# import torch
import math

import jax
import jax.numpy as jnp
from einops import rearrange
from flax import nnx

# from torch import Tensor
Tensor = jax.Array


def check_tpu() -> bool:
    return any("TPU" in d.device_kind for d in jax.devices())


if check_tpu():
    from jax.experimental.pallas.ops.tpu.flash_attention import flash_attention

    def flash_mha(q, k, v):
        # Shared interface: q, k, v are (batch, seq_len, num_heads, head_dim).
        # The pallas TPU kernel expects
        #   q: [batch_size, num_heads, q_seq_len, d_model]
        #   k: [batch_size, num_heads, kv_seq_len, d_model]
        #   v: [batch_size, num_heads, kv_seq_len, d_model]
        # so transpose in and out.
        q, k, v = map(lambda x: rearrange(x, "b l h d -> b h l d"), (q, k, v))
        out = flash_attention(q, k, v, sm_scale=1 / math.sqrt(q.shape[-1]))
        return rearrange(out, "b h l d -> b l h d")

else:
    from jax.experimental.pallas.ops.gpu.attention import mha, mha_reference

    def pallas_mha(q, k, v):
        # q, k, v: (batch, seq_len, num_heads, head_dim).
        # The pallas GPU kernel works on 128-wide sequence blocks, so pad the
        # sequence to a multiple of 128, mask the padding out via segment ids
        # (0 = real token, 1 = padding), and drop the padded rows afterwards.
        # return mha_reference(q, k, v, segment_ids=None, sm_scale=1 / math.sqrt(q.shape[-1]))
        q_len = q.shape[1]
        diff = (-q_len) & 127
        segment_ids = jnp.zeros((q.shape[0], q_len), dtype=jnp.int32)
        segment_ids = jnp.pad(segment_ids, ((0, 0), (0, diff)), mode="constant", constant_values=1)
        q, k, v = map(
            lambda x: jnp.pad(x, ((0, 0), (0, diff), (0, 0), (0, 0)), mode="constant", constant_values=0),
            (q, k, v),
        )
        out = mha(q, k, v, segment_ids=segment_ids, sm_scale=1 / math.sqrt(q.shape[-1]))
        return out[:, :q_len]

    # Prefer flash_attn_jax when it is installed; fall back to the pallas kernel.
    try:
        from flash_attn_jax import flash_mha
    except ImportError:
        flash_mha = pallas_mha
    # flash_mha = nnx.dot_product_attention

# Only needed by the commented-out memory-efficient fallback in dot_product_attention.
from flux.modules.attention_flax import jax_memory_efficient_attention


def dot_product_attention(q, k, v, sm_scale=1.0):
    # q, k, v: (batch, num_heads, seq_len, head_dim). sm_scale is kept for API
    # compatibility; nnx.dot_product_attention scales by 1/sqrt(head_dim) itself.
    q, k, v = map(lambda x: rearrange(x, "b h n d -> b n h d"), (q, k, v))
    # ret = pallas_mha(q, k, v)
    ret = nnx.dot_product_attention(q, k, v)
    # Memory-efficient fallback:
    # if q.shape[-3] % 64 == 0:
    #     query_chunk_size = int(q.shape[-3] / 64)
    # elif q.shape[-3] % 16 == 0:
    #     query_chunk_size = int(q.shape[-3] / 16)
    # elif q.shape[-3] % 4 == 0:
    #     query_chunk_size = int(q.shape[-3] / 4)
    # else:
    #     query_chunk_size = int(q.shape[-3])
    # ret = jax_memory_efficient_attention(q, k, v, query_chunk_size=query_chunk_size)
    return rearrange(ret, "b n h d -> b h n d")


def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
    # q, k, v: (B, H, L, D); pe: rotary frequencies as produced by rope().
    q, k = apply_rope(q, k, pe)
    # x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
    q, k, v = map(lambda x: rearrange(x, "B H L D -> B L H D"), (q, k, v))
    # x = nnx.dot_product_attention(q, k, v)
    x = flash_mha(q, k, v)
    # x = pallas_mha(q, k, v)
    # x = mha(q, k, v, None, sm_scale=1 / math.sqrt(q.shape[-1]))
    x = rearrange(x, "B L H D -> B L (H D)")
    return x


def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
    assert dim % 2 == 0
    # scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
    scale = jnp.arange(0, dim, 2, dtype=jnp.float32) / dim
    omega = 1.0 / (theta**scale)
    # out = torch.einsum("...n,d->...nd", pos, omega)
    out = jnp.einsum("...n,d->...nd", pos.astype(jnp.float32), omega)
    # out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
    out = jnp.stack([jnp.cos(out), -jnp.sin(out), jnp.sin(out), jnp.cos(out)], axis=-1)
    out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
    # return out.float()
    return out.astype(jnp.float32)


def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
    # xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
    xq_ = xq.astype(jnp.float32).reshape(*xq.shape[:-1], -1, 1, 2)
    xk_ = xk.astype(jnp.float32).reshape(*xk.shape[:-1], -1, 1, 2)
    xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
    xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
    # return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
    return xq_out.reshape(*xq.shape).astype(xq.dtype), xk_out.reshape(*xk.shape).astype(xk.dtype)
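

# Minimal shape sanity check (a hedged sketch, not part of the model code):
# batch 1, 4 heads, sequence 256, head_dim 128, and rope base theta=10_000 are
# illustrative assumptions; bfloat16 is used because the flash-attention
# backends generally expect fp16/bf16, and attention() needs a TPU/GPU backend
# (or flash_attn_jax) to run.
if __name__ == "__main__":
    B, H, L, D = 1, 4, 256, 128
    kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
    q = jax.random.normal(kq, (B, H, L, D), dtype=jnp.bfloat16)
    k = jax.random.normal(kk, (B, H, L, D), dtype=jnp.bfloat16)
    v = jax.random.normal(kv, (B, H, L, D), dtype=jnp.bfloat16)
    pos = jnp.arange(L, dtype=jnp.float32)[None, :]  # (B, L) position ids
    pe = rope(pos, D, 10_000)[:, None]               # add a head axis so pe broadcasts across heads
    out = attention(q, k, v, pe)
    print("attention:", out.shape)                   # (B, L, H * D)
    out2 = dot_product_attention(q, k, v)
    print("dot_product_attention:", out2.shape)      # (B, H, L, D)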