# Hugging Face Space (status: Running)
import torch | |
from PIL import Image, ImageOps, ImageSequence | |
import numpy as np | |
import comfy.sample | |
import comfy.sd | |
def vencode(vae, pth):
    """Encode an image into a ComfyUI-style latent dict using ``vae``.

    Args:
        vae: object exposing ``encode(pixels)``; it receives a float32 tensor
            of shape (1, H, W, 3) with values scaled from 0..255 to 0..1.
        pth: despite its name, the image itself (not a file path) - anything
            ``np.array`` accepts, e.g. a PIL image or an (H, W) / (H, W, C)
            uint8 array.

    Returns:
        ``{"samples": latent}`` where ``latent`` is whatever ``vae.encode``
        returned.
    """
    pixels = np.array(pth).astype(np.float32) / 255.0
    if pixels.ndim == 2:
        # Plain greyscale array: give it an explicit channel axis.
        pixels = pixels[:, :, None]
    if pixels.shape[2] < 3:
        # Greyscale (optionally with alpha): replicate the first channel into
        # RGB, since the VAE is fed exactly three channels below.
        pixels = np.repeat(pixels[:, :, :1], 3, axis=2)
    batch = torch.from_numpy(pixels)[None,]
    # Drop any alpha channel; only RGB goes into the VAE.
    return {"samples": vae.encode(batch[:, :, :, :3])}
from pathlib import Path

# Checkpoint used by this app, and where to fetch it from when it is missing.
MODEL_FILE = "model.safetensors"
MODEL_URL = "https://huggingface.co/parsee-mizuhashi/mangaka/resolve/main/mangaka.safetensors?download=true"

if not Path(MODEL_FILE).exists():
    import requests

    # Download into a temporary ".part" file and move it into place only after
    # the full body arrived with a success status. Writing straight to
    # MODEL_FILE would save an HTTP error page or a truncated transfer as the
    # checkpoint, and since the file would then exist, the download would
    # never be retried. Streaming also avoids buffering the whole file in RAM.
    _partial = Path(MODEL_FILE + ".part")
    with requests.get(MODEL_URL, stream=True, timeout=60) as _response:
        _response.raise_for_status()
        with open(_partial, "wb") as _out:
            for _chunk in _response.iter_content(chunk_size=1 << 20):
                _out.write(_chunk)
    _partial.replace(MODEL_FILE)

# Load the UNet, CLIP text encoder and VAE (the first three entries returned by
# ComfyUI's checkpoint loader); gradients are never needed here.
with torch.no_grad():
    unet, clip, vae = comfy.sd.load_checkpoint_guess_config(MODEL_FILE, output_vae=True, output_clip=True)[:3]
# Negative prompt that main() prepends to every user-supplied negative prompt.
# It uses ComfyUI's "(text:weight)" emphasis syntax; every group is closed
# right after its weight (or term) so the weights apply only to the intended
# words instead of swallowing the rest of the prompt as unbalanced text.
BASE_NEG = "(low-quality worst-quality:1.4) (bad-anatomy) (inaccurate-limb:1.2) bad-composition inaccurate-eyes extra-digit fewer-digits (extra-arms:1.2)"
# Torch device name: "cuda" when a GPU is available, otherwise "cpu".
# NOTE(review): not referenced anywhere else in the visible code.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0):
    """Run a single ComfyUI sampling pass and return the resulting latents.

    ``latent`` is a dict holding the ``"samples"`` tensor and, optionally, a
    ``"noise_mask"`` that is forwarded to the sampler unchanged. Starting noise
    is built from ``seed`` with ``comfy.sample.prepare_noise``; the progress bar
    is turned off.

    Returns:
        The tensor produced by ``comfy.sample.sample`` (not wrapped in a dict).
    """
    latent_image = latent["samples"]
    start_noise = comfy.sample.prepare_noise(latent_image, seed, None)
    return comfy.sample.sample(
        model,
        start_noise,
        steps,
        cfg,
        sampler_name,
        scheduler,
        positive,
        negative,
        latent_image,
        denoise=denoise,
        noise_mask=latent.get("noise_mask"),
        disable_pbar=True,
        seed=seed,
    )
def set_mask(samples, mask):
    """Return a shallow copy of ``samples`` with ``mask`` stored under ``"noise_mask"``.

    The mask is reshaped to (N, 1, H, W), where H and W are its last two
    dimensions; the dict passed in is not modified.
    """
    height, width = mask.shape[-2], mask.shape[-1]
    return {**samples, "noise_mask": mask.reshape((-1, 1, height, width))}
def load_image_mask(image):
    """Load the alpha channel of the image file at path ``image`` as a mask.

    The image is EXIF-orientation-corrected. Unless it is already RGBA, it is
    converted to RGBA, after first rescaling by 1/255 when its mode is "I".
    This guarantees that an alpha band exists.

    Returns:
        float32 tensor of shape (1, H, W) with values in [0, 1].
    """
    img = ImageOps.exif_transpose(Image.open(image))
    if img.getbands() != ("R", "G", "B", "A"):
        if img.mode == "I":
            img = img.point(lambda v: v * (1 / 255))
        img = img.convert("RGBA")
    # The image is RGBA at this point, so the "A" band always exists and no
    # fallback for a missing alpha channel is needed.
    alpha = np.array(img.getchannel("A")).astype(np.float32) / 255.0
    return torch.from_numpy(alpha).unsqueeze(0)
def main(img, variant, positive, negative, pilimg):
    """Regenerate one panel of a manga page by masked sampling.

    Args:
        img: page id (a key of ``limits``); selects ``./mangaka-d/<img>/``.
        variant: panel number as a string, clamped to ``limits[img]``; the
            panel mask is ``./mangaka-d/<img>/i<variant>.png``. The UI also
            offers ``"none"``, which cannot be generated.
        positive: user positive prompt, appended to a fixed monochrome prefix.
        negative: user negative prompt, appended to ``BASE_NEG``.
        pilimg: the current page image to inpaint into.

    Returns:
        The decoded result as a PIL image.

    Raises:
        gradio.Error: when no panel is selected or no page image is loaded,
            so the user gets a readable message instead of a stack trace.
    """
    import gradio as gr  # already a dependency of this app; used for UI-facing errors

    if variant is None or variant == "none":
        raise gr.Error("Select a panel number before generating.")
    if pilimg is None:
        raise gr.Error("No page image loaded - click Reset first.")

    panel_index = min(int(variant), limits[img])
    mask = load_image_mask(f"./mangaka-d/{img}/i{panel_index}.png")

    # Pure inference: skip autograd bookkeeping, consistent with how the
    # checkpoint itself is loaded.
    with torch.no_grad():
        pos_tokens = clip.tokenize("(greyscale monochrome black-and-white:1.3)" + positive)
        cond, cond_pooled = clip.encode_from_tokens(pos_tokens, return_pooled=True)
        neg_tokens = clip.tokenize(BASE_NEG + negative)
        uncond, uncond_pooled = clip.encode_from_tokens(neg_tokens, return_pooled=True)
        positive_cond = [[cond, {"pooled_output": cond_pooled}]]
        negative_cond = [[uncond, {"pooled_output": uncond_pooled}]]

        latent = set_mask(vencode(vae, pilimg), mask)
        # Fixed settings: seed 0, 20 steps, CFG 7, "ddpm" sampler with the
        # "karras" scheduler, full denoise inside the mask.
        denoised = common_ksampler(unet, 0, 20, 7, "ddpm", "karras", positive_cond, negative_cond, latent, denoise=1)
        decoded = vae.decode(denoised)

    # First image of the decoded batch; presumably (H, W, C) in [0, 1] given the
    # 255 scaling and clipping below - TODO confirm against the VAE.
    pixels = 255.0 * decoded[0].cpu().numpy()
    return Image.fromarray(np.clip(pixels, 0, 255).astype(np.uint8))
# Highest panel index for each page directory under ./mangaka-d/. main() and
# visualize_fn() clamp the selected panel number to this value before loading
# ./mangaka-d/<page>/i<panel>.png.
limits = {
    # numbered pages
    "1": 4, "2": 4, "3": 5, "4": 6, "5": 4, "6": 6, "7": 8, "8": 5, "9": 5,
    # "s" pages
    "s1": 4, "s2": 6, "s3": 5, "s4": 5, "s5": 4, "s6": 4,
}
import gradio as gr | |
def visualize_fn(page, panel):
    """Return the page image with the selected panel highlighted in red.

    Args:
        page: page id (a key of ``limits``); selects ``./mangaka-d/<page>/``.
        panel: panel number as a string (clamped to ``limits[page]``), or
            ``"none"`` to show the bare page.

    Returns:
        ``base.png`` as loaded when ``panel == "none"``. Otherwise an RGBA
        version of it with ``i<panel>.png`` pasted on top, using that overlay's
        own alpha. Before pasting, every pure-white overlay pixel is recoloured
        opaque red.
    """
    base = Image.open(f"./mangaka-d/{page}/base.png")
    if panel == "none":
        return base
    panel = min(int(panel), limits[page])
    base = base.convert("RGBA")

    overlay = np.array(Image.open(f"./mangaka-d/{page}/i{panel}.png").convert("RGBA"))
    # Vectorised recolouring: pixels whose RGB is exactly (255, 255, 255)
    # become opaque red; every other pixel (including its alpha) is unchanged.
    is_white = (overlay[..., :3] == 255).all(axis=-1)
    overlay[is_white] = (255, 0, 0, 255)
    overlay = Image.fromarray(overlay)

    # Paste the overlay onto the page, using its alpha channel as paste mask.
    base.paste(overlay, (0, 0), overlay)
    return base
def reset_fn(page):
    """Open and return the original page image ``./mangaka-d/<page>/base.png``."""
    return Image.open(f"./mangaka-d/{page}/base.png")
# Dropdown choices derived from ``limits`` so that the page list and the highest
# selectable panel number cannot drift out of sync with the mask table.
PAGE_CHOICES = list(limits)
PANEL_CHOICES = [str(n) for n in range(1, max(limits.values()) + 1)] + ["none"]

with gr.Blocks() as demo:
    with gr.Tab("Mangaka"):
        with gr.Row():
            # Left: prompts, page/panel selection and a panel preview.
            with gr.Column():
                positive = gr.Textbox(label="Positive prompt", lines=2)
                negative = gr.Textbox(label="Negative prompt")
                with gr.Accordion("Page Settings"):
                    with gr.Row():
                        with gr.Column():
                            page = gr.Dropdown(label="Page", choices=PAGE_CHOICES, value="s1")
                            panel = gr.Dropdown(label="Panel", choices=PANEL_CHOICES, value="1")
                            visualize = gr.Button("Visualize")
                        with gr.Column():
                            visualize_output = gr.Image(interactive=False)
                    visualize.click(visualize_fn, inputs=[page, panel], outputs=visualize_output)
            # Right: generation controls and the working page image, which
            # Reset reloads from base.png and Generate inpaints in place.
            with gr.Column():
                with gr.Row():
                    with gr.Column():
                        generate = gr.Button("Generate", variant="primary")
                    with gr.Column():
                        reset = gr.Button("Reset", variant="stop")
                current_panel = gr.Image(interactive=False)
                reset.click(reset_fn, inputs=[page], outputs=current_panel)
                generate.click(main, inputs=[page, panel, positive, negative, current_panel], outputs=current_panel)
demo.launch()