{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Dynamic Classifier-free Guidance demo\n",
    "\n",
    "Gradio app that generates the same prompt with Stable Diffusion 2 under two guidance schedules - a static guidance scale and an inverted k-decay cosine schedule - and shows the results side by side. The `#| export` cells are exported to `app.py` by nbdev (see the last cell)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| default_exp app"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "import functools\n",
    "\n",
    "import gradio as gr\n",
    "from cf_guidance import schedules, transforms\n",
    "from min_diffusion.core import MinimalDiffusion\n",
    "import torch\n",
    "import nbdev"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "\n",
    "## MODEL SETUP\n",
    "######################################\n",
    "######################################\n",
    "model_name = 'stabilityai/stable-diffusion-2'\n",
    "device = ('cpu','cuda')[torch.cuda.is_available()]\n",
    "if device == 'cuda':\n",
    "    revision = 'fp16'\n",
    "    dtype = torch.float16\n",
    "else:\n",
    "    revision = 'fp32'\n",
    "    dtype = torch.float32\n",
    "\n",
    "# model parameters\n",
    "better_vae = ''\n",
    "unet_attn_slice = True\n",
    "sampler_kls = 'dpm_multi'\n",
    "hf_sampler = 'dpm_multi'\n",
    "\n",
    "model_kwargs = {\n",
    "    'better_vae': better_vae,\n",
    "    'unet_attn_slice': unet_attn_slice,\n",
    "    'scheduler_kls': hf_sampler,\n",
    "}\n",
    "\n",
    "@functools.lru_cache(maxsize=1)\n",
    "def load_model():\n",
    "    '''Builds and loads the `MinimalDiffusion` pipeline for `model_name` on `device`.\n",
    "\n",
    "    Cached, so the model is loaded once per process and every later call\n",
    "    (e.g. each Gradio request) reuses the same pipeline object.\n",
    "    '''\n",
    "    pipeline = MinimalDiffusion(\n",
    "        model_name,\n",
    "        device,\n",
    "        dtype,\n",
    "        revision,\n",
    "        **model_kwargs,\n",
    "    )\n",
    "    pipeline.load()\n",
    "    return pipeline\n",
    "######################################\n",
    "######################################"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "\n",
    "## GENERATION PARAMETERS\n",
    "######################################\n",
    "######################################\n",
    "num_steps = 18\n",
    "height, width = 768, 768\n",
    "k_sampler = 'k_dpmpp_2m' #'k_dpmpp_sde'\n",
    "use_karras_sigmas = True\n",
    "\n",
    "# a good negative prompt\n",
    "NEG_PROMPT = \"ugly, stock photo, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, mutation, mutated, extra limbs, extra legs, extra arms, disfigured, deformed, cross-eye, body out of frame, blurry, bad art, bad anatomy, blurred, text, watermark, grainy\"\n",
    "\n",
    "generation_kwargs = {\n",
    "    'num_steps': num_steps,\n",
    "    'height': height,\n",
    "    'width': width,\n",
    "    'k_sampler': k_sampler,\n",
    "    'negative_prompt': NEG_PROMPT,\n",
    "    'use_karras_sigmas': use_karras_sigmas,\n",
    "}\n",
    "######################################\n",
    "######################################"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "\n",
    "## dynamicCFG SETUP\n",
    "######################################\n",
    "######################################\n",
    "\n",
    "# default cosine schedule parameters\n",
    "baseline_g = 9 # default, static guidance value\n",
    "max_val = 9 # the max scheduled guidance scaling value\n",
    "min_val = 6 # the minimum scheduled guidance value\n",
    "num_warmup_steps = 0 # number of warmup steps\n",
    "warmup_init_val = 0 # the initial warmup value\n",
    "num_cycles = 0.5 # number of cosine cycles\n",
    "k_decay = 1 # k-decay for cosine curve scaling\n",
    "\n",
    "# group the default schedule parameters\n",
    "DEFAULT_COS_PARAMS = {\n",
    "    'max_val': max_val,\n",
    "    'num_steps': num_steps,\n",
    "    'min_val': min_val,\n",
    "    'num_cycles': num_cycles,\n",
    "    'k_decay': k_decay,\n",
    "    'num_warmup_steps': num_warmup_steps,\n",
    "    'warmup_init_val': warmup_init_val,\n",
    "}\n",
    "\n",
    "def cos_harness(new_params: dict):\n",
    "    '''Creates a cosine guidance schedule from `DEFAULT_COS_PARAMS`, overridden by `new_params`.\n",
    "\n",
    "    The defaults are copied before updating, so neither `DEFAULT_COS_PARAMS`\n",
    "    nor `new_params` is modified. Returns the result of `schedules.get_cos_sched`.\n",
    "    '''\n",
    "    # start from a copy of the baseline `DEFAULT_COS_PARAMS`\n",
    "    cos_params = dict(DEFAULT_COS_PARAMS)\n",
    "    # update with the new, given parameters\n",
    "    cos_params.update(new_params)\n",
    "\n",
    "    # return the new cosine schedule\n",
    "    sched = schedules.get_cos_sched(**cos_params)\n",
    "    return sched\n",
    "\n",
    "\n",
    "# build the static schedule\n",
    "static_sched = [baseline_g] * num_steps\n",
    "\n",
    "# build the inverted kdecay schedule: mirror each value within [min_val, max_val]\n",
    "k_sched = cos_harness({'k_decay': 0.2})\n",
    "inv_k_sched = [max_val - g + min_val for g in k_sched]\n",
    "\n",
    "# group the schedules\n",
    "scheds = {\n",
    "    'cosine': {'g': inv_k_sched},\n",
    "    'static': {'g': static_sched},\n",
    "}\n",
    "######################################\n",
    "######################################"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "\n",
    "def compare_dynamic_guidance(prompt):\n",
    "    '''\n",
    "    Compares the default, static Classifier-free Guidance to a dynamic schedule.\n",
    "\n",
    "    Model and sampling parameters:\n",
    "        Stable Diffusion 2 v-model\n",
    "        Half-precision (when running on CUDA)\n",
    "        DPM++ 2M sampler, with Karras sigma schedule\n",
    "        18 sampling steps\n",
    "        (768 x 768) image\n",
    "        Using a generic negative prompt\n",
    "\n",
    "    Schedules:\n",
    "        Static guidance with scale of 9\n",
    "        Inverse kDecay (cosine variant) scheduled guidance\n",
    "\n",
    "    Returns a list with one generated image per entry of `scheds`, in the\n",
    "    order the schedules are defined, suitable for a `gr.Gallery` output.\n",
    "    '''\n",
    "    # `load_model` is cached, so only the first request pays the load cost\n",
    "    pipeline = load_model()\n",
    "\n",
    "    # stores the output images\n",
    "    res = []\n",
    "\n",
    "    # generate images with the dynamic and static schedules\n",
    "    for sched in scheds.values():\n",
    "        # build the guidance transform for this schedule\n",
    "        gtfm = transforms.GuidanceTfm(sched)\n",
    "        # autocast only on CUDA: the CPU path deliberately loads the model in float32\n",
    "        with torch.autocast(device, enabled=(device == 'cuda')), torch.no_grad():\n",
    "            img = pipeline.generate(prompt, gtfm, **generation_kwargs)\n",
    "        # add the generated image\n",
    "        res.append(img)\n",
    "\n",
    "    # a Gallery output takes a plain list of images; its label is set on the component\n",
    "    return res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "\n",
    "iface = gr.Interface(\n",
    "    compare_dynamic_guidance,\n",
    "    inputs=\"text\",\n",
    "    outputs=gr.Gallery(label='Cosine vs. Static CFG'),\n",
    "    title=\"Comparing image generations with dynamic Classifier-free Guidance\",\n",
    ")\n",
    "iface.launch()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import nbdev\n",
    "nbdev.export.nb_export('app.ipynb', '')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "sdiffkernel",
   "language": "python",
   "name": "sdiffkernel"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.8"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "7aa72ffd68a1153f913726b8656445c52d825f656451987cb25ebe84c64ea44d"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}