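"""Gradio demo for CCSR super-resolution.

Clones the camenduru/CCSR repository, downloads the real-world checkpoint
from Hugging Face, and serves an upscaling UI with optional tiled diffusion
and tiled VAE for large images.
"""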
import os
import sys
import torch
import gradio as gr
from PIL import Image
import numpy as np
from omegaconf import OmegaConf
import subprocess
from tqdm import tqdm
import requests
import einops
import math
import random
import pytorch_lightning as pl
def download_file(url, filename):
    """Stream a file from `url` to `filename` with a tqdm progress bar."""
    response = requests.get(url, stream=True)
    total_size = int(response.headers.get('content-length', 0))
    block_size = 1024
    with open(filename, 'wb') as file, tqdm(
        desc=filename,
        total=total_size,
        unit='iB',
        unit_scale=True,
        unit_divisor=1024,
    ) as progress_bar:
        for data in response.iter_content(block_size):
            size = file.write(data)
            progress_bar.update(size)
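# Note: download_file assumes a successful response; an HTTP error body would
# otherwise be written out as a corrupt checkpoint. A minimal hardening sketch
# would call response.raise_for_status() right after requests.get().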
def setup_environment():
    """Clone the CCSR repo, make it importable, and fetch the model weights."""
    if not os.path.exists("CCSR"):
        print("Cloning CCSR repository...")
        subprocess.run(["git", "clone", "-b", "dev", "https://github.com/camenduru/CCSR.git"])
    # Both the chdir and the sys.path entry are needed so the CCSR-internal
    # imports below resolve.
    os.chdir("CCSR")
    sys.path.append(os.getcwd())
    os.makedirs("weights", exist_ok=True)
    if not os.path.exists("weights/real-world_ccsr.ckpt"):
        print("Downloading model checkpoint...")
        download_file(
            "https://huggingface.co/camenduru/CCSR/resolve/main/real-world_ccsr.ckpt",
            "weights/real-world_ccsr.ckpt"
        )
    else:
        print("Model checkpoint already exists. Skipping download.")

setup_environment()
# These imports come from the cloned repo, so they must run after
# setup_environment() has chdir'd into it and extended sys.path.
from ldm.xformers_state import disable_xformers
from model.q_sampler import SpacedSampler
from model.ccsr_stage1 import ControlLDM
from utils.common import instantiate_from_config, load_state_dict
from utils.image import auto_resize
config = OmegaConf.load("configs/model/ccsr_stage2.yaml")
model = instantiate_from_config(config)
ckpt = torch.load("weights/real-world_ccsr.ckpt", map_location="cpu")
load_state_dict(model, ckpt, strict=True)
model.freeze()
model.to("cuda")
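# This script assumes a CUDA device is available. A hedged fallback sketch:
#
#     device = "cuda" if torch.cuda.is_available() else "cpu"
#     model.to(device)
#
# (the rest of the code reads model.device, so it would follow along).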
def process(
    control_img: Image.Image,
    num_samples: int,
    sr_scale: float,
    strength: float,
    positive_prompt: str,
    negative_prompt: str,
    cfg_scale: float,
    steps: int,
    use_color_fix: bool,
    seed: int,
    tile_diffusion: bool,
    tile_diffusion_size: int,
    tile_diffusion_stride: int,
    tile_vae: bool,
    vae_encoder_tile_size: int,
    vae_decoder_tile_size: int
):
    print(
        f"control image size={control_img.size}\n"
        f"num_samples={num_samples}, sr_scale={sr_scale}, strength={strength}\n"
        f"positive_prompt='{positive_prompt}', negative_prompt='{negative_prompt}'\n"
        f"cfg_scale={cfg_scale}, steps={steps}, use_color_fix={use_color_fix}\n"
        f"seed={seed}\n"
        f"tile_diffusion={tile_diffusion}, tile_diffusion_size={tile_diffusion_size}, tile_diffusion_stride={tile_diffusion_stride}\n"
        f"tile_vae={tile_vae}, vae_encoder_tile_size={vae_encoder_tile_size}, vae_decoder_tile_size={vae_decoder_tile_size}"
    )
    pl.seed_everything(seed)
    # Upscale the low-res input by the requested SR factor.
    if sr_scale != 1:
        control_img = control_img.resize(
            tuple(math.ceil(x * sr_scale) for x in control_img.size),
            Image.BICUBIC
        )
    input_size = control_img.size
    # Resize the shorter side to the model (or diffusion-tile) resolution.
    if not tile_diffusion:
        control_img = auto_resize(control_img, 512)
    else:
        control_img = auto_resize(control_img, tile_diffusion_size)
    # Round both sides up to multiples of 64, as the UNet requires.
    control_img = control_img.resize(
        tuple(math.ceil(s / 64) * 64 for s in control_img.size), Image.LANCZOS
    )
    control_img = np.array(control_img)
    # Convert to an NCHW float tensor in [0, 1].
    control = torch.tensor(control_img[None] / 255.0, dtype=torch.float32, device=model.device).clamp_(0, 1)
    control = einops.rearrange(control, "n h w c -> n c h w").contiguous()
    height, width = control.size(-2), control.size(-1)
    # One scale per ControlNet injection point (12 encoder blocks + middle).
    model.control_scales = [strength] * 13
    sampler = SpacedSampler(model, var_type="fixed_small")
    preds = []
    for _ in tqdm(range(num_samples)):
        # Latent-space shape: the VAE downsamples each spatial dim by 8.
        shape = (1, 4, height // 8, width // 8)
        x_T = torch.randn(shape, device=model.device, dtype=torch.float32)
        if not tile_diffusion and not tile_vae:
            samples = sampler.sample_ccsr(
                steps=steps, t_max=0.6667, t_min=0.3333, shape=shape, cond_img=control,
                positive_prompt=positive_prompt, negative_prompt=negative_prompt, x_T=x_T,
                cfg_scale=cfg_scale,
                color_fix_type="adain" if use_color_fix else "none"
            )
        else:
            if tile_vae:
                model._init_tiled_vae(encoder_tile_size=vae_encoder_tile_size, decoder_tile_size=vae_decoder_tile_size)
            if tile_diffusion:
                samples = sampler.sample_with_tile_ccsr(
                    tile_size=tile_diffusion_size, tile_stride=tile_diffusion_stride,
                    steps=steps, t_max=0.6667, t_min=0.3333, shape=shape, cond_img=control,
                    positive_prompt=positive_prompt, negative_prompt=negative_prompt, x_T=x_T,
                    cfg_scale=cfg_scale,
                    color_fix_type="adain" if use_color_fix else "none"
                )
            else:
                # Tiled VAE only: sample normally; the VAE handles the tiling.
                samples = sampler.sample_ccsr(
                    steps=steps, t_max=0.6667, t_min=0.3333, shape=shape, cond_img=control,
                    positive_prompt=positive_prompt, negative_prompt=negative_prompt, x_T=x_T,
                    cfg_scale=cfg_scale,
                    color_fix_type="adain" if use_color_fix else "none"
                )
        # Back to HWC uint8, then downscale to the exact requested output size.
        x_samples = samples.clamp(0, 1)
        x_samples = (einops.rearrange(x_samples, "b c h w -> b h w c") * 255).cpu().numpy().clip(0, 255).astype(np.uint8)
        img = Image.fromarray(x_samples[0, ...]).resize(input_size, Image.LANCZOS)
        preds.append(np.array(img))
    return preds
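# Hypothetical direct invocation, bypassing the UI (argument values mirror the
# Gradio defaults defined below):
#
#     preds = process(
#         Image.open("input.png"), num_samples=1, sr_scale=4.0, strength=1.0,
#         positive_prompt="", negative_prompt="", cfg_scale=1.0, steps=45,
#         use_color_fix=True, seed=231, tile_diffusion=False,
#         tile_diffusion_size=512, tile_diffusion_stride=256, tile_vae=True,
#         vae_encoder_tile_size=1024, vae_decoder_tile_size=224,
#     )
#     Image.fromarray(preds[0]).save("output.png")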
def update_output_resolution(image, scale):
    if image is not None:
        width, height = image.size
        return f"Current resolution: {width}x{height}. Output resolution: {int(width * scale)}x{int(height * scale)}"
    return "Upload an image to see the output resolution"
block = gr.Blocks().queue()
with block:
    with gr.Row():
        input_image = gr.Image(type="pil", label="Input Image")
    with gr.Row():
        sr_scale = gr.Slider(label="SR Scale", minimum=1, maximum=8, value=4, step=0.1, info="Super-resolution scale factor.")
        output_resolution = gr.Markdown("Upload an image to see the output resolution")
    with gr.Row():
        run_button = gr.Button(value="Run")
    with gr.Accordion("Options", open=False):
        with gr.Column():
            num_samples = gr.Slider(label="Number Of Samples", minimum=1, maximum=12, value=1, step=1)
            strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
            positive_prompt = gr.Textbox(label="Positive Prompt", value="")
            negative_prompt = gr.Textbox(
                label="Negative Prompt",
                value="longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality"
            )
            cfg_scale = gr.Slider(label="Classifier-Free Guidance Scale", minimum=0.1, maximum=30.0, value=1.0, step=0.1)
            steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=45, step=1)
            use_color_fix = gr.Checkbox(label="Use Color Correction", value=True)
            seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, value=231)
            tile_diffusion = gr.Checkbox(label="Tile diffusion", value=False)
            tile_diffusion_size = gr.Slider(label="Tile diffusion size", minimum=512, maximum=1024, value=512, step=256)
            tile_diffusion_stride = gr.Slider(label="Tile diffusion stride", minimum=256, maximum=512, value=256, step=128)
            tile_vae = gr.Checkbox(label="Tile VAE", value=True)
            vae_encoder_tile_size = gr.Slider(label="Encoder tile size", minimum=512, maximum=5000, value=1024, step=256)
            vae_decoder_tile_size = gr.Slider(label="Decoder tile size", minimum=64, maximum=512, value=224, step=128)
    with gr.Column():
        result_gallery = gr.Gallery(label="Output", show_label=False, elem_id="gallery")
    inputs = [
        input_image,
        num_samples,
        sr_scale,
        strength,
        positive_prompt,
        negative_prompt,
        cfg_scale,
        steps,
        use_color_fix,
        seed,
        tile_diffusion,
        tile_diffusion_size,
        tile_diffusion_stride,
        tile_vae,
        vae_encoder_tile_size,
        vae_decoder_tile_size,
    ]
    run_button.click(fn=process, inputs=inputs, outputs=[result_gallery])
    # Update output resolution when an image is uploaded or SR scale changes
    input_image.change(update_output_resolution, inputs=[input_image, sr_scale], outputs=[output_resolution])
    sr_scale.change(update_output_resolution, inputs=[input_image, sr_scale], outputs=[output_resolution])
    # Disable the SR scale slider when no image is uploaded
    input_image.change(
        lambda x: gr.update(interactive=x is not None),
        inputs=[input_image],
        outputs=[sr_scale]
    )
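# Deployment note: launch() serves locally by default; a public link may need
# launch(share=True) or launch(server_name="0.0.0.0"), depending on the host.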
block.launch()