Glyph-SDXL-v2 / glyph_sdxl /modules /simple_byt5_mapper.py
lzy-tony
feat: add english ver
d3653d5
raw
history blame
496 Bytes
from diffusers import ModelMixin
import torch.nn as nn
class ByT5Mapper(ModelMixin):
    """Map ByT5 embeddings into a feature space of width ``sdxl_text_dim``.

    The mapping is a small MLP:
    LayerNorm -> Linear(byt5_output_dim, sdxl_text_dim) -> ReLU
    -> Linear(sdxl_text_dim, sdxl_text_dim).

    Every layer acts on the last dimension of its input, so any leading
    dimensions pass through unchanged.
    """

    def __init__(self, byt5_output_dim: int, sdxl_text_dim: int):
        """Build the mapping MLP.

        Args:
            byt5_output_dim: Size of the last dimension of the input
                embeddings. This is normalized by the LayerNorm and consumed
                by the first Linear layer.
            sdxl_text_dim: Hidden width and output width of the MLP.
        """
        super().__init__()
        # Keep this exact layer order. nn.Sequential names its children
        # "0".."3", so reordering would change state-dict keys such as
        # "mapper.1.weight" and break loading of existing checkpoints.
        layers = [
            nn.LayerNorm(normalized_shape=byt5_output_dim),
            nn.Linear(in_features=byt5_output_dim, out_features=sdxl_text_dim),
            nn.ReLU(),
            nn.Linear(in_features=sdxl_text_dim, out_features=sdxl_text_dim),
        ]
        self.mapper = nn.Sequential(*layers)

    def forward(self, byt5_embedding):
        """Run ``byt5_embedding`` through the mapping MLP and return the output."""
        projected = self.mapper(byt5_embedding)
        return projected