-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
127 lines (118 loc) · 6.29 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
#!/usr/bin/env python3
import numpy as np
import sys
from Model import EEGNet
from load_data import load_data
from training import stage_training, standard_training
from sklearn.model_selection import KFold, train_test_split
from sklearn.metrics import accuracy_score
import os
import shutil
from tabulate import tabulate
import gc
import tensorflow as tf
def main(args):
    """Train and evaluate EEGNet on BCI Competition IV-2a data.

    Depending on ``args``, optionally pre-trains a base model across all
    subjects (``--cross_subject``), then runs iterated K-fold training on a
    single target subject (``--subject``) and prints an accuracy summary.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed command-line options (see the ``__main__`` block).
    """
    # Band-pass edges (Hz) for preprocessing, as a 1x2 array.
    bw = np.array([[args.frequency_cut_low, args.frequency_cut_high]])
    # np.int was removed in NumPy 1.24 — use the builtin int instead.
    # Sampling rate is 250 Hz; extract the 2.5 s – 6 s window of each trial.
    t_start = int(250.0 * 2.5)
    t_end = int(250.0 * 6)
    t_sample = 875
    n_channel = 22
    n_classes = 4
    callback = [tf.keras.callbacks.EarlyStopping(
        monitor="val_loss",
        patience=args.patience,
        verbose=0,
        mode="auto",
        restore_best_weights=True,
    )]
    subjects = [1]  # ,2,3,4,5,6,7,8,9]
    base_weights_path = './models/base_model'
    Base_model = EEGNet(nb_classes=n_classes, Chans=n_channel, Samples=t_sample)
    if args.cross_subject:
        # Pre-train on every subject, then save weights for per-subject fine-tuning.
        [All_data_train, All_label_train,
         All_data_eval, All_label_eval] = load_data(subjects, args.path, bw,
                                                    n_classes, t_start, t_end, t_sample)
        train_data, validation_data, train_lbl, validation_lbl = train_test_split(
            All_data_train, All_label_train)
        Base_model = stage_training(Base_model, train_data, train_lbl,
                                    validation_data, validation_lbl,
                                    callback, epochs=[args.epochs, args.fine_tune_epochs])
        # Ensure the checkpoint directory exists before saving.
        os.makedirs('./models', exist_ok=True)
        Base_model.save_weights(base_weights_path)
        if args.subject == 0:
            # No target subject requested: report the cross-subject accuracy and stop.
            pred_EEGNet = Base_model.predict(All_data_eval)
            pred_EEGNet = np.argmax(pred_EEGNet, axis=1)
            eval_label = np.argmax(All_label_eval, axis=1)
            accuracy = accuracy_score(y_true=eval_label, y_pred=pred_EEGNet)
            output = [
                ["Model", "EEGNet"],
                ["Stage training", 'Enable' if args.stage else 'Disable'],
                ["Training Epochs ", args.epochs],
                ["Accuracy ", accuracy]
            ]
            print(tabulate(output))
            sys.exit()
    # Per-subject training: iterated 4-fold cross-validation on the target subject.
    [All_data_train, All_label_train,
     All_data_eval, All_label_eval] = load_data([args.subject], args.path, bw,
                                                n_classes, t_start, t_end, t_sample)
    kf = KFold(n_splits=4)
    accuracies = []
    for _ in range(args.iterations):  # 'iter' would shadow the builtin
        for train_index, valid_index in kf.split(All_data_train):
            gc.collect()  # free Keras/TF memory from the previous fold
            train_data = All_data_train[train_index]
            validation_data = All_data_train[valid_index]
            train_lbl = All_label_train[train_index]
            validation_lbl = All_label_train[valid_index]
            EEGNet_model = EEGNet(nb_classes=n_classes, Chans=n_channel, Samples=t_sample)
            # Only load pre-trained weights if a cross-subject checkpoint exists;
            # the original unconditionally loaded and crashed without --cross_subject.
            if os.path.exists(base_weights_path + '.index'):
                EEGNet_model.load_weights(base_weights_path)
            if args.stage:
                EEGNet_model = stage_training(EEGNet_model, train_data, train_lbl,
                                              validation_data, validation_lbl,
                                              callback, epochs=[args.epochs, args.fine_tune_epochs])
            else:
                EEGNet_model = standard_training(EEGNet_model, train_data, train_lbl,
                                                 validation_data, validation_lbl,
                                                 callback, epochs=[args.epochs, args.fine_tune_epochs])
            pred_EEGNet = EEGNet_model.predict(All_data_eval)
            pred_EEGNet = np.argmax(pred_EEGNet, axis=1)
            eval_label = np.argmax(All_label_eval, axis=1)
            accuracy = accuracy_score(y_true=eval_label, y_pred=pred_EEGNet)
            accuracies.append(accuracy)
    output = [
        ["Model", "EEGNet"],
        ["Stage training", 'Enable' if args.stage else 'Disable'],
        ["Cross Subjects", 'Enable' if args.cross_subject else 'Disable'],
        ["Training Epochs ", args.epochs],
        ["Training Iterations ", args.iterations],
        ["K_Fold ", 'Enable' if args.k_fold else 'Disable'],
        ["Accuracy ", np.mean(accuracies)]
    ]
    print(tabulate(output))
    if not args.save_model:
        # Clean up saved checkpoints unless the user asked to keep them.
        shutil.rmtree("./models", ignore_errors=True)
if __name__ == "__main__":
    import argparse

    def _str2bool(value):
        # argparse's type=bool is a trap: bool("False") is True because any
        # non-empty string is truthy. Parse common textual booleans instead.
        return str(value).lower() in ("true", "1", "yes", "y")

    parser = argparse.ArgumentParser(description='Stage and Cross training strategy')
    parser.add_argument('--path', action="store", dest="path", default="../BCI/dataset/Competition_iv_2a/",
                        help="path to the dataset folder see : bcni datasets format")
    parser.add_argument('--patience', action="store", dest="patience", type=int, default=1,
                        help="early stopping callback patience")
    parser.add_argument('--epochs', action="store", dest="epochs", default=1, type=int,
                        help="epochs model will be trained on")
    parser.add_argument('--frequency_cut_low', action="store", dest="frequency_cut_low", type=float, default=4,
                        help="lower cut-off frequency in preprocessing")
    parser.add_argument('--frequency_cut_high', action="store", dest="frequency_cut_high", type=float, default=40,
                        help="higher cut-off frequency in preprocessing")
    parser.add_argument('--subject', action="store", dest="subject", type=int, default=0,
                        help="target subject")
    # type=bool would treat any non-empty string (including "False") as True.
    parser.add_argument('--k_fold', action="store", dest="k_fold", type=_str2bool, default=True)
    parser.add_argument('--iterations', action="store", dest="iterations", type=int, default=1)
    parser.add_argument('--fine_tune_epochs', action="store", dest="fine_tune_epochs", type=int, default=1)
    parser.add_argument('--save_model', action="store", dest="save_model", type=int, default=1,
                        help="if true cross trained model won't be deleted after execution")
    parser.add_argument('--stage', action="store_true", dest="stage", default=False,
                        help="if true stage training will be used instead of standard training")
    parser.add_argument('--cross_subject', action="store_true", dest="cross_subject", default=False,
                        help="model will be pre-trained on all subjects of the data set")
    args = parser.parse_args()
    main(args)
    sys.exit()