Commit 77feaaed authored by E144069X's avatar E144069X

Write raw predictions for each image from test loader when working with video_mode=False

parent 033782d3
......@@ -86,10 +86,9 @@ def epochSeqTr(model,optim,log_interval,loader, epoch, args,writer,**kwargs):
with torch.autograd.detect_anomaly():
resDict = model(data,timeElapsedTensor)
output = resDict["pred"]
loss = computeLoss(args.nll_weight,output,target,args.pn_reconst_weight,resDict,data,args.video_mode)
loss.backward()
if args.distributed:
average_gradients(model)
......@@ -300,6 +299,8 @@ def epochImgEval(model,log_interval,loader, epoch, args,writer,metricEarlyStop,m
metDictSample["Loss"] = loss.detach().data.item()
metrDict = metrics.updateMetrDict(metrDict,metDictSample)
writePreds(output,target,epoch,args.exp_id,args.model_id,args.class_nb,batch_idx)
validBatch += 1
if validBatch > 15 and args.debug:
......@@ -314,6 +315,18 @@ def epochImgEval(model,log_interval,loader, epoch, args,writer,metricEarlyStop,m
return metrDict[metricEarlyStop]
def writePreds(predBatch,targBatch,epoch,exp_id,model_id,class_nb,batch_idx):
csvPath = "../results/{}/{}_epoch{}.csv".format(exp_id,model_id,epoch)
if (batch_idx==0 and epoch==1) or not os.path.exists(csvPath):
with open(csvPath,"w") as text_file:
print("targ,"+",".join(np.arange(class_nb).astype(str)),file=text_file)
with open(csvPath,"a") as text_file:
for i in range(len(predBatch)):
print(str(targBatch[i].cpu().detach().numpy())+","+",".join(predBatch[i].cpu().detach().numpy().astype(str)),file=text_file)
def computeTransMat(dataset,transMat,priors,propStart,propEnd,propSetIntFormat):
videoPaths = load_data.findVideos(dataset,propStart,propEnd,propSetIntFormat)
......@@ -414,7 +427,7 @@ def initialize_Net_And_EpochNumber(net,exp_id,model_id,cuda,start_mode,init_path
startEpoch = 1
elif start_mode == "fine_tune":
params = torch.load(init_path)
params = torch.load(init_path,map_location="cpu" if not cuda else None)
#Checking if the key of the model start with "module."
startsWithModule = (list(net.state_dict().keys())[0].find("module.") != -1)
......@@ -605,7 +618,8 @@ def run(args):
if not args.no_train:
trainFunc(**kwargsTr)
else:
net.load_state_dict(torch.load("../models/{}/model{}_epoch{}".format(args.no_train[0],args.no_train[1],epoch)))
if not args.no_val:
net.load_state_dict(torch.load("../models/{}/model{}_epoch{}".format(args.exp_id_no_train[0],args.model_id_no_train[1],epoch),map_location="cpu" if not args.cuda else None))
if not args.no_val:
with torch.no_grad():
......@@ -625,7 +639,7 @@ def run(args):
kwargsTest['loader'] = testLoader
net.load_state_dict(torch.load("../models/{}/model{}_epoch{}".format(args.exp_id,args.model_id,bestEpoch)))
net.load_state_dict(torch.load("../models/{}/model{}_epoch{}".format(args.exp_id,args.model_id,bestEpoch),map_location="cpu" if not args.cuda else None))
kwargsTest["model"] = net
kwargsTest["epoch"] = bestEpoch
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment