# NOTE(review): removed non-Python residue pasted from a web file viewer
# ("meow / init / d6d3a5b / raw / history blame / 1.57 kB"); it was not
# valid Python and would have broken this module at import time.
import random
import time
import torch
import common.thing as thing
from common.ld_utils import ld2dl
def reweight_loss_by_keys(loss_dict, keys, alpha):
    """Scale the weight component of selected loss entries in place.

    Args:
        loss_dict: dict mapping key -> (value, weight) tuple.
        keys: iterable of keys in ``loss_dict`` whose weights are rescaled.
        alpha: multiplicative factor applied to each selected weight.

    Returns:
        The same ``loss_dict`` object, mutated in place.
    """
    for k in keys:
        value, old_weight = loss_dict[k]
        loss_dict[k] = (value, old_weight * alpha)
    return loss_dict
def select_loss_group(groups, agent_id, alphas):
    """Pick a loss-key group for this agent and a random reweighting alpha.

    Args:
        groups: list of loss-key groups; shuffled IN PLACE (caller's list
            is mutated).
        agent_id: int used to index into the shuffled groups (modulo length).
        alphas: sequence of candidate alpha factors to choose from.

    Returns:
        (keys, alpha): the selected group and a randomly chosen alpha.
    """
    # Fixed seed so every agent shuffles `groups` into the same
    # permutation; agent_id then selects a deterministic slot.
    random.seed(1)
    random.shuffle(groups)
    keys = groups[agent_id % len(groups)]
    # Temporarily reseed from the wall clock so the alpha choice varies
    # between calls, unlike the group selection above.
    random.seed(time.time())
    alpha = random.choice(alphas)
    # Restore the fixed seed before returning. NOTE(review): this leaves
    # the GLOBAL random state reset to seed 1 as a side effect —
    # presumably intentional for reproducibility downstream; confirm no
    # caller depends on fresh randomness immediately after this call.
    random.seed(1)
    return keys, alpha
def push_checkpoint_metric(key, val):
    """Wrap a scalar metric as a one-element float tensor in a dict.

    Args:
        key: name under which the metric is stored.
        val: scalar convertible to float.

    Returns:
        ``{key: torch.FloatTensor([float(val)])}``.
    """
    return {key: torch.FloatTensor([float(val)])}
def avg_losses_cpu(outputs):
    """Collapse a list of per-step loss dicts into per-key scalar means.

    Args:
        outputs: list of dicts mapping loss name -> tensor; converted via
            ``ld2dl`` into one dict of tensor lists.

    Returns:
        Dict mapping each loss name to the mean (a 0-dim CPU tensor) of
        all its values concatenated and flattened.
    """
    merged = ld2dl(outputs)
    for name in merged:
        # Move every piece to CPU before concatenating, then flatten
        # so the mean covers all elements uniformly.
        on_cpu = [t.cpu() for t in merged[name]]
        merged[name] = torch.cat(on_cpu, dim=0).view(-1).mean()
    return merged
def reform_outputs(out_list):
    """Merge a list of step results into aggregated outputs and losses.

    Args:
        out_list: list of dicts, each with an "out_dict" entry (dict of
            tensors or lists) and a "loss" entry (dict of tensors).

    Returns:
        (outputs, loss_dict):
            outputs: dict of numpy-converted aggregates (lists are
                flattened by concatenation; tensors are ``torch.cat``-ed),
                converted via ``thing.thing2np``.
            loss_dict: dict mapping each loss name to the scalar mean
                (Python float) of its concatenated values.
    """
    collated = ld2dl(out_list)
    out_dicts = ld2dl(collated["out_dict"])
    loss_dicts = ld2dl(collated["loss"])
    for name in out_dicts:
        values = out_dicts[name]
        # List-valued entries are flattened; tensor entries concatenated.
        if isinstance(values[0], list):
            out_dicts[name] = sum(values, [])
        else:
            out_dicts[name] = torch.cat(values)
    for name in loss_dicts:
        flat = [t.view(-1) for t in loss_dicts[name]]
        loss_dicts[name] = torch.cat(flat)
    outputs = {k: thing.thing2np(v) for k, v in out_dicts.items()}
    loss_dict = {k: v.mean().item() for k, v in loss_dicts.items()}
    return outputs, loss_dict