File size: 2,688 Bytes
7629b39
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74

# Modified from:
#   https://github.com/anibali/pytorch-stacked-hourglass 
#   https://github.com/bearpaw/pytorch-pose

import numpy as np

__all__ = ['Logger']


class Logger:
    """Log training metrics to a tab-separated text file.

    The file starts with a header line of field names, followed by one line
    of ``%.6f``-formatted numbers per call to :meth:`append`. All values are
    also kept in memory (``self.numbers``) so they can be plotted.

    Instances can be used as context managers; the log file is closed on exit.

    Attributes:
        names: Ordered list of field names (the header columns).
        numbers: Mapping from field name to the list of values logged so far.
        file: The open log file handle.
        header_written: Whether the header line is already present in the file.
    """

    def __init__(self, fpath, resume=False):
        """Open a log file.

        Args:
            fpath: Path of the log file.
            resume: If true, read the existing header and values from ``fpath``
                and open it for appending. If false, truncate or create the file;
                :meth:`set_names` must then be called before :meth:`append`.

        Raises:
            ValueError: When resuming, if a line has more fields than the header
                or contains a value that cannot be parsed as a float.
        """
        if resume:
            # Read header names and previously logged values.
            with open(fpath, 'r') as f:
                self.names = f.readline().rstrip().split('\t')
                self.numbers = {name: [] for name in self.names}
                # Line numbers start at 2 because the header was line 1.
                for line_no, line in enumerate(f, start=2):
                    if not line.strip():
                        # Tolerate blank lines (e.g. a trailing newline) instead
                        # of crashing on float('').
                        continue
                    fields = line.rstrip().split('\t')
                    if len(fields) > len(self.names):
                        raise ValueError(
                            '{}: line {} has {} fields but the header has {}'.format(
                                fpath, line_no, len(fields), len(self.names)))
                    # Short rows fill only their leading columns, as before.
                    for name, field in zip(self.names, fields):
                        self.numbers[name].append(float(field))

            self.file = open(fpath, 'a')
            self.header_written = True
        else:
            self.file = open(fpath, 'w')
            self.header_written = False

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        return False

    def _write_line(self, field_values):
        """Write one tab-joined line of strings and flush it to disk."""
        self.file.write('\t'.join(field_values) + '\n')
        self.file.flush()

    def set_names(self, names):
        """Set field names and write the log header line.

        Args:
            names: Iterable of field-name strings; stored as a list so that
                one-shot iterables (e.g. generators) are not exhausted.
        """
        assert not self.header_written, 'Log header has already been written'
        self.names = list(names)
        self.numbers = {name: [] for name in self.names}
        self._write_line(self.names)
        self.header_written = True

    def append(self, numbers):
        """Append one row of values to the log.

        Args:
            numbers: Sequence of numeric values, one per field in ``self.names``
                and in the same order. Values are written with 6 decimal places.
        """
        assert self.header_written, 'Log header has not been written yet (use `set_names`)'
        assert len(self.names) == len(numbers), 'Numbers do not match names'
        for name, num in zip(self.names, numbers):
            self.numbers[name].append(num)
        self._write_line(['{0:.6f}'.format(num) for num in numbers])

    def plot(self, ax, names=None):
        """Plot logged metrics on a set of Matplotlib axes.

        Args:
            ax: Matplotlib ``Axes`` to draw on.
            names: Field names to plot; defaults to all fields.
        """
        names = self.names if names is None else names
        for name in names:
            values = self.numbers[name]
            # x-axis is simply the index of each logged row.
            ax.plot(np.arange(len(values)), np.asarray(values))
        ax.grid(True)
        ax.legend(names, loc='best')

    def plot_to_file(self, fpath, names=None, dpi=150):
        """Plot logged metrics and save the resulting figure to a file.

        Args:
            fpath: Destination image path passed to ``Figure.savefig``.
            names: Field names to plot; defaults to all fields.
            dpi: Figure resolution.
        """
        # Imported lazily so matplotlib is only required when plotting.
        import matplotlib.pyplot as plt
        fig = plt.figure(dpi=dpi)
        try:
            ax = fig.subplots()
            self.plot(ax, names)
            fig.savefig(fpath)
        finally:
            # Always release the figure, even if plotting/saving fails.
            plt.close(fig)

    def close(self):
        """Close the underlying log file."""
        self.file.close()