""" Reference: https://github.com/KONAKONA666/q8_kernels/blob/9cee3f3d4ca5ec8ab463179be32c8001e31f8f33/q8_kernels/modules/attention.py """ import torch import q8_kernels.functional as Q8F from diffusers.models.transformers.transformer_ltx import apply_rotary_emb from diffusers.models.attention import Attention NON_MM_PRECISION_TYPE = torch.bfloat16 MM_PRECISION_TYPE = torch.bfloat16 class LTXVideoQ8AttentionProcessor: def __call__( self, attn: Attention, hidden_states: torch.Tensor, encoder_hidden_states=None, attention_mask=None, image_rotary_emb=None, ) -> torch.Tensor: if attention_mask is not None and attention_mask.ndim > 1: attention_mask = attention_mask.argmin(-1).squeeze().int() if encoder_hidden_states is None: encoder_hidden_states = hidden_states query = attn.to_q(hidden_states) key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) query = attn.norm_q(query, NON_MM_PRECISION_TYPE) key = attn.norm_k(key, NON_MM_PRECISION_TYPE) if image_rotary_emb is not None: query = apply_rotary_emb(query, image_rotary_emb) key = apply_rotary_emb(key, image_rotary_emb) query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2) key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2) value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2) hidden_states = Q8F.flash_attention.flash_attn_func( query, key, value, batch_mask=attention_mask, apply_qk_hadamard=True ) hidden_states = hidden_states.transpose(1, 2).flatten(2, 3) hidden_states = attn.to_out[0](hidden_states) hidden_states = attn.to_out[1](hidden_states) return hidden_states.to(NON_MM_PRECISION_TYPE)