sjc / app.py
amankishore's picture
app with gradio
10528ca
raw
history blame
8.07 kB
import numpy as np
import time
from pathlib import Path
import torch
import imageio
from my.utils import tqdm
from my.utils.seed import seed_everything
from run_img_sampling import SD, StableDiffusion
from misc import torch_samps_to_imgs
from pose import PoseConfig
from run_nerf import VoxConfig
from voxnerf.utils import every
from voxnerf.vis import stitch_vis, bad_vis as nerf_vis
from run_sjc import render_one_view, tsr_stats
import gradio as gr
import gc
# Global compute device. `submit` moves the voxel model (`vox`) here before
# optimisation and test-time rendering. Hard-coded to CUDA: there is no CPU
# fallback, so the app expects a GPU to be available.
device_glb = torch.device("cuda")
def vis_routine(y, depth):
    """Turn one rendered view into displayable outputs.

    Args:
        y: rendered/decoded image batch, passed to ``nerf_vis`` and
            ``torch_samps_to_imgs``.
        depth: depth tensor for the same view.

    Returns:
        ``(pane, im, depth)`` where ``pane`` is the ``nerf_vis`` panel
        rendered at ``final_H=256``, ``im`` is the first image produced by
        ``torch_samps_to_imgs(y)``, and ``depth`` is the input depth copied
        to the CPU as a NumPy array.
    """
    vis_pane = nerf_vis(y, depth, final_H=256)
    first_image = torch_samps_to_imgs(y)[0]
    depth_array = depth.cpu().numpy()
    return vis_pane, first_image, depth_array
def _export_movie(frames, fname, fps=30):
    """Encode a sequence of image frames into a video file.

    Args:
        frames: iterable of images accepted by the writer's ``append_data``.
        fname: output path; ``.mp4`` is appended when it has no suffix.
        fps: frames per second of the encoded video.
    """
    fname = Path(fname)
    if fname.suffix == "":
        fname = fname.with_suffix(".mp4")
    # The context manager closes the writer even if appending a frame raises.
    with imageio.get_writer(fname, fps=fps) as writer:
        for frame in frames:
            writer.append_data(frame)


with gr.Blocks(css=".gradio-container {max-width: 512px; margin: auto;}") as demo:
    # title
    gr.Markdown('[Score Jacobian Chaining](https://github.com/pals-ttic/sjc) Lifting Pretrained 2D Diffusion Models for 3D Generation')

    # inputs
    prompt = gr.Textbox(label="Prompt", max_lines=1, value="A high quality photo of a delicious burger")
    iters = gr.Slider(label="Iters", minimum=1000, maximum=20000, value=10000, step=100)
    seed = gr.Slider(label="Seed", minimum=0, maximum=2147483647, step=1, randomize=True)
    button = gr.Button('Generate')

    # outputs
    image = gr.Image(label="image", visible=True)
    depth = gr.Image(label="depth", visible=True)
    video = gr.Video(label="video", visible=False)
    logs = gr.Textbox(label="logging")

    def submit(prompt, iters, seed):
        """Optimise a voxel model for ``prompt`` and stream results to the UI.

        Generator used as the Gradio click handler. Every yielded dict maps
        the output components ``image``, ``depth``, ``video`` and ``logs`` to
        ``gr.update`` values (or a log string). It yields:

        * a preview whenever ``every(pbar, percent=1)`` fires during the
          ``iters`` optimisation steps,
        * one frame per test view returned by ``poser.sample_test(100)``,
          each also written to ``output/<i>.png``,
        * finally the stitched ``view_seq.mp4`` and the elapsed time.

        NOTE: local variables must never be named ``image``, ``depth``,
        ``video`` or ``logs``. Assigning to one anywhere in this function
        would make it local and break the dict keys, which must be the
        output components.
        """
        start_t = time.time()
        # Slider values may arrive as floats; range()/RNG seeding need ints.
        n_steps = int(iters)
        seed_everything(int(seed))

        pose_cfg = PoseConfig(rend_hw=64, FoV=60.0, R=1.5)
        poser = pose_cfg.make()
        sd_model = SD(variant='v1', v2_highres=False, prompt=prompt, scale=100.0, precision='autocast')
        model = sd_model.make()
        vox_cfg = VoxConfig(
            model_type="V_SD", grid_size=100, density_shift=-1.0, c=4,
            blend_bg_texture=True, bg_texture_hw=4,
            bbox_len=1.0)
        vox = vox_cfg.make()

        # Optimisation hyper-parameters.
        lr = 0.05
        emptiness_scale = 10
        emptiness_weight = 10000
        emptiness_step = 0.5          # fraction of n_steps after which the multiplier applies
        emptiness_multiplier = 20.0
        depth_weight = 0              # depth prior disabled when 0
        var_red = True                # use (Ds - y) instead of (Ds - zs) in the score gradient

        assert model.samps_centered()
        _, target_H, target_W = model.data_shape()
        bs = 1

        aabb = vox.aabb.T.cpu().numpy()
        vox = vox.to(device_glb)
        opt = torch.optim.Adamax(vox.opt_params(), lr=lr)

        H, W = poser.H, poser.W
        Ks, poses, prompt_prefixes = poser.sample_train(n_steps)
        ts = model.us[30:-10]

        with tqdm(total=n_steps) as pbar:
            for i in range(n_steps):
                p = f"{prompt_prefixes[i]} {model.prompt}"
                score_conds = model.prompts_emb([p])

                y, depth_map, ws = render_one_view(vox, aabb, H, W, Ks[i], poses[i], return_w=True)
                if not isinstance(model, StableDiffusion):
                    y = torch.nn.functional.interpolate(y, (target_H, target_W), mode='bilinear')

                opt.zero_grad()

                # Score estimate from the frozen diffusion model; no graph needed.
                with torch.no_grad():
                    chosen_σs = np.random.choice(ts, bs, replace=False)
                    chosen_σs = chosen_σs.reshape(-1, 1, 1, 1)
                    chosen_σs = torch.as_tensor(chosen_σs, device=model.device, dtype=torch.float32)

                    noise = torch.randn(bs, *y.shape[1:], device=model.device)
                    zs = y + chosen_σs * noise
                    Ds = model.denoise(zs, chosen_σs, **score_conds)

                    if var_red:
                        grad = (Ds - y) / chosen_σs
                    else:
                        grad = (Ds - zs) / chosen_σs
                    grad = grad.mean(0, keepdim=True)

                # Back-propagate the score through the renderer into the voxels.
                y.backward(-grad, retain_graph=True)

                if depth_weight > 0:
                    # Encourage the centre of the view to be deeper than a 7-px border.
                    center_depth = depth_map[7:-7, 7:-7]
                    border_count = depth_map.numel() - center_depth.numel()
                    border_depth_mean = (depth_map.sum() - center_depth.sum()) / border_count
                    center_depth_mean = center_depth.mean()
                    depth_diff = center_depth_mean - border_depth_mean
                    depth_loss = - torch.log(depth_diff + 1e-12)
                    depth_loss = depth_weight * depth_loss
                    depth_loss.backward(retain_graph=True)

                emptiness_loss = torch.log(1 + emptiness_scale * ws).mean()
                emptiness_loss = emptiness_weight * emptiness_loss
                if emptiness_step * n_steps <= i:
                    emptiness_loss *= emptiness_multiplier
                emptiness_loss.backward()

                opt.step()

                if every(pbar, percent=1):
                    with torch.no_grad():
                        if isinstance(model, StableDiffusion):
                            y = model.decode(y)
                        _, img, depth_np = vis_routine(y, depth_map)
                    # Yield outside the no_grad block: grad mode is thread-local
                    # and the generator may be resumed on another worker thread.
                    yield {
                        image: gr.update(value=img, visible=True),
                        depth: gr.update(value=depth_np, visible=True),
                        video: gr.update(visible=False),
                        logs: str(tsr_stats(y)),
                    }

                pbar.update()
                pbar.set_description(p)

        # TODO: Save checkpoint (vox.state_dict()).

        # Render the test views with the optimised voxels.
        vox.eval()
        K, test_poses = poser.sample_test(100)
        num_imgs = len(test_poses)

        out_dir = Path("output")
        out_dir.mkdir(parents=True, exist_ok=True)
        frames = []
        for i in tqdm(range(num_imgs)):
            with torch.no_grad():
                y, depth_map = render_one_view(vox, aabb, H, W, K, test_poses[i])
                if isinstance(model, StableDiffusion):
                    y = model.decode(y)
                _, img, depth_np = vis_routine(y, depth_map)

            # np.asarray accepts both arrays and PIL images.
            frame = np.asarray(img)
            imageio.imwrite(str(out_dir / f"{i}.png"), frame)
            frames.append(frame)

            yield {
                image: gr.update(value=img, visible=True),
                depth: gr.update(value=depth_np, visible=True),
                video: gr.update(visible=False),
                logs: str(tsr_stats(y)),
            }

        output_video = "view_seq.mp4"
        _export_movie(frames, output_video, fps=10)

        # Drop this run's largest references and return cached GPU memory
        # before the next queued request starts (see the concurrency note below).
        del model, vox, opt
        gc.collect()
        torch.cuda.empty_cache()

        end_t = time.time()
        yield {
            image: gr.update(value=img, visible=False),
            depth: gr.update(value=depth_np, visible=False),
            video: gr.update(value=output_video, visible=True),
            logs: f"Generation Finished in {(end_t - start_t)/ 60:.4f} minutes!",
        }

    button.click(
        submit,
        [prompt, iters, seed],
        [image, depth, video, logs]
    )

# concurrency_count: only allow ONE running progress, else GPU will OOM.
demo.queue(concurrency_count=1)
demo.launch()