Commit 6365155f authored by E144069X

Added auto start_mode

parent 1fa7fd99
@@ -104,7 +104,7 @@ def updateMetrDict(metrDict,metrDictSample):
     return metrDict
 
-def binaryToMetrics(output,target,paramDict,transition_matrix=None,videoNames=None,onlyPairsCorrelation=True,videoMode=True):
+def binaryToMetrics(output,target,paramDict=None,transition_matrix=None,videoNames=None,onlyPairsCorrelation=True,videoMode=True):
     ''' Computes metrics over a batch of targets and predictions
 
     Args:
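A note on this signature change: because `paramDict` now defaults to `None` and still sits before `transition_matrix`, any caller passing the transition matrix as the third positional argument would silently bind it to `paramDict`. A minimal sketch (with stand-in values, not the repository's real tensors) of why the call sites in the hunks below switch to keyword arguments:

```python
def binaryToMetrics(output, target, paramDict=None, transition_matrix=None,
                    videoNames=None, onlyPairsCorrelation=True, videoMode=True):
    # Stub that only reports how the arguments were bound.
    return {"paramDict": paramDict, "transition_matrix": transition_matrix}

transMat = "stand-in for model.transMat"

# Old positional call: the matrix lands in paramDict, not transition_matrix.
print(binaryToMetrics("out", "tgt", transMat))
# Updated call, as in the epochSeqTr and updateMetrics hunks below.
print(binaryToMetrics("out", "tgt", transition_matrix=transMat))
```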
@@ -50,7 +50,7 @@ exp_id = default
 cuda = True
 multi_gpu = False
 optim = SGD
-start_mode = scratch
+start_mode = auto
 init_path = None
 note = None
 val_batch_size = 1
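With the default switched to `auto`, a run is meant to resume from an existing checkpoint when one is found and to start from scratch otherwise (see the `initialize_Net_And_EpochNumber` hunk below). A minimal sketch of that resolution, assuming the repository's `../models/<exp_id>/` checkpoint layout:

```python
import glob

def resolve_start_mode(exp_id, model_id, start_mode="auto"):
    # Any file matching model<model_id>_epoch* means a previous run exists,
    # so resume (fine_tune); otherwise train from scratch.
    if start_mode == "auto":
        pattern = "../models/{}/model{}_epoch*".format(exp_id, model_id)
        start_mode = "fine_tune" if glob.glob(pattern) else "scratch"
    return start_mode
```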
@@ -97,7 +97,7 @@ def epochSeqTr(model,optim,log_interval,loader, epoch, args,writer,**kwargs):
         optim.zero_grad()
 
         #Metrics
-        metDictSample = metrics.binaryToMetrics(output,target,model.transMat,videoMode=args.video_mode)
+        metDictSample = metrics.binaryToMetrics(output,target,transition_matrix=model.transMat,videoMode=args.video_mode)
         metDictSample["Loss"] = loss.detach().data.item()
         metrDict = metrics.updateMetrDict(metrDict,metDictSample)
@@ -342,7 +342,9 @@ def computeTransMat(dataset,transMat,priors,propStart,propEnd,propSetIntFormat):
             #Taking the last target of the sequence into account only for prior
             priors[target[-1]] += 1
 
-    #Just in case where propStart==propEnd, which is true for example, when the training set is empty
-    print(torch.isnan(transMat).sum() == 0)
+    #Just in case propStart==propEnd, which happens for example when the training set is empty
+    if len(videoPaths) > 0:
+        return transMat/transMat.sum(dim=1,keepdim=True),priors/priors.sum()
+    else:
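For context, `transMat` accumulates transition counts over the training videos, and dividing each row by its sum makes the matrix row-stochastic; with an empty training set every row sums to zero and the division yields NaNs, which is presumably what the removed debug print was checking. A small self-contained illustration:

```python
import torch

counts = torch.tensor([[2., 1., 1.],
                       [0., 3., 1.],
                       [1., 1., 2.]])
trans_mat = counts / counts.sum(dim=1, keepdim=True)
print(trans_mat.sum(dim=1))  # tensor([1., 1., 1.]): each row is a distribution

# With all-zero counts the same division produces NaNs, hence the new guard.
empty = torch.zeros(3, 3)
print(torch.isnan(empty / empty.sum(dim=1, keepdim=True)).any())  # tensor(True)
```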
@@ -421,12 +423,22 @@ def initialize_Net_And_EpochNumber(net,exp_id,model_id,cuda,start_mode,init_path
 
     Returns: the start epoch number
     '''
 
+    if start_mode == "auto":
+        if len(glob.glob("../models/{}/model{}_epoch*".format(exp_id, model_id))) > 0:
+            start_mode = "fine_tune"
+        else:
+            start_mode = "scratch"
+        print("Autodetected mode", start_mode)
+
     if start_mode == "scratch":
         #Saving initial parameters
         torch.save(net.state_dict(), "../models/{}/{}_epoch0".format(exp_id,model_id))
         startEpoch = 1
     elif start_mode == "fine_tune":
+
+        if init_path == "None":
+            init_path = sorted(glob.glob("../models/{}/model{}_epoch*".format(exp_id, model_id)), key=utils.findLastNumbers)[-1]
+
         params = torch.load(init_path,map_location="cpu" if not cuda else None)
 
         #Checking if the keys of the model start with "module."
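When `init_path` is left as `"None"` in fine-tune mode, the newest checkpoint is picked by sorting on the number that `utils.findLastNumbers` extracts from the filename; a plain lexicographic sort would put `epoch10` before `epoch2`. A sketch with a regex-based stand-in for `utils.findLastNumbers`, whose exact behavior is assumed here:

```python
import re

def find_last_numbers(path):
    # Stand-in: return the last run of digits in the path, e.g. the epoch.
    digits = re.findall(r"\d+", path)
    return int(digits[-1]) if digits else -1

files = ["modelres50_epoch2", "modelres50_epoch10", "modelres50_epoch1"]
print(sorted(files, key=find_last_numbers)[-1])  # modelres50_epoch10
```

One caveat visible in the hunk above: the scratch branch appears to save the initial weights as `{model_id}_epoch0` while the autodetection globs `model{model_id}_epoch*`, so unless `model_id` itself begins with `model`, that epoch-0 file would not match the pattern.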
@@ -646,6 +658,20 @@ def run(args):
     with torch.no_grad():
         testFunc(**kwargsTest)
 
+
+def updateSeedAndNote(args):
+    if args.start_mode == "auto" and len(
+            glob.glob("../models/{}/model{}_epoch*".format(args.exp_id, args.model_id))) > 0:
+        args.seed += 1
+
+        init_path = args.init_path
+        if init_path == "None" and args.strict_init:
+            init_path = sorted(glob.glob("../models/{}/model{}_epoch*".format(args.exp_id, args.model_id)),
+                               key=utils.findLastNumbers)[-1]
+        startEpoch = utils.findLastNumbers(init_path)
+        args.note += ";s{} at {}".format(args.seed, startEpoch)
+    return args
+
 def main(argv=None):
 
     #Getting arguments from config file and command line
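`updateSeedAndNote` only fires when an `auto` run actually resumes: it bumps the seed, presumably so the resumed epochs do not replay the same random stream, and appends the seed and resume epoch to the run's note. A toy illustration of the bookkeeping (the argument namespace and epoch value are stand-ins):

```python
class Args:
    seed = 0
    note = "baseline"

args = Args()
args.seed += 1                  # new seed for the resumed portion of training
startEpoch = 12                 # would come from utils.findLastNumbers(init_path)
args.note += ";s{} at {}".format(args.seed, startEpoch)
print(args.note)                # baseline;s1 at 12
```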
@@ -686,6 +712,10 @@ def main(argv=None):
     if not (os.path.exists("../models/{}".format(args.exp_id))):
         os.makedirs("../models/{}".format(args.exp_id))
 
+    args = updateSeedAndNote(args)
+
+    # Update the config args
+    argreader.args = args
 
     #Write the arguments in a config file so the experiment can be re-run
     argreader.writeConfigFile("../models/{}/{}.ini".format(args.exp_id,args.model_id))
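Updating `argreader.args` before `writeConfigFile` matters here: the bumped seed and amended note are what get persisted, so replaying the saved `.ini` reproduces the resumed experiment. A rough `configparser`-based stand-in (the repository's `writeConfigFile` is assumed to do something equivalent):

```python
import configparser, io

config = configparser.ConfigParser()
config["default"] = {"seed": "1", "note": "baseline;s1 at 12", "start_mode": "auto"}

buf = io.StringIO()
config.write(buf)    # the real code writes ../models/<exp_id>/<model_id>.ini
print(buf.getvalue())
```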
@@ -66,7 +66,7 @@ def updateMetrics(args,model,allFeat,timeElapsed,allTarget,precVidName,nbVideos,
     if args.compute_metrics_during_eval:
         loss = F.cross_entropy(allOutput.squeeze(0),allTarget.squeeze(0)).data.item()
 
-        metDictSample = metrics.binaryToMetrics(allOutput,allTarget,model.transMat,videoMode=True)
+        metDictSample = metrics.binaryToMetrics(allOutput,allTarget,transition_matrix=model.transMat,videoMode=True)
         metDictSample["Loss"] = loss
         metrDict = metrics.updateMetrDict(metrDict,metDictSample)