Nadine Rueegg
initial commit for barc
7629b39
raw
history blame
2.69 kB
# Modified from:
# https://github.com/anibali/pytorch-stacked-hourglass
# https://github.com/bearpaw/pytorch-pose
import numpy as np
# Names exported by `from <this module> import *`.
__all__ = ['Logger']
class Logger:
    """Log training metrics to a tab-separated text file.

    The first line of the file is a header of tab-separated field names.
    Every following line is one row of tab-separated values, each written
    with six decimal places. All logged values are also kept in memory in
    ``self.numbers`` (a dict mapping field name -> list of values) so they
    can be plotted.

    The logger can be used as a context manager; the underlying file is
    closed when the ``with`` block exits.
    """

    def __init__(self, fpath, resume=False):
        """Open ``fpath`` for logging.

        Args:
            fpath: Path of the log file.
            resume: If True, read the header and all previously logged rows
                from the existing file, then append new rows to it (the
                header counts as already written). If False, the file is
                created/truncated and ``set_names`` must be called before
                ``append``.
        """
        if resume:
            # Read header names and previously logged values.
            with open(fpath, 'r') as f:
                header_line = f.readline()
                self.names = header_line.rstrip().split('\t')
                self.numbers = {name: [] for name in self.names}
                for line in f:
                    line = line.rstrip()
                    # Tolerate blank lines (e.g. a trailing newline added by
                    # an editor) instead of crashing on float('').
                    if not line:
                        continue
                    for i, value in enumerate(line.split('\t')):
                        self.numbers[self.names[i]].append(float(value))
            self.file = open(fpath, 'a')
            self.header_written = True
        else:
            self.file = open(fpath, 'w')
            self.header_written = False

    def __enter__(self):
        """Return the logger itself so it can be used in a ``with`` block."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Close the log file; exceptions raised in the block propagate."""
        self.close()
        return False

    def _write_line(self, field_values):
        """Write one tab-separated line of strings and flush it to disk."""
        self.file.write('\t'.join(field_values) + '\n')
        # Flush so the log stays readable while training is still running.
        self.file.flush()

    def set_names(self, names):
        """Set field names and write the log header line.

        Args:
            names: Sequence of field-name strings, one per logged column.
        """
        assert not self.header_written, 'Log header has already been written'
        self.names = names
        self.numbers = {name: [] for name in self.names}
        self._write_line(self.names)
        self.header_written = True

    def append(self, numbers):
        """Append one row of values to the log.

        Args:
            numbers: Sequence of numbers, one per field name, in the same
                order as ``self.names``.
        """
        assert self.header_written, 'Log header has not been written yet (use `set_names`)'
        assert len(self.names) == len(numbers), 'Numbers do not match names'
        for index, num in enumerate(numbers):
            self.numbers[self.names[index]].append(num)
        self._write_line(['{0:.6f}'.format(num) for num in numbers])

    def plot(self, ax, names=None):
        """Plot logged metrics on a set of Matplotlib axes.

        Args:
            ax: Matplotlib axes to draw on.
            names: Field names to plot; defaults to all logged fields.
        """
        names = self.names if names is None else names
        for name in names:
            values = self.numbers[name]
            ax.plot(np.arange(len(values)), np.asarray(values))
        ax.grid(True)
        ax.legend(names, loc='best')

    def plot_to_file(self, fpath, names=None, dpi=150):
        """Plot logged metrics and save the resulting figure to a file.

        Args:
            fpath: Output image path passed to ``Figure.savefig``.
            names: Field names to plot; defaults to all logged fields.
            dpi: Figure resolution.
        """
        # Imported lazily so matplotlib is only required when plotting.
        import matplotlib.pyplot as plt
        fig = plt.figure(dpi=dpi)
        try:
            ax = fig.subplots()
            self.plot(ax, names)
            fig.savefig(fpath)
        finally:
            # Always release the figure, even if plotting or saving fails.
            plt.close(fig)

    def close(self):
        """Close the underlying log file."""
        self.file.close()