import numpy as np
import torch

import common.thing as thing


def _print_stat(key, thing):
    """
    Print a one-line summary of a single key-value pair of an xdict.

    Lists/tuples are summarised by their length, torch tensors by shape and
    device, numpy arrays by shape; anything else only by its type.

    Args:
        key: The key of the entry; it is formatted with a ``<20`` width spec.
        thing: The value stored under ``key``.
            NOTE: this parameter shadows the module-level ``thing`` import
            inside this function; the module is not needed here.
    """
    mytype = type(thing)
    if isinstance(thing, (list, tuple)):
        print("{:<20}: {:<30}\t{:}".format(key, len(thing), mytype))
    elif isinstance(thing, torch.Tensor):
        dev = thing.device
        shape = str(thing.shape).replace(" ", "")
        print("{:<20}: {:<30}\t{:}\t{}".format(key, shape, mytype, dev))
    elif isinstance(thing, np.ndarray):
        shape = str(thing.shape).replace(" ", "")
        print("{:<20}: {:<30}\t{:}".format(key, shape, mytype))
    else:
        print("{:<20}: {:}".format(key, mytype))


class xdict(dict):
    """
    A dict subclass with extra helpers for filtering, renaming and converting
    entries.

    Item assignment (``d[k] = v``) refuses to overwrite an existing key; use
    ``overwrite`` for the plain dict assignment. Most helpers return a new
    xdict and leave the instance untouched (``merge`` and ``overwrite`` are
    the in-place exceptions).
    """

    def __init__(self, mydict=None):
        """
        Create an xdict, optionally populated from ``mydict``.

        Args:
            mydict (dict, optional): Source of the initial key-value pairs;
                it only needs to provide ``items()``. If None, the xdict
                starts empty.
        """
        if mydict is None:
            return

        # Go through dict.__setitem__ directly so that the duplicate-key
        # assertion of xdict.__setitem__ is not involved in construction.
        for k, v in mydict.items():
            super().__setitem__(k, v)

    def subset(self, keys):
        """
        Return a new xdict restricted to the given keys.

        Args:
            keys (iterable): Keys to keep.

        Returns:
            xdict: The selected key-value pairs.

        Raises:
            KeyError: If one of ``keys`` is not present.
        """
        return xdict({k: self[k] for k in keys})

    def __setitem__(self, key, val):
        """
        Assign ``val`` to ``key``, asserting that ``key`` does not exist yet.

        Raises:
            AssertionError: If ``key`` is already present.
        """
        assert key not in self.keys(), f"Key already exists {key}"
        super().__setitem__(key, val)

    def search(self, keyword, replace_to=None):
        """
        Return a new xdict with only the entries whose key contains
        ``keyword``.

        Args:
            keyword (str): Substring to look for in each key.
            replace_to (str, optional): If given, ``keyword`` is replaced by
                this string in the keys of the result.

        Returns:
            xdict: The matching entries (with renamed keys if requested).
        """
        if replace_to is None:
            return xdict({k: v for k, v in self.items() if keyword in k})
        return xdict(
            {
                k.replace(keyword, replace_to): v
                for k, v in self.items()
                if keyword in k
            }
        )

    def rm(self, keyword, keep_list=None, verbose=False):
        """
        Return a new xdict without the entries whose key contains ``keyword``.

        Args:
            keyword (str): Substring marking the keys to remove.
            keep_list (container, optional): Keys that are kept even if they
                contain ``keyword``. Defaults to keeping none.
            verbose (bool): If True, print each removed key.

        Returns:
            xdict: The remaining entries.
        """
        # A None sentinel replaces the former mutable default ([]).
        keep = () if keep_list is None else keep_list

        out_dict = {}
        for k, v in self.items():
            if keyword not in k or k in keep:
                out_dict[k] = v
            elif verbose:
                print(f"Removing: {k}")
        return xdict(out_dict)

    def overwrite(self, k, v):
        """
        Plain dict assignment: set ``k`` to ``v`` even if ``k`` exists.
        """
        super().__setitem__(k, v)

    def merge(self, dict2):
        """
        In-place update from ``dict2`` that refuses duplicate keys.

        Args:
            dict2 (dict or xdict): The dictionary to merge into this one.

        Raises:
            AssertionError: If dict2 is not a dictionary or xdict instance.
            AssertionError: If the two dictionaries share any key.
        """
        assert isinstance(dict2, (dict, xdict))
        intersect = set(self) & set(dict2)
        assert len(intersect) == 0, f"Merge failed: duplicate keys ({intersect})"
        self.update(dict2)

    def mul(self, scalar):
        """
        Return a new xdict with every value multiplied by ``scalar``.

        List values are multiplied element by element; every other value is
        multiplied directly with ``*``.

        Args:
            scalar (float or int): The factor; an int is converted to float.

        Returns:
            xdict: The scaled entries.

        Raises:
            AssertionError: If scalar is neither an int nor a float.
        """
        if isinstance(scalar, int):
            scalar = float(scalar)
        assert isinstance(scalar, float)

        out_dict = {}
        for k, v in self.items():
            if isinstance(v, list):
                out_dict[k] = [item * scalar for item in v]
            else:
                out_dict[k] = v * scalar
        return xdict(out_dict)

    def prefix(self, text):
        """
        Return a new xdict with ``text`` prepended to every key.

        Args:
            text (str): The prefix to add.

        Returns:
            xdict: The entries under their prefixed keys.
        """
        return xdict({text + k: v for k, v in self.items()})

    def replace_keys(self, str_src, str_tar):
        """
        Return a new xdict with ``str_src`` replaced by ``str_tar`` in every
        key.

        Args:
            str_src (str): The substring to replace.
            str_tar (str): The replacement string.

        Returns:
            xdict: The entries under their renamed keys. If two keys become
                equal after the replacement, the later entry wins.
        """
        return xdict({k.replace(str_src, str_tar): v for k, v in self.items()})

    def postfix(self, text):
        """
        Return a new xdict with ``text`` appended to every key.

        Args:
            text (str): The postfix to add.

        Returns:
            xdict: The entries under their postfixed keys.
        """
        return xdict({k + text: v for k, v in self.items()})

    def sorted_keys(self):
        """
        Return the keys of the xdict as a sorted list.

        Returns:
            list: The sorted keys.
        """
        return sorted(self)

    def to(self, dev):
        """
        Move the content of the xdict to a device.

        Args:
            dev (torch.device): The target device. If None, nothing is moved.

        Returns:
            xdict: ``self`` when ``dev`` is None; otherwise a new xdict built
                from the result of ``thing.thing2dev``.
        """
        if dev is None:
            return self
        # The helper receives a plain dict copy rather than the xdict itself.
        raw_dict = dict(self)
        return xdict(thing.thing2dev(raw_dict, dev))

    def to_torch(self):
        """
        Convert the values to Torch tensors via ``thing.thing2torch``.

        Returns:
            xdict: A new xdict built from the converted values.
        """
        return xdict(thing.thing2torch(self))

    def to_np(self):
        """
        Convert the values to numpy arrays via ``thing.thing2np``.

        Returns:
            xdict: A new xdict built from the converted values.
        """
        return xdict(thing.thing2np(self))

    def tolist(self):
        """
        Convert the values to Python lists via ``thing.thing2list``.

        Returns:
            xdict: A new xdict built from the converted values.
        """
        return xdict(thing.thing2list(self))

    def print_stat(self):
        """
        Print a one-line summary (see ``_print_stat``) for every entry.
        """
        for k, v in self.items():
            _print_stat(k, v)

    def detach(self):
        """
        Detach the Torch tensors in the xdict via ``thing.detach_thing``.

        NOTE(review): the previous documentation stated that tensors are also
        moved to the CPU and that non-tensor values are ignored; that is
        decided by ``thing.detach_thing`` -- confirm there.

        Returns:
            xdict: A new xdict built from the helper's result.
        """
        return xdict(thing.detach_thing(self))

    def has_invalid(self):
        """
        Check whether any Torch tensor value contains NaN or Inf.

        Only values that are ``torch.Tensor`` instances are inspected. The
        first offending key is printed and the search stops there.

        Returns:
            bool: True if a tensor contains NaN or Inf values, False
                otherwise.
        """
        for k, v in self.items():
            if not isinstance(v, torch.Tensor):
                continue
            if torch.isnan(v).any():
                print(f"{k} contains nan values")
                return True
            if torch.isinf(v).any():
                print(f"{k} contains inf values")
                return True
        return False

    def apply(self, operation, criterion=None):
        """
        Apply ``operation`` to the values selected by ``criterion``.

        Args:
            operation (callable): Takes a value and returns the new value.
            criterion (callable, optional): Takes ``(key, value)`` and returns
                a boolean. If None, every entry is selected.

        Returns:
            xdict: A new xdict holding the transformed values of the selected
                entries. Entries rejected by ``criterion`` are NOT carried
                over to the result.
        """
        return xdict(
            {
                k: operation(v)
                for k, v in self.items()
                if criterion is None or criterion(k, v)
            }
        )

    def save(self, path, dev=None, verbose=True):
        """
        Save the xdict to disk with ``torch.save``.

        Args:
            path (str): The path to save the xdict to.
            dev (torch.device, optional): Device to move the content to
                before saving. With the default None the xdict is saved as
                is, without moving anything.
            verbose (bool, optional): Whether to print the target path before
                saving (default is True).
        """
        if verbose:
            print(f"Saving to {path}")
        torch.save(self.to(dev), path)