{ "cells": [ { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import os\n", "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"4\" \n", "os.environ[\"WORLD_SIZE\"] = \"1\"" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/fsx/homes/afruchtman/.envs/ms_env/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n", "2023-10-17 14:09:13.213394: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n", "To enable the following instructions: AVX2 AVX512F FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n", "2023-10-17 14:09:18.049981: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n" ] } ], "source": [ "import time\n", "import torch\n", "import gradio as gr\n", "from diffusers import StableDiffusionPipeline, AutoencoderKL, ControlNetModel, StableDiffusionControlNetPipeline\n", "from diffusers import DDPMScheduler, DEISMultistepScheduler, EulerAncestralDiscreteScheduler, DDIMScheduler, EulerDiscreteScheduler, HeunDiscreteScheduler\n", "\n", "from PIL import Image\n", "\n", "from ip_adapter import IPAdapterPlus, IPAdapter" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "# def image_grid(imgs, rows, cols):\n", "# assert len(imgs) == rows*cols\n", "\n", "# w, h = imgs[0].size\n", "# grid = Image.new('RGB', size=(cols*w, rows*h))\n", "# grid_w, grid_h = grid.size\n", " \n", "# for i, img in enumerate(imgs):\n", "# grid.paste(img, box=(i%cols*w, i//cols*h))\n", "# return grid\n", "\n", "# def init_pipe(base_model, ip_ckpt_path, vae_model_path, image_encoder_path, device, noise_sampler):\n", "# base_model = str(base_model)\n", "# noise_sampler = noise_sampler.from_pretrained(base_model, subfolder=\"scheduler\")\n", "# print(f\"{noise_sampler._class_name} was successfully loaded\")\n", "# vae = AutoencoderKL.from_pretrained(str(vae_model_path)).to(dtype=torch.float16)\n", "# torch.cuda.empty_cache()\n", "# pipe = StableDiffusionPipeline.from_pretrained(\n", "# base_model,\n", "# torch_dtype=torch.float16,\n", "# scheduler=noise_sampler,\n", "# vae=vae,\n", "# feature_extractor=None,\n", "# safety_checker=None\n", "# )\n", "# # load ip-adapter\n", "# ip_model = IPAdapterPlus(pipe, image_encoder_path, str(ip_ckpt_path), device, num_tokens=16)\n", "# print(f\"{base_model} was successfully loaded\")\n", "\n", "# return ip_model\n" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "# def init_pipe(base_model, ip_ckpt_path, vae_model_path, image_encoder_path, device, noise_sampler):\n", "# base_model = str(base_model)\n", "# noise_sampler = noise_sampler.from_pretrained(base_model, subfolder=\"scheduler\")\n", "# print(f\"{noise_sampler._class_name} was successfully loaded\")\n", "# vae = AutoencoderKL.from_pretrained(str(vae_model_path)).to(dtype=torch.float16)\n", "# torch.cuda.empty_cache()\n", "# pipe = StableDiffusionPipeline.from_pretrained(\n", "# base_model,\n", "# torch_dtype=torch.float16,\n", "# scheduler=noise_sampler,\n", "# vae=vae,\n", "# feature_extractor=None,\n", "# safety_checker=None\n", "# )\n", "# # load ip-adapter\n", "# ip_model = IPAdapterPlus(pipe, image_encoder_path, str(ip_ckpt_path), device, num_tokens=16)\n", "# print(f\"{base_model} was successfully loaded\")\n", "\n", "# return ip_model\n", "\n", "def init_pipe(base_model, ip_ckpt_path, vae_model_path, image_encoder_path, device, noise_sampler, controlnet_model_path=None):\n", " torch.cuda.empty_cache()\n", " base_model = str(base_model)\n", " controlnet = ControlNetModel.from_pretrained(controlnet_model_path, torch_dtype=torch.float16) if controlnet_model_path is not None else None\n", " sd_pipeline = StableDiffusionControlNetPipeline if controlnet else StableDiffusionPipeline\n", " noise_sampler = noise_sampler.from_pretrained(base_model, subfolder=\"scheduler\")\n", " \n", " print(f\"{noise_sampler._class_name} was successfully loaded\")\n", "\n", "\n", " vae = AutoencoderKL.from_pretrained(str(vae_model_path)).to(dtype=torch.float16)\n", "\n", " args = {\n", " \"pretrained_model_name_or_path\": base_model,\n", " \"torch_dtype\": torch.float16,\n", " \"scheduler\": noise_sampler,\n", " \"vae\": vae,\n", " \"feature_extractor\": None,\n", " \"safety_checker\": None,\n", " }\n", "\n", " if controlnet:\n", " args[\"controlnet\"] = controlnet\n", " \n", " pipe = sd_pipeline.from_pretrained(**args)\n", " print(pipe)\n", " ip_adapt_cls = IPAdapterPlus if \"plus\" in ip_ckpt_path else IPAdapter\n", " num_tokens = 16 if \"plus\" in ip_ckpt_path else 4\n", " print(ip_adapt_cls)\n", " ip_model = ip_adapt_cls(pipe, image_encoder_path, str(ip_ckpt_path), device, num_tokens=num_tokens)\n", " print(f\"{base_model} was successfully loaded\")\n", "\n", " return ip_model\n" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "\n", "def generate_image(ip_pipe, prompt, negative_prompt, pil_image, image=None, num_samples=4, height=512, width=512,\n", " num_inference_steps=20, seed=42, scale=1.0, guidance_scale=7.5):\n", " torch.cuda.empty_cache()\n", " print(image)\n", " args = {\n", " \"prompt\":prompt, \n", " \"negative_prompt\":negative_prompt, \n", " \"pil_image\":pil_image, \n", " \"num_samples\":int(num_samples), \n", " \"height\":int(height), \n", " \"width\":int(width),\n", " \"num_inference_steps\":int(num_inference_steps), \n", " \"seed\":int(seed), \n", " \"scale\":scale, \n", " \"guidance_scale\":guidance_scale\n", " }\n", "\n", " if image is not None:\n", " print(image.size)\n", " args[\"image\"] = image # [..., np.newaxis] #.transpose(1,2,0)[np.newaxis, ...] \n", " images = ip_pipe.generate(**args)\n", "\n", " return images " ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "ip_dir = 'models/ip_adapters'\n", "base_model_choices = [\n", " 'Lykon/dreamshaper-8', 'runwayml/stable-diffusion-v1-5', \n", " 'dreamlike-art/dreamlike-anime-1.0','Lykon/AbsoluteReality',\n", " 'SG161222/Realistic_Vision_V5.1_noVAE'\n", "]\n", "\n", "noise_sampler_dict = {\n", " \"DDPM\" : DDPMScheduler,\n", " \"DEISMultiStep\" : DEISMultistepScheduler,\n", " \"DDIM\" : DDIMScheduler,\n", " \"Euler a\" : EulerAncestralDiscreteScheduler,\n", " \"Euler\" : EulerDiscreteScheduler,\n", " \"Heun\" : HeunDiscreteScheduler,\n", "}\n", "\n", "sampler_choices = noise_sampler_dict.keys()\n", "\n", "ip_choices = [os.path.join(ip_dir, f) for f in sorted(os.listdir(ip_dir)) if f]\n", "\n", "prompt_examples = [\n", " [\"A portrait of pretty person, smile, holding one cute dragon, as a beautiful fairy with glittering wings, decorations from outlandish stones, wearing a turtleneck decorated with glitter, nice even light, lofi, in a magical forest, digital art, trending on artstation, behance, deviantart, insanely detailed and intricate, many patterns\",\n", " \"dogs, wrinkled old face, cat, classic outfit, classic turtleneck, neckline, big breasts, enlarged breasts, big boobs, legs, closed eyes, deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, mutated hands and fingers, disconnected limbs, mutation, mutated, nudes, nude breasts, ugly, disgusting, blurry, amputation, unclear, bindi, pottu\"],\n", "\n", " [\"A person, high resolution, best quality, sharp focus, 8k, highly detailed, digital art, digital painting, trending art, smooth skin, detailed facial skin, symmetric\",\n", " \"blurry, unclear, low resolution, bad quality, nsfw, nudes, nude breasts, nude, shirtless, deformed iris, deformed pupils, mutated hands and fingers, nude, naked, deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation, closed eyes, big breasts, sign, mutated hands and fingers, bindi, pottu,\"],\n", " [\"high res\",\" low res\"],\n", "\n", "]" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Running on local URL: http://127.0.0.1:7861\n", "Running on public URL: https://30456c2767331a4c45.gradio.live\n", "\n", "This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)\n" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ "