|
|
|
from collections import OrderedDict |
|
|
|
import numpy as np |
|
|
|
|
|
class LogBuffer:
    """Accumulate named values and compute count-weighted averages of them.

    Each call to ``update`` appends one value per key to ``val_history`` and
    a matching weight (``count``) to ``n_history``. ``average`` then writes
    the weighted mean of each key's history (optionally restricted to the
    most recent entries) into ``output`` and sets ``ready`` to ``True``.
    """

    def __init__(self):
        # key -> list of recorded values (keys kept in first-seen order).
        self.val_history = OrderedDict()
        # key -> list of weights, parallel to ``val_history[key]``.
        self.n_history = OrderedDict()
        # key -> most recently computed weighted average.
        self.output = OrderedDict()
        # Set True by average(); reset to False by clear_output().
        self.ready = False

    def clear(self):
        """Drop all recorded history and any computed averages."""
        for history in (self.val_history, self.n_history):
            history.clear()
        self.clear_output()

    def clear_output(self):
        """Discard computed averages and mark the buffer as not ready."""
        self.output.clear()
        self.ready = False

    def update(self, vars, count=1):
        """Record one value per key from ``vars``, each weighted by ``count``.

        Args:
            vars (dict): Mapping of key -> value to append. Must be a
                ``dict`` (or subclass); anything else fails the assertion.
                (The name shadows the builtin but is kept for callers
                that pass it by keyword.)
            count: Weight stored alongside every value recorded in this call.
        """
        assert isinstance(vars, dict)
        values_by_key, counts_by_key = self.val_history, self.n_history
        for name, value in vars.items():
            if name not in values_by_key:
                # First sighting of this key: start both parallel lists.
                values_by_key[name], counts_by_key[name] = [], []
            values_by_key[name].append(value)
            counts_by_key[name].append(count)

    def average(self, n=0):
        """Average latest n values or all values.

        For every key, stores ``sum(value * count) / sum(count)`` over the
        selected history entries in ``output[key]``, then sets ``ready``.

        Args:
            n (int): Number of most recent entries to average. ``0`` selects
                the whole history, because a ``-0`` slice start is ``0``.

        NOTE(review): if the selected counts sum to 0, numpy division yields
        nan/inf (with a RuntimeWarning) rather than raising.
        """
        assert n >= 0
        window = slice(-n, None)
        for key, history in self.val_history.items():
            vals = np.array(history[window])
            weights = np.array(self.n_history[key][window])
            self.output[key] = np.sum(vals * weights) / np.sum(weights)
        self.ready = True
|
|