from diffusers import ModelMixin
import torch.nn as nn


class ByT5Mapper(ModelMixin):
    """Projects ByT5 encoder outputs into the SDXL text-embedding space."""

    def __init__(self, byt5_output_dim, sdxl_text_dim):
        super().__init__()
        # Two-layer MLP: normalize the ByT5 features, project them to the
        # SDXL text width, then refine with a second linear layer.
        self.mapper = nn.Sequential(
            nn.LayerNorm(byt5_output_dim),
            nn.Linear(byt5_output_dim, sdxl_text_dim),
            nn.ReLU(),
            nn.Linear(sdxl_text_dim, sdxl_text_dim),
        )

    def forward(self, byt5_embedding):
        return self.mapper(byt5_embedding)
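

# Minimal usage sketch (not part of the original file). The dimensions are
# assumptions for illustration: 1472 matches the ByT5-Small hidden size and
# 2048 matches SDXL's concatenated text-encoder width; substitute whatever
# values your pipeline actually uses.
if __name__ == "__main__":
    import torch

    mapper = ByT5Mapper(byt5_output_dim=1472, sdxl_text_dim=2048)
    byt5_embedding = torch.randn(1, 64, 1472)  # (batch, seq_len, hidden)
    mapped = mapper(byt5_embedding)
    print(mapped.shape)  # torch.Size([1, 64, 2048])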