Commit 033782d3 authored by E144069X

Added arg to stop pointnet from using coordinates as features

parent 1d3276d5
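
For context, the sketch below shows what a use_xyz switch typically controls in PointNet++-style models (an illustrative helper, not this repository's code): when enabled, each point's xyz coordinates are concatenated to its learned feature channels before the shared MLP, so turning it off makes the network rely on the encoder features alone.

import torch

def group_features(xyz, feats, use_xyz=True):
    # xyz:   (B, N, 3) point coordinates
    # feats: (B, C, N) per-point features from an encoder, or None
    # Hypothetical helper mirroring the usual PointNet++ behaviour: with
    # use_xyz=True the coordinates are appended as extra channels (C+3),
    # with use_xyz=False only the learned features are kept.
    if use_xyz:
        coords = xyz.transpose(1, 2)                  # (B, 3, N)
        return coords if feats is None else torch.cat([coords, feats], dim=1)
    return feats

xyz, feats = torch.rand(2, 256, 3), torch.rand(2, 1, 256)
print(group_features(xyz, feats, True).shape)   # torch.Size([2, 4, 256])
print(group_features(xyz, feats, False).shape)  # torch.Size([2, 1, 256])
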
@@ -121,6 +121,7 @@ pn_topk_enc_chan = 1
 pn_topk_pred_depth= False
 pn_topk_farthest_pts_sampling = False
 pn_topk_fps_nb_points = 64
+pn_use_xyz = True
 resnet_chan = 64
 resnet_stride = 2
......
@@ -417,14 +417,11 @@ class TopkPointExtractor(nn.Module):
         if self.furthestPointSampling:
             #abs,ord = abs.float(),ord.float()
             points = torch.cat((abs.unsqueeze(-1),ord.unsqueeze(-1)),dim=-1).float()
-            selectedPointInds = pointnet2.utils.pointnet2_utils.furthest_point_sample(points,self.furthestPointSampling_nb_pts)
-            try:
-                abs = abs[torch.arange(abs.size(0)).unsqueeze(1).long(),selectedPointInds.long()]
-                ord = ord[torch.arange(ord.size(0)).unsqueeze(1).long(),selectedPointInds.long()]
-            except RuntimeError:
-                print(abs.size(),torch.arange(abs.size(0)).unsqueeze(1).long().max(),selectedPointInds.max(),selectedPointInds.min())
-                print(ord.size(),torch.arange(ord.size(0)).unsqueeze(1).long().max(),selectedPointInds.max(),selectedPointInds.min())
-                sys.exit(0)
+            selectedPointInds = pointnet2.utils.pointnet2_utils.furthest_point_sample(points,self.furthestPointSampling_nb_pts).to(abs.device)
+            abs = abs[torch.arange(abs.size(0)).unsqueeze(1).long(),selectedPointInds.long()]
+            ord = ord[torch.arange(ord.size(0)).unsqueeze(1).long(),selectedPointInds.long()]

         retDict={}
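
The removed try/except only printed tensor shapes and exited when the batched indexing failed; the replacement instead moves the sampled indices onto the coordinate tensors' device before indexing. A minimal, self-contained sketch of that gather pattern (random indices stand in for the CUDA furthest_point_sample kernel):

import torch

def gather_sampled(abs_coord, ord_coord, selected_inds):
    # abs_coord, ord_coord: (B, N) x/y coordinates of candidate points
    # selected_inds:        (B, K) indices from farthest-point sampling
    # Index tensors must live on the same device as the tensors they index,
    # which is what the added .to(abs.device) guarantees in this commit.
    selected_inds = selected_inds.to(abs_coord.device).long()
    batch_inds = torch.arange(abs_coord.size(0), device=abs_coord.device).unsqueeze(1)
    return abs_coord[batch_inds, selected_inds], ord_coord[batch_inds, selected_inds]

abs_c, ord_c = torch.rand(2, 1024), torch.rand(2, 1024)
inds = torch.randint(0, 1024, (2, 64))  # stand-in for pointnet2_utils.furthest_point_sample
abs_sel, ord_sel = gather_sampled(abs_c, ord_c, inds)
print(abs_sel.shape, ord_sel.shape)  # torch.Size([2, 64]) torch.Size([2, 64])
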
@@ -490,7 +487,7 @@ class PointNet2(FirstModel):
     def __init__(self,cuda,videoMode,pn_builder,classNb,nbFeat,featModelName='resnet18',pretrainedFeatMod=True,topk=False,\
                 topk_softcoord=False,topk_softCoord_kerSize=2,topk_softCoord_secOrder=False,point_nb=256,reconst=False,\
                 encoderChan=1,encoderHidChan=64,predictDepth=False,topk_softcoord_shiftpred=False,topk_fps=False,topk_fps_nb_pts=64,\
-                **kwargs):
+                useXYZ=True,**kwargs):

         super(PointNet2,self).__init__(videoMode,featModelName,pretrainedFeatMod,True,chan=encoderHidChan,**kwargs)
@@ -501,7 +498,7 @@ class PointNet2(FirstModel):
         else:
             self.pointExtr = DirectPointExtractor(point_nb)

-        self.pn2 = pointnet2.models.pointnet2_ssg_cls.Pointnet2SSG(num_classes=classNb,input_channels=encoderChan if topk else 0,use_xyz=True)
+        self.pn2 = pn_builder(num_classes=classNb,input_channels=encoderChan if topk else 0,use_xyz=useXYZ)

     def forward(self,x):
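
With this change the flag reaches the PointNet++ classifier through the builder call. A hedged usage sketch (the constructor and its num_classes / input_channels / use_xyz arguments are taken from the lines above; the values and the need for the repo's pointnet2 package are assumptions):

import pointnet2.models.pointnet2_ssg_cls as pn2_ssg

classNb, encoderChan, topk = 10, 1, True   # illustrative values only
# Same call as in PointNet2.__init__, but with coordinates disabled as features:
# the classifier then only consumes the encoderChan feature channels.
pn2 = pn2_ssg.Pointnet2SSG(num_classes=classNb,
                           input_channels=encoderChan if topk else 0,
                           use_xyz=False)
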
@@ -735,7 +732,7 @@ def netBuilder(args):
                 point_nb=args.pn_point_nb,reconst=args.pn_topk_reconst,topk_softcoord_shiftpred=args.pn_topk_softcoord_shiftpred,\
                 encoderChan=args.pn_topk_enc_chan,multiModel=args.resnet_multi_model,multiModSparseConst=args.resnet_multi_model_sparse_const,predictDepth=args.pn_topk_pred_depth,\
                 layerSizeReduce=args.resnet_layer_size_reduce,preLayerSizeReduce=args.resnet_prelay_size_reduce,dilation=args.resnet_dilation,\
-                topk_fps=args.pn_topk_farthest_pts_sampling,topk_fps_nb_pts=args.pn_topk_fps_nb_points)
+                topk_fps=args.pn_topk_farthest_pts_sampling,topk_fps_nb_pts=args.pn_topk_fps_nb_points,useXYZ=args.pn_use_xyz)

         secondModel = Identity(args.video_mode,nbFeat,args.class_nb,False)
     elif args.temp_mod == "pointnet2_pp":
         pn_builder = pointnet2.models.pointnet2_msg_cls.Pointnet2MSG
@@ -744,7 +741,7 @@ def netBuilder(args):
                 point_nb=args.pn_point_nb,reconst=args.pn_topk_reconst,topk_softcoord_shiftpred=args.pn_topk_softcoord_shiftpred,\
                 encoderChan=args.pn_topk_enc_chan,multiModel=args.resnet_multi_model,multiModSparseConst=args.resnet_multi_model_sparse_const,predictDepth=args.pn_topk_pred_depth,\
                 layerSizeReduce=args.resnet_layer_size_reduce,preLayerSizeReduce=args.resnet_prelay_size_reduce,dilation=args.resnet_dilation,\
-                topk_fps=args.pn_topk_farthest_pts_sampling,topk_fps_nb_pts=args.pn_topk_fps_nb_points)
+                topk_fps=args.pn_topk_farthest_pts_sampling,topk_fps_nb_pts=args.pn_topk_fps_nb_points,useXYZ=args.pn_use_xyz)

         secondModel = Identity(args.video_mode,nbFeat,args.class_nb,False)
     elif args.temp_mod == "identity":
         secondModel = Identity(args.video_mode,nbFeat,args.class_nb,False)
@@ -853,6 +850,8 @@ def addArgs(argreader):
                                   help='For the topk point net model. This is the number of hidden channel of the encoder')
     argreader.parser.add_argument('--pn_topk_pred_depth', type=args.str2bool, metavar='INT',
                                   help='For the pointnet2 model. To predict the depth of chosen points.')
+    argreader.parser.add_argument('--pn_use_xyz', type=args.str2bool, metavar='INT',
+                                  help='For the pointnet2 model. To use the point coordinates as feature.')
     argreader.parser.add_argument('--resnet_chan', type=int, metavar='INT',
                                   help='The channel number for the visual model when resnet is used')
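
The new --pn_use_xyz flag is converted with the same args.str2bool helper as the other boolean options. A minimal standalone sketch of that parsing pattern (the str2bool implementation below is an assumption matching common practice, not this repository's exact helper):

import argparse

def str2bool(v):
    # Assumed equivalent of args.str2bool: map the usual command-line strings
    # ("True"/"False", "yes"/"no", "1"/"0") to Python booleans.
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "1"):
        return True
    if v.lower() in ("no", "false", "f", "0"):
        return False
    raise argparse.ArgumentTypeError("Boolean value expected.")

parser = argparse.ArgumentParser()
parser.add_argument('--pn_use_xyz', type=str2bool, metavar='INT',
                    help='For the pointnet2 model. To use the point coordinates as feature.')
print(parser.parse_args(['--pn_use_xyz', 'False']).pn_use_xyz)  # False
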
......
@@ -121,6 +121,7 @@ pn_topk_enc_chan = 1
 pn_topk_pred_depth= False
 pn_topk_farthest_pts_sampling = False
 pn_topk_fps_nb_points = 64
+pn_use_xyz = True
 resnet_chan = 64
 resnet_stride = 2
......
@@ -126,6 +126,7 @@ pn_topk_enc_chan = 1
 pn_topk_pred_depth= False
 pn_topk_farthest_pts_sampling = False
 pn_topk_fps_nb_points = 64
+pn_use_xyz = True
 resnet_chan = 64
 resnet_stride = 2
......
@@ -121,6 +121,7 @@ pn_topk_enc_chan = 1
 pn_topk_pred_depth= False
 pn_topk_farthest_pts_sampling = False
 pn_topk_fps_nb_points = 64
+pn_use_xyz = True
 resnet_chan = 64
 resnet_stride = 2
......
@@ -126,6 +126,7 @@ pn_topk_enc_chan = 1
 pn_topk_pred_depth= False
 pn_topk_farthest_pts_sampling = False
 pn_topk_fps_nb_points = 64
+pn_use_xyz = True
 resnet_chan = 64
 resnet_stride = 2
......