chaitanya9 commited on
Commit
9a38ba2
·
1 Parent(s): 6609b72

Upload emotion_recognition.py

Browse files
Files changed (1) hide show
  1. emotion_recognition.py +497 -0
emotion_recognition.py ADDED
@@ -0,0 +1,497 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from data_extractor import load_data
2
+ from utils import extract_feature, AVAILABLE_EMOTIONS
3
+ from create_csv import write_emodb_csv, write_tess_ravdess_csv, write_custom_csv
4
+
5
+ from sklearn.metrics import accuracy_score, make_scorer, fbeta_score, mean_squared_error, mean_absolute_error
6
+ from sklearn.metrics import confusion_matrix
7
+ from sklearn.model_selection import GridSearchCV
8
+
9
+ import matplotlib.pyplot as pl
10
+ from time import time
11
+ from utils import get_best_estimators, get_audio_config
12
+ import numpy as np
13
+ import tqdm
14
+ import os
15
+ import random
16
+ import pandas as pd
17
+
18
+
19
+ class EmotionRecognizer:
20
+ """A class for training, testing and predicting emotions based on
21
+ speech's features that are extracted and fed into `sklearn` or `keras` model"""
22
+ def __init__(self, model=None, **kwargs):
23
+ """
24
+ Params:
25
+ model (sklearn model): the model used to detect emotions. If `model` is None, then self.determine_best_model()
26
+ will be automatically called
27
+ emotions (list): list of emotions to be used. Note that these emotions must be available in
28
+ RAVDESS_TESS & EMODB Datasets, available nine emotions are the following:
29
+ 'neutral', 'calm', 'happy', 'sad', 'angry', 'fear', 'disgust', 'ps' ( pleasant surprised ), 'boredom'.
30
+ Default is ["sad", "neutral", "happy"].
31
+ tess_ravdess (bool): whether to use TESS & RAVDESS Speech datasets, default is True
32
+ emodb (bool): whether to use EMO-DB Speech dataset, default is True,
33
+ custom_db (bool): whether to use custom Speech dataset that is located in `data/train-custom`
34
+ and `data/test-custom`, default is True
35
+ tess_ravdess_name (str): the name of the output CSV file for TESS&RAVDESS dataset, default is "tess_ravdess.csv"
36
+ emodb_name (str): the name of the output CSV file for EMO-DB dataset, default is "emodb.csv"
37
+ custom_db_name (str): the name of the output CSV file for the custom dataset, default is "custom.csv"
38
+ features (list): list of speech features to use, default is ["mfcc", "chroma", "mel"]
39
+ (i.e MFCC, Chroma and MEL spectrogram )
40
+ classification (bool): whether to use classification or regression, default is True
41
+ balance (bool): whether to balance the dataset ( both training and testing ), default is True
42
+ verbose (bool/int): whether to print messages on certain tasks, default is 1
43
+ Note that when `tess_ravdess`, `emodb` and `custom_db` are set to `False`, `tess_ravdess` will be set to True
44
+ automatically.
45
+ """
46
+ # emotions
47
+ self.emotions = kwargs.get("emotions", ["sad", "neutral", "happy"])
48
+ # make sure that there are only available emotions
49
+ self._verify_emotions()
50
+ # audio config
51
+ self.features = kwargs.get("features", ["mfcc", "chroma", "mel"])
52
+ self.audio_config = get_audio_config(self.features)
53
+ # datasets
54
+ self.tess_ravdess = kwargs.get("tess_ravdess", True)
55
+ self.emodb = kwargs.get("emodb", True)
56
+ self.custom_db = kwargs.get("custom_db", True)
57
+
58
+ if not self.tess_ravdess and not self.emodb and not self.custom_db:
59
+ self.tess_ravdess = True
60
+
61
+ self.classification = kwargs.get("classification", True)
62
+ self.balance = kwargs.get("balance", True)
63
+ self.override_csv = kwargs.get("override_csv", True)
64
+ self.verbose = kwargs.get("verbose", 1)
65
+
66
+ self.tess_ravdess_name = kwargs.get("tess_ravdess_name", "tess_ravdess.csv")
67
+ self.emodb_name = kwargs.get("emodb_name", "emodb.csv")
68
+ self.custom_db_name = kwargs.get("custom_db_name", "custom.csv")
69
+
70
+ self.verbose = kwargs.get("verbose", 1)
71
+
72
+ # set metadata path file names
73
+ self._set_metadata_filenames()
74
+ # write csv's anyway
75
+ self.write_csv()
76
+
77
+ # boolean attributes
78
+ self.data_loaded = False
79
+ self.model_trained = False
80
+
81
+ # model
82
+ if not model:
83
+ self.determine_best_model()
84
+ else:
85
+ self.model = model
86
+
87
+ def _set_metadata_filenames(self):
88
+ """
89
+ Protected method to get all CSV (metadata) filenames into two instance attributes:
90
+ - `self.train_desc_files` for training CSVs
91
+ - `self.test_desc_files` for testing CSVs
92
+ """
93
+ train_desc_files, test_desc_files = [], []
94
+ if self.tess_ravdess:
95
+ train_desc_files.append(f"train_{self.tess_ravdess_name}")
96
+ test_desc_files.append(f"test_{self.tess_ravdess_name}")
97
+ if self.emodb:
98
+ train_desc_files.append(f"train_{self.emodb_name}")
99
+ test_desc_files.append(f"test_{self.emodb_name}")
100
+ if self.custom_db:
101
+ train_desc_files.append(f"train_{self.custom_db_name}")
102
+ test_desc_files.append(f"test_{self.custom_db_name}")
103
+
104
+ # set them to be object attributes
105
+ self.train_desc_files = train_desc_files
106
+ self.test_desc_files = test_desc_files
107
+
108
+ def _verify_emotions(self):
109
+ """
110
+ This method makes sure that emotions passed in parameters are valid.
111
+ """
112
+ for emotion in self.emotions:
113
+ assert emotion in AVAILABLE_EMOTIONS, "Emotion not recognized."
114
+
115
+ def get_best_estimators(self):
116
+ """Loads estimators from grid files and returns them"""
117
+ return get_best_estimators(self.classification)
118
+
119
+ def write_csv(self):
120
+ """
121
+ Write available CSV files in `self.train_desc_files` and `self.test_desc_files`
122
+ determined by `self._set_metadata_filenames()` method.
123
+ """
124
+ for train_csv_file, test_csv_file in zip(self.train_desc_files, self.test_desc_files):
125
+ # not safe approach
126
+ if os.path.isfile(train_csv_file) and os.path.isfile(test_csv_file):
127
+ # file already exists, just skip writing csv files
128
+ if not self.override_csv:
129
+ continue
130
+ if self.emodb_name in train_csv_file:
131
+ write_emodb_csv(self.emotions, train_name=train_csv_file, test_name=test_csv_file, verbose=self.verbose)
132
+ if self.verbose:
133
+ print("[+] Writed EMO-DB CSV File")
134
+ elif self.tess_ravdess_name in train_csv_file:
135
+ write_tess_ravdess_csv(self.emotions, train_name=train_csv_file, test_name=test_csv_file, verbose=self.verbose)
136
+ if self.verbose:
137
+ print("[+] Writed TESS & RAVDESS DB CSV File")
138
+ elif self.custom_db_name in train_csv_file:
139
+ write_custom_csv(emotions=self.emotions, train_name=train_csv_file, test_name=test_csv_file, verbose=self.verbose)
140
+ if self.verbose:
141
+ print("[+] Writed Custom DB CSV File")
142
+
143
+ def load_data(self):
144
+ """
145
+ Loads and extracts features from the audio files for the db's specified
146
+ """
147
+ if not self.data_loaded:
148
+ result = load_data(self.train_desc_files, self.test_desc_files, self.audio_config, self.classification,
149
+ emotions=self.emotions, balance=self.balance)
150
+ self.X_train = result['X_train']
151
+ self.X_test = result['X_test']
152
+ self.y_train = result['y_train']
153
+ self.y_test = result['y_test']
154
+ self.train_audio_paths = result['train_audio_paths']
155
+ self.test_audio_paths = result['test_audio_paths']
156
+ self.balance = result["balance"]
157
+ if self.verbose:
158
+ print("[+] Data loaded")
159
+ self.data_loaded = True
160
+
161
+ def train(self, verbose=1):
162
+ """
163
+ Train the model, if data isn't loaded, it 'll be loaded automatically
164
+ """
165
+ if not self.data_loaded:
166
+ # if data isn't loaded yet, load it then
167
+ self.load_data()
168
+ if not self.model_trained:
169
+ self.model.fit(X=self.X_train, y=self.y_train)
170
+ self.model_trained = True
171
+ if verbose:
172
+ print("[+] Model trained")
173
+
174
+ def predict(self, audio_path):
175
+ """
176
+ given an `audio_path`, this method extracts the features
177
+ and predicts the emotion
178
+ """
179
+ feature = extract_feature(audio_path, **self.audio_config).reshape(1, -1)
180
+ return self.model.predict(feature)[0]
181
+
182
+ def predict_proba(self, audio_path):
183
+ """
184
+ Predicts the probability of each emotion.
185
+ """
186
+ if self.classification:
187
+ feature = extract_feature(audio_path, **self.audio_config).reshape(1, -1)
188
+ proba = self.model.predict_proba(feature)[0]
189
+ result = {}
190
+ for emotion, prob in zip(self.model.classes_, proba):
191
+ result[emotion] = prob
192
+ return result
193
+ else:
194
+ raise NotImplementedError("Probability prediction doesn't make sense for regression")
195
+
196
+ def grid_search(self, params, n_jobs=2, verbose=1):
197
+ """
198
+ Performs GridSearchCV on `params` passed on the `self.model`
199
+ And returns the tuple: (best_estimator, best_params, best_score).
200
+ """
201
+ score = accuracy_score if self.classification else mean_absolute_error
202
+ grid = GridSearchCV(estimator=self.model, param_grid=params, scoring=make_scorer(score),
203
+ n_jobs=n_jobs, verbose=verbose, cv=3)
204
+ grid_result = grid.fit(self.X_train, self.y_train)
205
+ return grid_result.best_estimator_, grid_result.best_params_, grid_result.best_score_
206
+
207
+ def determine_best_model(self):
208
+ """
209
+ Loads best estimators and determine which is best for test data,
210
+ and then set it to `self.model`.
211
+ In case of regression, the metric used is MSE and accuracy for classification.
212
+ Note that the execution of this method may take several minutes due
213
+ to training all estimators (stored in `grid` folder) for determining the best possible one.
214
+ """
215
+ if not self.data_loaded:
216
+ self.load_data()
217
+
218
+ # loads estimators
219
+ estimators = self.get_best_estimators()
220
+
221
+ result = []
222
+
223
+ if self.verbose:
224
+ estimators = tqdm.tqdm(estimators)
225
+
226
+ for estimator, params, cv_score in estimators:
227
+ if self.verbose:
228
+ estimators.set_description(f"Evaluating {estimator.__class__.__name__}")
229
+ detector = EmotionRecognizer(estimator, emotions=self.emotions, tess_ravdess=self.tess_ravdess,
230
+ emodb=self.emodb, custom_db=self.custom_db, classification=self.classification,
231
+ features=self.features, balance=self.balance, override_csv=False)
232
+ # data already loaded
233
+ detector.X_train = self.X_train
234
+ detector.X_test = self.X_test
235
+ detector.y_train = self.y_train
236
+ detector.y_test = self.y_test
237
+ detector.data_loaded = True
238
+ # train the model
239
+ detector.train(verbose=0)
240
+ # get test accuracy
241
+ accuracy = detector.test_score()
242
+ # append to result
243
+ result.append((detector.model, accuracy))
244
+
245
+ # sort the result
246
+ # regression: best is the lower, not the higher
247
+ # classification: best is higher, not the lower
248
+ result = sorted(result, key=lambda item: item[1], reverse=self.classification)
249
+ best_estimator = result[0][0]
250
+ accuracy = result[0][1]
251
+ self.model = best_estimator
252
+ self.model_trained = True
253
+ if self.verbose:
254
+ if self.classification:
255
+ print(f"[+] Best model determined: {self.model.__class__.__name__} with {accuracy*100:.3f}% test accuracy")
256
+ else:
257
+ print(f"[+] Best model determined: {self.model.__class__.__name__} with {accuracy:.5f} mean absolute error")
258
+
259
+ def test_score(self):
260
+ """
261
+ Calculates score on testing data
262
+ if `self.classification` is True, the metric used is accuracy,
263
+ Mean-Squared-Error is used otherwise (regression)
264
+ """
265
+ y_pred = self.model.predict(self.X_test)
266
+ if self.classification:
267
+ return accuracy_score(y_true=self.y_test, y_pred=y_pred)
268
+ else:
269
+ return mean_squared_error(y_true=self.y_test, y_pred=y_pred)
270
+
271
+ def train_score(self):
272
+ """
273
+ Calculates accuracy score on training data
274
+ if `self.classification` is True, the metric used is accuracy,
275
+ Mean-Squared-Error is used otherwise (regression)
276
+ """
277
+ y_pred = self.model.predict(self.X_train)
278
+ if self.classification:
279
+ return accuracy_score(y_true=self.y_train, y_pred=y_pred)
280
+ else:
281
+ return mean_squared_error(y_true=self.y_train, y_pred=y_pred)
282
+
283
+ def train_fbeta_score(self, beta):
284
+ y_pred = self.model.predict(self.X_train)
285
+ return fbeta_score(self.y_train, y_pred, beta, average='micro')
286
+
287
+ def test_fbeta_score(self, beta):
288
+ y_pred = self.model.predict(self.X_test)
289
+ return fbeta_score(self.y_test, y_pred, beta, average='micro')
290
+
291
+ def confusion_matrix(self, percentage=True, labeled=True):
292
+ """
293
+ Computes confusion matrix to evaluate the test accuracy of the classification
294
+ and returns it as numpy matrix or pandas dataframe (depends on params).
295
+ params:
296
+ percentage (bool): whether to use percentage instead of number of samples, default is True.
297
+ labeled (bool): whether to label the columns and indexes in the dataframe.
298
+ """
299
+ if not self.classification:
300
+ raise NotImplementedError("Confusion matrix works only when it is a classification problem")
301
+ y_pred = self.model.predict(self.X_test)
302
+ matrix = confusion_matrix(self.y_test, y_pred, labels=self.emotions).astype(np.float32)
303
+ if percentage:
304
+ for i in range(len(matrix)):
305
+ matrix[i] = matrix[i] / np.sum(matrix[i])
306
+ # make it percentage
307
+ matrix *= 100
308
+ if labeled:
309
+ matrix = pd.DataFrame(matrix, index=[ f"true_{e}" for e in self.emotions ],
310
+ columns=[ f"predicted_{e}" for e in self.emotions ])
311
+ return matrix
312
+
313
+ def draw_confusion_matrix(self):
314
+ """Calculates the confusion matrix and shows it"""
315
+ matrix = self.confusion_matrix(percentage=False, labeled=False)
316
+ #TODO: add labels, title, legends, etc.
317
+ pl.imshow(matrix, cmap="binary")
318
+ pl.show()
319
+
320
+ def get_n_samples(self, emotion, partition):
321
+ """Returns number data samples of the `emotion` class in a particular `partition`
322
+ ('test' or 'train')
323
+ """
324
+ if partition == "test":
325
+ return len([y for y in self.y_test if y == emotion])
326
+ elif partition == "train":
327
+ return len([y for y in self.y_train if y == emotion])
328
+
329
+ def get_samples_by_class(self):
330
+ """
331
+ Returns a dataframe that contains the number of training
332
+ and testing samples for all emotions.
333
+ Note that if data isn't loaded yet, it'll be loaded
334
+ """
335
+ if not self.data_loaded:
336
+ self.load_data()
337
+ train_samples = []
338
+ test_samples = []
339
+ total = []
340
+ for emotion in self.emotions:
341
+ n_train = self.get_n_samples(emotion, "train")
342
+ n_test = self.get_n_samples(emotion, "test")
343
+ train_samples.append(n_train)
344
+ test_samples.append(n_test)
345
+ total.append(n_train + n_test)
346
+
347
+ # get total
348
+ total.append(sum(train_samples) + sum(test_samples))
349
+ train_samples.append(sum(train_samples))
350
+ test_samples.append(sum(test_samples))
351
+ return pd.DataFrame(data={"train": train_samples, "test": test_samples, "total": total}, index=self.emotions + ["total"])
352
+
353
+ def get_random_emotion(self, emotion, partition="train"):
354
+ """
355
+ Returns random `emotion` data sample index on `partition`.
356
+ """
357
+ if partition == "train":
358
+ index = random.choice(list(range(len(self.y_train))))
359
+ while self.y_train[index] != emotion:
360
+ index = random.choice(list(range(len(self.y_train))))
361
+ elif partition == "test":
362
+ index = random.choice(list(range(len(self.y_test))))
363
+ while self.y_train[index] != emotion:
364
+ index = random.choice(list(range(len(self.y_test))))
365
+ else:
366
+ raise TypeError("Unknown partition, only 'train' or 'test' is accepted")
367
+
368
+ return index
369
+
370
+
371
+ def plot_histograms(classifiers=True, beta=0.5, n_classes=3, verbose=1):
372
+ """
373
+ Loads different estimators from `grid` folder and calculate some statistics to plot histograms.
374
+ Params:
375
+ classifiers (bool): if `True`, this will plot classifiers, regressors otherwise.
376
+ beta (float): beta value for calculating fbeta score for various estimators.
377
+ n_classes (int): number of classes
378
+ """
379
+ # get the estimators from the performed grid search result
380
+ estimators = get_best_estimators(classifiers)
381
+
382
+ final_result = {}
383
+ for estimator, params, cv_score in estimators:
384
+ final_result[estimator.__class__.__name__] = []
385
+ for i in range(3):
386
+ result = {}
387
+ # initialize the class
388
+ detector = EmotionRecognizer(estimator, verbose=0)
389
+ # load the data
390
+ detector.load_data()
391
+ if i == 0:
392
+ # first get 1% of sample data
393
+ sample_size = 0.01
394
+ elif i == 1:
395
+ # second get 10% of sample data
396
+ sample_size = 0.1
397
+ elif i == 2:
398
+ # last get all the data
399
+ sample_size = 1
400
+ # calculate number of training and testing samples
401
+ n_train_samples = int(len(detector.X_train) * sample_size)
402
+ n_test_samples = int(len(detector.X_test) * sample_size)
403
+ # set the data
404
+ detector.X_train = detector.X_train[:n_train_samples]
405
+ detector.X_test = detector.X_test[:n_test_samples]
406
+ detector.y_train = detector.y_train[:n_train_samples]
407
+ detector.y_test = detector.y_test[:n_test_samples]
408
+ # calculate train time
409
+ t_train = time()
410
+ detector.train()
411
+ t_train = time() - t_train
412
+ # calculate test time
413
+ t_test = time()
414
+ test_accuracy = detector.test_score()
415
+ t_test = time() - t_test
416
+ # set the result to the dictionary
417
+ result['train_time'] = t_train
418
+ result['pred_time'] = t_test
419
+ result['acc_train'] = cv_score
420
+ result['acc_test'] = test_accuracy
421
+ result['f_train'] = detector.train_fbeta_score(beta)
422
+ result['f_test'] = detector.test_fbeta_score(beta)
423
+ if verbose:
424
+ print(f"[+] {estimator.__class__.__name__} with {sample_size*100}% ({n_train_samples}) data samples achieved {cv_score*100:.3f}% Validation Score in {t_train:.3f}s & {test_accuracy*100:.3f}% Test Score in {t_test:.3f}s")
425
+ # append the dictionary to the list of results
426
+ final_result[estimator.__class__.__name__].append(result)
427
+ if verbose:
428
+ print()
429
+ visualize(final_result, n_classes=n_classes)
430
+
431
+
432
+
433
+ def visualize(results, n_classes):
434
+ """
435
+ Visualization code to display results of various learners.
436
+
437
+ inputs:
438
+ - results: a dictionary of lists of dictionaries that contain various results on the corresponding estimator
439
+ - n_classes: number of classes
440
+ """
441
+
442
+ n_estimators = len(results)
443
+
444
+ # naive predictor
445
+ accuracy = 1 / n_classes
446
+ f1 = 1 / n_classes
447
+ # Create figure
448
+ fig, ax = pl.subplots(2, 4, figsize = (11,7))
449
+ # Constants
450
+ bar_width = 0.4
451
+ colors = [ (random.random(), random.random(), random.random()) for _ in range(n_estimators) ]
452
+ # Super loop to plot four panels of data
453
+ for k, learner in enumerate(results.keys()):
454
+ for j, metric in enumerate(['train_time', 'acc_train', 'f_train', 'pred_time', 'acc_test', 'f_test']):
455
+ for i in np.arange(3):
456
+ x = bar_width * n_estimators
457
+ # Creative plot code
458
+ ax[j//3, j%3].bar(i*x+k*(bar_width), results[learner][i][metric], width = bar_width, color = colors[k])
459
+ ax[j//3, j%3].set_xticks([x-0.2, x*2-0.2, x*3-0.2])
460
+ ax[j//3, j%3].set_xticklabels(["1%", "10%", "100%"])
461
+ ax[j//3, j%3].set_xlabel("Training Set Size")
462
+ ax[j//3, j%3].set_xlim((-0.2, x*3))
463
+ # Add unique y-labels
464
+ ax[0, 0].set_ylabel("Time (in seconds)")
465
+ ax[0, 1].set_ylabel("Accuracy Score")
466
+ ax[0, 2].set_ylabel("F-score")
467
+ ax[1, 0].set_ylabel("Time (in seconds)")
468
+ ax[1, 1].set_ylabel("Accuracy Score")
469
+ ax[1, 2].set_ylabel("F-score")
470
+ # Add titles
471
+ ax[0, 0].set_title("Model Training")
472
+ ax[0, 1].set_title("Accuracy Score on Training Subset")
473
+ ax[0, 2].set_title("F-score on Training Subset")
474
+ ax[1, 0].set_title("Model Predicting")
475
+ ax[1, 1].set_title("Accuracy Score on Testing Set")
476
+ ax[1, 2].set_title("F-score on Testing Set")
477
+ # Add horizontal lines for naive predictors
478
+ ax[0, 1].axhline(y = accuracy, xmin = -0.1, xmax = 3.0, linewidth = 1, color = 'k', linestyle = 'dashed')
479
+ ax[1, 1].axhline(y = accuracy, xmin = -0.1, xmax = 3.0, linewidth = 1, color = 'k', linestyle = 'dashed')
480
+ ax[0, 2].axhline(y = f1, xmin = -0.1, xmax = 3.0, linewidth = 1, color = 'k', linestyle = 'dashed')
481
+ ax[1, 2].axhline(y = f1, xmin = -0.1, xmax = 3.0, linewidth = 1, color = 'k', linestyle = 'dashed')
482
+ # Set y-limits for score panels
483
+ ax[0, 1].set_ylim((0, 1))
484
+ ax[0, 2].set_ylim((0, 1))
485
+ ax[1, 1].set_ylim((0, 1))
486
+ ax[1, 2].set_ylim((0, 1))
487
+ # Set additional plots invisibles
488
+ ax[0, 3].set_visible(False)
489
+ ax[1, 3].axis('off')
490
+ # Create legend
491
+ for i, learner in enumerate(results.keys()):
492
+ pl.bar(0, 0, color=colors[i], label=learner)
493
+ pl.legend()
494
+ # Aesthetics
495
+ pl.suptitle("Performance Metrics for Three Supervised Learning Models", fontsize = 16, y = 1.10)
496
+ pl.tight_layout()
497
+ pl.show()