File size: 5,256 Bytes
31f2f07 27ed9e0 31f2f07 27ed9e0 31f2f07 a48c418 31f2f07 5ed6d89 31f2f07 27ed9e0 31f2f07 27ed9e0 31f2f07 27ed9e0 31f2f07 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 |
# Inference handler for Lightning AI: serves Stable Diffusion txt2img requests.
import re
import os
import logging
# import json
from pydantic import BaseModel
from typing import Any, Dict, Optional, TYPE_CHECKING
from dataclasses import dataclass
def _drop_triton_unavailable(record):
    # Log filter: reject only records mentioning the missing-Triton notice;
    # every other message on this logger is let through unchanged.
    return 'A matching Triton is not available' not in record.getMessage()


logging.getLogger("xformers").addFilter(_drop_triton_unavailable)
import lightning as L
from lightning.app.components.serve import PythonServer, Text
from lightning.app import BuildConfig
class _DefaultInputData(BaseModel):
    # Request body: a single text prompt. ``predict`` falls back to its
    # built-in default prompt when this string is empty.
    prompt: str
class _DefaultOutputData(BaseModel):
    # Response body produced by ``predict``.
    img_data: str  # first generated image, base64-encoded and decoded to text
    parameters: str  # the image's "parameters" info text, or "" when absent
@dataclass
class CustomBuildConfig(BuildConfig):
    """Cloud build config that fetches model files and installs libGL/GLib.

    ``build_commands`` downloads a Stable Diffusion checkpoint, a VAE and two
    LoRA files from the ``Hardy01/chill_watcher`` Hugging Face repo into
    ``/content/models/...``, then installs the system libraries whose absence
    causes ``ImportError: libGL.so.1``.
    """

    def build_commands(self):
        """Return the shell commands to run, in order, while building the image."""
        dir_path = "/content/"
        hf_models = "https://huggingface.co/Hardy01/chill_watcher/resolve/main/models"

        # (destination directory relative to dir_path, file path under the repo's models/)
        downloads = [
            ("models/Stable-diffusion", "Stable-diffusion/chilloutmix_NiPrunedFp32Fix.safetensors"),
            ("models/VAE", "VAE/vae-ft-mse-840000-ema-pruned.ckpt"),
            ("models/Lora", "Lora/koreanDollLikeness_v10.safetensors"),
            ("models/Lora", "Lora/taiwanDollLikeness_v10.safetensors"),
        ]
        commands = [
            f"wget -P {os.path.join(dir_path, target_dir)} {hf_models}/{repo_file}"
            for target_dir, repo_file in downloads
        ]

        # Fix for "ImportError: libGL.so.1: cannot open shared object file", see
        # https://stackoverflow.com/questions/55313610/importerror-libgl-so-1-cannot-open-shared-object-file-no-such-file-or-directo
        # ``-y`` is needed: without it apt-get stops at its "Do you want to
        # continue?" prompt, which a non-interactive build cannot answer, and aborts.
        commands += [
            "sudo apt-get update",
            "sudo apt-get install -y libgl1-mesa-glx",
            "sudo apt-get install -y libglib2.0-0",
        ]

        # NOTE(review): torch/torchvision are not installed by these commands;
        # they are assumed to come from the environment/requirements — confirm.
        return commands
class PyTorchServer(PythonServer):
    """Lightning ``PythonServer`` that answers txt2img requests.

    Each request carries a ``prompt``. The response holds the first generated
    image as a base64 string plus that image's "parameters" info text.
    Heavy imports (``torch``, ``handler``, ``modules.*``) are done lazily inside
    ``setup``/``predict``, because they are only importable once dependencies
    have been installed.
    """

    def __init__(
        self,
        input_type: type = _DefaultInputData,
        output_type: type = _DefaultOutputData,
        **kwargs: Any,
    ):
        # Remember whether the caller passed its own build config so it is not
        # silently replaced below.
        caller_build_config = kwargs.get("cloud_build_config")
        super().__init__(input_type=input_type, output_type=output_type, **kwargs)
        if caller_build_config is None:
            # Default: the custom build that downloads models and installs libGL.
            self.cloud_build_config = CustomBuildConfig()

    def setup(self):
        """Import torch, normalise dev/nightly version strings, then initialise the model."""
        # Dependencies must be installed before these packages can be imported.
        import torch

        # Truncate the version number of nightly/local PyTorch builds so that
        # CodeFormer or safetensors do not raise while parsing it.
        if ".dev" in torch.__version__ or "+git" in torch.__version__:
            torch.__long_version__ = torch.__version__
            release = re.search(r'[\d.]+[\d]', torch.__version__)
            if release is not None:
                torch.__version__ = release.group(0)

        from handler import initialize

        initialize()

    @staticmethod
    def _txt2img_args(prompt):
        """Return txt2img keyword arguments, overriding the default prompt with *prompt* if non-empty."""
        # NOTE(review): the default prompt writes "lora:koreanDollLikeness_v15:0.66"
        # without "<...>" brackets, while the build only downloads the *_v10 LoRA
        # files, so this LoRA is probably not applied. It also contains an
        # unbalanced "((full body)" group. Confirm the intended prompt.
        args = {
            "do_not_save_samples": True,
            "do_not_save_grid": True,
            "outpath_samples": "/content/desktop",
            "prompt": "lora:koreanDollLikeness_v15:0.66, best quality, ultra high res, (photorealistic:1.4), 1girl, beige sweater, black choker, smile, laughing, bare shoulders, solo focus, ((full body), (brown hair:1), looking at viewer",
            "negative_prompt": "paintings, sketches, (worst quality:2), (low quality:2), (normal quality:2), lowres, normal quality, ((monochrome)), ((grayscale)), skin spots, acnes, skin blemishes, age spot, glans, (ugly:1.331), (duplicate:1.331), (morbid:1.21), (mutilated:1.21), (tranny:1.331), mutated hands, (poorly drawn hands:1.331), blurry, 3hands,4fingers,3arms, bad anatomy, missing fingers, extra digit, fewer digits, cropped, jpeg artifacts,poorly drawn face,mutation,deformed",
            "sampler_name": "DPM++ SDE Karras",
            "steps": 20,
            "cfg_scale": 8,
            "width": 512,
            "height": 768,
            "seed": -1,
        }
        if prompt:
            args["prompt"] = prompt
        return args

    def predict(self, request):
        """Generate one image for ``request.prompt`` and return it base64-encoded.

        Raises:
            RuntimeError: if image processing returned no images.
        """
        from modules.api.api import encode_pil_to_base64
        from modules import shared
        from modules.processing import StableDiffusionProcessingTxt2Img, process_images

        log = logging.getLogger(__name__)
        log.debug("Received request: %s", request)
        if request.prompt:
            log.info("get prompt from request: %s", request.prompt)

        args = self._txt2img_args(request.prompt)
        p = StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args)
        processed = process_images(p)
        if not processed.images:
            raise RuntimeError("txt2img processing returned no images")

        image = processed.images[0]
        return {
            "img_data": encode_pil_to_base64(image).decode('utf-8'),
            "parameters": image.info.get('parameters', ""),
        }
# Single GPU worker: 20 GB disk, idles out after 30 (CloudCompute idle_timeout).
_gpu_compute = L.CloudCompute('gpu', disk_size=20, idle_timeout=30)
component = PyTorchServer(cloud_compute=_gpu_compute)

# Deploy with: lightning run app app.py --cloud
app = L.LightningApp(component)
|