from typing import Iterator

import torch
import torchvision.transforms.functional as TVF
from torchvision.transforms import InterpolationMode

from .template import BaseSampler, CKPT_PATH


class T2ISampler(BaseSampler):
""" |
|
A sampler for text-to-image generation. |
|
:param root_dir: directory for model checkpoints. |
|
:param sampling_type: ["default", "fast"] |
|
""" |

    def __init__(
        self,
        root_dir: str,
        sampling_type: str = "default",
    ):
        super().__init__(root_dir, sampling_type)

    @classmethod
    def from_pretrained(
        cls,
        root_dir: str,
        clip_model_path: str,
        clip_stat_path: str,
        sampling_type: str = "default",
    ):
        model = cls(
            root_dir=root_dir,
            sampling_type=sampling_type,
        )
        model.load_clip(clip_model_path)
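
        # Karlo generates images in three cascaded stages: a prior that maps
        # CLIP text features to a CLIP image embedding, a decoder that renders
        # 64x64 images from that embedding, and an SR module that upsamples
        # them to 256x256.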
        model.load_prior(
            CKPT_PATH["prior"],
            clip_stat_path=clip_stat_path,
            prior_config="configs/karlo/prior_1B_vit_l.yaml",
        )
        model.load_decoder(
            CKPT_PATH["decoder"],
            decoder_config="configs/karlo/decoder_900M_vit_l.yaml",
        )
        model.load_sr_64_256(
            CKPT_PATH["sr_256"],
            sr_config="configs/karlo/improved_sr_64_256_1.4B.yaml",
        )
        return model

    def preprocess(
        self,
        prompt: str,
        bsz: int,
    ):
        """Set up prompts and classifier-free guidance scales."""
        prompts_batch = [prompt for _ in range(bsz)]

        prior_cf_scales_batch = [self._prior_cf_scale] * len(prompts_batch)
        prior_cf_scales_batch = torch.tensor(prior_cf_scales_batch, device="cuda")

        decoder_cf_scales_batch = [self._decoder_cf_scale] * len(prompts_batch)
        decoder_cf_scales_batch = torch.tensor(decoder_cf_scales_batch, device="cuda")

        # Get the CLIP text features.
        clip_model = self._clip
        tokenizer = self._tokenizer
        max_txt_length = self._prior.model.text_ctx

        tok, mask = tokenizer.padded_tokens_and_mask(prompts_batch, max_txt_length)
        cf_token, cf_mask = tokenizer.padded_tokens_and_mask([""], max_txt_length)
        if cf_token.shape != tok.shape:
            cf_token = cf_token.expand(tok.shape[0], -1)
            cf_mask = cf_mask.expand(tok.shape[0], -1)
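
        # For classifier-free guidance, unconditional (empty-prompt) tokens are
        # stacked after the conditional ones along the batch dimension, so one
        # forward pass scores both branches.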
        tok = torch.cat([tok, cf_token], dim=0)
        mask = torch.cat([mask, cf_mask], dim=0)

        tok, mask = tok.to(device="cuda"), mask.to(device="cuda")
        txt_feat, txt_feat_seq = clip_model.encode_text(tok)

        return (
            prompts_batch,
            prior_cf_scales_batch,
            decoder_cf_scales_batch,
            txt_feat,
            txt_feat_seq,
            tok,
            mask,
        )

    def __call__(
        self,
        prompt: str,
        bsz: int,
        progressive_mode=None,
    ) -> Iterator[torch.Tensor]:
        assert progressive_mode in ("loop", "stage", "final")
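        # "loop" yields every denoising step, "stage" yields the output of each
        # stage (64x64, then 256x256), and "final" yields only the finished
        # 256x256 images.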
        with torch.no_grad(), torch.cuda.amp.autocast():
            (
                prompts_batch,
                prior_cf_scales_batch,
                decoder_cf_scales_batch,
                txt_feat,
                txt_feat_seq,
                tok,
                mask,
            ) = self.preprocess(
                prompt,
                bsz,
            )

            # Transform the CLIP text feature into an image feature.
            img_feat = self._prior(
                txt_feat,
                txt_feat_seq,
                mask,
                prior_cf_scales_batch,
                timestep_respacing=self._prior_sm,
            )

            # Generate 64x64 images.
            images_64_outputs = self._decoder(
                txt_feat,
                txt_feat_seq,
                tok,
                mask,
                img_feat,
                cf_guidance_scales=decoder_cf_scales_batch,
                timestep_respacing=self._decoder_sm,
            )
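
            # The decoder's outputs are in [-1, 1]; "* 0.5 + 0.5" rescales the
            # yielded images to [0, 1].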
            images_64 = None
            for out in images_64_outputs:
                images_64 = out
                if progressive_mode == "loop":
                    yield torch.clamp(out * 0.5 + 0.5, 0.0, 1.0)
            if progressive_mode == "stage":
                yield torch.clamp(images_64 * 0.5 + 0.5, 0.0, 1.0)

            images_64 = torch.clamp(images_64, -1, 1)

            # Upsample 64x64 to 256x256.
            images_256 = TVF.resize(
                images_64,
                [256, 256],
                interpolation=InterpolationMode.BICUBIC,
                antialias=True,
            )
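            # The bicubic upsample only initializes the SR stage; the diffusion
            # SR model below refines it into the final 256x256 images.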
            images_256_outputs = self._sr_64_256(
                images_256, timestep_respacing=self._sr_sm
            )

            for out in images_256_outputs:
                images_256 = out
                if progressive_mode == "loop":
                    yield torch.clamp(out * 0.5 + 0.5, 0.0, 1.0)
            if progressive_mode == "stage":
                yield torch.clamp(images_256 * 0.5 + 0.5, 0.0, 1.0)

        yield torch.clamp(images_256 * 0.5 + 0.5, 0.0, 1.0)


class PriorSampler(BaseSampler):
    """
    A sampler that runs only the prior stage of text-to-image generation,
    producing CLIP image features rather than pixels.

    :param root_dir: directory for model checkpoints.
    :param sampling_type: one of "default" or "fast".
    """

    def __init__(
        self,
        root_dir: str,
        sampling_type: str = "default",
    ):
        super().__init__(root_dir, sampling_type)

    @classmethod
    def from_pretrained(
        cls,
        root_dir: str,
        clip_model_path: str,
        clip_stat_path: str,
        sampling_type: str = "default",
    ):
        model = cls(
            root_dir=root_dir,
            sampling_type=sampling_type,
        )
        model.load_clip(clip_model_path)
        # Only the prior is loaded here; no decoder or SR module.
        model.load_prior(
            CKPT_PATH["prior"],
            clip_stat_path=clip_stat_path,
            prior_config="configs/karlo/prior_1B_vit_l.yaml",
        )
        return model

    def preprocess(
        self,
        prompt: str,
        bsz: int,
    ):
        """Set up prompts and classifier-free guidance scales."""
        prompts_batch = [prompt for _ in range(bsz)]

        prior_cf_scales_batch = [self._prior_cf_scale] * len(prompts_batch)
        prior_cf_scales_batch = torch.tensor(prior_cf_scales_batch, device="cuda")

        # Unused by the prior-only pipeline, but computed so the return
        # signature matches T2ISampler.preprocess.
        decoder_cf_scales_batch = [self._decoder_cf_scale] * len(prompts_batch)
        decoder_cf_scales_batch = torch.tensor(decoder_cf_scales_batch, device="cuda")

        # Get the CLIP text features.
        clip_model = self._clip
        tokenizer = self._tokenizer
        max_txt_length = self._prior.model.text_ctx

        tok, mask = tokenizer.padded_tokens_and_mask(prompts_batch, max_txt_length)
        cf_token, cf_mask = tokenizer.padded_tokens_and_mask([""], max_txt_length)
        if cf_token.shape != tok.shape:
            cf_token = cf_token.expand(tok.shape[0], -1)
            cf_mask = cf_mask.expand(tok.shape[0], -1)
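
        # As in T2ISampler.preprocess, unconditional tokens are appended along
        # the batch dimension for classifier-free guidance.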
        tok = torch.cat([tok, cf_token], dim=0)
        mask = torch.cat([mask, cf_mask], dim=0)

        tok, mask = tok.to(device="cuda"), mask.to(device="cuda")
        txt_feat, txt_feat_seq = clip_model.encode_text(tok)

        return (
            prompts_batch,
            prior_cf_scales_batch,
            decoder_cf_scales_batch,
            txt_feat,
            txt_feat_seq,
            tok,
            mask,
        )

    def __call__(
        self,
        prompt: str,
        bsz: int,
        progressive_mode=None,
    ) -> Iterator[torch.Tensor]:
        assert progressive_mode in ("loop", "stage", "final")
        with torch.no_grad(), torch.cuda.amp.autocast():
            (
                prompts_batch,
                prior_cf_scales_batch,
                decoder_cf_scales_batch,
                txt_feat,
                txt_feat_seq,
                tok,
                mask,
            ) = self.preprocess(
                prompt,
                bsz,
            )

            # Transform the CLIP text feature into an image feature.
            img_feat = self._prior(
                txt_feat,
                txt_feat_seq,
                mask,
                prior_cf_scales_batch,
                timestep_respacing=self._prior_sm,
            )
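
            # The prior yields a CLIP image embedding rather than pixels; it
            # can be rendered to images with a decoder such as T2ISampler's.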
            yield img_feat