"""Sample scene images from a trained Net2NetTransformer conditioned on object
layouts (bounding boxes or object center points). Resolutions larger than the
training resolution are handled by sliding a crop window over the latent grid."""
import glob
import os
import sys
from itertools import product
from pathlib import Path
from typing import Literal, List, Optional, Tuple

import numpy as np
import torch
from omegaconf import OmegaConf
from pytorch_lightning import seed_everything
from torch import Tensor
from torchvision.utils import save_image
from tqdm import tqdm

from scripts.make_samples import get_parser, load_model_and_dset
from taming.data.conditional_builder.objects_center_points import ObjectsCenterPointsConditionalBuilder
from taming.data.helper_types import BoundingBox, Annotation
from taming.data.annotated_objects_dataset import AnnotatedObjectsDataset
from taming.models.cond_transformer import Net2NetTransformer

seed_everything(42424242)
device: Literal['cuda', 'cpu'] = 'cuda'
first_stage_factor = 16  # downsampling factor of the first-stage model; also the crop-window size in the latent grid
trained_on_res = 256     # image resolution the transformer was trained on
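# desired_z_shape and desired_resolution are set in the __main__ block below (from the
# --resolution argument) and are read as module-level globals by the helpers that follow.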


def _helper(coord: int, coord_max: int, coord_window: int) -> int:
    """Return the start of a window of size `coord_window` that keeps `coord`
    roughly centered while staying inside `[0, coord_max)`."""
    assert 0 <= coord < coord_max
    coord_desired_center = (coord_window - 1) // 2
    return int(np.clip(coord - coord_desired_center, 0, coord_max - coord_window))


def get_crop_coordinates(x: int, y: int) -> BoundingBox:
    """Return the relative (x0, y0, w, h) coordinates of the crop window around latent position (x, y)."""
    WIDTH, HEIGHT = desired_z_shape[1], desired_z_shape[0]
    x0 = _helper(x, WIDTH, first_stage_factor) / WIDTH
    y0 = _helper(y, HEIGHT, first_stage_factor) / HEIGHT
    w = first_stage_factor / WIDTH
    h = first_stage_factor / HEIGHT
    return x0, y0, w, h


def get_z_indices_crop_out(z_indices: Tensor, predict_x: int, predict_y: int) -> Tensor:
    """Gather the already-sampled latent indices inside the crop window, in raster
    order, up to (but excluding) the position (predict_x, predict_y)."""
    WIDTH, HEIGHT = desired_z_shape[1], desired_z_shape[0]
    x0 = _helper(predict_x, WIDTH, first_stage_factor)
    y0 = _helper(predict_y, HEIGHT, first_stage_factor)
    no_images = z_indices.shape[0]
    cut_out_1 = z_indices[:, y0:predict_y, x0:x0+first_stage_factor].reshape((no_images, -1))  # full rows above
    cut_out_2 = z_indices[:, predict_y, x0:predict_x]  # current row, left of the target position
    return torch.cat((cut_out_1, cut_out_2), dim=1)


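# If the requested resolution exceeds the training resolution, tokens are sampled one at a
# time in raster order, each step conditioned on a first_stage_factor-sized window of the
# previously sampled tokens and on the layout re-cropped to that window.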
@torch.no_grad()
def sample(model: Net2NetTransformer, annotations: List[Annotation], dataset: AnnotatedObjectsDataset,
           conditional_builder: ObjectsCenterPointsConditionalBuilder, no_samples: int,
           temperature: float, top_k: int) -> Tensor:
    """Sample `no_samples` images for one layout and append a rendering of the layout itself."""
    x_max, y_max = desired_z_shape[1], desired_z_shape[0]

    annotations = [a._replace(category_no=dataset.get_category_number(a.category_id)) for a in annotations]

    recompute_conditional = any((desired_resolution[0] > trained_on_res, desired_resolution[1] > trained_on_res))
    if not recompute_conditional:
        # Target resolution matches the training resolution: condition once and sample all tokens in one pass.
        crop_coordinates = get_crop_coordinates(0, 0)
        conditional_indices = conditional_builder.build(annotations, crop_coordinates)
        c_indices = conditional_indices.to(device).repeat(no_samples, 1)
        z_indices = torch.zeros((no_samples, 0), device=device).long()
        output_indices = model.sample(z_indices, c_indices, steps=x_max*y_max, temperature=temperature,
                                      sample=True, top_k=top_k)
    else:
        # Larger target resolution: sample token by token with a sliding crop window.
        output_indices = torch.zeros((no_samples, y_max, x_max), device=device).long()
        for predict_y, predict_x in tqdm(product(range(y_max), range(x_max)), desc='sampling_image', total=x_max*y_max):
            crop_coordinates = get_crop_coordinates(predict_x, predict_y)
            z_indices = get_z_indices_crop_out(output_indices, predict_x, predict_y)
            conditional_indices = conditional_builder.build(annotations, crop_coordinates)
            c_indices = conditional_indices.to(device).repeat(no_samples, 1)
            new_index = model.sample(z_indices, c_indices, steps=1, temperature=temperature, sample=True, top_k=top_k)
            output_indices[:, predict_y, predict_x] = new_index[:, -1]
    z_shape = (
        no_samples,
        model.first_stage_model.quantize.e_dim,  # codebook embedding dimension
        desired_z_shape[0],                      # latent height
        desired_z_shape[1]                       # latent width
    )
    x_sample = model.decode_to_img(output_indices, z_shape) * 0.5 + 0.5  # rescale from [-1, 1] to [0, 1]
    x_sample = x_sample.to('cpu')

    plotter = conditional_builder.plot
    figure_size = (x_sample.shape[2], x_sample.shape[3])
    scene_graph = conditional_builder.build(annotations, (0., 0., 1., 1.))
    plot = plotter(scene_graph, dataset.get_textual_label_for_category_no, figure_size)
    return torch.cat((x_sample, plot.unsqueeze(0)))


def get_resolution(resolution_str: str) -> Tuple[Tuple[int, int], Tuple[int, int]]:
    """Parse a 'height,width' string and return (latent grid shape, pixel resolution).

    Values below the training resolution are raised to it, and the pixel resolution is
    aligned to a multiple of first_stage_factor, e.g. '384,512' -> ((24, 32), (384, 512)).
    """
    if not resolution_str.count(',') == 1:
        raise ValueError("Give resolution as in 'height,width'")
    res_h, res_w = resolution_str.split(',')
    res_h = max(int(res_h), trained_on_res)
    res_w = max(int(res_w), trained_on_res)
    z_h = int(round(res_h/first_stage_factor))
    z_w = int(round(res_w/first_stage_factor))
    return (z_h, z_w), (z_h*first_stage_factor, z_w*first_stage_factor)


def add_arg_to_parser(parser):
    parser.add_argument(
        "-R",
        "--resolution",
        type=str,
        default='256,256',
        help=f"give resolution in multiples of {first_stage_factor}, default is '256,256'",
    )
    parser.add_argument(
        "-C",
        "--conditional",
        type=str,
        default='objects_bbox',
        help="objects_bbox or objects_center_points",
    )
    parser.add_argument(
        "-N",
        "--n_samples_per_layout",
        type=int,
        default=4,
        help="how many samples to generate per layout",
    )
    return parser


if __name__ == "__main__":
    sys.path.append(os.getcwd())

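    # Options other than -R/-C/-N used below (resume, base, config, outdir, top_k,
    # temperature, ignore_base_data) are expected to come from get_parser() in scripts/make_samples.py.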
    parser = get_parser()
    parser = add_arg_to_parser(parser)

    opt, unknown = parser.parse_known_args()

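    # Resolve checkpoint and log directory from --resume: either a checkpoint file
    # inside a log directory or the log directory itself (then use checkpoints/last.ckpt).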
    ckpt = None
    if opt.resume:
        if not os.path.exists(opt.resume):
            raise ValueError("Cannot find {}".format(opt.resume))
        if os.path.isfile(opt.resume):
            paths = opt.resume.split("/")
            try:
                idx = len(paths)-paths[::-1].index("logs")+1
            except ValueError:
                idx = -2
            logdir = "/".join(paths[:idx])
            ckpt = opt.resume
        else:
            assert os.path.isdir(opt.resume), opt.resume
            logdir = opt.resume.rstrip("/")
            ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")
        print(f"logdir:{logdir}")
        base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*-project.yaml")))
        opt.base = base_configs+opt.base

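    # Load the project config(s) and merge them with any command-line overrides.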
    if opt.config:
        if isinstance(opt.config, str):
            opt.base = [opt.config]
        else:
            opt.base = [opt.base[-1]]

    configs = [OmegaConf.load(cfg) for cfg in opt.base]
    cli = OmegaConf.from_dotlist(unknown)
    if opt.ignore_base_data:
        for config in configs:
            if hasattr(config, "data"):
                del config["data"]
    config = OmegaConf.merge(*configs, cli)
    desired_z_shape, desired_resolution = get_resolution(opt.resolution)
    conditional = opt.conditional

    print(ckpt)
    gpu = True
    eval_mode = True
    show_config = False
    if show_config:
        print(OmegaConf.to_container(config))

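    # Load the model and datasets, then pick the conditional builder for the chosen layout type.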
    dsets, model, global_step = load_model_and_dset(config, ckpt, gpu, eval_mode)
    print(f"Global step: {global_step}")

    data_loader = dsets.val_dataloader()
    print(dsets.datasets["validation"].conditional_builders)
    conditional_builder = dsets.datasets["validation"].conditional_builders[conditional]

    outdir = Path(opt.outdir).joinpath(f"{global_step:06}_{opt.top_k}_{opt.temperature}")
    outdir.mkdir(exist_ok=True, parents=True)
    print("Writing samples to ", outdir)

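    # Sample every layout from the validation loader; each saved grid holds the generated
    # samples followed by a rendering of the conditioning layout.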
    p_bar_1 = tqdm(enumerate(iter(data_loader)), desc='batch', total=len(data_loader))
    for batch_no, batch in p_bar_1:
        save_img: Optional[Tensor] = None
        for i, annotations in tqdm(enumerate(batch['annotations']), desc='within_batch', total=data_loader.batch_size):
            imgs = sample(model, annotations, dsets.datasets["validation"], conditional_builder,
                          opt.n_samples_per_layout, opt.temperature, opt.top_k)
            # torchvision's save_image expects `nrow` (not `n_row`) for the grid layout
            save_image(imgs, outdir.joinpath(f'{batch_no:04}_{i:02}.png'), nrow=opt.n_samples_per_layout+1)
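
# Example invocation (a sketch; the spellings of the shared flags --resume/--outdir/--top_k/
# --temperature are assumed to match scripts.make_samples.get_parser and may differ):
#   python <path/to/this/script> --resume logs/<run> --outdir outputs/scene_samples \
#       --top_k 100 --temperature 1.0 -R 512,512 -C objects_bbox -N 4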
|