Commit 238a3442 authored by E144069X's avatar E144069X

Print more significant digits for correlation and skip the first phase if already done

parent 97c56cb8
......@@ -21,7 +21,6 @@ from PIL import Image
import load_data
import metrics
import utils
import formatData
......@@ -63,8 +62,6 @@ def evalModel(dataset,partBeg,partEnd,propSetIntFormat,exp_id,model_id,epoch,nbC
- epoch (int): the epoch at which to evaluate
'''
try:
resFilePaths = np.array(sorted(glob.glob("../results/{}/{}_epoch{}_*.csv".format(exp_id,model_id,epoch)),key=utils.findNumbers))
videoNameDict = buildVideoNameDict(dataset,partBeg,partEnd,propSetIntFormat,resFilePaths)
......@@ -136,7 +133,6 @@ def evalModel(dataset,partBeg,partEnd,propSetIntFormat,exp_id,model_id,epoch,nbC
printHeader = not os.path.exists("../results/{}/metrics.csv".format(exp_id))
with open("../results/{}/metrics.csv".format(exp_id),"a") as text_file:
if printHeader:
#print("Model,Accuracy,Accuracy (Viterbi),Correlation,Temp Accuracy",file=text_file)
print("Model,",end="",file=text_file)
print(",".join([key for key in metEval.keys() if key.find("Temp Accuracy") == -1])+",",end="",file=text_file)
print("Temp Accuracy "+str(tempAccThresToPrint),file=text_file)
......@@ -188,8 +184,11 @@ def computeMetrics(path,dataset,videoName,resFilePaths,videoNameDict,metTun,tran
return metr_dict,len(scores)
def formatMetr(mean,std):
return "$"+str(round(mean,2))+" \pm "+str(round(std,2))+"$"
def formatMetr(mean,std,corr):
if corr:
return "$"+str(round(mean,4))+" \pm "+str(round(std,4))+"$"
else:
return "$"+str(round(mean,2))+" \pm "+str(round(std,2))+"$"
def plotScore(dataset,exp_id,model_ids,epochs,trainPartBeg,trainPartEnd,testPartBeg,testPartEnd,model_labels):
''' This function plots the scores given by a model to seral videos.
......@@ -413,7 +412,7 @@ def agregatePerfs(exp_id,paramAgr,keysRef,namesRef,tempAccMinThres,tempAccMaxThr
csvStr += keyToNameDict[key]
for j in range(len(mean[0])):
csvStr += "&"+formatMetr(mean[i,j],std[i,j])
csvStr += "&"+formatMetr(mean[i,j],std[i,j],corr=metricNames[j]=="Correlation")
csvStr += "\\\\ \n"
......@@ -970,19 +969,21 @@ def main(argv=None):
args.test_part_beg,args.test_part_end,args.names)
if args.eval_model:
if os.path.exists("../results/{}/metrics.csv".format(args.exp_id)):
os.remove("../results/{}/metrics.csv".format(args.exp_id))
#if os.path.exists("../results/{}/metrics.csv".format(args.exp_id)):
# os.remove("../results/{}/metrics.csv".format(args.exp_id))
if not os.path.exists("../results/{}/metrics.csv".format(args.exp_id)):
for i,model_id in enumerate(args.model_ids):
for i,model_id in enumerate(args.model_ids):
conf = configparser.ConfigParser()
conf.read("../models/{}/{}.ini".format(args.exp_id,model_id))
conf = conf["default"]
conf = configparser.ConfigParser()
conf.read("../models/{}/{}.ini".format(args.exp_id,model_id))
conf = conf["default"]
evalModel(conf["dataset_test"],float(conf["test_part_beg"]),float(conf["test_part_end"]),str2bool(conf["prop_set_int_fmt"]),args.exp_id,model_id,epoch=args.epochs_to_process[i],\
nbClass=int(conf["class_nb"]),tempAccMinThres=args.temp_acc_min_thres,tempAccMaxThres=args.temp_acc_max_thres,tempAccThresStep=args.temp_acc_thres_step,\
tempAccThresToPrint=args.temp_acc_threstoprint)
evalModel(conf["dataset_test"],float(conf["test_part_beg"]),float(conf["test_part_end"]),str2bool(conf["prop_set_int_fmt"]),args.exp_id,model_id,epoch=args.epochs_to_process[i],\
nbClass=int(conf["class_nb"]),tempAccMinThres=args.temp_acc_min_thres,tempAccMaxThres=args.temp_acc_max_thres,tempAccThresStep=args.temp_acc_thres_step,\
tempAccThresToPrint=args.temp_acc_threstoprint)
if len(args.param_agr) > 0:
agregatePerfs(args.exp_id,args.param_agr,args.keys,args.names,args.temp_acc_min_thres,args.temp_acc_max_thres,args.temp_acc_thres_step)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment