Commit 22ce5acf authored by E144069X

Use pytorch_geometric

parent fbfbfb16
@@ -9,11 +9,15 @@ import args
import sys
import cv2
import deeplab
import glob
from skimage.transform import resize
import matplotlib.pyplot as plt
#pointnet2 can only be installed on a machine that has a gpu
#pointnet2 and torch_cluster can only be installed on a machine that has a gpu
#But to avoid the program stopping when testing on a machine without a GPU,
#the error is caught
try:
import torch_geometric
import pointnet2
except ModuleNotFoundError:
pass
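
A minimal sketch of the same optional-import guard, with an explicit availability flag so callers can fail with a clear message (the HAS_POINTNET flag name is an illustration, not part of this repository):

try:
    import torch_geometric
    import pointnet2
    HAS_POINTNET = True
except ModuleNotFoundError:
    # CPU-only machine: the point-cloud models simply become unavailable
    HAS_POINTNET = False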
@@ -376,7 +380,10 @@ class DirectPointExtractor(nn.Module):
points = x.view(x.size(0),x.size(1)//2,2)
points = torch.cat((points,torch.zeros(points.size(0),points.size(1),1).to(x.device)),dim=-1)
return {"points":points}
return {"points":points,
"batch":torch.arange(points.size(0)).unsqueeze(1).expand(points.size(0),points.size(1)).reshape(-1).to(points.device),
"pos":points.reshape(points.size(0)*points.size(1),points.size(2)),
"pointfeatures":None}
class TopkPointExtractor(nn.Module):
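
torch_geometric represents a batch of point clouds as one flat pos tensor plus a batch vector mapping every row to its example, which is exactly what the new return dict builds. A minimal standalone sketch of that construction (shapes are illustrative):

import torch

B, N = 2, 4                    # 2 examples, 4 points each
points = torch.randn(B, N, 3)  # [B, N, 3], as produced by the extractor
# one example index per point: tensor([0, 0, 0, 0, 1, 1, 1, 1])
batch = torch.arange(B).unsqueeze(1).expand(B, N).reshape(-1)
pos = points.reshape(B * N, 3)  # flat [B*N, 3] coordinates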
@@ -437,6 +444,10 @@ class TopkPointExtractor(nn.Module):
def forward(self,imgBatch):
featureMaps = self.feat(imgBatch)
#Because of zero padding, the borders are very active, so we remove them.
featureMaps = featureMaps[:,:,3:-3,3:-3]
pointFeaturesMap = self.conv1x1(featureMaps)
x = torch.pow(pointFeaturesMap,2).sum(dim=1,keepdim=True)
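
The map x above scores every spatial location by the squared L2 norm of its point features; the top-k locations (the flatInds used below) then become the extracted points. A minimal sketch of that selection, with the topk call inferred from the surrounding code:

import torch

feat = torch.randn(2, 8, 10, 10)            # [B, C, H, W] point feature maps
act = feat.pow(2).sum(dim=1, keepdim=True)  # [B, 1, H, W] activation per location
k = 16
_, flatInds = act.view(act.size(0), -1).topk(k, dim=-1)
absc = flatInds % act.shape[-1]             # column of each selected point
ordi = flatInds // act.shape[-1]            # row of each selected point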
@@ -445,9 +456,15 @@ class TopkPointExtractor(nn.Module):
abs,ord = (flatInds%x.shape[-1],flatInds//x.shape[-1])
if self.furthestPointSampling:
#abs,ord = abs.float(),ord.float()
points = torch.cat((abs.unsqueeze(-1),ord.unsqueeze(-1)),dim=-1).float()
selectedPointInds = pointnet2.utils.pointnet2_utils.furthest_point_sample(points,self.furthestPointSampling_nb_pts).to(abs.device)
# exampleInds indicates the example index of each point in the batch.
# If batch_size==1, exampleInds is just a list of zeros.
# If batch_size==2, the first half of exampleInds is filled with 0's and the rest with 1's.
exampleInds = torch.arange(points.size(0)).unsqueeze(1).expand(points.size(0),points.size(1)).reshape(-1).to(points.device)
selectedPointInds = torch_geometric.nn.fps(points.view(-1,2),exampleInds,ratio=self.furthestPointSampling_nb_pts/abs.size(1))
selectedPointInds = selectedPointInds.reshape(points.size(0),-1)
selectedPointInds = selectedPointInds%abs.size(1)
abs = abs[torch.arange(abs.size(0)).unsqueeze(1).long(),selectedPointInds.long()]
ord = ord[torch.arange(ord.size(0)).unsqueeze(1).long(),selectedPointInds.long()]
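
torch_geometric.nn.fps takes the flattened points, the batch vector, and a sampling ratio, and returns indices into the flattened tensor; the modulo on abs.size(1) then folds those global indices back into per-example indices. A minimal sketch, assuming every example contributes the same number of candidate points:

import torch
import torch_geometric

B, N, k = 2, 100, 16
points = torch.rand(B, N, 2)   # 2D candidates (x, y) per example
batch = torch.arange(B).unsqueeze(1).expand(B, N).reshape(-1)
idx = torch_geometric.nn.fps(points.view(-1, 2), batch, ratio=k / N)
idx = idx.reshape(B, -1) % N   # global flat indices -> local indices in [0, N)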
@@ -477,8 +494,11 @@ class TopkPointExtractor(nn.Module):
ord = yShift + ord.float()
abs,ord = abs.unsqueeze(-1).float(),ord.unsqueeze(-1).float()
points = torch.cat((abs,ord,depth,pointFeat),dim=-1).float()
retDict["points"] = points
points = torch.cat((abs,ord,depth),dim=-1).float()
retDict['points'] = torch.cat((abs,ord,depth,pointFeat),dim=-1).float()
retDict["batch"] = torch.arange(points.size(0)).unsqueeze(1).expand(points.size(0),points.size(1)).reshape(-1).to(points.device)
retDict["pos"] = points.reshape(points.size(0)*points.size(1),points.size(2))
retDict["pointfeatures"] = pointFeat.reshape(pointFeat.size(0)*pointFeat.size(1),pointFeat.size(2))
return retDict
def fastSoftCoordRefiner(self,x,abs,ord,kerSize=5,secondOrderSpatWeight=False):
@@ -513,7 +533,7 @@ class TopkPointExtractor(nn.Module):
class PointNet2(FirstModel):
def __init__(self,cuda,videoMode,pn_builder,classNb,nbFeat,featModelName='resnet18',pretrainedFeatMod=True,topk=False,\
def __init__(self,cuda,videoMode,classNb,nbFeat,featModelName='resnet18',pretrainedFeatMod=True,topk=False,\
topk_softcoord=False,topk_softCoord_kerSize=2,topk_softCoord_secOrder=False,point_nb=256,reconst=False,\
encoderChan=1,encoderHidChan=64,predictDepth=False,topk_softcoord_shiftpred=False,topk_fps=False,topk_fps_nb_pts=64,\
useXYZ=True,**kwargs):
@@ -527,14 +547,14 @@ class PointNet2(FirstModel):
else:
self.pointExtr = DirectPointExtractor(point_nb)
self.pn2 = pn_builder(num_classes=classNb,input_channels=encoderChan if topk else 0,use_xyz=useXYZ)
self.pn2 = pointnet2.Net(num_classes=classNb,input_channels=encoderChan if topk else 0)
def forward(self,x):
self.batchSize = x.size(0)
x = x.view(x.size(0)*x.size(1),x.size(2),x.size(3),x.size(4)) if self.videoMode else x
retDict = self.pointExtr(x)
x = self.pn2(retDict['points'].float())
x = self.pn2(retDict['pointfeatures'],retDict['pos'],retDict['batch'])
retDict['x'] = x
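
The new pointnet2.Net consuming (pointfeatures, pos, batch) is not shown in this diff. For orientation, a set-abstraction level in the style of torch_geometric's pointnet2_classification example is sketched below; the class is an illustration of the expected interface, not the repository's actual Net:

import torch
from torch_geometric.nn import PointConv, fps, radius

class SAModule(torch.nn.Module):
    """One PointNet++ set-abstraction level (after torch_geometric's examples)."""
    def __init__(self, ratio, r, nn):
        super().__init__()
        self.ratio, self.r = ratio, r
        self.conv = PointConv(nn)  # renamed PointNetConv in recent torch_geometric releases

    def forward(self, x, pos, batch):
        idx = fps(pos, batch, ratio=self.ratio)  # sample centroids
        row, col = radius(pos, pos[idx], self.r, batch, batch[idx], max_num_neighbors=64)
        edge_index = torch.stack([col, row], dim=0)  # group neighbours around each centroid
        x = self.conv(x, (pos, pos[idx]), edge_index)  # local PointNet per group
        return x, pos[idx], batch[idx]

Here x may be None when only coordinates are available, which matches the input_channels=encoderChan if topk else 0 setting above.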
@@ -763,18 +783,7 @@ def netBuilder(args):
multiModSparseConst=args.resnet_multi_model_sparse_const,layerSizeReduce=args.resnet_layer_size_reduce)
secondModel = Identity(args.video_mode,nbFeat,args.class_nb,False)
elif args.temp_mod == "pointnet2":
pn_builder = pointnet2.models.pointnet2_ssg_cls.Pointnet2SSG
firstModel = PointNet2(args.cuda,args.video_mode,pn_builder,args.class_nb,nbFeat=nbFeat,featModelName=args.feat,pretrainedFeatMod=args.pretrained_visual,encoderHidChan=args.pn_topk_hid_chan,\
topk=args.pn_topk,topk_softcoord=args.pn_topk_softcoord,topk_softCoord_kerSize=args.pn_topk_softcoord_kersize,topk_softCoord_secOrder=args.pn_topk_softcoord_secorder,\
point_nb=args.pn_point_nb,reconst=args.pn_topk_reconst,topk_softcoord_shiftpred=args.pn_topk_softcoord_shiftpred,\
encoderChan=args.pn_topk_enc_chan,multiModel=args.resnet_multi_model,multiModSparseConst=args.resnet_multi_model_sparse_const,predictDepth=args.pn_topk_pred_depth,\
layerSizeReduce=args.resnet_layer_size_reduce,preLayerSizeReduce=args.resnet_prelay_size_reduce,dilation=args.resnet_dilation,\
topk_fps=args.pn_topk_farthest_pts_sampling,topk_fps_nb_pts=args.pn_topk_fps_nb_points,useXYZ=args.pn_use_xyz)
secondModel = Identity(args.video_mode,nbFeat,args.class_nb,False)
elif args.temp_mod == "pointnet2_pp":
pn_builder = pointnet2.models.pointnet2_msg_cls.Pointnet2MSG
firstModel = PointNet2(args.cuda,args.video_mode,pn_builder,args.class_nb,nbFeat=nbFeat,featModelName=args.feat,pretrainedFeatMod=args.pretrained_visual,encoderHidChan=args.pn_topk_hid_chan,\
firstModel = PointNet2(args.cuda,args.video_mode,args.class_nb,nbFeat=nbFeat,featModelName=args.feat,pretrainedFeatMod=args.pretrained_visual,encoderHidChan=args.pn_topk_hid_chan,\
topk=args.pn_topk,topk_softcoord=args.pn_topk_softcoord,topk_softCoord_kerSize=args.pn_topk_softcoord_kersize,topk_softCoord_secOrder=args.pn_topk_softcoord_secorder,\
point_nb=args.pn_point_nb,reconst=args.pn_topk_reconst,topk_softcoord_shiftpred=args.pn_topk_softcoord_shiftpred,\
encoderChan=args.pn_topk_enc_chan,multiModel=args.resnet_multi_model,multiModSparseConst=args.resnet_multi_model_sparse_const,predictDepth=args.pn_topk_pred_depth,\
......