Jonathan Malott committed
Commit 44df93e
Parent(s): initial
- .gitignore +20 -0
- Procfile +1 -0
- dalle/__init__.py +0 -0
- dalle/models/__init__.py +206 -0
- dalle/models/stage1/layers.py +373 -0
- dalle/models/stage1/vqgan.py +99 -0
- dalle/models/stage2/layers.py +140 -0
- dalle/models/stage2/transformer.py +257 -0
- dalle/models/tokenizer.py +26 -0
- dalle/utils/__init__.py +3 -0
- dalle/utils/config.py +123 -0
- dalle/utils/sampling.py +162 -0
- dalle/utils/utils.py +84 -0
- page/generate.py +97 -0
- page/reduce.py +58 -0
- requirements.txt +18 -0
- streamlit_app.py +48 -0
- utils.py +160 -0
.gitignore
ADDED
@@ -0,0 +1,20 @@
.ipynb_checkpoints/


__pycache__/


_archives/


_exampleImages/


_trash/


minDALL-E/



temp/
Procfile
ADDED
@@ -0,0 +1 @@
web: sh setup.sh && streamlit run streamlit_app.py
dalle/__init__.py
ADDED
File without changes
dalle/models/__init__.py
ADDED
@@ -0,0 +1,206 @@
# ------------------------------------------------------------------------------------
# minDALL-E
# Copyright (c) 2021 Kakao Brain Corp. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------

import os
import torch
import torch.nn as nn
import pytorch_lightning as pl
from typing import Optional, Tuple
from omegaconf import OmegaConf
from torch.cuda.amp import autocast
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.nn import functional as F
from .stage1.vqgan import VQGAN
from .stage2.transformer import Transformer1d, iGPT
from .. import utils
from ..utils.config import get_base_config
from ..utils.sampling import sampling, sampling_igpt
from .tokenizer import build_tokenizer

_MODELS = {
    'minDALL-E/1.3B': 'https://arena.kakaocdn.net/brainrepo/models/minDALL-E/57b008f02ceaa02b779c8b7463143315/1.3B.tar.gz'
}


class Dalle(nn.Module):
    def __init__(self,
                 config: OmegaConf) -> None:
        super().__init__()
        self.tokenizer = None
        self.stage1 = VQGAN(n_embed=config.stage1.n_embed,
                            embed_dim=config.stage1.embed_dim,
                            hparams=config.stage1.hparams)
        self.stage2 = Transformer1d(vocab_size_txt=config.stage2.vocab_size_txt,
                                    vocab_size_img=config.stage2.vocab_size_img,
                                    hparams=config.stage2.hparams)
        self.config_stage1 = config.stage1
        self.config_stage2 = config.stage2
        self.config_dataset = config.dataset

    @classmethod
    def from_pretrained(cls,
                        path: str) -> nn.Module:
        #path = _MODELS[path] if path in _MODELS else path
        #path = utils.realpath_url_or_path(path, root=os.path.expanduser(".cache/minDALL-E"))
        path = ''

        config_base = get_base_config()
        config_new = OmegaConf.load(os.path.join(path, '.cache/minDALL-E/1.3B/config.yaml'))
        config_update = OmegaConf.merge(config_base, config_new)

        model = cls(config_update)
        model.tokenizer = build_tokenizer('.cache/minDALL-E/1.3B/tokenizer',
                                          context_length=model.config_dataset.context_length,
                                          lowercase=True,
                                          dropout=None)
        model.stage1.from_ckpt('.cache/minDALL-E/1.3B/stage1_last.ckpt')
        model.stage2.from_ckpt('https://utexas.box.com/shared/static/54jc9fw0bious5nx6wvayeqaskcrdgv4.ckpt')
        #model.stage1.from_ckpt('https://utexas.box.com/shared/static/rpt9miyj2kikogyekpqnkd6y115xp51i.ckpt')
        #model.stage2.from_ckpt('https://utexas.box.com/shared/static/54jc9fw0bious5nx6wvayeqaskcrdgv4.ckpt')

        return model

    @torch.no_grad()
    def sampling(self,
                 prompt: str,
                 top_k: int = 256,
                 top_p: Optional[float] = None,
                 softmax_temperature: float = 1.0,
                 num_candidates: int = 96,
                 device: str = 'cuda:0',
                 use_fp16: bool = True) -> torch.FloatTensor:
        self.stage1.eval()
        self.stage2.eval()

        tokens = self.tokenizer.encode(prompt)
        tokens = torch.LongTensor(tokens.ids)
        tokens = torch.repeat_interleave(tokens.unsqueeze(0), num_candidates, dim=0)

        # Check if the encoding works as intended
        # print(self.tokenizer.decode_batch(tokens.tolist(), skip_special_tokens=True)[0])

        tokens = tokens.to(device)
        codes = sampling(self.stage2,
                         tokens,
                         top_k=top_k,
                         top_p=top_p,
                         softmax_temperature=softmax_temperature,
                         use_fp16=use_fp16)
        codes = codes.view(num_candidates, 16, 16)  # [B, 16, 16]
        pixels = torch.clamp(self.stage1.decode_code(codes) * 0.5 + 0.5, 0, 1)  # [B, 256, 256]
        return pixels


class ImageGPT(pl.LightningModule):
    def __init__(self,
                 config: OmegaConf) -> None:
        super().__init__()
        self.stage1 = VQGAN(n_embed=config.stage1.n_embed,
                            embed_dim=config.stage1.embed_dim,
                            hparams=config.stage1.hparams)
        self.stage2 = iGPT(vocab_size_img=config.stage2.vocab_size_img,
                           use_cls_cond=config.stage2.use_cls_cond,
                           hparams=config.stage2.hparams)
        self.config = config
        self.use_cls_cond = config.stage2.use_cls_cond

        # make the parameters in stage 1 not trainable
        self.stage1.eval()
        for p in self.stage1.parameters():
            p.requires_grad = False

    @classmethod
    def from_pretrained(cls,
                        path_upstream: str,
                        path_downstream: str) -> Tuple[nn.Module, OmegaConf]:
        config_base = get_base_config(use_default=False)
        config_down = OmegaConf.load(path_downstream)
        config_down = OmegaConf.merge(config_base, config_down)

        model = cls(config_down)
        model.stage1.from_ckpt(os.path.join(path_upstream, 'stage1_last.ckpt'), strict=True)
        model.stage2.from_ckpt(os.path.join(path_upstream, 'stage2_last.ckpt'), strict=False)
        return model, config_down

    def sample(self,
               cls_idx: Optional[int] = None,
               top_k: int = 256,
               top_p: Optional[float] = None,
               softmax_temperature: float = 1.0,
               num_candidates: int = 16,
               device: str = 'cuda:0',
               use_fp16: bool = True,
               is_tqdm: bool = True) -> torch.FloatTensor:
        self.stage1.eval()
        self.stage2.eval()

        if cls_idx is None:
            sos = self.stage2.sos.repeat(num_candidates, 1, 1)
        else:
            sos = torch.LongTensor([cls_idx]).to(device=device)
            sos = sos.repeat(num_candidates)
            sos = self.stage2.sos(sos).unsqueeze(1)

        codes = sampling_igpt(self.stage2,
                              sos=sos,
                              top_k=top_k,
                              top_p=top_p,
                              softmax_temperature=softmax_temperature,
                              use_fp16=use_fp16,
                              is_tqdm=is_tqdm)
        codes = codes.view(num_candidates, 16, 16)  # [B, 16, 16]
        pixels = torch.clamp(self.stage1.decode_code(codes) * 0.5 + 0.5, 0, 1)  # [B, 256, 256]
        return pixels

    def forward(self,
                images: torch.FloatTensor,
                labels: Optional[torch.LongTensor] = None) -> torch.FloatTensor:
        B, C, H, W = images.shape
        with torch.no_grad():
            with autocast(enabled=False):
                codes = self.stage1.get_codes(images).detach()
        logits = self.stage2(codes, labels)
        return logits, codes

    def training_step(self, batch, batch_idx):
        images, labels = batch
        logits, codes = self(images, labels=labels if self.use_cls_cond else None)
        loss = F.cross_entropy(logits.view(-1, logits.shape[-1]), codes.view(-1))
        self.log("train/loss", loss, on_step=True, on_epoch=True, prog_bar=False, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        images, labels = batch
        logits, codes = self(images, labels=labels if self.use_cls_cond else None)
        loss = F.cross_entropy(logits.view(-1, logits.shape[-1]), codes.view(-1))
        self.log("val/loss", loss, on_step=False, on_epoch=True, prog_bar=False, logger=True)
        return loss

    def configure_optimizers(self):
        assert self.config.optimizer.opt_type == 'adamW'
        assert self.config.optimizer.sched_type == 'cosine'

        opt = torch.optim.AdamW(self.parameters(),
                                lr=self.config.optimizer.base_lr,
                                betas=self.config.optimizer.betas,
                                weight_decay=self.config.optimizer.weight_decay)
        sched = CosineAnnealingLR(opt,
                                  T_max=self.config.optimizer.max_steps,
                                  eta_min=self.config.optimizer.min_lr)
        sched = {
            'scheduler': sched,
            'name': 'cosine'
        }
        return [opt], [sched]

    def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure,
                       on_tpu=False, using_native_amp=False, using_lbfgs=False):
        optimizer.step(closure=optimizer_closure)
        self.lr_schedulers().step()
        self.log("lr", self.lr_schedulers().get_last_lr()[0], on_step=True, on_epoch=False, prog_bar=True, logger=True)

    def on_epoch_start(self):
        self.stage1.eval()
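For orientation (not part of this commit): a minimal sketch of how the Dalle class above is driven. It assumes the .cache/minDALL-E/1.3B/ config, tokenizer, and stage-1/stage-2 checkpoints referenced in from_pretrained are already in place, that a CUDA device is available, and that st.session_state.page and st.session_state.bar have been set by the surrounding Streamlit app, since dalle.utils.sampling.sampling (added later in this diff) reads them during generation.

import numpy as np
from PIL import Image
from dalle.models import Dalle

# Hypothetical driver script; the path argument is currently ignored (paths are hard-coded).
model = Dalle.from_pretrained('')
model.to('cuda:0')

# Returns a [num_candidates, 3, 256, 256] tensor with values in [0, 1].
pixels = model.sampling(prompt='a painting of a tree on the ocean',
                        top_k=256,
                        softmax_temperature=1.0,
                        num_candidates=4,
                        device='cuda:0')

arr = (pixels.cpu().numpy().transpose(0, 2, 3, 1) * 255).astype(np.uint8)
Image.fromarray(arr[0]).save('candidate_0.png')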
dalle/models/stage1/layers.py
ADDED
@@ -0,0 +1,373 @@
# ------------------------------------------------------------------------------------
# Modified from VQGAN (https://github.com/CompVis/taming-transformers)
# Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer. All Rights Reserved.
# ------------------------------------------------------------------------------------

import torch
import torch.nn as nn
from typing import Tuple, Optional


def nonlinearity(x):
    # swish
    return x*torch.sigmoid(x)


def Normalize(in_channels):
    return torch.nn.GroupNorm(num_groups=32,
                              num_channels=in_channels,
                              eps=1e-6,
                              affine=True)


class Upsample(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            self.conv = torch.nn.Conv2d(in_channels,
                                        in_channels,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)

    def forward(self, x):
        x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        if self.with_conv:
            x = self.conv(x)
        return x


class Downsample(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            # no asymmetric padding in torch conv, must do it ourselves
            self.conv = torch.nn.Conv2d(in_channels,
                                        in_channels,
                                        kernel_size=3,
                                        stride=2,
                                        padding=0)

    def forward(self, x):
        if self.with_conv:
            pad = (0, 1, 0, 1)
            x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
            x = self.conv(x)
        else:
            x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
        return x


class ResnetBlock(nn.Module):
    def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
                 dropout, temb_channels=512):
        assert temb_channels == 0
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = Normalize(in_channels)
        self.conv1 = torch.nn.Conv2d(in_channels,
                                     out_channels,
                                     kernel_size=3,
                                     stride=1,
                                     padding=1)
        self.norm2 = Normalize(out_channels)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(out_channels,
                                     out_channels,
                                     kernel_size=3,
                                     stride=1,
                                     padding=1)
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = torch.nn.Conv2d(in_channels,
                                                     out_channels,
                                                     kernel_size=3,
                                                     stride=1,
                                                     padding=1)
            else:
                self.nin_shortcut = torch.nn.Conv2d(in_channels,
                                                    out_channels,
                                                    kernel_size=1,
                                                    stride=1,
                                                    padding=0)

    def forward(self, x, temb=None):
        assert temb is None

        h = x
        h = self.norm1(h)
        h = nonlinearity(h)
        h = self.conv1(h)

        h = self.norm2(h)
        h = nonlinearity(h)
        h = self.dropout(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)
        return x+h


class AttnBlock(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.k = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.v = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.proj_out = torch.nn.Conv2d(in_channels,
                                        in_channels,
                                        kernel_size=1,
                                        stride=1,
                                        padding=0)

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        b, c, h, w = q.shape
        q = q.reshape(b, c, h*w)
        q = q.permute(0, 2, 1)    # b,hw,c
        k = k.reshape(b, c, h*w)  # b,c,hw
        w_ = torch.bmm(q, k)      # b,hw,hw    w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
        w_ = w_ * (int(c)**(-0.5))
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # attend to values
        v = v.reshape(b, c, h*w)
        w_ = w_.permute(0, 2, 1)   # b,hw,hw (first hw of k, second of q)
        h_ = torch.bmm(v, w_)      # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
        h_ = h_.reshape(b, c, h, w)

        h_ = self.proj_out(h_)
        return x+h_


class Encoder(nn.Module):
    def __init__(self,
                 *,  # forced to use named arguments
                 ch: int,
                 out_ch: int,
                 ch_mult: Tuple[int] = (1, 2, 4, 8),
                 num_res_blocks: int,
                 attn_resolutions: Tuple[int],
                 pdrop: float = 0.0,
                 resamp_with_conv: bool = True,
                 in_channels: int,
                 resolution: int,
                 z_channels: int,
                 double_z: Optional[bool] = None) -> None:
        super().__init__()
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels

        # downsampling
        self.conv_in = torch.nn.Conv2d(in_channels,
                                       self.ch,
                                       kernel_size=3,
                                       stride=1,
                                       padding=1)

        curr_res = resolution
        in_ch_mult = (1,)+tuple(ch_mult)
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch*in_ch_mult[i_level]
            block_out = ch*ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in,
                                         out_channels=block_out,
                                         temb_channels=self.temb_ch,
                                         dropout=pdrop))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(AttnBlock(block_in))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions-1:
                down.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=pdrop)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=pdrop)

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in,
                                        2*z_channels if double_z else z_channels,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)

    def forward(self, x):
        assert x.shape[2] == x.shape[3] == self.resolution, \
            "{}, {}".format(x.shape, self.resolution)

        # downsampling
        h = self.conv_in(x)
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](h)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
            if i_level != self.num_resolutions-1:
                h = self.down[i_level].downsample(h)

        # middle
        h = self.mid.block_1(h)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h


class Decoder(nn.Module):
    def __init__(self,
                 *,  # forced to use named arguments
                 ch: int,
                 out_ch: int,
                 ch_mult: Tuple[int] = (1, 2, 4, 8),
                 num_res_blocks: int,
                 attn_resolutions: Tuple[int],
                 pdrop: float = 0.0,
                 resamp_with_conv: bool = True,
                 in_channels: int,
                 resolution: int,
                 z_channels: int,
                 double_z: bool) -> None:
        super().__init__()
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels

        # compute in_ch_mult, block_in and curr_res at lowest res
        block_in = ch*ch_mult[self.num_resolutions-1]
        curr_res = resolution // 2**(self.num_resolutions-1)
        self.z_shape = (1, z_channels, curr_res, curr_res)

        # z to block_in
        self.conv_in = torch.nn.Conv2d(z_channels,
                                       block_in,
                                       kernel_size=3,
                                       stride=1,
                                       padding=1)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=pdrop)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=pdrop)

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch*ch_mult[i_level]
            for i_block in range(self.num_res_blocks+1):
                block.append(ResnetBlock(in_channels=block_in,
                                         out_channels=block_out,
                                         temb_channels=self.temb_ch,
                                         dropout=pdrop))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(AttnBlock(block_in))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to get consistent order

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in,
                                        out_ch,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)

    def forward(self, z):
        assert z.shape[1:] == self.z_shape[1:]
        self.last_z_shape = z.shape

        # z to block_in
        h = self.conv_in(z)

        # middle
        h = self.mid.block_1(h)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks+1):
                h = self.up[i_level].block[i_block](h)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h
dalle/models/stage1/vqgan.py
ADDED
@@ -0,0 +1,99 @@
# ------------------------------------------------------------------------------------
# Modified from VQGAN (https://github.com/CompVis/taming-transformers)
# Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer. All Rights Reserved.
# ------------------------------------------------------------------------------------

import torch
import torch.nn as nn
from typing import List, Tuple, Optional
from einops import rearrange
from omegaconf import OmegaConf
from .layers import Encoder, Decoder


class VectorQuantizer(nn.Module):
    """
    Simplified VectorQuantizer from the original VQGAN repository,
    with modules unnecessary for sampling removed
    """
    def __init__(self, dim: int, n_embed: int, beta: float) -> None:
        super().__init__()
        self.n_embed = n_embed
        self.dim = dim
        self.beta = beta

        self.embedding = nn.Embedding(self.n_embed, self.dim)
        self.embedding.weight.data.uniform_(-1.0 / self.n_embed, 1.0 / self.n_embed)

    def forward(self,
                z: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.LongTensor]:
        z = rearrange(z, 'b c h w -> b h w c').contiguous()  # [B,C,H,W] -> [B,H,W,C]
        z_flattened = z.view(-1, self.dim)

        d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
            torch.sum(self.embedding.weight**2, dim=1) - 2 * \
            torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n'))

        min_encoding_indices = torch.argmin(d, dim=1)
        z_q = self.embedding(min_encoding_indices).view(z.shape)
        return z_q, min_encoding_indices

    def get_codebook_entry(self,
                           indices: torch.LongTensor,
                           shape: Optional[List[int]] = None) -> torch.FloatTensor:
        z_q = self.embedding(indices)
        if shape is not None:
            z_q = z_q.view(shape)
            z_q = z_q.permute(0, 3, 1, 2).contiguous()
        return z_q


class VQGAN(nn.Module):
    def __init__(self, n_embed: int, embed_dim: int, hparams: OmegaConf) -> None:
        super().__init__()
        self.encoder = Encoder(**hparams)
        self.decoder = Decoder(**hparams)
        self.quantize = VectorQuantizer(dim=embed_dim, n_embed=n_embed, beta=0.25)
        self.quant_conv = torch.nn.Conv2d(hparams.z_channels, embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, hparams.z_channels, 1)
        self.latent_dim = hparams.attn_resolutions[0]

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
        quant = self.encode(x)
        dec = self.decode(quant)
        return dec

    def encode(self, x: torch.FloatTensor) -> torch.FloatTensor:
        h = self.encoder(x)
        h = self.quant_conv(h)
        quant = self.quantize(h)[0]
        quant = rearrange(quant, 'b h w c -> b c h w').contiguous()
        return quant

    def decode(self, quant: torch.FloatTensor) -> torch.FloatTensor:
        quant = self.post_quant_conv(quant)
        dec = self.decoder(quant)
        return dec

    def decode_code(self, code: torch.LongTensor) -> torch.FloatTensor:
        quant = self.quantize.get_codebook_entry(code)
        quant = quant.permute(0, 3, 1, 2)
        dec = self.decode(quant)
        return dec

    def get_codes(self, x: torch.FloatTensor) -> torch.LongTensor:
        h = self.encoder(x)
        h = self.quant_conv(h)
        codes = self.quantize(h)[1].view(x.shape[0], self.latent_dim ** 2)
        return codes

    def from_ckpt(self, path: str, strict: bool = True) -> None:
        #ckpt = torch.load(path, map_location='cpu')['state_dict']
        #self.load_state_dict(ckpt, strict=strict)
        #print(f'{path} successfully restored..')

        #ckpt = torch.load(path, map_location='cpu')['state_dict']
        ckpt = torch.utils.model_zoo.load_url('https://utexas.box.com/shared/static/rpt9miyj2kikogyekpqnkd6y115xp51i.ckpt', map_location='cpu')['state_dict']

        self.load_state_dict(ckpt, strict=True)
        print(f'{path} successfully restored..')
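For reference (not part of the commit), the decode path above can be exercised without any checkpoint; the sketch below builds a VQGAN from the default hyperparameters added later in dalle/utils/config.py and decodes a random 16x16 code grid with untrained weights, just to show the tensor shapes.

import torch
from dalle.utils.config import get_base_config
from dalle.models.stage1.vqgan import VQGAN

config = get_base_config()
vqgan = VQGAN(n_embed=config.stage1.n_embed,      # 16384 codebook entries
              embed_dim=config.stage1.embed_dim,  # 256-d codes
              hparams=config.stage1.hparams)

codes = torch.randint(0, config.stage1.n_embed, (1, 16, 16))  # random code grid
with torch.no_grad():
    pixels = vqgan.decode_code(codes)  # [1, 3, 256, 256], meaningless without weights
print(pixels.shape)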
dalle/models/stage2/layers.py
ADDED
@@ -0,0 +1,140 @@
# ------------------------------------------------------------------------------------
# minDALL-E
# Copyright (c) 2021 Kakao Brain Corp. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------
# Modified from minGPT (https://github.com/karpathy/minGPT)
# Copyright (c) 2020 Andrej Karpathy. All Rights Reserved.
# ------------------------------------------------------------------------------------

import math
import torch
import torch.nn as nn
from torch.nn import functional as F


class GELU(nn.Module):
    def __init__(self, use_approx=False):
        super().__init__()
        self.use_approx = use_approx

    def forward(self, x):
        if self.use_approx:
            return x * torch.sigmoid(1.702 * x)
        else:
            return F.gelu(x)


class MultiHeadSelfAttention(nn.Module):

    def __init__(self,
                 ctx_len: int,
                 embed_dim: int,
                 n_heads: int,
                 resid_pdrop: float,
                 attn_pdrop: float,
                 attn_bias: bool,
                 use_mask: bool = True):
        super().__init__()
        assert embed_dim % n_heads == 0

        # key, query, value projections for all heads
        self.key = nn.Linear(embed_dim, embed_dim, bias=attn_bias)
        self.query = nn.Linear(embed_dim, embed_dim, bias=attn_bias)
        self.value = nn.Linear(embed_dim, embed_dim, bias=attn_bias)

        # regularization
        self.attn_drop = nn.Dropout(attn_pdrop)
        self.resid_drop = nn.Dropout(resid_pdrop)

        # output projection
        self.proj = nn.Linear(embed_dim, embed_dim, attn_bias)

        self.n_heads = n_heads
        self.ctx_len = ctx_len
        self.use_mask = use_mask
        if self.use_mask:
            self.register_buffer("mask", torch.ones(ctx_len, ctx_len), persistent=False)
            self.mask = torch.tril(self.mask).view(1, ctx_len, ctx_len)

    def forward(self, x, use_cache=False, layer_past=None):
        B, T, C = x.shape
        x = x.transpose(0, 1).contiguous()  # (B, T, C) -> (T, B, C)

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        k = self.key(x).view(T, B*self.n_heads, C//self.n_heads).transpose(0, 1)    # (B*nh, T, hs)
        q = self.query(x).view(T, B*self.n_heads, C//self.n_heads).transpose(0, 1)  # (B*nh, T, hs)
        v = self.value(x).view(T, B*self.n_heads, C//self.n_heads).transpose(0, 1)  # (B*nh, T, hs)

        if use_cache:
            present = torch.stack([k, v])

        if layer_past is not None:
            past_key, past_value = layer_past
            k = torch.cat([past_key, k], dim=-2)
            v = torch.cat([past_value, v], dim=-2)

        if use_cache and layer_past is not None:
            # Tensor shape below: (B * nh, 1, hs) X (B * nh, hs, K) -> (B * nh, 1, K)
            att = torch.bmm(q, (k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))))
            att = F.softmax(att, dim=-1)
            att = self.attn_drop(att)
            y = torch.bmm(att, v)  # (B*nh, 1, K) X (B*nh, K, hs) -> (B*nh, 1, hs)
        else:
            # Tensor shape below: (B * nh, T, hs) X (B * nh, hs, T) -> (B * nh, T, T)
            att = torch.bmm(q, (k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))))
            if self.use_mask:
                mask = self.mask if T == self.ctx_len else self.mask[:, :T, :T]
                att = att.masked_fill(mask == 0, float('-inf'))
            att = F.softmax(att, dim=-1)
            att = self.attn_drop(att)
            y = torch.bmm(att, v)  # (B*nh, T, T) X (B*nh, T, hs) -> (B*nh, T, hs)
        y = y.transpose(0, 1).contiguous().view(T, B, C)  # re-assemble all head outputs side by side

        # output projection
        y = self.resid_drop(self.proj(y))
        if use_cache:
            return y.transpose(0, 1).contiguous(), present  # (T, B, C) -> (B, T, C)
        else:
            return y.transpose(0, 1).contiguous()  # (T, B, C) -> (B, T, C)


class Block(nn.Module):

    def __init__(self,
                 ctx_len: int,
                 embed_dim: int,
                 n_heads: int,
                 mlp_bias: bool,
                 attn_bias: bool,
                 resid_pdrop: bool,
                 attn_pdrop: bool,
                 gelu_use_approx: bool):
        super().__init__()
        self.ln1 = nn.LayerNorm(embed_dim)
        self.ln2 = nn.LayerNorm(embed_dim)

        self.attn = MultiHeadSelfAttention(ctx_len=ctx_len,
                                           embed_dim=embed_dim,
                                           n_heads=n_heads,
                                           attn_pdrop=attn_pdrop,
                                           resid_pdrop=resid_pdrop,
                                           attn_bias=attn_bias,
                                           use_mask=True)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, 4 * embed_dim, bias=mlp_bias),
            GELU(gelu_use_approx),
            nn.Linear(4 * embed_dim, embed_dim, bias=mlp_bias),
            nn.Dropout(resid_pdrop),
        )

    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x

    def sample(self, x, layer_past=None):
        attn, present = self.attn(self.ln1(x), use_cache=True, layer_past=layer_past)
        x = x + attn
        x = x + self.mlp(self.ln2(x))
        return x, present
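A tiny, self-contained illustration (not part of the commit) of the Block above, using made-up hyperparameters far smaller than the real model (embed_dim=1536, 42 layers); it shows the plain causal forward pass and the cached sample path used during generation.

import torch
from dalle.models.stage2.layers import Block

block = Block(ctx_len=8, embed_dim=32, n_heads=4,
              mlp_bias=True, attn_bias=True,
              resid_pdrop=0.0, attn_pdrop=0.0, gelu_use_approx=False)

x = torch.randn(2, 8, 32)                  # (batch, sequence, embedding)
y = block(x)                               # causal self-attention + MLP, shape (2, 8, 32)

y_step, present = block.sample(x[:, :1])   # incremental path: one token in
print(y.shape, present.shape)              # (2, 8, 32) and stacked k/v cache (2, 8, 1, 8)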
dalle/models/stage2/transformer.py
ADDED
@@ -0,0 +1,257 @@
# ------------------------------------------------------------------------------------
# minDALL-E
# Copyright (c) 2021 Kakao Brain Corp. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------
# Modified from minGPT (https://github.com/karpathy/minGPT)
# Copyright (c) 2020 Andrej Karpathy. All Rights Reserved.
# ------------------------------------------------------------------------------------

import torch
import torch.nn as nn
from typing import Optional, Tuple, List
from torch.cuda.amp import autocast
from omegaconf import OmegaConf
from .layers import Block
import io


class Transformer1d(nn.Module):

    def __init__(self,
                 vocab_size_txt: int,
                 vocab_size_img: int,
                 hparams: OmegaConf) -> None:
        super().__init__()
        assert hparams.n_layers == hparams.n_dense_layers

        # input embedding for image and text
        self.tok_emb_img = nn.Embedding(vocab_size_img, hparams.embed_dim)
        self.tok_emb_txt = nn.Embedding(vocab_size_txt, hparams.embed_dim)

        self.pos_emb_img = nn.Embedding(hparams.ctx_len_img, hparams.embed_dim)
        self.pos_emb_txt = nn.Embedding(hparams.ctx_len_txt, hparams.embed_dim)

        self.drop = nn.Dropout(hparams.embd_pdrop)

        # transformer blocks
        self.blocks = [Block(ctx_len=hparams.ctx_len_img + hparams.ctx_len_txt,
                             embed_dim=hparams.embed_dim,
                             n_heads=hparams.n_heads,
                             mlp_bias=hparams.mlp_bias,
                             attn_bias=hparams.attn_bias,
                             resid_pdrop=hparams.resid_pdrop,
                             attn_pdrop=hparams.attn_pdrop,
                             gelu_use_approx=hparams.gelu_use_approx) for i in range(1, hparams.n_layers+1)]
        self.blocks = nn.Sequential(*self.blocks)

        # heads for image and text
        self.ln_f = nn.LayerNorm(hparams.embed_dim)
        self.head_img = nn.Linear(hparams.embed_dim, vocab_size_img, bias=False)
        self.head_txt = nn.Linear(hparams.embed_dim, vocab_size_txt, bias=False)

        self.ctx_len_img = hparams.ctx_len_img
        self.ctx_len_txt = hparams.ctx_len_txt
        self.n_layers = hparams.n_layers

        self.apply(self._init_weights)

    def _init_weights(self, module: nn.Module) -> None:
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def forward(self,
                images: torch.LongTensor,
                texts: torch.LongTensor,
                pos_images: torch.LongTensor,
                pos_texts: torch.LongTensor) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
        B, T = images.shape
        _, N = texts.shape

        assert T <= self.ctx_len_img, "Already reached the maximum context length (image)."
        assert N == self.ctx_len_txt, "Already reached the maximum context length (text)."

        texts = self.tok_emb_txt(texts)
        images = self.tok_emb_img(images)

        texts = texts + self.pos_emb_txt(pos_texts)
        images = images + self.pos_emb_img(pos_images)

        x = torch.cat([texts, images], axis=1).contiguous()
        x = self.drop(x)
        x = self.blocks(x)
        x = self.ln_f(x)

        texts = x[:, :N-1].contiguous()
        images = x[:, N-1:-1].contiguous()

        logits_txt = self.head_txt(texts)
        logits_img = self.head_img(images)
        return logits_img, logits_txt

    @torch.no_grad()
    def sampling(self,
                 images: torch.LongTensor,
                 texts: torch.LongTensor,
                 pos_images: torch.LongTensor,
                 pos_texts: torch.LongTensor,
                 use_fp16: bool = True,
                 past: Optional[List[torch.Tensor]] = None) -> Tuple[torch.FloatTensor, List[torch.FloatTensor]]:
        _, N = texts.shape
        assert N == self.ctx_len_txt, "Already reached the maximum context length (text)."

        with autocast(enabled=use_fp16):
            if images is None:
                assert past is None

                texts = self.tok_emb_txt(texts)
                x = texts + self.pos_emb_txt(pos_texts)
                x = self.drop(x)

                presents = []
                for i, block in enumerate(self.blocks):
                    x, present = block.sample(x, layer_past=None)
                    presents.append(present)
                x = self.ln_f(x)
                x = x[:, N-1].contiguous()
                logits = self.head_img(x)
            else:
                if past is None:
                    texts = self.tok_emb_txt(texts)
                    images = self.tok_emb_img(images)
                    texts = texts + self.pos_emb_txt(pos_texts)
                    images = images + self.pos_emb_img(pos_images)
                    x = torch.cat([texts, images], axis=1).contiguous()
                else:
                    images = self.tok_emb_img(images)
                    x = images + self.pos_emb_img(pos_images)
                x = self.drop(x)

                if past is not None:
                    past = torch.cat(past, dim=-2)
                presents = []
                for i, block in enumerate(self.blocks):
                    x, present = block.sample(x, layer_past=None if past is None else past[i])
                    presents.append(present)
                x = self.ln_f(x)
                x = x[:, -1].contiguous()
                logits = self.head_img(x)
            return logits, presents

    def from_ckpt(self, path: str) -> None:
        #ckpt = torch.load(path, map_location='cpu')['state_dict']
        ckpt = torch.utils.model_zoo.load_url('https://utexas.box.com/shared/static/54jc9fw0bious5nx6wvayeqaskcrdgv4.ckpt', map_location='cpu')['state_dict']

        self.load_state_dict(ckpt, strict=True)
        print(f'{path} successfully restored..')


class iGPT(nn.Module):
    def __init__(self,
                 vocab_size_img: int,
                 use_cls_cond: bool,
                 hparams: OmegaConf) -> None:
        super().__init__()
        self.use_cls_cond = use_cls_cond

        # sos token embedding
        if self.use_cls_cond:
            self.sos = nn.Embedding(hparams.n_classes, hparams.embed_dim)
        else:
            self.sos = nn.Parameter(torch.randn(1, 1, hparams.embed_dim))

        # input embedding
        self.tok_emb_img = nn.Embedding(vocab_size_img, hparams.embed_dim)
        self.pos_emb_img = nn.Embedding(hparams.ctx_len_img, hparams.embed_dim)

        self.drop = nn.Dropout(hparams.embd_pdrop)

        # transformer blocks
        self.blocks = [Block(ctx_len=hparams.ctx_len_img + 1,
                             embed_dim=hparams.embed_dim,
                             n_heads=hparams.n_heads,
                             mlp_bias=hparams.mlp_bias,
                             attn_bias=hparams.attn_bias,
                             resid_pdrop=hparams.resid_pdrop,
                             attn_pdrop=hparams.attn_pdrop,
                             gelu_use_approx=hparams.gelu_use_approx) for i in range(1, hparams.n_layers+1)]
        self.blocks = nn.Sequential(*self.blocks)

        # head
        self.ln_f = nn.LayerNorm(hparams.embed_dim)
        self.head = nn.Linear(hparams.embed_dim, vocab_size_img, bias=False)

        self.ctx_len_img = hparams.ctx_len_img
        self.n_layers = hparams.n_layers

        self.apply(self._init_weights)

    def _init_weights(self, module: nn.Module) -> None:
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    @torch.no_grad()
    def sampling(self,
                 sos: torch.FloatTensor,
                 codes: torch.LongTensor,
                 pos_codes: torch.LongTensor,
                 n_samples: int = 16,
                 use_fp16: bool = True,
                 past: Optional[torch.Tensor] = None) -> Tuple[torch.FloatTensor, List[torch.FloatTensor]]:
        with autocast(enabled=use_fp16):
            if codes is None:
                assert past is None
                xs = self.drop(sos)
                presents = []
                for i, block in enumerate(self.blocks):
                    xs, present = block.sample(xs, layer_past=None)
                    presents.append(present)
                xs = self.ln_f(xs)
                logits = self.head(xs)[:, -1]
            else:
                if past is None:
                    xs = self.tok_emb_img(codes) + self.pos_emb_img(pos_codes)
                    xs = torch.cat([sos, xs], dim=1)
                else:
                    xs = self.tok_emb_img(codes) + self.pos_emb_img(pos_codes)
                xs = self.drop(xs)

                past = torch.cat(past, dim=-2) if past is not None else past
                presents = []
                for i, block in enumerate(self.blocks):
                    xs, present = block.sample(xs, layer_past=None if past is None else past[i])
                    presents.append(present)

                xs = self.ln_f(xs)
                logits = self.head(xs)[:, -1]
            return logits, presents

    def forward(self,
                codes: torch.LongTensor,
                labels: Optional[torch.LongTensor] = None) -> torch.FloatTensor:
        B, T = codes.shape
        xps = torch.arange(T, device=codes.device).repeat((B, 1))
        sos = self.sos.repeat((B, 1, 1)) if labels is None else self.sos(labels).unsqueeze(1)

        h = self.tok_emb_img(codes) + self.pos_emb_img(xps)
        h = torch.cat([sos, h[:, :-1]], dim=1).contiguous()

        h = self.drop(h)
        h = self.blocks(h)
        h = self.ln_f(h)
        logits = self.head(h)
        return logits

    def from_ckpt(self, path: str, strict: bool = True) -> None:
        ckpt = torch.load(path, map_location='cpu')['state_dict']
        self.load_state_dict(ckpt, strict=strict)
        print(f'{path} successfully restored..')
dalle/models/tokenizer.py
ADDED
@@ -0,0 +1,26 @@
# ------------------------------------------------------------------------------------
# minDALL-E
# Copyright (c) 2021 Kakao Brain Corp. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------

import os
from functools import partial
from tokenizers import CharBPETokenizer


def build_tokenizer(path: str,
                    context_length: int = 64,
                    *args,
                    **kwargs):
    from_file = partial(CharBPETokenizer.from_file,
                        vocab_filename=os.path.join(path, 'bpe-16k-vocab.json'),
                        merges_filename=os.path.join(path, 'bpe-16k-merges.txt'),
                        unk_token='[UNK]')
    tokenizer = from_file(*args, **kwargs)
    tokenizer.add_special_tokens(['[PAD]'])
    tokenizer.enable_padding(length=context_length,
                             pad_id=tokenizer.token_to_id('[PAD]'))
    tokenizer.enable_truncation(max_length=context_length)
    print(f'{path} successfully restored..')
    return tokenizer
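As a point of reference (not part of the commit), this is the call pattern Dalle.from_pretrained uses for the tokenizer above; it assumes the bpe-16k-vocab.json and bpe-16k-merges.txt files are present under the given path, as in .cache/minDALL-E/1.3B/tokenizer.

from dalle.models.tokenizer import build_tokenizer

tokenizer = build_tokenizer('.cache/minDALL-E/1.3B/tokenizer',
                            context_length=64,
                            lowercase=True,   # forwarded to CharBPETokenizer.from_file
                            dropout=None)

encoded = tokenizer.encode('a painting of a tree on the ocean')
print(len(encoded.ids))  # 64: padded or truncated to the configured context length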
dalle/utils/__init__.py
ADDED
@@ -0,0 +1,3 @@
from .utils import *
from .config import *
from .sampling import *
dalle/utils/config.py
ADDED
@@ -0,0 +1,123 @@
# ------------------------------------------------------------------------------------
# minDALL-E
# Copyright (c) 2021 Kakao Brain Corp. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------

from typing import Optional, List
from dataclasses import dataclass, field
from omegaconf import OmegaConf


@dataclass
class DataConfig:
    dataset: Optional[str] = None
    tokenizer_type: str = 'CharBPE'
    context_length: int = 64
    image_resolution: int = 256
    transforms: str = 'dalle-vqvae'
    bpe_pdrop: Optional[float] = None


@dataclass
class Stage1Hparams:
    double_z: bool = False
    z_channels: int = 256
    resolution: int = 256
    in_channels: int = 3
    out_ch: int = 3
    ch: int = 128
    ch_mult: List[int] = field(default_factory=lambda: [1, 1, 2, 2, 4])
    num_res_blocks: int = 2
    attn_resolutions: List[int] = field(default_factory=lambda: [16])
    pdrop: float = 0.0


@dataclass
class Stage2Hparams:
    embed_dim: int = 1536
    n_layers: int = 42
    n_heads: int = 24
    n_dense_layers: int = 42
    ctx_len_img: int = 256
    ctx_len_txt: int = 64
    embd_pdrop: float = 0.0
    resid_pdrop: float = 0.0
    attn_pdrop: float = 0.0
    mlp_bias: bool = True
    attn_bias: bool = True
    gelu_use_approx: bool = False
    use_head_txt: bool = True
    n_classes: Optional[int] = None


@dataclass
class Stage1Config:
    type: str = 'vqgan'
    embed_dim: int = 256
    n_embed: int = 16384
    hparams: Stage1Hparams = Stage1Hparams()


@dataclass
class Stage2Config:
    type: str = 'transformer1d'
    vocab_size_txt: int = 16384
    vocab_size_img: int = 16384
    use_cls_cond: Optional[bool] = None
    hparams: Stage2Hparams = Stage2Hparams()


@dataclass
class WarmupConfig:
    epoch: int = 1
    multiplier: int = 1
    buffer_epoch: int = 0
    min_lr: float = 0.0
    mode: str = 'fix'
    peak_lr: float = 1e-4
    start_from_zero: bool = True


@dataclass
class OptConfig:
    opt_type: str = 'adamW'
    base_lr: float = 1e-4
    weight_decay: float = 1e-4
    betas: List[float] = field(default_factory=lambda: [0.9, 0.99])
    grad_clip_norm: float = 1.0

    sched_type: str = 'cosine'
    max_steps: int = 0
    min_lr: float = 0.0


@dataclass
class ExpConfig:
    local_batch_size: int = 4
    total_batch_size: int = 512
    valid_batch_size: int = 32
    epochs: int = 10
    save_ckpt_freq: int = 2
    test_freq: int = 1
    use_amp: bool = True


@dataclass
class DefaultConfig:
    dataset: DataConfig = DataConfig()
    stage1: Stage1Config = Stage1Config()
    stage2: Stage2Config = Stage2Config()


@dataclass
class FineTuningConfig:
    dataset: DataConfig = DataConfig()
    stage1: Stage1Config = Stage1Config()
    stage2: Stage2Config = Stage2Config()
    optimizer: OptConfig = OptConfig()
    experiment: ExpConfig = ExpConfig()


def get_base_config(use_default=True):
    return OmegaConf.structured(DefaultConfig if use_default else FineTuningConfig)
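For reference (not part of the commit), the structured config above can be materialized and overridden with plain OmegaConf calls; this mirrors what Dalle.from_pretrained does when it merges the base config with the 1.3B config.yaml.

from omegaconf import OmegaConf
from dalle.utils.config import get_base_config

config = get_base_config()              # DefaultConfig as an OmegaConf object
print(config.stage2.hparams.n_layers)   # 42
print(config.stage1.n_embed)            # 16384

# Overrides merge the same way a loaded config.yaml would in Dalle.from_pretrained
override = OmegaConf.create({'dataset': {'context_length': 32}})
config = OmegaConf.merge(config, override)
print(config.dataset.context_length)    # 32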
dalle/utils/sampling.py
ADDED
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ------------------------------------------------------------------------------------
# minDALL-E
# Copyright (c) 2021 Kakao Brain Corp. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------

import torch
from typing import Optional
from tqdm import tqdm
from torch.nn import functional as F
import streamlit as st


def cutoff_topk_logits(logits: torch.FloatTensor, k: int) -> torch.FloatTensor:
    if k is None:
        return logits
    else:
        v, ix = torch.topk(logits, k)
        out = logits.clone()
        out[out < v[:, [-1]]] = -float('Inf')
        return out


def cutoff_topp_probs(probs: torch.FloatTensor, p: float) -> torch.FloatTensor:
    if p is None:
        return probs
    else:
        sorted_probs, sorted_indices = torch.sort(probs, dim=-1, descending=True)
        cum_probs = torch.cumsum(sorted_probs, dim=-1)

        sorted_idx_remove_cond = cum_probs >= p

        sorted_idx_remove_cond[..., 1:] = sorted_idx_remove_cond[..., :-1].clone()
        sorted_idx_remove_cond[..., 0] = 0

        indices_to_remove = sorted_idx_remove_cond.scatter(-1, sorted_indices, sorted_idx_remove_cond)
        probs = probs.masked_fill(indices_to_remove, 0.0)
        norm_probs = probs / torch.sum(probs, dim=-1, keepdim=True)
        return norm_probs


def get_positional_encoding(inputs: torch.LongTensor, mode: str = '1d') -> torch.LongTensor:
    device = inputs.device
    if mode == '1d':
        B, N = inputs.shape
        xs_pos = torch.arange(N, device=device).repeat((B, 1))
    elif mode == '2d':
        B, H, W = inputs.shape
        xs_pos_h = torch.arange(H, device=device).repeat(B, W, 1).transpose(1, 2)
        xs_pos_w = torch.arange(W, device=device).repeat(B, H, 1)
        xs_pos = (xs_pos_h, xs_pos_w)
    else:
        raise ValueError('%s positional encoding invalid' % mode)
    return xs_pos


@torch.no_grad()
def sampling(model: torch.nn.Module,
             tokens: torch.LongTensor,
             top_k: Optional[float] = None,
             top_p: Optional[float] = None,
             softmax_temperature: float = 1.0,
             is_tqdm: bool = True,
             use_fp16: bool = True,
             max_seq_len: int = 256) -> torch.LongTensor:
    code = None
    past = None

    pbar = tqdm(range(max_seq_len), total=max_seq_len) if is_tqdm else range(max_seq_len)
    pos_enc_tokens = get_positional_encoding(tokens, mode='1d')

    for cnt, h in enumerate(pbar):
        if code is None:
            code_ = None
            pos_enc_code_ = None
        else:
            code_ = code.clone().detach()
            pos_enc_code_ = get_positional_encoding(code_, mode='1d')
            code_ = code_[:, cnt-1].unsqueeze(-1)
            pos_enc_code_ = pos_enc_code_[:, cnt-1].unsqueeze(-1)

        logits, present = model.sampling(images=code_,
                                         texts=tokens,
                                         pos_images=pos_enc_code_,
                                         pos_texts=pos_enc_tokens,
                                         use_fp16=use_fp16,
                                         past=past)
        logits = logits.to(dtype=torch.float32)
        logits = logits / softmax_temperature

        present = torch.stack(present).clone().detach()
        if past is None:
            past = [present]
        else:
            past.append(present)

        logits = cutoff_topk_logits(logits, top_k)
        probs = F.softmax(logits, dim=-1)
        probs = cutoff_topp_probs(probs, top_p)

        idx = torch.multinomial(probs, num_samples=1).clone().detach()
        code = idx if code is None else torch.cat([code, idx], axis=1)

        # Streamlit-specific additions: stop early if the user has navigated away
        # from the generation page, otherwise advance the page's progress bar.
        if st.session_state.page != 0:
            break

        st.session_state.bar.progress(cnt / max_seq_len)

    del past
    return code


@torch.no_grad()
def sampling_igpt(model: torch.nn.Module,
                  sos: torch.FloatTensor,
                  top_k: Optional[float] = None,
                  top_p: Optional[float] = None,
                  softmax_temperature: float = 1.0,
                  is_tqdm: bool = True,
                  use_fp16: bool = True,
                  max_seq_len: int = 256) -> torch.LongTensor:
    code = None
    past = None
    pbar = tqdm(range(max_seq_len), total=max_seq_len) if is_tqdm else range(max_seq_len)

    for cnt, h in enumerate(pbar):
        if code is None:
            code_ = None
            pos_enc_code_ = None
        else:
            code_ = code.clone().detach()
            pos_enc_code_ = get_positional_encoding(code_, mode='1d')
            code_ = code_[:, cnt-1].unsqueeze(-1)
            pos_enc_code_ = pos_enc_code_[:, cnt-1].unsqueeze(-1)

        logits, present = model.sampling(sos=sos,
                                         codes=code_,
                                         pos_codes=pos_enc_code_,
                                         use_fp16=use_fp16,
                                         past=past)
        logits = logits.to(dtype=torch.float32)
        logits = logits / softmax_temperature

        present = torch.stack(present).clone().detach()
        if past is None:
            past = [present]
        else:
            past.append(present)

        logits = cutoff_topk_logits(logits, top_k)
        probs = F.softmax(logits, dim=-1)
        probs = cutoff_topp_probs(probs, top_p)

        idx = torch.multinomial(probs, num_samples=1).clone().detach()
        code = idx if code is None else torch.cat([code, idx], axis=1)

    del past
    return code
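The two cutoff helpers above are easiest to see on a toy case. The sketch below is illustrative only and not part of the commit; it assumes the helpers are importable from dalle.utils.sampling, and the five-token logits tensor is made up.

import torch
from torch.nn import functional as F
from dalle.utils.sampling import cutoff_topk_logits, cutoff_topp_probs

toy_logits = torch.tensor([[2.0, 1.0, 0.5, 0.1, -1.0]])   # hypothetical 5-token vocabulary
filtered = cutoff_topk_logits(toy_logits, 3)               # logits outside the top 3 become -inf
probs = F.softmax(filtered, dim=-1)
probs = cutoff_topp_probs(probs, 0.9)                      # drop the tail beyond cumulative p=0.9, then renormalize
next_token = torch.multinomial(probs, num_samples=1)       # sample one token id from the survivors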
dalle/utils/utils.py
ADDED
@@ -0,0 +1,84 @@
# ------------------------------------------------------------------------------------
# minDALL-E
# Copyright (c) 2021 Kakao Brain Corp. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------

import os
import random
import urllib.request
import urllib.parse
import hashlib
import tarfile
import torch
import clip
import numpy as np
from PIL import Image
from torch.nn import functional as F
from tqdm import tqdm


def set_seed(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


@torch.no_grad()
def clip_score(prompt: str,
               images: np.ndarray,
               model_clip: torch.nn.Module,
               preprocess_clip,
               device: str) -> np.ndarray:
    images = [preprocess_clip(Image.fromarray((image * 255).astype(np.uint8))) for image in images]
    images = torch.stack(images, dim=0).to(device=device)
    texts = clip.tokenize(prompt).to(device=device)
    texts = torch.repeat_interleave(texts, images.shape[0], dim=0)

    image_features = model_clip.encode_image(images)
    text_features = model_clip.encode_text(texts)

    scores = F.cosine_similarity(image_features, text_features).squeeze()
    rank = torch.argsort(scores, descending=True).cpu().numpy()
    return rank


def download(url: str, root: str) -> str:
    os.makedirs(root, exist_ok=True)
    filename = os.path.basename(url)
    pathname = filename[:-len('.tar.gz')]

    expected_md5 = url.split("/")[-2]
    download_target = os.path.join(root, filename)
    result_path = os.path.join(root, pathname)

    if os.path.isfile(download_target) and (os.path.exists(result_path) and not os.path.isfile(result_path)):
        return result_path

    with urllib.request.urlopen(url) as source, open(download_target, 'wb') as output:
        with tqdm(total=int(source.info().get('Content-Length')), ncols=80, unit='iB', unit_scale=True,
                  unit_divisor=1024) as loop:
            while True:
                buffer = source.read(8192)
                if not buffer:
                    break

                output.write(buffer)
                loop.update(len(buffer))

    if hashlib.md5(open(download_target, 'rb').read()).hexdigest() != expected_md5:
        raise RuntimeError(f'Model has been downloaded but the md5 checksum does not match')

    with tarfile.open(download_target, 'r:gz') as f:
        pbar = tqdm(f.getmembers(), total=len(f.getmembers()))
        for member in pbar:
            pbar.set_description(f'extracting: {member.name} (size:{member.size // (1024 * 1024)}MB)')
            f.extract(member=member, path=root)

    return result_path


def realpath_url_or_path(url_or_path: str, root: str = None) -> str:
    if urllib.parse.urlparse(url_or_path).scheme in ('http', 'https'):
        return download(url_or_path, root)
    return url_or_path
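As a rough usage sketch (not part of the commit): download() reads the expected md5 checksum from the second-to-last URL segment, and realpath_url_or_path() only triggers it for http(s) inputs. The cache directory below is a placeholder choice.

from dalle.utils.utils import set_seed, realpath_url_or_path

set_seed(0)  # seeds random, numpy and torch (CPU and CUDA) in one call

# A plain filesystem path is returned unchanged ...
assert realpath_url_or_path('/tmp/minDALL-E/1.3B') == '/tmp/minDALL-E/1.3B'

# ... while an http(s) URL is fetched and extracted by download(); the URL shape it
# assumes looks like this (placeholder, not a real endpoint):
# realpath_url_or_path('https://example.com/<md5-checksum>/1.3B.tar.gz', root='.cache/minDALL-E')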
page/generate.py
ADDED
@@ -0,0 +1,97 @@
import collections
from numpy.core.defchararray import lower
import streamlit as st
import numpy as np
import pandas as pd
import os, random, time
from utils import footer, generate, drawGrid
from PIL import Image

mode = "ai"
#mode = "dummy"

def app():

    st.title('AI-Generated Architecture')

    st.subheader('Describe a building, interior, or other architecture you would like to see.')

    #Modern architecture museum with black brick and large windows.
    prompt = st.text_input(label="", value="Modern architecture museum with black brick and large windows.")

    st.text("")

    with st.expander("Having trouble thinking of something? Click here to view examples."):
        st.write("""
        • Modern architecture museum with black brick and large windows.\n
        • A prosaic, simple architecture.\n
        • An urban, post-modern architecture with concrete and steel.\n
        • A sleek urban interior design.
        """)

    st.text("")

    crazy = st.slider('Temperature. This controls how "crazy" generated images are, where 0 is the least crazy.', 0.0, 1.0, 0.75)
    k = st.slider('Top K. The higher the value, the higher quality the results tend to be at the cost of extra processing time.', 1, 10, 1)

    if 'results' not in st.session_state:
        st.session_state.results = []

    holder = st.empty()
    startButton = holder.button("Start")

    already = []

    print("-0-")

    if startButton or hasattr(st.session_state, 'load_state'):

        with st.spinner("Generating..."):

            print("-1-")

            holder.empty()

            nextButton = holder.button("finished generating images")
            st.session_state.load_state = True

            placeholder = st.empty()
            second = st.empty()

            with second.container():
                drawGrid()

            # Keep generating until the working set holds 16 images or the user
            # clicks "finished generating images".
            while len(st.session_state.results) <= 15:

                print("Length " + str(len(st.session_state.results)))

                with placeholder.container():

                    st.session_state.bar = placeholder.progress(0)

                    if nextButton:
                        st.session_state.page = 1
                        break

                    generate(prompt, crazy, k)

                    with second.container():
                        drawGrid()

            #placeholder.empty()
            #st.session_state.bar = placeholder.progress(0)
            #drawGrid(placeholder)
page/reduce.py
ADDED
@@ -0,0 +1,58 @@
import collections
from numpy.core.defchararray import lower
import streamlit as st
import numpy as np
import pandas as pd
from zipfile import ZipFile
import io
import os

def dell(ix):
    print("!!!!")
    st.session_state.results.pop(ix)


def app():

    st.title('AI-Generated Architecture')

    st.subheader('Choose which images you would like to remove from your working set.')

    # Clear out any images left over in the temporary directory.
    for f in os.listdir("temp/"):
        os.remove(os.path.join("temp/", f))

    # Create a ZipFile object and add the current working set of images to it.
    zipObj = ZipFile('ai_architecture.zip', 'w')
    for ix, result in enumerate(st.session_state.results):
        result['image'].save("temp/" + str(ix) + ".jpeg")
        zipObj.write("temp/" + str(ix) + ".jpeg")

    zipObj.close()

    st.download_button(
        label="Download images as zip",
        data=open('ai_architecture.zip', 'rb'),
        file_name='ai_architecture.zip',
        mime='application/zip'
    )

    deleteButtons = []

    for ix, result in enumerate(st.session_state.results):

        with st.container():
            col1, col2 = st.columns(2)

            with col1:
                st.image(result['image'])
            with col2:
                st.button("delete ", key=ix, on_click=dell, kwargs=dict(ix=ix))

            m = st.markdown("""
            <hr />""", unsafe_allow_html=True)
requirements.txt
ADDED
@@ -0,0 +1,18 @@
clip==0.2.0
Cython==0.29.30
clip_anytorch==2.4.0
htbuilder==0.6.0
iteration_utilities==0.11.0
numpy==1.22.4
omegaconf==2.2.2
pages==0.3
pandas==1.4.2
Pillow==9.2.0
pytorch_lightning==1.6.3
ruclip==0.0.1
rudalle==1.1.3
streamlit==1.10.0
tokenizers==0.12.1
torch==1.8.0
torchvision==0.9.0
tqdm==4.64.0
streamlit_app.py
ADDED
@@ -0,0 +1,48 @@
import streamlit as st
import pandas as pd
import numpy as np

import os, random, time

from utils import footer
from page import generate, reduce


if not hasattr(st.session_state, 'page'):
    st.session_state.page = 0

if not hasattr(st.session_state, 'results'):
    st.session_state.results = []

p1 = st.empty()
p2 = st.empty()
p3 = st.empty()


st.session_state.stop = False
st.session_state.progress = 0
st.session_state.regenerate = False

if st.session_state.page == 0:
    p2.empty()
    p3.empty()
    with p1.container():
        generate.app()


if st.session_state.page == 1:
    p1.empty()
    p3.empty()
    with p2.container():
        reduce.app()

if st.session_state.page == 2:
    p1.empty()
    p2.empty()
    with p3.container():
        st.write("This 333")
        startButton = st.button("S3")
        if startButton:
            st.session_state.page = 0

footer()
utils.py
ADDED
@@ -0,0 +1,160 @@
from htbuilder import HtmlElement, div, ul, li, br, hr, a, p, img, styles, classes, fonts
from htbuilder.units import percent, px
from htbuilder.funcs import rgba, rgb
import streamlit as st
import os
import sys
import argparse
import clip
import numpy as np
from PIL import Image
from dalle.models import Dalle
from dalle.utils.utils import set_seed, clip_score

def link(link, text, **style):
    return a(_href=link, _target="_blank", style=styles(**style))(text)

def layout(*args):

    style = """
    <style>
      # MainMenu {visibility: hidden;}
      footer {visibility: hidden;}
      .stApp { bottom: 105px; }
    </style>
    """

    style_div = styles(
        position="fixed",
        left=0,
        bottom=0,
        margin=px(0, 0, 0, 0),
        width=percent(100),
        color="black",
        text_align="center",
        height="auto",
        opacity=1
    )

    style_hr = styles(
        display="block",
        margin=px(8, 8, "auto", "auto"),
        border_style="inset",
        border_width=px(2)
    )

    body = p()
    foot = div(
        style=style_div
    )(
        hr(
            style=style_hr
        ),
        body
    )

    st.markdown(style, unsafe_allow_html=True)

    for arg in args:
        if isinstance(arg, str):
            body(arg)

        elif isinstance(arg, HtmlElement):
            body(arg)

    st.markdown(str(foot), unsafe_allow_html=True)

def footer():
    myargs = [
        "Created by ",
        link("https://jonathanmalott.com", "Jonathan Malott"),
        br(),
        link("https://bridgingbarriers.utexas.edu/good-systems", "Good Systems"),
        " Grand Challenge",
        ", The University of Texas at Austin.",
        " Advised by Dr. Junfeng Jiao.",
        br(),
        br(),
    ]
    layout(*myargs)

#footer()

def generate(prompt, crazy, k):

    device = 'cpu'
    print("-2-")
    model = Dalle.from_pretrained('.cache/minDALL-E/1.3B')  # This will automatically download the pretrained model.
    print("-3-")
    model.to(device=device)
    num_candidates = 1

    images = []

    set_seed(np.random.randint(0, 10000))

    # Sampling
    images = model.sampling(prompt=prompt,
                            top_k=2048,
                            top_p=None,
                            softmax_temperature=crazy,
                            num_candidates=num_candidates,
                            device=device).cpu().numpy()
    images = np.transpose(images, (0, 2, 3, 1))

    # CLIP Re-ranking
    model_clip, preprocess_clip = clip.load("ViT-B/32", device=device)
    model_clip.to(device=device)
    rank = clip_score(prompt=prompt,
                      images=images,
                      model_clip=model_clip,
                      preprocess_clip=preprocess_clip,
                      device=device)

    result = images[rank]

    item = {}
    item['prompt'] = prompt
    item['crazy'] = crazy
    item['k'] = k
    item['image'] = Image.fromarray((result * 255).astype(np.uint8))
    st.session_state.results.append(item)


def drawGrid():
    master = {}
    order = 0

    #print(st.session_state.results)

    for r in st.session_state.results[::-1]:
        _txt = r['prompt'] + " " + str(r['crazy']) + " " + str(r['k'])

        if _txt not in master:
            master[_txt] = [r]
            order += 1
        else:
            master[_txt].append(r)

    for m in master:
        #with placeholder.container():

        txt = master[m][0]['prompt'] + " (temperature:" + str(master[m][0]['crazy']) + ", top k:" + str(master[m][0]['k']) + ")"
        st.subheader(txt)
        col1, col2, col3 = st.columns(3)

        for ix, item in enumerate(master[m]):
            if ix % 3 == 0:
                with col1:
                    st.image(item["image"])
            if ix % 3 == 1:
                with col2:
                    st.image(item["image"])
            if ix % 3 == 2:
                with col3:
                    st.image(item["image"])
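For reference, a sketch (not part of the commit) of the dictionary shape that generate() appends to st.session_state.results and that drawGrid() groups by prompt, temperature, and top-k; the blank PIL image stands in for a generated sample.

from PIL import Image

example_item = {
    'prompt': 'A sleek urban interior design.',   # text prompt shown as the grid subheader
    'crazy': 0.75,                                # softmax temperature used for this sample
    'k': 1,                                       # top-k setting recorded with the result
    'image': Image.new('RGB', (256, 256)),        # placeholder for the generated image
}
# Appending it via st.session_state.results.append(example_item) would make drawGrid() render it.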