winfred2027 committed (verified)
Commit 257d404 · 1 parent: 8a7f4ef

Update demo_support/generation.py

Files changed (1)
  1. demo_support/generation.py +203 -203
demo_support/generation.py CHANGED
@@ -1,204 +1,204 @@
- import torch
- import torch_redstone as rst
- import transformers
- import numpy as np
- from torch import nn
- from typing import Tuple, List, Union, Optional
- from transformers import GPT2Tokenizer, GPT2LMHeadModel, GPT2Config
- from huggingface_hub import hf_hub_download
- from diffusers import StableUnCLIPImg2ImgPipeline
-
-
- class Wrapper(transformers.modeling_utils.PreTrainedModel):
-     def __init__(self) -> None:
-         super().__init__(transformers.configuration_utils.PretrainedConfig())
-         self.param = torch.nn.Parameter(torch.tensor(0.))
-
-     def forward(self, x):
-         return rst.ObjectProxy(image_embeds=x)
-
- class MLP(nn.Module):
-
-     def forward(self, x: T) -> T:
-         return self.model(x)
-
-     def __init__(self, sizes: Tuple[int, ...], bias=True, act=nn.Tanh):
-         super(MLP, self).__init__()
-         layers = []
-         for i in range(len(sizes) -1):
-             layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=bias))
-             if i < len(sizes) - 2:
-                 layers.append(act())
-         self.model = nn.Sequential(*layers)
-
- class ClipCaptionModel(nn.Module):
-
-     #@functools.lru_cache #FIXME
-     def get_dummy_token(self, batch_size: int, device: D) -> T:
-         return torch.zeros(batch_size, self.prefix_length, dtype=torch.int64, device=device)
-
-     def forward(self, tokens: T, prefix: T, mask: Optional[T] = None, labels: Optional[T] = None):
-         embedding_text = self.gpt.transformer.wte(tokens)
-         prefix_projections = self.clip_project(prefix).view(-1, self.prefix_length, self.gpt_embedding_size)
-         #print(embedding_text.size()) #torch.Size([5, 67, 768])
-         #print(prefix_projections.size()) #torch.Size([5, 1, 768])
-         embedding_cat = torch.cat((prefix_projections, embedding_text), dim=1)
-         if labels is not None:
-             dummy_token = self.get_dummy_token(tokens.shape[0], tokens.device)
-             labels = torch.cat((dummy_token, tokens), dim=1)
-         out = self.gpt(inputs_embeds=embedding_cat, labels=labels, attention_mask=mask)
-         return out
-
-     def __init__(self, prefix_length: int, prefix_size: int = 512):
-         super(ClipCaptionModel, self).__init__()
-         self.prefix_length = prefix_length
-         self.gpt = GPT2LMHeadModel(GPT2Config.from_pretrained('gpt2'))
-         self.gpt_embedding_size = self.gpt.transformer.wte.weight.shape[1]
-         if prefix_length > 10: # not enough memory
-             self.clip_project = nn.Linear(prefix_size, self.gpt_embedding_size * prefix_length)
-         else:
-             self.clip_project = MLP((prefix_size, (self.gpt_embedding_size * prefix_length) // 2, self.gpt_embedding_size * prefix_length))
-
- class ClipCaptionPrefix(ClipCaptionModel):
-
-     def parameters(self, recurse: bool = True):
-         return self.clip_project.parameters()
-
-     def train(self, mode: bool = True):
-         super(ClipCaptionPrefix, self).train(mode)
-         self.gpt.eval()
-         return self
-
- def generate2(
-     model,
-     tokenizer,
-     tokens=None,
-     prompt=None,
-     embed=None,
-     entry_count=1,
-     entry_length=67, # maximum number of words
-     top_p=0.8,
-     temperature=1.,
-     stop_token: str = '.',
- ):
-     model.eval()
-     generated_num = 0
-     generated_list = []
-     stop_token_index = tokenizer.encode(stop_token)[0]
-     filter_value = -float("Inf")
-     device = next(model.parameters()).device
-     score_col = []
-     with torch.no_grad():
-
-         for entry_idx in range(entry_count):
-             if embed is not None:
-                 generated = embed
-             else:
-                 if tokens is None:
-                     tokens = torch.tensor(tokenizer.encode(prompt))
-                     tokens = tokens.unsqueeze(0).to(device)
-
-                 generated = model.gpt.transformer.wte(tokens)
-
-             for i in range(entry_length):
-
-                 outputs = model.gpt(inputs_embeds=generated)
-                 logits = outputs.logits
-                 logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
-                 sorted_logits, sorted_indices = torch.sort(logits, descending=True)
-                 cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
-                 sorted_indices_to_remove = cumulative_probs > top_p
-                 sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
-                     ..., :-1
-                 ].clone()
-                 sorted_indices_to_remove[..., 0] = 0
-
-                 indices_to_remove = sorted_indices[sorted_indices_to_remove]
-                 logits[:, indices_to_remove] = filter_value
-                 next_token = torch.argmax(torch.softmax(logits, dim=-1), -1).reshape(1, 1)
-                 score = torch.softmax(logits, dim=-1).reshape(-1)[next_token.item()].item()
-                 score_col.append(score)
-                 next_token_embed = model.gpt.transformer.wte(next_token)
-                 if tokens is None:
-                     tokens = next_token
-                 else:
-                     tokens = torch.cat((tokens, next_token), dim=1)
-                 generated = torch.cat((generated, next_token_embed), dim=1)
-                 if stop_token_index == next_token.item():
-                     break
-
-             output_list = list(tokens.squeeze(0).cpu().numpy())
-             output_text = tokenizer.decode(output_list)
-             generated_list.append(output_text)
-     return generated_list[0]
-
-
- @torch.no_grad()
- def pc_to_text(pc_encoder: torch.nn.Module, pc, cond_scale):
-     ref_dev = next(pc_encoder.parameters()).device
-     prefix = pc_encoder(torch.tensor(pc.T[None], device=ref_dev))
-     prefix = prefix.float() * cond_scale
-     prefix = prefix.to(next(model.parameters()).device)
-     prefix_embed = model.clip_project(prefix).reshape(1, prefix_length, -1)
-     text = generate2(model, tokenizer, embed=prefix_embed)
-     return text
-
- @torch.no_grad()
- def pc_to_image(pc_encoder: torch.nn.Module, pc, prompt, noise_level, width, height, cfg_scale, num_steps, callback):
-     ref_dev = next(pc_encoder.parameters()).device
-     enc = pc_encoder(torch.tensor(pc.T[None], device=ref_dev))
-     enc = torch.nn.functional.normalize(enc, dim=-1) * (768 ** 0.5) / 2
-     if torch.cuda.is_available():
-         enc = enc.to('cuda:' + str(torch.cuda.current_device()))
-     # enc = enc.type(half)
-     # with torch.autocast("cuda"):
-     return pipe(
-         prompt=', '.join(["best quality"] + ([prompt] if prompt else [])),
-         negative_prompt="cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry",
-         image=enc,
-         width=width, height=height,
-         guidance_scale=cfg_scale,
-         noise_level=noise_level,
-         callback=callback,
-         num_inference_steps=num_steps
-     ).images[0]
-
-
- N = type(None)
- V = np.array
- ARRAY = np.ndarray
- ARRAYS = Union[Tuple[ARRAY, ...], List[ARRAY]]
- VS = Union[Tuple[V, ...], List[V]]
- VN = Union[V, N]
- VNS = Union[VS, N]
- T = torch.Tensor
- TS = Union[Tuple[T, ...], List[T]]
- TN = Optional[T]
- TNS = Union[Tuple[TN, ...], List[TN]]
- TSN = Optional[TS]
- TA = Union[T, ARRAY]
-
-
- D = torch.device
-
- pipe = StableUnCLIPImg2ImgPipeline.from_pretrained(
-     "diffusers/stable-diffusion-2-1-unclip-i2i-l",
-     # variant="fp16",
-     image_encoder = Wrapper()
- )
- # pe = pipe.text_encoder.text_model.embeddings
- # pe.position_ids = torch.arange(pe.position_ids.shape[-1]).expand((1, -1)).to(pe.position_ids) # workaround
- if torch.cuda.is_available():
-     pipe = pipe.to('cuda:' + str(torch.cuda.current_device()))
-     pipe.enable_model_cpu_offload(torch.cuda.current_device())
-     pipe.enable_attention_slicing()
-     pipe.enable_vae_slicing()
-
- tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
- prefix_length = 10
- model = ClipCaptionModel(prefix_length)
- # print(model.gpt_embedding_size)
- model.load_state_dict(torch.load(hf_hub_download('OpenShape/clipcap-cc', 'conceptual_weights.pt'), map_location='cpu'))
- model.eval()
- if torch.cuda.is_available():
+ import torch
+ import torch_redstone as rst
+ import transformers
+ import numpy as np
+ from torch import nn
+ from typing import Tuple, List, Union, Optional
+ from transformers import GPT2Tokenizer, GPT2LMHeadModel, GPT2Config
+ from huggingface_hub import hf_hub_download
+ from diffusers import StableUnCLIPImg2ImgPipeline
+
+ N = type(None)
+ V = np.array
+ ARRAY = np.ndarray
+ ARRAYS = Union[Tuple[ARRAY, ...], List[ARRAY]]
+ VS = Union[Tuple[V, ...], List[V]]
+ VN = Union[V, N]
+ VNS = Union[VS, N]
+ T = torch.Tensor
+ TS = Union[Tuple[T, ...], List[T]]
+ TN = Optional[T]
+ TNS = Union[Tuple[TN, ...], List[TN]]
+ TSN = Optional[TS]
+ TA = Union[T, ARRAY]
+
+
+ D = torch.device
+
+ class Wrapper(transformers.modeling_utils.PreTrainedModel):
+     def __init__(self) -> None:
+         super().__init__(transformers.configuration_utils.PretrainedConfig())
+         self.param = torch.nn.Parameter(torch.tensor(0.))
+
+     def forward(self, x):
+         return rst.ObjectProxy(image_embeds=x)
+
+ class MLP(nn.Module):
+
+     def forward(self, x: T) -> T:
+         return self.model(x)
+
+     def __init__(self, sizes: Tuple[int, ...], bias=True, act=nn.Tanh):
+         super(MLP, self).__init__()
+         layers = []
+         for i in range(len(sizes) -1):
+             layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=bias))
+             if i < len(sizes) - 2:
+                 layers.append(act())
+         self.model = nn.Sequential(*layers)
+
+ class ClipCaptionModel(nn.Module):
+
+     #@functools.lru_cache #FIXME
+     def get_dummy_token(self, batch_size: int, device: D) -> T:
+         return torch.zeros(batch_size, self.prefix_length, dtype=torch.int64, device=device)
+
+     def forward(self, tokens: T, prefix: T, mask: Optional[T] = None, labels: Optional[T] = None):
+         embedding_text = self.gpt.transformer.wte(tokens)
+         prefix_projections = self.clip_project(prefix).view(-1, self.prefix_length, self.gpt_embedding_size)
+         #print(embedding_text.size()) #torch.Size([5, 67, 768])
+         #print(prefix_projections.size()) #torch.Size([5, 1, 768])
+         embedding_cat = torch.cat((prefix_projections, embedding_text), dim=1)
+         if labels is not None:
+             dummy_token = self.get_dummy_token(tokens.shape[0], tokens.device)
+             labels = torch.cat((dummy_token, tokens), dim=1)
+         out = self.gpt(inputs_embeds=embedding_cat, labels=labels, attention_mask=mask)
+         return out
+
+     def __init__(self, prefix_length: int, prefix_size: int = 512):
+         super(ClipCaptionModel, self).__init__()
+         self.prefix_length = prefix_length
+         self.gpt = GPT2LMHeadModel(GPT2Config.from_pretrained('gpt2'))
+         self.gpt_embedding_size = self.gpt.transformer.wte.weight.shape[1]
+         if prefix_length > 10: # not enough memory
+             self.clip_project = nn.Linear(prefix_size, self.gpt_embedding_size * prefix_length)
+         else:
+             self.clip_project = MLP((prefix_size, (self.gpt_embedding_size * prefix_length) // 2, self.gpt_embedding_size * prefix_length))
+
+ class ClipCaptionPrefix(ClipCaptionModel):
+
+     def parameters(self, recurse: bool = True):
+         return self.clip_project.parameters()
+
+     def train(self, mode: bool = True):
+         super(ClipCaptionPrefix, self).train(mode)
+         self.gpt.eval()
+         return self
+
+ def generate2(
+     model,
+     tokenizer,
+     tokens=None,
+     prompt=None,
+     embed=None,
+     entry_count=1,
+     entry_length=67, # maximum number of words
+     top_p=0.8,
+     temperature=1.,
+     stop_token: str = '.',
+ ):
+     model.eval()
+     generated_num = 0
+     generated_list = []
+     stop_token_index = tokenizer.encode(stop_token)[0]
+     filter_value = -float("Inf")
+     device = next(model.parameters()).device
+     score_col = []
+     with torch.no_grad():
+
+         for entry_idx in range(entry_count):
+             if embed is not None:
+                 generated = embed
+             else:
+                 if tokens is None:
+                     tokens = torch.tensor(tokenizer.encode(prompt))
+                     tokens = tokens.unsqueeze(0).to(device)
+
+                 generated = model.gpt.transformer.wte(tokens)
+
+             for i in range(entry_length):
+
+                 outputs = model.gpt(inputs_embeds=generated)
+                 logits = outputs.logits
+                 logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
+                 sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+                 cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
+                 sorted_indices_to_remove = cumulative_probs > top_p
+                 sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
+                     ..., :-1
+                 ].clone()
+                 sorted_indices_to_remove[..., 0] = 0
+
+                 indices_to_remove = sorted_indices[sorted_indices_to_remove]
+                 logits[:, indices_to_remove] = filter_value
+                 next_token = torch.argmax(torch.softmax(logits, dim=-1), -1).reshape(1, 1)
+                 score = torch.softmax(logits, dim=-1).reshape(-1)[next_token.item()].item()
+                 score_col.append(score)
+                 next_token_embed = model.gpt.transformer.wte(next_token)
+                 if tokens is None:
+                     tokens = next_token
+                 else:
+                     tokens = torch.cat((tokens, next_token), dim=1)
+                 generated = torch.cat((generated, next_token_embed), dim=1)
+                 if stop_token_index == next_token.item():
+                     break
+
+             output_list = list(tokens.squeeze(0).cpu().numpy())
+             output_text = tokenizer.decode(output_list)
+             generated_list.append(output_text)
+     return generated_list[0]
+
+
+ @torch.no_grad()
+ def pc_to_text(pc_encoder: torch.nn.Module, pc, cond_scale):
+     ref_dev = next(pc_encoder.parameters()).device
+     prefix = pc_encoder(torch.tensor(pc.T[None], device=ref_dev))
+     prefix = prefix.float() * cond_scale
+     prefix = prefix.to(next(model.parameters()).device)
+     prefix_embed = model.clip_project(prefix).reshape(1, prefix_length, -1)
+     text = generate2(model, tokenizer, embed=prefix_embed)
+     return text
+
+ @torch.no_grad()
+ def pc_to_image(pc_encoder: torch.nn.Module, pc, prompt, noise_level, width, height, cfg_scale, num_steps, callback):
+     ref_dev = next(pc_encoder.parameters()).device
+     enc = pc_encoder(torch.tensor(pc.T[None], device=ref_dev))
+     enc = torch.nn.functional.normalize(enc, dim=-1) * (768 ** 0.5) / 2
+     if torch.cuda.is_available():
+         enc = enc.to('cuda:' + str(torch.cuda.current_device()))
+     # enc = enc.type(half)
+     # with torch.autocast("cuda"):
+     return pipe(
+         prompt=', '.join(["best quality"] + ([prompt] if prompt else [])),
+         negative_prompt="cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry",
+         image=enc,
+         width=width, height=height,
+         guidance_scale=cfg_scale,
+         noise_level=noise_level,
+         callback=callback,
+         num_inference_steps=num_steps
+     ).images[0]
+
+
+
+ pipe = StableUnCLIPImg2ImgPipeline.from_pretrained(
+     "diffusers/stable-diffusion-2-1-unclip-i2i-l",
+     # variant="fp16",
+     image_encoder = Wrapper()
+ )
+ # pe = pipe.text_encoder.text_model.embeddings
+ # pe.position_ids = torch.arange(pe.position_ids.shape[-1]).expand((1, -1)).to(pe.position_ids) # workaround
+ if torch.cuda.is_available():
+     pipe = pipe.to('cuda:' + str(torch.cuda.current_device()))
+     pipe.enable_model_cpu_offload(torch.cuda.current_device())
+     pipe.enable_attention_slicing()
+     pipe.enable_vae_slicing()
+
+ tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
+ prefix_length = 10
+ model = ClipCaptionModel(prefix_length)
+ # print(model.gpt_embedding_size)
+ model.load_state_dict(torch.load(hf_hub_download('OpenShape/clipcap-cc', 'conceptual_weights.pt'), map_location='cpu'))
+ model.eval()
+ if torch.cuda.is_available():
      model = model.cuda()
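The change in this commit is structural: the module-level type aliases (N, V, ARRAY, T, D, ...) move from the bottom of the file to just below the imports, so annotations such as x: T in MLP.forward and device: D in ClipCaptionModel.get_dummy_token refer to names that are already bound when the class bodies are evaluated; the functions and module-level setup are otherwise unchanged. Below is a minimal usage sketch (not part of the commit), assuming the repository is importable as the demo_support package and that its heavy dependencies (torch_redstone, diffusers, and the checkpoints downloaded at import time) are available. DummyPointCloudEncoder is a hypothetical stand-in for the real OpenShape point-cloud encoder, used only to show the expected shapes.

import numpy as np
import torch
from torch import nn

from demo_support import generation  # triggers the pipeline/weights setup at import time


class DummyPointCloudEncoder(nn.Module):
    # Hypothetical placeholder: maps a (1, 3, N) point cloud to a (1, 512)
    # prefix vector, the input size expected by ClipCaptionModel.clip_project.
    def __init__(self, out_dim: int = 512):
        super().__init__()
        self.proj = nn.Linear(3, out_dim)

    def forward(self, pc: torch.Tensor) -> torch.Tensor:
        # Mean-pool over the N points, then project the 3 coordinates to the prefix size.
        return self.proj(pc.mean(dim=-1))


pc = np.random.rand(1024, 3).astype(np.float32)  # toy point cloud, shape (N, 3)
encoder = DummyPointCloudEncoder()

# pc_to_text transposes pc to (1, 3, N), encodes it, projects the result through
# the ClipCap prefix MLP and decodes a caption with GPT-2 via generate2.
caption = generation.pc_to_text(encoder, pc, cond_scale=1.0)
print(caption)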