import torch
import numpy as np

from typing import Optional, Dict, Any
from functools import wraps

class TeaCacheConfig:
    """Configuration and mutable caching state for TeaCache acceleration."""

    def __init__(
        self,
        rel_l1_thresh: float = 0.15,
        enable: bool = True,
    ):
        self.rel_l1_thresh = rel_l1_thresh
        self.enable = enable
        self._reset_state()

    def _reset_state(self):
        """Reset the per-run caching state."""
        self.cnt = 0
        self.accumulated_rel_l1_distance = 0.0
        self.previous_modulated_input = None
        self.previous_residual = None
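
# The skip decision is driven by the relative L1 distance between the current
# and previous step's modulated input. A minimal sketch of that metric,
# assuming two same-shaped tensors ``curr`` and ``prev`` (illustration only;
# the wrapped forward below computes it inline):
#
#   rel_l1 = ((curr - prev).abs().mean() / prev.abs().mean()).item()
#
# The distance is rescaled by a fitted polynomial and accumulated across
# steps; a full forward pass runs only once the accumulator reaches
# ``rel_l1_thresh``.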

def create_teacache_forward(original_forward):
    """Factory that wraps a model's forward pass with TeaCache caching.

    ``original_forward`` is expected to be the bound forward method of the
    model instance (see ``enable_teacache`` below), so it is called without
    an explicit ``self`` argument.
    """
    @wraps(original_forward)
    def teacache_forward(
        self,
        hidden_states: torch.Tensor,
        timestep: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        pooled_projections: Optional[torch.Tensor] = None,
        guidance: Optional[torch.Tensor] = None,
        attention_kwargs: Optional[Dict[str, Any]] = None,
        return_dict: bool = True,
    ):
        # Fall back to the unmodified forward pass when TeaCache is disabled.
        # ``original_forward`` is already bound, so ``self`` is not passed.
        if not hasattr(self, 'teacache_config') or not self.teacache_config.enable:
            return original_forward(
                hidden_states=hidden_states,
                timestep=timestep,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                pooled_projections=pooled_projections,
                guidance=guidance,
                attention_kwargs=attention_kwargs,
                return_dict=return_dict,
            )

        config = self.teacache_config

        # Build the conditioning vector used by the first block's modulation.
        # ``vec`` is initialized up front so the guidance branch is safe even
        # when ``pooled_projections`` is None.
        vec = None
        if pooled_projections is not None:
            vec = self.vector_in(pooled_projections)

        if guidance is not None:
            if vec is None:
                vec = self.guidance_in(guidance)
            else:
                vec = vec + self.guidance_in(guidance)
        # Probe the first block's modulated input; its change between steps is
        # a cheap proxy for how much the full model output would change.
        inp = hidden_states.clone()
        if hasattr(self, 'double_blocks') and hasattr(self.double_blocks[0], 'img_norm1'):
            # Double-stream block (HunyuanVideo-style): apply the first
            # modulation's shift/scale to the normed image stream.
            img_mod1_shift, img_mod1_scale, _, _, _, _ = self.double_blocks[0].img_mod(vec).chunk(6, dim=-1)
            normed_inp = self.double_blocks[0].img_norm1(inp)
            modulated_inp = normed_inp * (1 + img_mod1_scale) + img_mod1_shift
        else:
            # Generic fallback: use the first transformer block's norm output
            # without modulation.
            normed_inp = self.transformer_blocks[0].norm1(inp)
            modulated_inp = normed_inp

        # Decide whether to run the full forward pass or reuse the cached
        # residual. The first and last steps are always computed.
        should_calc = True
        if config.cnt == 0 or config.cnt == self.num_inference_steps - 1:
            config.accumulated_rel_l1_distance = 0.0
        elif config.previous_modulated_input is not None:
            # Rescale the raw relative L1 distance with a polynomial fitted
            # offline for this model family.
            coefficients = [7.33226126e+02, -4.01131952e+02, 6.75869174e+01,
                            -3.14987800e+00, 9.61237896e-02]
            rescale_func = np.poly1d(coefficients)

            rel_l1 = ((modulated_inp - config.previous_modulated_input).abs().mean() /
                      config.previous_modulated_input.abs().mean()).cpu().item()
            config.accumulated_rel_l1_distance += rescale_func(rel_l1)

            # Recompute only once the accumulated distance crosses the
            # threshold, then reset the accumulator.
            should_calc = config.accumulated_rel_l1_distance >= config.rel_l1_thresh
            if should_calc:
                config.accumulated_rel_l1_distance = 0.0

        config.previous_modulated_input = modulated_inp
        config.cnt += 1
        if config.cnt >= self.num_inference_steps:
            config.cnt = 0
        if not should_calc and config.previous_residual is not None:
            # Skip the transformer entirely: reapply the residual cached at
            # the last full computation. Avoid in-place mutation so the
            # caller's tensor is left untouched.
            hidden_states = hidden_states + config.previous_residual
        else:
            ori_hidden_states = hidden_states.clone()

            out = original_forward(
                hidden_states=hidden_states,
                timestep=timestep,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                pooled_projections=pooled_projections,
                guidance=guidance,
                attention_kwargs=attention_kwargs,
                return_dict=True,
            )
            hidden_states = out["sample"]

            # Cache the residual so subsequent skipped steps can reuse it.
            config.previous_residual = hidden_states - ori_hidden_states

        if not return_dict:
            return (hidden_states,)
        return {"sample": hidden_states}

    return teacache_forward


def enable_teacache(model: Any, num_inference_steps: int, rel_l1_thresh: float = 0.15):
    """Enable TeaCache acceleration for a model by patching its forward pass."""
    if not hasattr(model, '_original_forward'):
        model._original_forward = model.forward

    model.teacache_config = TeaCacheConfig(rel_l1_thresh=rel_l1_thresh)
    model.num_inference_steps = num_inference_steps
    # Bind the wrapper to this instance via the descriptor protocol so it
    # receives ``self`` like a normal method.
    model.forward = create_teacache_forward(model._original_forward).__get__(model)


def disable_teacache(model: Any):
    """Disable TeaCache acceleration for a model."""
    if hasattr(model, '_original_forward'):
        model.forward = model._original_forward
        del model._original_forward

    if hasattr(model, 'teacache_config'):
        del model.teacache_config
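

# Usage sketch (hypothetical names; ``model`` stands in for any transformer
# whose forward matches the signature wrapped above, e.g. a HunyuanVideo-style
# transformer exposing ``double_blocks`` or ``transformer_blocks``):
#
#   enable_teacache(model, num_inference_steps=50, rel_l1_thresh=0.15)
#   ...run the diffusion sampling loop as usual...
#   disable_teacache(model)  # restores the original forward
#
# Raising ``rel_l1_thresh`` skips more steps (faster, lower fidelity);
# lowering it recomputes more often.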