Spaces:
Paused
Paused
from diffusers import ModelMixin | |
import torch.nn as nn | |
class ByT5Mapper(ModelMixin): | |
def __init__(self, byt5_output_dim, sdxl_text_dim): | |
super().__init__() | |
self.mapper = nn.Sequential( | |
nn.LayerNorm(byt5_output_dim), | |
nn.Linear(byt5_output_dim, sdxl_text_dim), | |
nn.ReLU(), | |
nn.Linear(sdxl_text_dim, sdxl_text_dim) | |
) | |
def forward(self, byt5_embedding): | |
return self.mapper(byt5_embedding) | |