Commit 2164528a authored by E144069X

Added simple topk attention for resnet

parent 8cba2160
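The gist of the change: when top-k attention is enabled, the learned attention branch is replaced by a parameter-free selection of the k most active pixels, scored by the squared L2 norm of each feature vector across channels. A minimal standalone sketch of that pooling, with hypothetical shapes and torch.gather in place of the commit's tuple-indexing helper:

    import torch

    # Hypothetical feature map: batch of 8, 64 channels, 7x7 grid.
    features = torch.randn(8, 64, 7, 7)
    k = 16  # k must not exceed H*W (49 here)

    # Per-pixel activity score: squared L2 norm over channels.
    featNorm = features.pow(2).sum(dim=1)                   # 8 x 7 x 7
    flatInds = featNorm.flatten(1).topk(k, dim=-1).indices  # 8 x k

    # Gather the k selected C-dimensional vectors and average them.
    flatFeats = features.flatten(2)                         # 8 x 64 x 49
    picked = flatFeats.gather(2, flatInds.unsqueeze(1).expand(-1, 64, -1))
    pooled = picked.mean(dim=2)                             # 8 x 64, replaces global average pooling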
@@ -110,6 +110,8 @@ pn_topk_farthest_pts_sampling = False
 pn_topk_fps_nb_points = 64
 pn_use_xyz = True
 pn_topk_dropout = 0
+resnet_simple_att_topk = False
+resnet_simple_att_topk_pxls_nb = 256
 resnet_chan = 64
 resnet_stride = 2
......
@@ -45,6 +45,15 @@ def buildFeatModel(featModelName,pretrainedFeatMod,featMap=False,bigMaps=False,l
     return featModel
+
+def mapToList(map,abs,ord):
+    #This extracts the desired pixels from a map
+    indices = tuple([torch.arange(map.size(0), dtype=torch.long).unsqueeze(1).unsqueeze(1),
+                     torch.arange(map.size(1), dtype=torch.long).unsqueeze(1).unsqueeze(0),
+                     ord.long().unsqueeze(1),abs.long().unsqueeze(1)])
+    list = map[indices].permute(0,2,1)
+    return list
+
 #This class is just the class nn.DataParallel that allows running computation on multiple gpus
 #but it adds the possibility to access the attributes of the model
 class DataParallelModel(nn.DataParallel):
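For reference, mapToList relies on broadcasted advanced indexing: the two arange tensors enumerate every (batch, channel) pair, while ord/abs supply the same K pixel coordinates for each channel, producing an N x C x K gather that is permuted to N x K x C. A shape-check sketch with dummy tensors (names illustrative, not from the repo):

    import torch

    N, C, H, W, K = 2, 3, 5, 5, 4
    feat = torch.randn(N, C, H, W)
    xs = torch.randint(0, W, (N, K))  # column ("abs") coordinates
    ys = torch.randint(0, H, (N, K))  # row ("ord") coordinates

    indices = (torch.arange(N).unsqueeze(1).unsqueeze(1),  # N x 1 x 1 -> batch
               torch.arange(C).unsqueeze(1).unsqueeze(0),  # 1 x C x 1 -> channel
               ys.unsqueeze(1),                            # N x 1 x K -> rows
               xs.unsqueeze(1))                            # N x 1 x K -> columns

    out = feat[indices].permute(0, 2, 1)  # N x K x C: one feature vector per pixel
    assert out.shape == (N, K, C)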
@@ -175,7 +184,9 @@ class CNN2D(FirstModel):
 class CNN2D_simpleAttention(FirstModel):
 
-    def __init__(self,videoMode,featModelName,pretrainedFeatMod=True,featMap=False,bigMaps=False,chan=64,attBlockNb=2,attChan=16,**kwargs):
+    def __init__(self,videoMode,featModelName,pretrainedFeatMod=True,featMap=False,bigMaps=False,chan=64,attBlockNb=2,attChan=16,\
+                 topk=False,topk_pxls_nb=256,**kwargs):
 
         super(CNN2D_simpleAttention,self).__init__(videoMode,featModelName,pretrainedFeatMod,featMap,bigMaps,**kwargs)
 
         self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
@@ -186,7 +197,13 @@ class CNN2D_simpleAttention(FirstModel):
         attention.append(resnet.BasicBlock(inFeat, inFeat))
         attention.append(resnet.conv1x1(inFeat,1))
-        self.attention = nn.Sequential(*attention)
+        self.topk = topk
+        if not topk:
+            self.attention = nn.Sequential(*attention)
+            self.topk_pxls_nb = None
+        else:
+            self.attention = None
+            self.topk_pxls_nb = topk_pxls_nb
 
     def forward(self,x):
         # N x T x C x H x L
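Note the design choice here: when topk is set, self.attention is None, so the top-k path introduces no extra learnable parameters; the attention map is derived purely from feature activity at forward time.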
@@ -194,11 +211,25 @@ class CNN2D_simpleAttention(FirstModel):
         x = x.view(x.size(0)*x.size(1),x.size(2),x.size(3),x.size(4)).contiguous() if self.videoMode else x
         # NT x C x H x L
         features = self.featMod(x)
 
-        spatialWeights = torch.sigmoid(self.attention(features))
-        features = spatialWeights*features
-        features = self.avgpool(features)
-        features = features.view(features.size(0), -1)
+        if not self.topk:
+            spatialWeights = torch.sigmoid(self.attention(features))
+            features = spatialWeights*features
+            features = self.avgpool(features)
+            features = features.view(features.size(0), -1)
+        else:
+            spatialWeights = torch.zeros((features.size(0),1,features.size(2),features.size(3)),device=features.device)
+            #Compute the mean of the k most active pixels (activity = squared L2 norm over channels)
+            featNorm = torch.pow(features,2).sum(dim=1,keepdim=True)
+            flatFeatNorm = featNorm.view(featNorm.size(0),-1)
+            flatVals,flatInds = torch.topk(flatFeatNorm, self.topk_pxls_nb, dim=-1, largest=True)
+            abs,ord = (flatInds%featNorm.shape[-1],flatInds//featNorm.shape[-1])
+            featureList = mapToList(features,abs,ord)
+            features = featureList.mean(dim=1)
+            #Mark the k selected pixels with weight 1 so the map can be visualised
+            indices = tuple([torch.arange(spatialWeights.size(0), dtype=torch.long).unsqueeze(1).unsqueeze(1),
+                             torch.arange(spatialWeights.size(1), dtype=torch.long).unsqueeze(1).unsqueeze(0),
+                             ord.long().unsqueeze(1),abs.long().unsqueeze(1)])
+            spatialWeights[indices] = 1
 
         return {'x':features,'attMaps':spatialWeights}
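One caveat: torch.topk raises an error when k exceeds the number of available elements, so topk_pxls_nb (default 256) must be at most H*W of the final feature map (e.g. a standard stride-32 resnet on 224x224 inputs yields only 7x7 = 49 pixels). A defensive variant, not in the commit, would clamp k first:

    # Hypothetical guard: never ask topk for more pixels than exist.
    k = min(self.topk_pxls_nb, flatFeatNorm.size(-1))
    flatVals, flatInds = torch.topk(flatFeatNorm, k, dim=-1, largest=True)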
@@ -389,6 +420,7 @@ class TopkPointExtractor(nn.Module):
     def __init__(self,cuda,nbFeat,featMod,softCoord,softCoord_kerSize,softCoord_secOrder,point_nb,reconst,encoderChan,\
                  predictDepth,softcoord_shiftpred,furthestPointSampling,furthestPointSampling_nb_pts,dropout):
 
         super(TopkPointExtractor,self).__init__()
 
         self.feat = featMod
@@ -433,15 +465,6 @@ class TopkPointExtractor(nn.Module):
         self.ordKerDict,self.absKerDict,self.spatialWeightKerDict = {},{},{}
 
-    def mapToList(self,map,abs,ord):
-        #This extract the desired pixels in a map
-        indices = tuple([torch.arange(map.size(0), dtype=torch.long).unsqueeze(1).unsqueeze(1),
-                         torch.arange(map.size(1), dtype=torch.long).unsqueeze(1).unsqueeze(0),
-                         ord.long().unsqueeze(1),abs.long().unsqueeze(1)])
-        list = map[indices].permute(0,2,1)
-        return list
 
     def forward(self,imgBatch):
 
         featureMaps = self.feat(imgBatch)
@@ -485,17 +508,17 @@ class TopkPointExtractor(nn.Module):
             #Keeping only one channel (the three channels are just copies)
             retDict['reconst'] = reconst[:,0:1]
 
-        pointFeat = self.mapToList(pointFeaturesMap,abs,ord)
+        pointFeat = mapToList(pointFeaturesMap,abs,ord)
 
         if self.predictDepth:
             depthMap = torch.tanh(self.conv1x1_depth(featureMaps))*max(featureMaps.size(-2),featureMaps.size(-1))//10
-            depth = self.mapToList(depthMap,abs,ord)
+            depth = mapToList(depthMap,abs,ord)
         else:
             depth = torch.zeros(abs.size(0),abs.size(1),1).to(x.device)
 
         if self.softcoord_shiftpred:
             xShiftMap,yShiftMap = torch.tanh(self.shiftpred_x(featureMaps)),torch.tanh(self.shiftpred_y(featureMaps))
-            xShift,yShift = self.mapToList(xShiftMap,abs,ord).squeeze(-1),self.mapToList(yShiftMap,abs,ord).squeeze(-1)
+            xShift,yShift = mapToList(xShiftMap,abs,ord).squeeze(-1),mapToList(yShiftMap,abs,ord).squeeze(-1)
             abs = xShift + abs.float()
             ord = yShift + ord.float()
@@ -678,7 +701,7 @@ def netBuilder(args):
         kwargs={}
     else:
         CNNconst = CNN2D_simpleAttention
-        kwargs={"featMap":True}
+        kwargs={"featMap":True,"topk":args.resnet_simple_att_topk,"topk_pxls_nb":args.resnet_simple_att_topk_pxls_nb}
 
     firstModel = CNNconst(args.video_mode,args.feat,args.pretrained_visual,chan=args.resnet_chan,stride=args.resnet_stride,dilation=args.resnet_dilation,\
                 attChan=args.resnet_att_chan,attBlockNb=args.resnet_att_blocks_nb,attActFunc=args.resnet_att_act_func,\
@@ -847,6 +870,10 @@ def addArgs(argreader):
                         help='To apply a stride of 2 in the convolution and the maxpooling before the layer 1.')
 
     argreader.parser.add_argument('--resnet_simple_att', type=args.str2bool, metavar='INT',
                         help='To apply a simple attention on top of the resnet model.')
 
+    argreader.parser.add_argument('--resnet_simple_att_topk', type=args.str2bool, metavar='BOOL',
+                        help='To use top-k feature selection as the attention model with resnet. Ignored when --resnet_simple_att is False.')
+    argreader.parser.add_argument('--resnet_simple_att_topk_pxls_nb', type=int, metavar='INT',
+                        help='The value of k when using top-k selection for resnet simple attention. Ignored when --resnet_simple_att_topk is False.')
 
     argreader.parser.add_argument('--resnet_att_chan', type=int, metavar='INT',
                         help='For the \'resnetX_att\' feat models. The number of channels in the attention module.')
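Putting the new arguments together, a run enabling the top-k variant might look like the line below (hypothetical entry script; the flags are the ones added above, other flags omitted):

    python train.py --resnet_simple_att True --resnet_simple_att_topk True --resnet_simple_att_topk_pxls_nb 256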
......
@@ -115,6 +115,8 @@ pn_topk_farthest_pts_sampling = False
 pn_topk_fps_nb_points = 64
 pn_use_xyz = True
 pn_topk_dropout = 0
+resnet_simple_att_topk = False
+resnet_simple_att_topk_pxls_nb = 256
 resnet_chan = 64
 resnet_stride = 2
......