Commit 432398dd authored by Tristan GOMEZ

Added lstm temp model

parent 8ef153a0
......@@ -46,3 +46,7 @@ train_step_to_ignore = 0
class_nb = 15
compute_val_metrics = True
temp_mod = linear
lstm_lay = 2
lstm_hid_size = 1024
......@@ -28,78 +28,6 @@ def buildFeatModel(featModelName):
return featModel
# NOTE: everything between the two ''' delimiters below is a bare string
# literal, not live code. It holds the text of CNN / CNN3D class definitions,
# which are therefore never defined nor executed when this module is imported.
'''
class CNN(nn.Module):
def __init__(self,featModelName="resnet50",dropout=0.5,classNb=16):
super(CNN,self).__init__()
self.featMod = buildFeatModel(featModelName)
self.classNb = classNb
if featModelName=="resnet50" or featModelName=="resnet101" or featModelName=="resnet151":
self.nbFeat = 256*2**(4-1)
elif featModelName.find("vgg") != -1:
self.nbFeat = 4096
else:
self.nbFeat = 64*2**(4-1)
self.dropout = nn.Dropout(p=dropout)
self.linLay = nn.Linear(self.nbFeat,classNb)
self.transMat = torch.zeros((classNb,classNb))
def forward(self,x):
# N x T x C x H x L
x = self.computeFeat(x)
# NT x D
x = self.computeScore(x)
# N x T x classNb
return x
def computeFeat(self,x):
# N x T x C x H x L
self.batchSize = x.size(0)
x = x.view(x.size(0)*x.size(1),x.size(2),x.size(3),x.size(4))
# NT x C x H x L
x = self.featMod(x)
# NT x D
return x
def computeScore(self,x):
# NT x D
x = self.dropout(x)
x = self.linLay(x)
# NT x classNb
x = x.view(self.batchSize,-1,self.classNb)
# N x T x classNb
return x
class CNN3D(CNN):
def __init__(self,featModelName,dropout=0.5,classNb=16):
super(CNN3D,self).__init__(featModelName,dropout,classNb)
#The r2plus1d_18 architecture has 512 features
self.nbFeat = 512
self.linLay = nn.Linear(self.nbFeat,classNb)
def computeFeat(self,x):
# N x T x C x H x L
self.batchSize = x.size(0)
x = x.permute(0,2,1,3,4)
# N x C x T x H x L
x = self.featMod(x)
# N x D x T
x = x.permute(0,2,1)
# N x T x D
x = x.contiguous().view(x.size(0)*x.size(1),-1)
# NT x D
return x
'''
class Model(nn.Module):
def __init__(self,visualModel,tempModel):
......@@ -182,36 +110,72 @@ class LinearTempModel(TempModel):
# N x T x classNb
return x
class LSTMTempModel(TempModel):
    """Temporal model made of a bidirectional LSTM followed by a linear classifier.

    The flat per-frame features are regrouped into one sequence per batch
    element, run through a multi-layer bidirectional ``nn.LSTM`` and the
    resulting hidden states are scored frame by frame by a ``LinearTempModel``.

    Args:
        nbFeat: size D of each input feature vector.
        nbClass: number of output classes.
        dropout: dropout probability, passed both to the LSTM and to the
            linear classifier.
        nbLayers: number of stacked LSTM layers.
        nbHidden: hidden size H of the LSTM, per direction.
    """

    def __init__(self,nbFeat,nbClass,dropout,nbLayers,nbHidden):
        # NOTE(review): self.nbFeat / self.nbClass are presumably stored by
        # TempModel.__init__ (not visible here) -- confirm.
        super(LSTMTempModel,self).__init__(nbFeat,nbClass)

        self.lstmTempMod = nn.LSTM(input_size=self.nbFeat,hidden_size=nbHidden,num_layers=nbLayers,
                                   batch_first=True,dropout=dropout,bidirectional=True)
        # The LSTM is bidirectional: forward and backward states are
        # concatenated, so the classifier receives 2*nbHidden features.
        self.linTempMod = LinearTempModel(nbFeat=nbHidden*2,nbClass=self.nbClass,dropout=dropout)

    def forward(self,x,batchSize):
        """Score every frame of every sequence.

        Args:
            x: features of all frames, laid out as NT x D.
            batchSize: number N of sequences in the batch.

        Returns:
            Whatever ``LinearTempModel`` returns for the LSTM states
            (N x T x classNb according to the comments of this file).
        """
        # NT x D
        x = x.view(batchSize,-1,x.size(-1))
        # N x T x D
        x,_ = self.lstmTempMod(x)
        # N x T x 2H
        # With batch_first=True the LSTM output may not be contiguous in
        # memory, in which case a bare view() that merges N and T raises.
        x = x.contiguous().view(-1,x.size(-1))
        # NT x 2H
        x = self.linTempMod(x,batchSize)
        # N x T x classNb
        return x
def netBuilder(args):
    """Build the whole network: a visual model followed by a temporal model.

    Args:
        args: namespace of options. Reads ``feat`` (name of the visual
            model), ``temp_mod`` (``"linear"`` or ``"lstm"``), ``class_nb``,
            ``dropout`` and, for the LSTM temporal model only, ``lstm_lay``
            and ``lstm_hid_size``.

    Returns:
        A ``Model`` wrapping the chosen visual and temporal models.

    Raises:
        ValueError: if ``args.feat`` or ``args.temp_mod`` is not recognised.
    """

    ############### Visual Model #######################
    if args.feat.find("resnet") != -1:
        # The model name lives in args.feat; there is no featModelName
        # variable in this function.
        # NOTE(review): "resnet151" looks like a typo for "resnet152" --
        # confirm against buildFeatModel before changing it.
        if args.feat in ("resnet50","resnet101","resnet151"):
            nbFeat = 256*2**(4-1)
        else:
            nbFeat = 64*2**(4-1)
        visualModelConstruct = CNN2D
    elif args.feat.find("vgg") != -1:
        nbFeat = 4096
        visualModelConstruct = CNN2D
    elif args.feat == "r2plus1d_18":
        nbFeat = 512
        visualModelConstruct = CNN3D
    else:
        raise ValueError("Unknown visual model type : ",args.feat)

    # Built once, after the feature size and constructor have been selected
    visualModel = visualModelConstruct(args.feat)

    ############### Temporal Model #######################
    if args.temp_mod == "lstm":
        tempModel = LSTMTempModel(nbFeat,args.class_nb,args.dropout,args.lstm_lay,args.lstm_hid_size)
    elif args.temp_mod == "linear":
        tempModel = LinearTempModel(nbFeat,args.class_nb,args.dropout)
    else:
        raise ValueError("Unknown temporal model type : ",args.temp_mod)

    ############### Whole Model ##########################
    net = Model(visualModel,tempModel)

    return net
def addArgs(argreader):
    """Register the model-related command-line options on ``argreader.parser``.

    Adds ``--feat``, ``--dropout``, ``--temp_mod``, ``--lstm_lay`` and
    ``--lstm_hid_size``. No default is set here, so an option that is not
    given on the command line parses to ``None``.

    Args:
        argreader: object exposing an argparse-style ``parser`` attribute.

    Returns:
        The same ``argreader``, to allow chaining.
    """

    argreader.parser.add_argument('--feat', type=str, metavar='MOD',
                        help='the net to use to produce feature for each frame')

    argreader.parser.add_argument('--dropout', type=float,metavar='D',
                        help='The dropout amount on each layer of the RNN except the last one')

    argreader.parser.add_argument('--temp_mod', type=str,metavar='MOD',
                        help='The temporal model. Can be "linear" or "lstm".')

    argreader.parser.add_argument('--lstm_lay', type=int,metavar='N',
                        help='Number of layers for the lstm temporal model')

    argreader.parser.add_argument('--lstm_hid_size', type=int,metavar='N',
                        help='Size of hidden layers for the lstm temporal model')

    return argreader
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment