Commit 945a05bc authored by E144069X

Added simple attention model

parent 77feaaed
......@@ -133,6 +133,7 @@ resnet_multi_model = False
resnet_multi_model_sparse_const = False
resnet_layer_size_reduce = True
resnet_prelay_size_reduce = True
resnet_simple_att = False
spat_transf = False
spat_transf_img_size = 112
......
......@@ -70,6 +70,7 @@ class Model(nn.Module):
visResDict = self.firstModel(x)
x = visResDict["x"]
resDict = self.secondModel(x,self.firstModel.batchSize,timeElapsed)
for key in visResDict.keys():
......@@ -168,6 +169,35 @@ class CNN2D(FirstModel):
else:
return {'x':res}
class CNN2D_simpleAttention(FirstModel):
    """2D CNN feature extractor topped with a one-channel spatial attention map.

    The attention branch is a stack of ``attBlockNb`` residual basic blocks
    followed by a 1x1 convolution that outputs a single channel. Its sigmoid
    is used to re-weight the feature maps before global average pooling.
    """

    def __init__(self, videoMode, featModelName, pretrainedFeatMod=True, featMap=False, bigMaps=False,
                 chan=64, attBlockNb=2, attChan=16, **kwargs):
        """Build the backbone (through ``FirstModel``) and the attention branch.

        Args:
            videoMode, featModelName, pretrainedFeatMod, featMap, bigMaps:
                forwarded unchanged to ``FirstModel.__init__``.
            chan: passed to ``getResnetFeat`` together with ``featModelName``
                to obtain the channel count used by the attention layers.
            attBlockNb: number of ``resnet.BasicBlock`` layers in the
                attention branch.
            attChan: accepted for signature compatibility; not used here.
            **kwargs: forwarded to ``FirstModel.__init__``.
        """
        super(CNN2D_simpleAttention, self).__init__(videoMode, featModelName, pretrainedFeatMod,
                                                    featMap, bigMaps, **kwargs)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        # Channel count of the feature maps the attention branch operates on.
        featChanNb = getResnetFeat(featModelName, chan)

        # Residual blocks keep the channel count; the final 1x1 conv collapses
        # it to a single attention channel.
        attLayers = [resnet.BasicBlock(featChanNb, featChanNb) for _ in range(attBlockNb)]
        attLayers.append(resnet.conv1x1(featChanNb, 1))
        self.attention = nn.Sequential(*attLayers)

    def forward(self, x):
        """Return the pooled attended features and the attention maps.

        Args:
            x: input batch, N x T x C x H x L in video mode. In that case the
                first two dimensions are merged before the backbone.

        Returns:
            dict with ``'x'``, the attended features flattened to two
            dimensions, and ``'attMaps'``, the sigmoid attention maps.
        """
        # Recorded before any reshaping so it reflects the first input dimension.
        self.batchSize = x.size(0)

        if self.videoMode:
            # N x T x C x H x L  ->  NT x C x H x L
            mergedDim = x.size(0) * x.size(1)
            x = x.view(mergedDim, x.size(2), x.size(3), x.size(4)).contiguous()

        featMaps = self.featMod(x)

        # One weight in (0, 1) per spatial position, broadcast over channels.
        attMaps = torch.sigmoid(self.attention(featMaps))
        weighted = attMaps * featMaps

        pooled = self.avgpool(weighted)
        flat = pooled.view(pooled.size(0), -1)

        return {'x': flat, 'attMaps': attMaps}
class CNN3D(FirstModel):
def __init__(self,videoMode,featModelName,pretrainedFeatMod=True,featMap=False,bigMaps=False):
......@@ -418,11 +448,10 @@ class TopkPointExtractor(nn.Module):
#abs,ord = abs.float(),ord.float()
points = torch.cat((abs.unsqueeze(-1),ord.unsqueeze(-1)),dim=-1).float()
selectedPointInds = pointnet2.utils.pointnet2_utils.furthest_point_sample(points,self.furthestPointSampling_nb_pts).to(abs.device)
abs = abs[torch.arange(abs.size(0)).unsqueeze(1).long(),selectedPointInds.long()]
ord = ord[torch.arange(ord.size(0)).unsqueeze(1).long(),selectedPointInds.long()]
retDict={}
if self.softCoord:
......@@ -685,10 +714,19 @@ def netBuilder(args):
############### Visual Model #######################
if args.feat.find("resnet") != -1:
nbFeat = getResnetFeat(args.feat,args.resnet_chan)
firstModel = CNN2D(args.video_mode,args.feat,args.pretrained_visual,chan=args.resnet_chan,stride=args.resnet_stride,dilation=args.resnet_dilation,\
if not args.resnet_simple_att:
CNNconst = CNN2D
kwargs={}
else:
CNNconst = CNN2D_simpleAttention
kwargs={"featMap":True}
firstModel = CNNconst(args.video_mode,args.feat,args.pretrained_visual,chan=args.resnet_chan,stride=args.resnet_stride,dilation=args.resnet_dilation,\
attChan=args.resnet_att_chan,attBlockNb=args.resnet_att_blocks_nb,attActFunc=args.resnet_att_act_func,\
multiModel=args.resnet_multi_model,\
multiModSparseConst=args.resnet_multi_model_sparse_const,num_classes=args.class_nb)
multiModSparseConst=args.resnet_multi_model_sparse_const,num_classes=args.class_nb,**kwargs)
elif args.feat.find("vgg") != -1:
nbFeat = 4096
firstModel = CNN2D(args.video_mode,args.feat,args.pretrained_visual)
......@@ -757,14 +795,17 @@ def netBuilder(args):
net = Model(firstModel,secondModel,spatTransf=spatTransf)
if args.multi_gpu:
net = DataParallelModel(net)
if args.temp_mod == "pointnet2":
net = net.float()
if args.cuda:
net = net.cuda()
net.cuda()
if args.multi_gpu:
net = DataParallelModel(net)
#net.to("cuda" if args.cuda else "cpu")
return net
......@@ -869,6 +910,8 @@ def addArgs(argreader):
help='To apply a stride of 2 in the layer 2,3 and 4 when the resnet model is used.')
argreader.parser.add_argument('--resnet_prelay_size_reduce', type=args.str2bool, metavar='INT',
help='To apply a stride of 2 in the convolution and the maxpooling before the layer 1.')
argreader.parser.add_argument('--resnet_simple_att', type=args.str2bool, metavar='INT',
help='To apply a simple attention on top of the resnet model.')
argreader.parser.add_argument('--resnet_att_chan', type=int, metavar='INT',
help='For the \'resnetX_att\' feat models. The number of channels in the attention module.')
......
......@@ -133,6 +133,7 @@ resnet_multi_model = False
resnet_multi_model_sparse_const = False
resnet_layer_size_reduce = True
resnet_prelay_size_reduce = True
resnet_simple_att = False
spat_transf = False
spat_transf_img_size = 112
......
......@@ -138,3 +138,4 @@ resnet_multi_model = False
resnet_multi_model_sparse_const = False
resnet_layer_size_reduce = True
resnet_prelay_size_reduce = True
resnet_simple_att = False
......@@ -133,6 +133,7 @@ resnet_multi_model = False
resnet_multi_model_sparse_const = False
resnet_layer_size_reduce = True
resnet_prelay_size_reduce = True
resnet_simple_att = False
spat_transf = False
spat_transf_img_size = 112
......
......@@ -138,3 +138,4 @@ resnet_multi_model = False
resnet_multi_model_sparse_const = False
resnet_layer_size_reduce = True
resnet_prelay_size_reduce = True
resnet_simple_att = False
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment