import torch
import torch_redstone as rst
import transformers
import numpy as np
from torch import nn
from typing import Tuple, List, Union, Optional
from transformers import GPT2Tokenizer, GPT2LMHeadModel, GPT2Config
from huggingface_hub import hf_hub_download
from diffusers import StableUnCLIPImg2ImgPipeline
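# Shorthand type aliases (T = tensor, TN = optional tensor, ARRAY = numpy array, D = device, ...)
# used in the annotations below.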
N = type(None)
V = np.array
ARRAY = np.ndarray
ARRAYS = Union[Tuple[ARRAY, ...], List[ARRAY]]
VS = Union[Tuple[V, ...], List[V]]
VN = Union[V, N]
VNS = Union[VS, N]
T = torch.Tensor
TS = Union[Tuple[T, ...], List[T]]
TN = Optional[T]
TNS = Union[Tuple[TN, ...], List[TN]]
TSN = Optional[TS]
TA = Union[T, ARRAY]
D = torch.device
class Wrapper(transformers.modeling_utils.PreTrainedModel):
    def __init__(self) -> None:
        super().__init__(transformers.configuration_utils.PretrainedConfig())
        self.param = torch.nn.Parameter(torch.tensor(0.))

    def forward(self, x):
        return rst.ObjectProxy(image_embeds=x)
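# MLP: small feed-forward network (Linear + Tanh stack) used below as the projection
# from the prefix feature vector into GPT-2's token-embedding space.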
class MLP(nn.Module):
    def forward(self, x: T) -> T:
        return self.model(x)

    def __init__(self, sizes: Tuple[int, ...], bias=True, act=nn.Tanh):
        super(MLP, self).__init__()
        layers = []
        for i in range(len(sizes) - 1):
            layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=bias))
            if i < len(sizes) - 2:
                layers.append(act())
        self.model = nn.Sequential(*layers)
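# ClipCap-style captioning model: `clip_project` maps a single feature vector
# (here produced by the point-cloud encoder) to `prefix_length` GPT-2 token
# embeddings, which are prepended to the text embeddings as a learned prefix.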
class ClipCaptionModel(nn.Module):
    # @functools.lru_cache  # FIXME
    def get_dummy_token(self, batch_size: int, device: D) -> T:
        return torch.zeros(batch_size, self.prefix_length, dtype=torch.int64, device=device)

    def forward(self, tokens: T, prefix: T, mask: Optional[T] = None, labels: Optional[T] = None):
        embedding_text = self.gpt.transformer.wte(tokens)
        prefix_projections = self.clip_project(prefix).view(-1, self.prefix_length, self.gpt_embedding_size)
        # print(embedding_text.size())  # torch.Size([5, 67, 768])
        # print(prefix_projections.size())  # torch.Size([5, 1, 768])
        embedding_cat = torch.cat((prefix_projections, embedding_text), dim=1)
        if labels is not None:
            dummy_token = self.get_dummy_token(tokens.shape[0], tokens.device)
            labels = torch.cat((dummy_token, tokens), dim=1)
        out = self.gpt(inputs_embeds=embedding_cat, labels=labels, attention_mask=mask)
        return out

    def __init__(self, prefix_length: int, prefix_size: int = 512):
        super(ClipCaptionModel, self).__init__()
        self.prefix_length = prefix_length
        self.gpt = GPT2LMHeadModel(GPT2Config.from_pretrained('gpt2'))
        self.gpt_embedding_size = self.gpt.transformer.wte.weight.shape[1]
        if prefix_length > 10:  # not enough memory
            self.clip_project = nn.Linear(prefix_size, self.gpt_embedding_size * prefix_length)
        else:
            self.clip_project = MLP((prefix_size, (self.gpt_embedding_size * prefix_length) // 2, self.gpt_embedding_size * prefix_length))
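# Prefix-only variant: exposes just the projection's parameters to the optimizer
# and keeps GPT-2 frozen in eval mode during training.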
class ClipCaptionPrefix(ClipCaptionModel):
    def parameters(self, recurse: bool = True):
        return self.clip_project.parameters()

    def train(self, mode: bool = True):
        super(ClipCaptionPrefix, self).train(mode)
        self.gpt.eval()
        return self
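# Decoding loop for the caption model: at each step the last-token logits are
# nucleus (top-p) filtered and the highest-probability remaining token is taken
# (effectively greedy decoding); generation stops at `stop_token` or after
# `entry_length` tokens.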
def generate2(
        model,
        tokenizer,
        tokens=None,
        prompt=None,
        embed=None,
        entry_count=1,
        entry_length=67,  # maximum number of words
        top_p=0.8,
        temperature=1.,
        stop_token: str = '.',
):
    model.eval()
    generated_num = 0
    generated_list = []
    stop_token_index = tokenizer.encode(stop_token)[0]
    filter_value = -float("Inf")
    device = next(model.parameters()).device
    score_col = []
    with torch.no_grad():
        for entry_idx in range(entry_count):
            if embed is not None:
                generated = embed
            else:
                if tokens is None:
                    tokens = torch.tensor(tokenizer.encode(prompt))
                    tokens = tokens.unsqueeze(0).to(device)
                generated = model.gpt.transformer.wte(tokens)
            for i in range(entry_length):
                outputs = model.gpt(inputs_embeds=generated)
                logits = outputs.logits
                logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
                sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                sorted_indices_to_remove[..., 0] = 0
                indices_to_remove = sorted_indices[sorted_indices_to_remove]
                logits[:, indices_to_remove] = filter_value
                next_token = torch.argmax(torch.softmax(logits, dim=-1), -1).reshape(1, 1)
                score = torch.softmax(logits, dim=-1).reshape(-1)[next_token.item()].item()
                score_col.append(score)
                next_token_embed = model.gpt.transformer.wte(next_token)
                if tokens is None:
                    tokens = next_token
                else:
                    tokens = torch.cat((tokens, next_token), dim=1)
                generated = torch.cat((generated, next_token_embed), dim=1)
                if stop_token_index == next_token.item():
                    break
            output_list = list(tokens.squeeze(0).cpu().numpy())
            output_text = tokenizer.decode(output_list)
            generated_list.append(output_text)
    return generated_list[0]
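# Point cloud -> caption: encode the point cloud, scale the embedding by
# `cond_scale`, project it to GPT-2 prefix embeddings, and decode with generate2.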
@torch.no_grad()
def pc_to_text(pc_encoder: torch.nn.Module, pc, cond_scale):
    ref_dev = next(pc_encoder.parameters()).device
    prefix = pc_encoder(torch.tensor(pc.T[None], device=ref_dev))
    prefix = prefix.float() * cond_scale
    prefix = prefix.to(next(model.parameters()).device)
    prefix_embed = model.clip_project(prefix).reshape(1, prefix_length, -1)
    text = generate2(model, tokenizer, embed=prefix_embed)
    return text
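# Point cloud -> image: encode the point cloud and rescale the normalized
# embedding (presumably to match the magnitude the unCLIP image-variation
# pipeline expects), then run Stable unCLIP img2img conditioned on it.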
@torch.no_grad()
def pc_to_image(pc_encoder: torch.nn.Module, pc, prompt, noise_level, width, height, cfg_scale, num_steps, callback):
    ref_dev = next(pc_encoder.parameters()).device
    enc = pc_encoder(torch.tensor(pc.T[None], device=ref_dev))
    enc = torch.nn.functional.normalize(enc, dim=-1) * (768 ** 0.5) / 2
    if torch.cuda.is_available():
        enc = enc.to('cuda:' + str(torch.cuda.current_device()))
    # enc = enc.type(half)
    # with torch.autocast("cuda"):
    return pipe(
        prompt=', '.join(["best quality"] + ([prompt] if prompt else [])),
        negative_prompt="cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry",
        image=enc,
        width=width, height=height,
        guidance_scale=cfg_scale,
        noise_level=noise_level,
        callback=callback,
        num_inference_steps=num_steps
    ).images[0]
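# Module-level setup: load the Stable unCLIP img2img pipeline with a pass-through
# "image encoder" (Wrapper) so precomputed embeddings can be fed in directly, and
# load the ClipCap captioning model with pretrained weights from OpenShape/clipcap-cc.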
pipe = StableUnCLIPImg2ImgPipeline.from_pretrained(
    "diffusers/stable-diffusion-2-1-unclip-i2i-l",
    # variant="fp16",
    image_encoder=Wrapper()
)
# pe = pipe.text_encoder.text_model.embeddings
# pe.position_ids = torch.arange(pe.position_ids.shape[-1]).expand((1, -1)).to(pe.position_ids) # workaround
if torch.cuda.is_available():
    pipe = pipe.to('cuda:' + str(torch.cuda.current_device()))
    pipe.enable_model_cpu_offload(torch.cuda.current_device())
    pipe.enable_attention_slicing()
    pipe.enable_vae_slicing()
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
prefix_length = 10
model = ClipCaptionModel(prefix_length)
# print(model.gpt_embedding_size)
model.load_state_dict(torch.load(hf_hub_download('OpenShape/clipcap-cc', 'conceptual_weights.pt'), map_location='cpu'))
model.eval()
if torch.cuda.is_available():
    model = model.cuda()