Commit 3e8fb033 authored by E144069X's avatar E144069X

Added point dropout

parent 372bc218
......@@ -66,14 +66,6 @@ lstm_hid_size = 1024
augment_data = False
regression = False
uncertainty = False
uncer_loss_type = MSE
uncer_exact_inf_div = True
uncert_inf_div_weight = 0
uncer_ll_ratio_weight = 1
uncer_max_adv_entr_weight = 0
maximise_val_metric = True
metric_early_stop = Accuracy
max_worse_epoch_nb = 30
......@@ -117,6 +109,7 @@ pn_topk_pred_depth= False
pn_topk_farthest_pts_sampling = False
pn_topk_fps_nb_points = 64
pn_use_xyz = True
pn_topk_dropout = 0
resnet_chan = 64
resnet_stride = 2
......
......@@ -388,7 +388,7 @@ class DirectPointExtractor(nn.Module):
class TopkPointExtractor(nn.Module):
def __init__(self,cuda,nbFeat,featMod,softCoord,softCoord_kerSize,softCoord_secOrder,point_nb,reconst,encoderChan,\
predictDepth,softcoord_shiftpred,furthestPointSampling,furthestPointSampling_nb_pts):
predictDepth,softcoord_shiftpred,furthestPointSampling,furthestPointSampling_nb_pts,dropout):
super(TopkPointExtractor,self).__init__()
self.feat = featMod
......@@ -400,6 +400,7 @@ class TopkPointExtractor(nn.Module):
self.softcoord_shiftpred = softcoord_shiftpred
self.furthestPointSampling = furthestPointSampling
self.furthestPointSampling_nb_pts = furthestPointSampling_nb_pts
self.dropout = dropout
self.reconst = reconst
if reconst:
......@@ -469,6 +470,11 @@ class TopkPointExtractor(nn.Module):
abs = abs[torch.arange(abs.size(0)).unsqueeze(1).long(),selectedPointInds.long()]
ord = ord[torch.arange(ord.size(0)).unsqueeze(1).long(),selectedPointInds.long()]
if self.dropout and self.training:
idx = torch.randperm(abs.size(1))[:int((1-self.dropout)*abs.size(1))].unsqueeze(0)
abs = abs[torch.arange(abs.size(0)).unsqueeze(1).long(),idx]
ord = ord[torch.arange(ord.size(0)).unsqueeze(1).long(),idx]
retDict={}
if self.softCoord:
......@@ -536,14 +542,14 @@ class PointNet2(FirstModel):
def __init__(self,cuda,videoMode,classNb,nbFeat,featModelName='resnet18',pretrainedFeatMod=True,topk=False,\
topk_softcoord=False,topk_softCoord_kerSize=2,topk_softCoord_secOrder=False,point_nb=256,reconst=False,\
encoderChan=1,encoderHidChan=64,predictDepth=False,topk_softcoord_shiftpred=False,topk_fps=False,topk_fps_nb_pts=64,\
useXYZ=True,**kwargs):
topk_dropout=0,**kwargs):
super(PointNet2,self).__init__(videoMode,featModelName,pretrainedFeatMod,True,chan=encoderHidChan,**kwargs)
if topk:
self.pointExtr = TopkPointExtractor(cuda,nbFeat,self.featMod,topk_softcoord,topk_softCoord_kerSize,\
topk_softCoord_secOrder,point_nb,reconst,encoderChan,predictDepth,\
topk_softcoord_shiftpred,topk_fps,topk_fps_nb_pts)
topk_softcoord_shiftpred,topk_fps,topk_fps_nb_pts,topk_dropout)
else:
self.pointExtr = DirectPointExtractor(point_nb)
......@@ -718,7 +724,7 @@ def netBuilder(args):
point_nb=args.pn_point_nb,reconst=args.pn_topk_reconst,topk_softcoord_shiftpred=args.pn_topk_softcoord_shiftpred,\
encoderChan=args.pn_topk_enc_chan,multiModel=args.resnet_multi_model,multiModSparseConst=args.resnet_multi_model_sparse_const,predictDepth=args.pn_topk_pred_depth,\
layerSizeReduce=args.resnet_layer_size_reduce,preLayerSizeReduce=args.resnet_prelay_size_reduce,dilation=args.resnet_dilation,\
topk_fps=args.pn_topk_farthest_pts_sampling,topk_fps_nb_pts=args.pn_topk_fps_nb_points,useXYZ=args.pn_use_xyz)
topk_fps=args.pn_topk_farthest_pts_sampling,topk_fps_nb_pts=args.pn_topk_fps_nb_points,topk_dropout=args.pn_topk_dropout)
secondModel = Identity(args.video_mode,nbFeat,args.class_nb,False)
elif args.temp_mod == "identity":
secondModel = Identity(args.video_mode,nbFeat,args.class_nb,False)
......@@ -820,6 +826,8 @@ def addArgs(argreader):
help='For the pointnet2 model. To predict the depth of chosen points.')
argreader.parser.add_argument('--pn_use_xyz', type=args.str2bool, metavar='INT',
help='For the pointnet2 model. To use the point coordinates as feature.')
argreader.parser.add_argument('--pn_topk_dropout', type=float, metavar='FLOAT',
help='The proportion of point to randomly drop to decrease overfitting.')
argreader.parser.add_argument('--resnet_chan', type=int, metavar='INT',
help='The channel number for the visual model when resnet is used')
......
......@@ -66,14 +66,6 @@ lstm_hid_size = 1024
augment_data = False
regression = False
uncertainty = False
uncer_loss_type = MSE
uncer_exact_inf_div = True
uncert_inf_div_weight = 0
uncer_ll_ratio_weight = 1
uncer_max_adv_entr_weight = 0
maximise_val_metric = True
metric_early_stop = Accuracy
max_worse_epoch_nb = 30
......@@ -117,6 +109,7 @@ pn_topk_pred_depth= False
pn_topk_farthest_pts_sampling = False
pn_topk_fps_nb_points = 64
pn_use_xyz = True
pn_topk_dropout = 0
resnet_chan = 64
resnet_stride = 2
......
......@@ -66,14 +66,6 @@ lstm_hid_size = 1024
augment_data = False
regression = False
uncertainty = False
uncer_loss_type = MSE
uncer_exact_inf_div = True
uncert_inf_div_weight = 0
uncer_ll_ratio_weight = 1
uncer_max_adv_entr_weight = 0
maximise_val_metric = True
metric_early_stop = Accuracy
max_worse_epoch_nb = 30
......@@ -122,6 +114,7 @@ pn_topk_pred_depth= False
pn_topk_farthest_pts_sampling = False
pn_topk_fps_nb_points = 64
pn_use_xyz = True
pn_topk_dropout = 0
resnet_chan = 64
resnet_stride = 2
......
......@@ -66,14 +66,6 @@ lstm_hid_size = 1024
augment_data = False
regression = False
uncertainty = False
uncer_loss_type = MSE
uncer_exact_inf_div = True
uncert_inf_div_weight = 0
uncer_ll_ratio_weight = 1
uncer_max_adv_entr_weight = 0
maximise_val_metric = True
metric_early_stop = Accuracy
max_worse_epoch_nb = 30
......@@ -117,6 +109,7 @@ pn_topk_pred_depth= False
pn_topk_farthest_pts_sampling = False
pn_topk_fps_nb_points = 64
pn_use_xyz = True
pn_topk_dropout = 0
resnet_chan = 64
resnet_stride = 2
......
......@@ -66,14 +66,6 @@ lstm_hid_size = 1024
augment_data = False
regression = False
uncertainty = False
uncer_loss_type = MSE
uncer_exact_inf_div = True
uncert_inf_div_weight = 0
uncer_ll_ratio_weight = 1
uncer_max_adv_entr_weight = 0
maximise_val_metric = True
metric_early_stop = Accuracy
max_worse_epoch_nb = 30
......@@ -122,6 +114,7 @@ pn_topk_pred_depth= False
pn_topk_farthest_pts_sampling = False
pn_topk_fps_nb_points = 64
pn_use_xyz = True
pn_topk_dropout = 0
resnet_chan = 64
resnet_stride = 2
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment