williamcfrancis's picture
Upload 74 files
e22b55b
raw
history blame
1.74 kB
from tensorflow.keras import callbacks
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np
import json
import os
class TrainMonitor(callbacks.BaseLogger):
def __init__(self, figPath, jsonPath=None, startAt=0):
super(TrainMonitor, self).__init__()
self.figPath= figPath
self.jsonPath= jsonPath
self.startAt= startAt
def on_train_begin(self, logs={}):
self.H={}
if self.jsonPath is not None:
if os.path.exists(self.jsonPath):
self.H = json.loads(open(self.jsonPath).read())
if self.startAt > 0:
for k in self.H.keys():
self.H[k] = self.H[k][:self.startAt]
def on_epoch_end(self, epoch, logs={}):
for keys, values in logs.items():
l= self.H.get(keys, [])
l.append(float(values))
self.H[keys] = l
if self.jsonPath is not None:
with open(self.jsonPath, 'w') as f:
f.write(json.dumps(self.H))
f.close()
if len(self.H["loss"]) > 1:
N = np.arange(0, len(self.H["loss"]), 1)
plt.style.use("ggplot")
plt.figure()
plt.plot(N, self.H["loss"], label="train_loss")
plt.plot(N, self.H["val_loss"], label="val_loss")
plt.plot(N, self.H["accuracy"], label="train_acc")
plt.plot(N, self.H["val_accuracy"], label="val_acc")
plt.title("Training Loss and Accuracy [Epoch {}]".format(len(self.H["loss"])))
plt.xlabel("Epoch #")
plt.ylabel("Loss/Accuracy")
plt.legend()
plt.savefig(self.figPath)
plt.close()