Commit 205be22c authored by E144069X's avatar E144069X

Added args to vary threshold in temporal accuracy

parent d2c75351
......@@ -45,7 +45,7 @@ from PIL import Image
from PIL import ImageFont
from PIL import ImageDraw
def evalModel(dataset,partBeg,partEnd,propSetIntFormat,exp_id,model_id,epoch,nbClass):
def evalModel(dataset,partBeg,partEnd,propSetIntFormat,exp_id,model_id,epoch,nbClass,tempAccMinThres,tempAccMaxThres,tempAccThresNb):
'''
Evaluate a model. It requires the scores for each video to have been computed already with the trainVal.py script. Check readme to
see how to compute the scores for each video.
......@@ -90,11 +90,16 @@ def evalModel(dataset,partBeg,partEnd,propSetIntFormat,exp_id,model_id,epoch,nbC
metEval={}
for metricName in metricNameList:
if metricName.find("Accuracy") != -1:
metEval[metricName] = np.zeros(len(resFilePaths))
if metricName.find("Temp") != -1:
metEval[metricName] = np.zeros((len(resFilePaths),tempAccThresNb))
else:
metEval[metricName] = np.zeros(len(resFilePaths))
if metricName == "Correlation":
metEval[metricName] = []
metricParamDict = {"Temp Accuracy":{"minThres":tempAccMinThres,"maxThres":tempAccMaxThres,"thresNb":tempAccThresNb}}
transMat,priors = torch.zeros((nbClass,nbClass)).float(),torch.zeros((nbClass)).float()
transMat,_ = trainVal.computeTransMat(dataset,transMat,priors,partBeg,partEnd,propSetIntFormat)
......@@ -106,7 +111,7 @@ def evalModel(dataset,partBeg,partEnd,propSetIntFormat,exp_id,model_id,epoch,nbC
videoName = videoNameDict[path]
#Compute the metrics with the default threshold (0.5) and with a threshold tuned on each video with a leave-one out method
metrDict,frameNb = computeMetrics(path,dataset,videoName,resFilePaths,videoNameDict,metTun,transMat)
metrDict,frameNb = computeMetrics(path,dataset,videoName,resFilePaths,videoNameDict,metTun,transMat,metricParamDict)
for metricName in metEval.keys():
......@@ -125,6 +130,9 @@ def evalModel(dataset,partBeg,partEnd,propSetIntFormat,exp_id,model_id,epoch,nbC
metEval["Correlation"] = np.array(metEval["Correlation"])
metEval["Correlation"] = np.corrcoef(metEval["Correlation"][:,0],metEval["Correlation"][:,1])[0,1]
metEval["Temp Accuracy"],meanAccPerThres = agregateTempAcc(metEval["Temp Accuracy"],metricParamDict["Temp Accuracy"])
saveTempAccPerThres(exp_id,model_id,meanAccPerThres)
#Writing the latex table
printHeader = not os.path.exists("../results/{}/metrics.csv".format(exp_id))
with open("../results/{}/metrics.csv".format(exp_id),"a") as text_file:
......@@ -134,7 +142,18 @@ def evalModel(dataset,partBeg,partEnd,propSetIntFormat,exp_id,model_id,epoch,nbC
print(model_id+","+str(metEval["Accuracy"].sum()/totalFrameNb)+","+str(metEval["Accuracy (Viterbi)"].sum()/totalFrameNb)+","\
+str(metEval["Correlation"])+","+str(metEval["Temp Accuracy"].mean()),file=text_file)
def computeMetrics(path,dataset,videoName,resFilePaths,videoNameDict,metTun,transMat):
def agregateTempAcc(acc,accParamDict):
    '''Aggregate a matrix of temporal accuracies.

    Args:
        acc: 2D array where rows are videos and columns are tolerance thresholds.
        accParamDict (dict): threshold parameters (currently unused here).

    Returns:
        A (per-video mean, per-threshold mean) pair of 1D arrays.
    '''
    #Average across videos (rows) to get one value per threshold,
    #and across thresholds (columns) to get one value per video.
    perThreshold = np.mean(acc,axis=0)
    perVideo = np.mean(acc,axis=1)
    return perVideo,perThreshold
def saveTempAccPerThres(exp_id,model_id,meanAccPerThres):
    '''Write the mean temporal accuracy obtained at each threshold to a csv file.

    The file is later read back by plotTempAcc, which reconstructs the
    threshold values themselves from the min/max/nb arguments, so only the
    accuracy values (one per line) are stored here.

    Args:
        exp_id (str): experiment id, used to locate the results folder.
        model_id (str): model id, used to name the output file.
        meanAccPerThres: 1D array of mean accuracies, one per threshold.
    '''
    np.savetxt("../results/{}/{}_tempAcc.csv".format(exp_id,model_id),meanAccPerThres)
def computeMetrics(path,dataset,videoName,resFilePaths,videoNameDict,metTun,transMat,metricParamDict):
'''
Evaluate a model on a video by using the default threshold and a threshold tuned on all the other video
......@@ -159,7 +178,7 @@ def computeMetrics(path,dataset,videoName,resFilePaths,videoNameDict,metTun,tran
gt = gt[:len(scores)]
metr_dict = metrics.binaryToMetrics(torch.tensor(scores[np.newaxis,:]).float(),torch.tensor(gt[np.newaxis,:]),transMat,videoNames=[videoName])
metr_dict = metrics.binaryToMetrics(torch.tensor(scores[np.newaxis,:]).float(),torch.tensor(gt[np.newaxis,:]),metricParamDict,transMat,videoNames=[videoName])
return metr_dict,len(scores)
......@@ -351,7 +370,7 @@ def plotData(nbClass,dataset):
plt.tight_layout()
plt.savefig("../vis/prior_{}.png".format(dataset))
def agregatePerfs(exp_id,paramAgr,keysRef,namesRef):
def agregatePerfs(exp_id,paramAgr,keysRef,namesRef,tempAccMinThres,tempAccMaxThres,tempAccThresNb):
csv = np.genfromtxt("../results/{}/metrics.csv".format(exp_id),delimiter=",",dtype="str")
......@@ -369,12 +388,12 @@ def agregatePerfs(exp_id,paramAgr,keysRef,namesRef):
else:
groupedLines[key] = [line]
plotTempAcc(groupedLines.keys(),tempAccMinThres,tempAccMaxThres,tempAccThresNb,exp_id,keyToNameDict,groupedLines)
csvStr = "Model&"+"&".join(metricNames)+"\\\\ \n \hline \n"
mean = np.zeros((len(groupedLines.keys()),csv.shape[1]-1))
std = np.zeros((len(groupedLines.keys()),csv.shape[1]-1))
keys = groupedLines.keys()
#Reordering the keys
orderedKeys = []
for name in nameToKeyDict.keys():
......@@ -403,10 +422,31 @@ def agregatePerfs(exp_id,paramAgr,keysRef,namesRef):
#Computing the t-test
ttest_matrix(groupedLines,orderedKeys,exp_id,metricNames,keyToNameDict)
def plotRes(mean,std,csv,keys,exp_id,metricNames,keyToNameDict):
def plotTempAcc(keys,minThres,maxThres,thresNb,exp_id,keyToNameDict,groupedLines):
    '''Plot the mean temporal accuracy of each model group against the tolerance threshold.

    For every key, the per-threshold accuracy files written by
    saveTempAccPerThres for each model of the group are loaded, averaged over
    models, and plotted as one curve. The figure is saved to
    ../vis/{exp_id}/tempAcc.png.

    Args:
        keys: iterable of group keys.
        minThres (float): smallest tolerance threshold (in hours).
        maxThres (float): largest tolerance threshold (in hours, excluded).
        thresNb (int): number of thresholds.
        exp_id (str): experiment id, used to locate result files.
        keyToNameDict (dict): maps a group key to its display name.
        groupedLines (dict): maps a group key to its metric csv lines; the
            first column of each line is the model id.
    '''
    #NOTE: the original version started with "csv = csv[1:]", a copy-paste
    #from plotRes that raised a NameError since "csv" is undefined here.
    step = (maxThres-minThres)/thresNb
    #Build exactly thresNb thresholds. np.arange with a float step can yield
    #one element too many due to rounding, which would make plt.plot fail
    #because the accuracy files contain exactly thresNb values.
    thresList = minThres+step*np.arange(thresNb)

    plt.figure()
    for key in keys:
        tempAccPaths = []
        for line in groupedLines[key]:
            tempAccPaths.append("../results/{}/{}_tempAcc.csv".format(exp_id,line[0]))

        #Stack the per-model accuracy curves and average them over models
        tempAccFiles = np.concatenate(list(map(lambda x:np.genfromtxt(x)[np.newaxis],tempAccPaths)),axis=0)
        meanTempAcc = tempAccFiles.mean(axis=0)

        plt.plot(thresList,meanTempAcc,label=keyToNameDict[key], marker='o',)

    plt.ylim(0,1)
    plt.ylabel("Temporal Accuracy")
    plt.xlabel("Tolerance threshold (hours)")
    plt.legend()
    plt.savefig("../vis/{}/tempAcc.png".format(exp_id))
    #Close the figure so repeated calls do not accumulate open figures.
    plt.close()
def plotRes(mean,std,csv,keys,exp_id,metricNames,keyToNameDict):
csv = csv[1:]
#Plot the agregated results
fig = plt.figure()
plt.subplots_adjust(bottom=0.2)
......@@ -862,6 +902,11 @@ def main(argv=None):
argreader.parser.add_argument('--epochs_to_process',nargs="*",type=int,metavar="N",help='The list of epoch at which to evaluate each model. This argument should be set when using the --eval_model argument.')
argreader.parser.add_argument('--model_ids',type=str,nargs="*",metavar="NAME",help='The id of the models to process.')
argreader.parser.add_argument('--temp_acc_min_thres',type=float,metavar="MIN",help='The minimum threshold to use for temporal accuracy (In hours)')
argreader.parser.add_argument('--temp_acc_max_thres',type=float,metavar="MAX",help='The maximum threshold to use for temporal accuracy (In hours)')
argreader.parser.add_argument('--temp_acc_thres_nb',type=int,metavar="INT",help='The number of threshold to use for temporal accuracy.')
######################## Database plot #################################
argreader.parser.add_argument('--plot_data',type=int,metavar="N",help='To plot the state transition matrix and the prior vector. The value is the number of classes. The --dataset_test must be set.')
......@@ -917,6 +962,7 @@ def main(argv=None):
args.test_part_beg,args.test_part_end,args.names)
if args.eval_model:
'''
if os.path.exists("../results/{}/metrics.csv".format(args.exp_id)):
os.remove("../results/{}/metrics.csv".format(args.exp_id))
......@@ -928,10 +974,10 @@ def main(argv=None):
evalModel(conf["dataset_test"],float(conf["test_part_beg"]),float(conf["test_part_end"]),str2bool(conf["prop_set_int_fmt"]),args.exp_id,model_id,epoch=args.epochs_to_process[i],\
nbClass=int(conf["class_nb"]))
nbClass=int(conf["class_nb"]),tempAccMinThres=args.temp_acc_min_thres,tempAccMaxThres=args.temp_acc_max_thres,tempAccThresNb=args.temp_acc_thres_nb)
'''
if len(args.param_agr) > 0:
agregatePerfs(args.exp_id,args.param_agr,args.keys,args.names)
agregatePerfs(args.exp_id,args.param_agr,args.keys,args.names,args.temp_acc_min_thres,args.temp_acc_max_thres,args.temp_acc_thres_nb)
if not args.plot_data is None:
plotData(args.plot_data,args.dataset_test)
if args.plot_attention_maps:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment