Commit df0e1ab6 authored by E144069X's avatar E144069X

Replaced pims by torchvision's video reader and paralelised validation batch computation

parent 8546c776
This diff is collapsed.
......@@ -68,7 +68,7 @@ class DataParallelModel(nn.DataParallel):
class Model(nn.Module):
def __init__(self,firstModel,secondModel,spatTransf=None):
def __init__(self,firstModel,secondModel,nbDevices,spatTransf=None):
super(Model,self).__init__()
self.firstModel = firstModel
self.secondModel = secondModel
......@@ -77,14 +77,16 @@ class Model(nn.Module):
self.transMat = torch.zeros((self.secondModel.nbClass,self.secondModel.nbClass))
self.priors = torch.zeros((self.secondModel.nbClass))
self.nbDevices = nbDevices
def forward(self,x,timeElapsed=None):
if self.spatTransf:
x = self.spatTransf(x)["x"]
self.batchSize=x.size(0)
visResDict = self.firstModel(x)
x = visResDict["x"]
resDict = self.secondModel(x,self.firstModel.batchSize,timeElapsed)
resDict = self.secondModel(x,self.batchSize,timeElapsed)
for key in visResDict.keys():
resDict[key] = visResDict[key]
......@@ -92,15 +94,9 @@ class Model(nn.Module):
return resDict
def computeVisual(self,x):
if self.spatTransf:
resDict = self.spatTransf(x)
x = resDict["x"]
theta = resDict["theta"]
resDict = self.firstModel(x)
if self.spatTransf:
resDict["theta"] = theta
return resDict
def setTransMat(self,transMat):
......@@ -173,6 +169,7 @@ class CNN2D(FirstModel):
self.batchSize = x.size(0)
x = x.view(x.size(0)*x.size(1),x.size(2),x.size(3),x.size(4)).contiguous() if self.videoMode else x
# NT x C x H x L
res = self.featMod(x)
# NT x D
......@@ -766,8 +763,7 @@ def netBuilder(args):
else:
spatTransf = None
net = Model(firstModel,secondModel,spatTransf=spatTransf)
net = Model(firstModel,secondModel,torch.cuda.device_count() if args.multi_gpu else 1,spatTransf=spatTransf)
if args.temp_mod == "pointnet2":
net = net.float()
......@@ -776,8 +772,14 @@ def netBuilder(args):
net.cuda()
if args.multi_gpu:
net = DataParallelModel(net)
#netTrain = DataParallelModel(net)
#netVal = net
#netVal.firstModel = DataParallelModel(netVal.firstModel)
net.firstModel = DataParallelModel(net.firstModel)
#net.secondModel = DataParallelModel(net.secondModel)
#net = DataParallelModel(net)
#net.to("cuda" if args.cuda else "cpu")
return net
......
......@@ -110,7 +110,7 @@ def epochSeqTr(model,optim,log_interval,loader, epoch, args,writer,**kwargs):
torch.save(model.state_dict(), "../models/{}/model{}_epoch{}".format(args.exp_id,args.model_id, epoch))
writeSummaries(metrDict,validBatch,writer,epoch,"train",args.model_id,args.exp_id)
if args.debug:
if args.debug and validBatch > 0:
totalTime = time.time() - start_time
update.updateTimeCSV(epoch,"train",args.exp_id,args.model_id,totalTime,batch_idx)
......@@ -182,6 +182,7 @@ def epochSeqVal(model,log_interval,loader, epoch, args,writer,metricEarlyStop,mo
videoBegining = True
validBatch = 0
nbVideos = 0
deviceNb = torch.cuda.device_count() if args.multi_gpu else 1
for batch_idx, (data,target,vidName,frameInds,timeElapsedTensor) in enumerate(loader):
......@@ -197,7 +198,10 @@ def epochSeqVal(model,log_interval,loader, epoch, args,writer,metricEarlyStop,mo
if not timeElapsedTensor is None:
timeElapsedTensor = timeElapsedTensor.cuda()
visualDict = model.computeVisual(data)
if data.size(1)%deviceNb == 0:
data = data.reshape(deviceNb,data.size(1)//deviceNb,data.size(2),data.size(3),data.size(4))
visualDict = model.firstModel(data)
feat = visualDict["x"].data
update.updateFrameDict(frameIndDict,frameInds,vidName)
......
......@@ -47,9 +47,7 @@ def computeScore(model,allFeats,timeElapsed,allTarget,valLTemp,vidName):
sumSize = 0
for i in range(len(chunkList)):
output = model.secondModel(chunkList[i].squeeze(0),batchSize=1,timeTensor=timeElapsedChunkList[i])
output = model.secondModel(chunkList[i].squeeze(0),1,timeElapsedChunkList[i])
for tensorName in output.keys():
if not tensorName in allOutput.keys():
allOutput[tensorName] = output[tensorName]
......@@ -65,7 +63,6 @@ def updateDictsAndMetrics(args,model,allFeat,timeElapsed,allTarget,precVidName,n
allOutputDict = computeScore(model,allFeat,timeElapsed,allTarget,args.val_l_temp,precVidName)
allOutput = allOutputDict["pred"]
if args.compute_metrics_during_eval:
loss = F.cross_entropy(allOutput.squeeze(0),allTarget.squeeze(0)).data.item()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment