# Page residue captured with this file (hosting status banner): "Spaces: Runtime error".
# Not part of the program.
from enum import Enum | |
import gc | |
import numpy as np | |
import torch | |
import jax | |
import jax.numpy as jnp | |
import numpy as np | |
from flax.jax_utils import replicate | |
from flax.training.common_utils import shard | |
from PIL import Image | |
from diffusers import FlaxStableDiffusionControlNetPipeline, FlaxControlNetModel | |
import utils | |
import gradio_utils | |
import os | |
from einops import rearrange | |
import matplotlib.pyplot as plt | |
def create_key(seed=0):
    """Build a JAX PRNG key from ``seed``.

    Args:
        seed: Seed handed to ``jax.random.PRNGKey``. Defaults to 0.

    Returns:
        The key produced by ``jax.random.PRNGKey(seed)``.
    """
    key = jax.random.PRNGKey(seed)
    return key
class Model:
    """Pose-conditioned Stable Diffusion video renderer running on Flax/JAX.

    Holds a ``FlaxStableDiffusionControlNetPipeline`` together with a
    separately loaded ControlNet and renders a video one frame at a time.
    """

    def __init__(self, **kwargs):
        """Load the ControlNet and the Stable Diffusion pipeline.

        Args:
            **kwargs: Accepted for call-site compatibility; not used.
        """
        # Alternative ControlNet checkpoint tried previously: "JFoz/dog-cat-pose".
        self.base_controlnet, self.base_controlnet_params = FlaxControlNetModel.from_pretrained(
            "lllyasviel/control_v11p_sd15_openpose", dtype=jnp.bfloat16, from_pt=True
        )
        self.pipe, self.params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
            "runwayml/stable-diffusion-v1-5",
            controlnet=self.base_controlnet,
            revision="flax",
            dtype=jnp.bfloat16,
        )

    @staticmethod
    def _sampling_options(options):
        """Extract the sampling settings that are forwarded to the pipeline call.

        Args:
            options: Mapping that may contain ``num_inference_steps``,
                ``guidance_scale`` and ``controlnet_conditioning_scale``.
                Other keys are ignored.

        Returns:
            A dict with exactly those three keys. A missing or ``None`` entry
            falls back to 50 steps, guidance 7.5 and conditioning scale 1.0.
        """
        def pick(name, default):
            value = options.get(name)
            return default if value is None else value

        # The scales are coerced to Python floats and the step count to int:
        # UI widgets may deliver ints/floats interchangeably, and the jitted
        # pipeline path appears to broadcast only float scales across devices.
        return {
            "num_inference_steps": int(pick("num_inference_steps", 50)),
            "guidance_scale": float(pick("guidance_scale", 7.5)),
            "controlnet_conditioning_scale": float(pick("controlnet_conditioning_scale", 1.0)),
        }

    def _replicate_params(self):
        """Return the full parameter tree replicated for the pipeline call."""
        # The ControlNet weights are loaded separately in __init__, so attach
        # them under the "controlnet" key before replicating.
        self.params["controlnet"] = self.base_controlnet_params
        return replicate(self.params)

    def infer_frame(self, frame_id, prompt, negative_prompt, rng, *, p_params=None, **kwargs):
        """Render a single frame.

        Args:
            frame_id: Index of the frame to render.
            prompt: Sequence of prompts, indexed by ``frame_id``.
            negative_prompt: Sequence of negative prompts, indexed by ``frame_id``.
            rng: PRNG key(s) passed to the pipeline as ``prng_seed``.
            p_params: Already replicated parameters. When ``None`` they are
                built for this call; pass them in to avoid replicating the
                weights again for every frame.
            **kwargs: Must contain ``image`` (per-frame control images). May
                contain ``num_inference_steps``, ``guidance_scale`` and
                ``controlnet_conditioning_scale``; anything else is ignored.

        Returns:
            ``numpy`` array holding the generated sample(s) for this frame,
            with a leading axis of length ``num_samples`` (1).
        """
        print(prompt, frame_id)
        num_samples = 1
        prompt_ids = self.pipe.prepare_text_inputs([prompt[frame_id]] * num_samples)
        negative_prompt_ids = self.pipe.prepare_text_inputs([negative_prompt[frame_id]] * num_samples)
        processed_image = self.pipe.prepare_image_inputs([kwargs['image'][frame_id]] * num_samples)

        if p_params is None:
            p_params = self._replicate_params()

        # NOTE(review): shard() splits the leading batch axis across devices;
        # with a batch of one this presumably only works on a single device —
        # confirm before running on multi-device hosts.
        output = self.pipe(
            prompt_ids=shard(prompt_ids),
            image=shard(processed_image),
            params=p_params,
            prng_seed=rng,
            neg_prompt_ids=shard(negative_prompt_ids),
            jit=True,
            **self._sampling_options(kwargs),
        ).images

        # Collapse the leading axes into a single sample axis.
        return np.asarray(output.reshape((num_samples,) + output.shape[-3:]))

    def inference(self, **kwargs):
        """Render every frame of a video and stack the results.

        Args:
            **kwargs: Must contain ``image`` (sequence of control images, one
                per frame) and ``prompt``. ``seed`` (default 0) seeds the PRNG
                and ``negative_prompt`` defaults to ``''``. Remaining entries
                are forwarded to ``infer_frame``.

        Returns:
            ``numpy`` array with the per-frame results stacked along axis 0.
        """
        seed = kwargs.pop('seed', 0)
        # Honour the caller's seed (it used to be dropped in favour of 0).
        # ``None`` keeps the old default; floats from UI widgets are truncated.
        seed = 0 if seed is None else int(seed)
        rng = jax.random.split(create_key(seed), jax.device_count())

        f = len(kwargs['image'])
        print('frames', f)
        assert 'prompt' in kwargs
        prompt = [kwargs.pop('prompt')] * f
        negative_prompt = [kwargs.pop('negative_prompt', '')] * f

        # The weights do not change between frames: replicate them once.
        p_params = self._replicate_params()

        result = []
        for i in range(f):
            print(f'Processing frame {i + 1} / {f}')
            frame = self.infer_frame(frame_id=i,
                                     prompt=prompt,
                                     negative_prompt=negative_prompt,
                                     rng=rng,
                                     p_params=p_params,
                                     **kwargs)
            result.append(frame)
        return np.stack(result, axis=0)

    def process_controlnet_pose(self,
                                video_path,
                                prompt,
                                num_inference_steps=20,
                                controlnet_conditioning_scale=1.0,
                                guidance_scale=9.0,
                                seed=42,
                                eta=0.0,
                                resolution=512,
                                save_path=None):
        """Generate a pose-guided clip for ``prompt`` and write it out as a GIF.

        Args:
            video_path: Motion selector, resolved to a file path through
                ``gradio_utils.motion_to_video_path``.
            prompt: User prompt; fixed quality tags are appended to it.
            num_inference_steps: Denoising steps per frame.
            controlnet_conditioning_scale: Strength of the pose conditioning.
            guidance_scale: Classifier-free guidance scale.
            seed: Seed for the PRNG key shared by all frames.
            eta: Forwarded to ``inference`` but not consumed by the frame
                renderer.
            resolution: Passed to ``utils.prepare_video``.
            save_path: Passed to ``utils.create_gif`` as ``path``.

        Returns:
            Whatever ``utils.create_gif`` returns.
        """
        print("Module Pose")
        video_path = gradio_utils.motion_to_video_path(video_path)
        added_prompt = 'best quality, extremely detailed, HD, ultra-realistic, 8K, HQ, masterpiece, trending on artstation, art, smooth'
        negative_prompts = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer difits, cropped, worst quality, low quality, deformed body, bloated, ugly, unrealistic'

        video, fps = utils.prepare_video(
            video_path, resolution, False, output_fps=4)
        control = utils.pre_process_pose(
            video, apply_pose_detect=False)
        print('N frames', len(control))

        # Only the last two axes are needed; the unpack still requires a 4-D video.
        _, _, h, w = video.shape

        # height, width, eta and output_type are passed along for parity with
        # other call sites; the frame renderer does not read them.
        result = self.inference(image=control,
                                prompt=prompt + ', ' + added_prompt,
                                height=h,
                                width=w,
                                negative_prompt=negative_prompts,
                                num_inference_steps=num_inference_steps,
                                guidance_scale=guidance_scale,
                                controlnet_conditioning_scale=controlnet_conditioning_scale,
                                eta=eta,
                                seed=seed,
                                output_type='numpy',
                                )
        return utils.create_gif(result.astype(jnp.float16), fps, path=save_path)