# NOTE(review): removed pasted extraction artifacts ("Spaces:" / "Runtime error")
# that preceded the code and were not part of the module.
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel
class AbstractEncoder(nn.Module):
    """Common base for conditioning encoders.

    Subclasses are expected to override :meth:`encode`; this base only
    hooks into ``nn.Module`` initialization.
    """

    def __init__(self):
        super().__init__()

    def encode(self, *args, **kwargs):
        # Deliberately abstract — each concrete encoder supplies its own logic.
        raise NotImplementedError
class IdentityEncoder(AbstractEncoder):
    """Encoder that returns its input unchanged (a pass-through)."""

    def encode(self, x):
        # No transformation: the conditioning is consumed as-is downstream.
        return x
class ClassEmbedder(nn.Module):
    """Embeds integer class labels for cross-attention conditioning.

    The highest class index (``n_classes - 1``) is reserved as the
    "unconditional" label used for classifier-free guidance (ucg):
    during training, each label is replaced by it with probability
    ``ucg_rate`` (label dropout).

    Args:
        embed_dim: dimensionality of each embedding vector.
        n_classes: total number of classes, including the reserved
            unconditional class.
        key: batch-dict key under which labels are stored.
        ucg_rate: probability of dropping a label to the unconditional
            class; 0 disables the dropout entirely.
    """

    def __init__(self, embed_dim, n_classes=1000, key='class', ucg_rate=0.1):
        super().__init__()
        self.key = key
        self.embedding = nn.Embedding(n_classes, embed_dim)
        self.n_classes = n_classes
        self.ucg_rate = ucg_rate

    def forward(self, batch, key=None, disable_dropout=False):
        """Look up embeddings for ``batch[key]``; shape (B, 1, embed_dim)."""
        if key is None:
            key = self.key
        # Trailing singleton dim so the result plugs into cross-attention.
        c = batch[key][:, None]
        if self.ucg_rate > 0. and not disable_dropout:
            # BUGFIX: draw the drop mask from an explicitly-float probability
            # tensor. The original used torch.ones_like(c) * rate, which keeps
            # the input dtype — class labels are typically int64, and
            # torch.bernoulli raises on integer tensors.
            probs = torch.full(c.shape, self.ucg_rate,
                               dtype=torch.float, device=c.device)
            mask = 1. - torch.bernoulli(probs)
            # Dropped entries (mask == 0) become the reserved unconditional class.
            c = mask * c + (1 - mask) * float(self.n_classes - 1)
        c = c.long()
        c = self.embedding(c)
        return c

    def get_unconditional_conditioning(self, bs, device="cuda"):
        """Return a batch of ``bs`` unconditional labels under ``self.key``."""
        uc_class = self.n_classes - 1  # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000)
        uc = torch.ones((bs,), device=device) * uc_class
        uc = {self.key: uc}
        return uc
class DanbooruEmbedder(AbstractEncoder):
    def __init__(self):
        super().__init__()