# Spaces:
# Runtime error
from typing import Tuple, Union, Literal | |
from einops import repeat | |
import torch | |
import numpy as np | |
def get_diags_indices(
    shape: Union[int, Tuple[int, int]], k_min: int = 0, k_max: int = 0
):
    """Locate every entry of a 2-D grid whose diagonal offset lies in a range.

    The diagonal offset of entry ``(i, j)`` is ``j - i``: 0 is the main
    diagonal, positive offsets lie above it and negative offsets below it.

    Args:
        shape: Grid shape ``(rows, cols)``; a single int means a square grid.
        k_min: Smallest diagonal offset to include (inclusive). Defaults to 0.
        k_max: Largest diagonal offset to include (inclusive). Defaults to 0.

    Returns:
        A ``(row_indices, col_indices)`` tuple of integer arrays in row-major
        order, usable directly for fancy-indexing an array of that shape.
    """
    n_rows, n_cols = (shape, shape) if isinstance(shape, int) else shape
    # Broadcast a column of row numbers against a row of column numbers to get
    # the (n_rows, n_cols) grid of offsets j - i.
    offsets = np.arange(n_cols)[np.newaxis, :] - np.arange(n_rows)[:, np.newaxis]
    in_band = np.logical_and(offsets >= k_min, offsets <= k_max)
    return np.nonzero(in_band)
def generate_mask_from_indices(
    shape: Tuple[int, int],
    indices: Tuple[np.ndarray, np.ndarray],
    big_value: float = 0,
    small_value: float = -1e9,
):
    """Build a float matrix filled with ``small_value`` except at ``indices``.

    Args:
        shape: Shape of the matrix to build.
        indices: ``(rows, cols)`` index arrays (e.g. the output of ``np.where``);
            the addressed positions receive ``big_value``.
        big_value: Value written at the selected positions. Defaults to 0.
        small_value: Value for every other position. Defaults to -1e9.

    Returns:
        A freshly allocated float64 ``np.ndarray`` of the given shape.
    """
    filled = np.full(shape, small_value, dtype=float)
    filled[indices] = big_value
    return filled
def generate_sparse_causcal_attn_mask(
    batch_size: int,
    n: int,
    n_near: int = 1,
    big_value: float = 0,
    small_value: float = -1e9,
    out_type: Literal["torch", "numpy"] = "numpy",
    expand: int = 1,
) -> Union[np.ndarray, torch.Tensor]:
    """Generate a batched sparse causal attention mask.

    In the base ``(n, n)`` matrix, entry ``(i, j)`` is set to ``big_value`` when
    ``i - n_near <= j <= i`` (the main diagonal plus the ``n_near`` diagonals
    just below it) or when ``j == 0`` (the first column). Every other entry is
    set to ``small_value``. The matrix is then repeated over a leading batch
    dimension.

    Meaning of ``expand``: when the attention input has ``n`` positions the
    mask is ``(b, n, n)``. With ``expand > 1`` every base entry is copied into
    an ``expand x expand`` block, giving ``(b, n * expand, n * expand)``. Row or
    column ``r`` of the expanded mask corresponds to base position
    ``r // expand`` (einops layout ``(n expand)``, expand as the inner axis).

    Args:
        batch_size: Size of the leading batch dimension.
        n: Number of (unexpanded) positions; the base mask is ``(n, n)``.
        n_near: Number of sub-diagonals below the main diagonal that are kept.
            Defaults to 1.
        big_value: Value for kept entries. Defaults to 0.
        small_value: Value for masked-out entries. Defaults to -1e9.
        out_type: ``"torch"`` converts the result with ``torch.from_numpy``;
            any other value returns the NumPy array. Defaults to ``"numpy"``.
        expand: Repeat factor applied to both position axes. Values ``<= 1``
            leave the mask unexpanded. Defaults to 1.

    Returns:
        Mask of shape ``(batch_size, n * e, n * e)`` with ``e = max(expand, 1)``,
        as an ``np.ndarray`` or (``out_type == "torch"``) a ``torch.Tensor``.
        Its dtype is whatever ``generate_mask_from_indices`` produces (float64
        for float-valued ``big_value``/``small_value``).
    """
    shape = (n, n)
    # Diagonal offset k = col - row; keep k in [-n_near, 0], i.e. the main
    # diagonal and the n_near diagonals just below it.
    diag_indices = get_diags_indices(n, k_min=-n_near, k_max=0)
    # Every row also keeps column 0. ``np.int`` was removed in NumPy 1.24 and
    # raised AttributeError here; the builtin ``int`` is what it aliased.
    first_column = (np.arange(n), np.zeros(n, dtype=int))
    # Entries covered by both sets (e.g. (0, 0)) appear twice; assigning the
    # same value twice through fancy indexing is harmless.
    indices = (
        np.concatenate([diag_indices[0], first_column[0]]),
        np.concatenate([diag_indices[1], first_column[1]]),
    )
    mask = generate_mask_from_indices(
        shape=shape, indices=indices, big_value=big_value, small_value=small_value
    )
    mask = repeat(mask, "m n -> b m n", b=batch_size)
    if expand > 1:
        # Copy each entry into an expand x expand block (expand is the inner axis).
        mask = repeat(
            mask,
            "b m n -> b (m d1) (n d2)",
            d1=expand,
            d2=expand,
        )
    if out_type == "torch":
        mask = torch.from_numpy(mask)
    return mask