Commit 8546c776 authored by E144069X's avatar E144069X

Added condition to prevent errors when using marseile dataset

parent dd1891bd
......@@ -42,7 +42,7 @@ import cv2
import torch.distributed as dist
from torch.multiprocessing import Process
import torchvision
import time
def epochSeqTr(model,optim,log_interval,loader, epoch, args,writer,**kwargs):
......@@ -145,7 +145,7 @@ def average_gradients(model):
dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
param.grad.data /= size
def epochSeqVal(model,log_interval,loader, epoch, args,writer,metricEarlyStop,mode="val",computeMetrics=True):
def epochSeqVal(model,log_interval,loader, epoch, args,writer,metricEarlyStop,mode="val"):
'''
Validate a model. This function computes several metrics and return the best value found until this point.
......@@ -185,6 +185,8 @@ def epochSeqVal(model,log_interval,loader, epoch, args,writer,metricEarlyStop,mo
for batch_idx, (data,target,vidName,frameInds,timeElapsedTensor) in enumerate(loader):
#torchvision.utils.save_image(data.reshape(data.size(0)*data.size(1),data.size(2),data.size(3),data.size(4)),"../vis/batch_{}_{}.png".format(args.dataset_val,batch_idx))
newVideo = (vidName != precVidName) or videoBegining
if (batch_idx % log_interval == 0):
......@@ -201,8 +203,7 @@ def epochSeqVal(model,log_interval,loader, epoch, args,writer,metricEarlyStop,mo
update.updateFrameDict(frameIndDict,frameInds,vidName)
if newVideo and not videoBegining:
if computeMetrics:
allOutput,nbVideos = update.updateMetrics(args,model,allFeat,allTimeElapsedTensor,allTarget,precVidName,nbVideos,metrDict,outDict,targDict)
allOutput,nbVideos = update.updateDictsAndMetrics(args,model,allFeat,allTimeElapsedTensor,allTarget,precVidName,nbVideos,metrDict,outDict,targDict)
intermVarDict = update.saveIntermediateVariables(intermVarDict,args.exp_id,args.model_id,epoch,precVidName)
intermVarDict = update.catIntermediateVariables(visualDict,intermVarDict,nbVideos)
......@@ -228,8 +229,7 @@ def epochSeqVal(model,log_interval,loader, epoch, args,writer,metricEarlyStop,mo
break
if not args.debug:
if computeMetrics:
allOutput,nbVideos = update.updateMetrics(args,model,allFeat,allTimeElapsedTensor,allTarget,precVidName,nbVideos,metrDict,outDict,targDict)
allOutput,nbVideos = update.updateDictsAndMetrics(args,model,allFeat,allTimeElapsedTensor,allTarget,precVidName,nbVideos,metrDict,outDict,targDict)
intermVarDict = update.saveIntermediateVariables(intermVarDict,args.exp_id,args.model_id,epoch,precVidName)
for key in outDict.keys():
......@@ -345,8 +345,6 @@ def computeTransMat(dataset,transMat,priors,propStart,propEnd,propSetIntFormat):
#Taking the last target of the sequence into account only for prior
priors[target[-1]] += 1
print(torch.isnan(transMat).sum() == 0)
#Just in case where propStart==propEnd, which is true when the training set is empty for example
if len(videoPaths) > 0:
return transMat/transMat.sum(dim=1,keepdim=True),priors/priors.sum()
......@@ -372,7 +370,8 @@ def writeSummaries(metrDict,sampleNb,writer,epoch,mode,model_id,exp_id):
'''
for metric in metrDict.keys():
metrDict[metric] /= sampleNb
if sampleNb > 0:
metrDict[metric] /= sampleNb
for metric in metrDict:
writer.add_scalars(metric,{model_id+"_"+mode:metrDict[metric]},epoch)
......@@ -609,7 +608,6 @@ def run(args):
kwargsVal['loader'] = valLoader
kwargsVal["metricEarlyStop"] = args.metric_early_stop
kwargsVal["computeMetrics"] = args.compute_metrics_during_eval
#Getting the contructor and the kwargs for the choosen optimizer
optimConst,kwargsOpti = get_OptimConstructor_And_Kwargs(args.optim,args.momentum)
......
......@@ -39,7 +39,10 @@ def computeScore(model,allFeats,timeElapsed,allTarget,valLTemp,vidName):
chunkList = torch.split(allFeats,split_size_or_sections=splitSizes,dim=1)
timeElapsedChunkList = torch.split(timeElapsed,split_size_or_sections=splitSizes,dim=1)
if not timeElapsed is None:
timeElapsedChunkList = torch.split(timeElapsed,split_size_or_sections=splitSizes,dim=1)
else:
timeElapsedChunkList = [None for _ in range(len(chunkList))]
sumSize = 0
......@@ -57,7 +60,7 @@ def computeScore(model,allFeats,timeElapsed,allTarget,valLTemp,vidName):
return allOutput
def updateMetrics(args,model,allFeat,timeElapsed,allTarget,precVidName,nbVideos,metrDict,outDict,targDict):
def updateDictsAndMetrics(args,model,allFeat,timeElapsed,allTarget,precVidName,nbVideos,metrDict,outDict,targDict):
allOutputDict = computeScore(model,allFeat,timeElapsed,allTarget,args.val_l_temp,precVidName)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment