Commit a4824f9e authored by Antoine PIGEAU's avatar Antoine PIGEAU
Browse files

NN is done

parent a8747aed
......@@ -29,7 +29,7 @@ import keras as kr
# from keras.utils import plot_model
from keras.utils.vis_utils import plot_model
from tensorflow.keras import optimizers
from keras import optimizers
from classifierManager.model.neuralNetworkInitializer import ModelLSTM
from classifierManager.model.neuralNetworkInitializer import ModelPerceptron
......@@ -58,7 +58,7 @@ class DenseNN:
self.callbacks_list = [
kr.callbacks.EarlyStopping(
monitor='acc',
monitor='val_accuracy',
patience=1,
),
kr.callbacks.ModelCheckpoint(
......@@ -129,14 +129,15 @@ class DenseNN:
def compile(self):
#optimizer = 'rmsprop' #'adam'
optimizer = optimizers.RMSprop(lr=0.001, rho=0.9, epsilon=None, decay=0.0)
#optimizer = optimizers.RMSprop(lr=0.001, rho=0.9, epsilon=None, decay=0.0)
optimizer = optimizers.RMSprop(lr=0.001)
self.model.compile(optimizer = optimizer, loss = self.loss, metrics = ['accuracy'])
print("DenseNN - compile() - compile is done")
def fit(self, train_x, train_y, validation_x, validation_y, batchSize=32, epochs=20):
print("DenseNN - fit() - start fitting")
if validation_x is None and validation_y is None:
self.model.fit(x=train_x,
......@@ -146,7 +147,7 @@ class DenseNN:
verbose=2)
else:
self.model.fit(x=train_x,
y=train_y,
batch_size=batchSize,
......
......@@ -256,7 +256,7 @@ class ScriptNeuralNetwork(ScriptClassifier):
for _ in range(0, ntime):
#(accuracy, confusionMatrix, nbEpoch, auc)
result = self.predictionTask(course, whereToCut, cache=cache)
result = self.predictionTask(course, whereToCut)
if auc is None:
raise ValueError("classifierManager.script.ScriptNeuralNetwork - predictionTaskNTimes : auc is None")
......@@ -310,6 +310,8 @@ class ScriptNeuralNetwork(ScriptClassifier):
accuracies = []
aucs = []
epochs = []
size = len(self.classifier.nameGroups)
avgConfusionMatrix = np.zeros((size, size))
accuraciesPerClass = []
dictAllCourses = {}
......@@ -349,7 +351,7 @@ class ScriptNeuralNetwork(ScriptClassifier):
fileResult.write(str(auc))
fileResult.write("\n nbEpoch\n")
fileResult.write(str(nbEpoch))
fileResult.write(str(epoch))
......
......@@ -109,7 +109,7 @@ if __name__ == "__main__":
''' for all periods '''
classifier.predictionTaskForAllPeriods(ntime=10, cache=False)
classifier.predictionTaskForAllPeriods(ntime=10, cache=True)
''' NN : all courses with hidden layers'''
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment