# teacache.py
import torch
import numpy as np
from typing import Optional, Dict, Any
from functools import wraps


class TeaCacheConfig:
    """Configuration for TeaCache acceleration."""

    def __init__(
        self,
        rel_l1_thresh: float = 0.15,
        enable: bool = True,
    ):
        self.rel_l1_thresh = rel_l1_thresh
        self.enable = enable
        self._reset_state()

    def _reset_state(self):
        """Reset the internal caching state."""
        self.cnt = 0
        self.accumulated_rel_l1_distance = 0.0
        self.previous_modulated_input = None
        self.previous_residual = None
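

# TeaCache in brief: during diffusion sampling, the timestep-modulated input
# of the first transformer block changes smoothly from step to step, so its
# relative L1 distance is a cheap proxy for how much the full forward output
# would change. While the accumulated (polynomially rescaled) distance stays
# below `rel_l1_thresh`, the cached output residual is reused instead of
# recomputing the whole network.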
def create_teacache_forward(original_forward):
    """Factory that wraps `original_forward` in a TeaCache-enabled forward pass."""

    @wraps(original_forward)
    def teacache_forward(
        self,
        hidden_states: torch.Tensor,
        timestep: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        pooled_projections: Optional[torch.Tensor] = None,
        guidance: Optional[torch.Tensor] = None,
        attention_kwargs: Optional[Dict[str, Any]] = None,
        return_dict: bool = True,
    ):
        # Fall back to the unmodified forward pass if TeaCache is not enabled.
        if not hasattr(self, 'teacache_config') or not self.teacache_config.enable:
            return original_forward(
                self,
                hidden_states=hidden_states,
                timestep=timestep,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                pooled_projections=pooled_projections,
                guidance=guidance,
                attention_kwargs=attention_kwargs,
                return_dict=return_dict,
            )

        config = self.teacache_config

        # Prepare the modulation vector, mirroring the HunyuanVideo implementation.
        # `vec` must be initialized here: it stays None when neither pooled
        # projections nor guidance embeddings are provided.
        vec = None
        if pooled_projections is not None:
            vec = self.vector_in(pooled_projections)
        if guidance is not None:
            if vec is None:
                vec = self.guidance_in(guidance)
            else:
                vec = vec + self.guidance_in(guidance)
# TeaCache optimization logic
inp = hidden_states.clone()
if hasattr(self.double_blocks[0], 'img_norm1'):
# HunyuanVideo specific modulation
img_mod1_shift, img_mod1_scale, _, _, _, _ = self.double_blocks[0].img_mod(vec).chunk(6, dim=-1)
normed_inp = self.double_blocks[0].img_norm1(inp)
modulated_inp = normed_inp * (1 + img_mod1_scale) + img_mod1_shift
else:
# Fallback modulation
normed_inp = self.transformer_blocks[0].norm1(inp)
modulated_inp = normed_inp

        # Decide whether to run the full forward pass or reuse the cached residual.
        should_calc = True
        if config.cnt == 0 or config.cnt == self.num_inference_steps - 1:
            # Always compute on the first and last steps to anchor quality.
            should_calc = True
            config.accumulated_rel_l1_distance = 0.0
        elif config.previous_modulated_input is not None:
            # Polynomial fitted to map the raw relative L1 distance of the
            # modulated input onto the expected change in the model output.
            coefficients = [7.33226126e+02, -4.01131952e+02, 6.75869174e+01,
                            -3.14987800e+00, 9.61237896e-02]
            rescale_func = np.poly1d(coefficients)
            rel_l1 = ((modulated_inp - config.previous_modulated_input).abs().mean() /
                      config.previous_modulated_input.abs().mean()).cpu().item()
            config.accumulated_rel_l1_distance += rescale_func(rel_l1)
            # Recompute only once the accumulated estimate crosses the threshold.
            should_calc = config.accumulated_rel_l1_distance >= config.rel_l1_thresh
            if should_calc:
                config.accumulated_rel_l1_distance = 0.0

        config.previous_modulated_input = modulated_inp
        config.cnt += 1
        if config.cnt >= self.num_inference_steps:
            config.cnt = 0

        # Reuse the cached residual, or run the full forward pass and cache a new one.
        if not should_calc and config.previous_residual is not None:
            # Out-of-place add: avoid mutating the caller's tensor in place.
            hidden_states = hidden_states + config.previous_residual
        else:
            ori_hidden_states = hidden_states.clone()
            out = original_forward(
                self,
                hidden_states=hidden_states,
                timestep=timestep,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                pooled_projections=pooled_projections,
                guidance=guidance,
                attention_kwargs=attention_kwargs,
                return_dict=True,
            )
            hidden_states = out["sample"]
            # Cache the residual so later steps can skip the full computation.
            config.previous_residual = hidden_states - ori_hidden_states

        if not return_dict:
            return (hidden_states,)
        return {"sample": hidden_states}

    return teacache_forward
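

# `create_teacache_forward(...)` returns a plain function whose first argument
# is `self`; binding it with `__get__(model)` turns it into a method on that
# one instance, leaving the class and any other instances untouched.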
def enable_teacache(model: Any, num_inference_steps: int, rel_l1_thresh: float = 0.15):
    """Enable TeaCache acceleration for a model."""
    if not hasattr(model, '_original_forward'):
        # Store the unbound class function: the wrapper passes `self` explicitly,
        # so wrapping the already-bound method would pass the model twice.
        model._original_forward = type(model).forward
    model.teacache_config = TeaCacheConfig(rel_l1_thresh=rel_l1_thresh)
    model.num_inference_steps = num_inference_steps
    model.forward = create_teacache_forward(model._original_forward).__get__(model)


def disable_teacache(model: Any):
    """Disable TeaCache acceleration for a model."""
    if hasattr(model, '_original_forward'):
        # Rebind the stored class function to the instance before restoring it.
        model.forward = model._original_forward.__get__(model)
        del model._original_forward
    if hasattr(model, 'teacache_config'):
        del model.teacache_config
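

if __name__ == "__main__":
    # Minimal smoke test (an illustrative sketch, not part of the original API):
    # a toy module stands in for a real video transformer and exercises the
    # fallback path (`transformer_blocks[0].norm1`) without any model weights.
    import torch.nn as nn

    class ToyBlock(nn.Module):
        def __init__(self, dim):
            super().__init__()
            self.norm1 = nn.LayerNorm(dim)

    class ToyTransformer(nn.Module):
        def __init__(self, dim=16):
            super().__init__()
            self.transformer_blocks = nn.ModuleList([ToyBlock(dim)])
            self.proj = nn.Linear(dim, dim)

        def forward(self, hidden_states, timestep, encoder_hidden_states=None,
                    encoder_attention_mask=None, pooled_projections=None,
                    guidance=None, attention_kwargs=None, return_dict=True):
            sample = hidden_states + self.proj(hidden_states)
            return {"sample": sample} if return_dict else (sample,)

    model = ToyTransformer()
    enable_teacache(model, num_inference_steps=4, rel_l1_thresh=0.15)
    x = torch.randn(1, 8, 16)
    for step in range(4):  # stand-in for a sampler's denoising loop
        x = model(hidden_states=x, timestep=torch.tensor([step]))["sample"]
    disable_teacache(model)
    print("TeaCache smoke test finished:", tuple(x.shape))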