Spaces:
Runtime error
Runtime error
import random | |
import time | |
import torch | |
import common.thing as thing | |
from common.ld_utils import ld2dl | |
def reweight_loss_by_keys(loss_dict, keys, alpha):
    """Scale the weight of selected loss terms, mutating ``loss_dict``.

    Every entry named in ``keys`` is a ``(value, weight)`` pair; its weight
    is multiplied by ``alpha`` and its value is left untouched. A key listed
    twice is rescaled twice.

    Args:
        loss_dict: Mapping of loss name -> ``(value, weight)`` pair. Updated
            in place.
        keys: Names of the entries to rescale; a missing name raises
            ``KeyError``.
        alpha: Multiplicative factor applied to each selected weight.

    Returns:
        The same ``loss_dict`` object.
    """
    for name in keys:
        loss_val, loss_weight = loss_dict[name]
        loss_dict[name] = (loss_val, loss_weight * alpha)
    return loss_dict
def select_loss_group(groups, agent_id, alphas):
    """Pick a loss-key group for an agent plus a random reweighting factor.

    ``groups`` is permuted with a fixed seed (1), so every call sees the same
    permutation. The agent gets the group at position
    ``agent_id % len(groups)`` of that permutation. The factor is drawn from
    ``alphas`` after reseeding from the current time.

    Args:
        groups: Sequence of key groups. Not modified; a shuffled copy is used.
        agent_id: Integer id of the agent; wraps around ``len(groups)``.
        alphas: Non-empty sequence of candidate factors.

    Returns:
        ``(keys, alpha)``: the selected group and the selected factor.

    Note:
        This reseeds the module-level ``random`` generator and leaves it
        seeded with 1 on return.
    """
    # Shuffle a copy. Shuffling the caller's list in place compounded the
    # permutation across calls, so the same agent_id could map to different
    # groups, and two agents could end up with the same group.
    shuffled = list(groups)
    random.seed(1)
    random.shuffle(shuffled)
    keys = shuffled[agent_id % len(shuffled)]

    random.seed(time.time())
    alpha = random.choice(alphas)

    # Leave the global generator seeded with 1, as the original flow does.
    random.seed(1)
    return keys, alpha
def push_checkpoint_metric(key, val):
    """Wrap a scalar metric as a one-element float tensor under ``key``.

    Args:
        key: Name under which the metric is stored.
        val: Anything ``float()`` accepts, e.g. a Python number or a
            single-element tensor.

    Returns:
        A single-entry dict ``{key: torch.FloatTensor([float(val)])}``.
    """
    return {key: torch.FloatTensor([float(val)])}
def avg_losses_cpu(outputs):
    """Move per-step tensors to the CPU, flatten them, and average per key.

    Args:
        outputs: Per-step results that are passed through ``ld2dl``. After
            that call, each key maps to a sequence of tensors. Presumably
            ``ld2dl`` turns a list of dicts into a dict of lists; TODO
            confirm.

    Returns:
        The mapping produced by ``ld2dl``, with each key's value replaced by
        the mean over all elements of its tensors, as a 0-d CPU tensor.
    """
    outputs = ld2dl(outputs)
    for key, tensors in outputs.items():
        # Flatten each tensor before concatenating. ``torch.cat`` rejects
        # 0-d tensors (e.g. scalar losses) and tensors whose trailing dims
        # differ. For inputs the old cat-then-flatten form accepted,
        # flattening first gives exactly the same element order.
        flat = torch.cat([t.cpu().reshape(-1) for t in tensors], dim=0)
        # Reassigning an existing key does not change the dict's size, so
        # doing it while iterating is safe.
        outputs[key] = flat.mean()
    return outputs
def reform_outputs(out_list):
    """Merge a list of per-step results into combined outputs and mean losses.

    ``out_list`` is passed through ``ld2dl``, and the result must provide
    ``"out_dict"`` and ``"loss"`` entries. Each of those is passed through
    ``ld2dl`` again, giving per-key sequences of per-step values.

    Outputs:
        If a key's first per-step value is a list, the lists are
        concatenated into one list. Otherwise the values are joined with
        ``torch.cat``. The merged values are then converted with
        ``thing.thing2np`` (presumably to numpy, per its name; TODO
        confirm).

    Losses:
        Each per-step tensor is flattened with ``view(-1)``, the pieces are
        concatenated, and the mean is returned as a Python scalar via
        ``.item()``.

    Returns:
        ``(outputs, loss_dict)``.
    """
    out_list_dict = ld2dl(out_list)
    outputs = ld2dl(out_list_dict["out_dict"])
    losses = ld2dl(out_list_dict["loss"])

    for k, chunks in outputs.items():
        if isinstance(chunks[0], list):
            # Flatten in one linear pass. ``sum(chunks, [])`` re-copies the
            # growing list at every step, which is quadratic in total length.
            outputs[k] = [item for chunk in chunks for item in chunk]
        else:
            outputs[k] = torch.cat(chunks)

    for k, chunks in losses.items():
        # Flatten first so that scalar (0-d) losses can be concatenated.
        losses[k] = torch.cat([ten.view(-1) for ten in chunks])

    outputs = {k: thing.thing2np(v) for k, v in outputs.items()}
    loss_dict = {k: v.mean().item() for k, v in losses.items()}
    return outputs, loss_dict