Commit dd262a69 authored by E144069X's avatar E144069X

Added edge net

parent 2164528a
......@@ -562,7 +562,7 @@ class TopkPointExtractor(nn.Module):
class PointNet2(FirstModel):
def __init__(self,cuda,videoMode,classNb,nbFeat,featModelName='resnet18',pretrainedFeatMod=True,topk=False,\
def __init__(self,cuda,videoMode,classNb,nbFeat,pn_model,featModelName='resnet18',pretrainedFeatMod=True,topk=False,\
topk_softcoord=False,topk_softCoord_kerSize=2,topk_softCoord_secOrder=False,point_nb=256,reconst=False,\
encoderChan=1,encoderHidChan=64,predictDepth=False,topk_softcoord_shiftpred=False,topk_fps=False,topk_fps_nb_pts=64,\
topk_dropout=0,**kwargs):
......@@ -576,7 +576,7 @@ class PointNet2(FirstModel):
else:
self.pointExtr = DirectPointExtractor(point_nb)
self.pn2 = pointnet2.Net(num_classes=classNb,input_channels=encoderChan if topk else 0)
self.pn2 = pn_model
def forward(self,x):
......@@ -741,8 +741,13 @@ def netBuilder(args):
chan=args.resnet_chan,multiModel=args.resnet_multi_model,dilation=args.resnet_dilation,\
multiModSparseConst=args.resnet_multi_model_sparse_const,layerSizeReduce=args.resnet_layer_size_reduce)
secondModel = Identity(args.video_mode,nbFeat,args.class_nb,False)
elif args.temp_mod == "pointnet2":
firstModel = PointNet2(args.cuda,args.video_mode,args.class_nb,nbFeat=nbFeat,featModelName=args.feat,pretrainedFeatMod=args.pretrained_visual,encoderHidChan=args.pn_topk_hid_chan,\
elif args.temp_mod == "pointnet2" or args.temp_mod == "edgenet":
if args.temp_mod == "pointnet2":
pn_model = pointnet2.Net(num_classes=args.class_nb,input_channels=args.pn_topk_enc_chan if args.pn_topk else 0)
else:
pn_model = pointnet2.EdgeNet(num_classes=args.class_nb,input_channels=args.pn_topk_enc_chan if args.pn_topk else 3)
firstModel = PointNet2(args.cuda,args.video_mode,args.class_nb,nbFeat=nbFeat,pn_model=pn_model,featModelName=args.feat,pretrainedFeatMod=args.pretrained_visual,encoderHidChan=args.pn_topk_hid_chan,\
topk=args.pn_topk,topk_softcoord=args.pn_topk_softcoord,topk_softCoord_kerSize=args.pn_topk_softcoord_kersize,topk_softCoord_secOrder=args.pn_topk_softcoord_secorder,\
point_nb=args.pn_point_nb,reconst=args.pn_topk_reconst,topk_softcoord_shiftpred=args.pn_topk_softcoord_shiftpred,\
encoderChan=args.pn_topk_enc_chan,multiModel=args.resnet_multi_model,multiModSparseConst=args.resnet_multi_model_sparse_const,predictDepth=args.pn_topk_pred_depth,\
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment