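# NOTE: the following lines appear to be page residue from a web file viewer,
# kept here as comments so the module remains valid Python:
#   kumahiyo
#   add translator
#   3675354
#   raw
#   history blame
#   No virus
#   3.44 kB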
import os
import time
import sys
import re
import random
import torch
from PIL import Image
from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from pydantic import Field
from diffusers import StableDiffusionPipeline
from googletrans import Translator
# ASGI application object; the route decorators and the static mount in this
# file attach to it.
app = FastAPI()
class Data(BaseModel):
    """Request body accepted by the ``/draw`` endpoint."""

    # Text to draw; may start with '#' and may carry a '_seed' marker
    # (optionally followed by a digit 1-4).
    string: str
    # Shared secret compared against the MEMBER_SECRET environment variable.
    member_secret: str
class ItemOut(BaseModel):
    """Response body returned by the ``/draw`` endpoint."""

    # "OK" on success, otherwise the members-only notice.
    status: str
    # Name of the saved PNG, or an empty string when the request was refused.
    file: str
@app.get("/")
def index():
    """Answer requests to the site root with a fixed members-only notice."""
    notice = "SORRY! This file is member only."
    return notice
def _is_member(secret):
    """Return True when *secret* equals the MEMBER_SECRET environment variable.

    An empty secret, or an unset variable, never authenticates.
    """
    import hmac  # local import: only this check needs it

    expected = os.environ.get("MEMBER_SECRET")
    if not secret or expected is None:
        return False
    # compare_digest does not short-circuit at the first differing character.
    # Encode first because it rejects non-ASCII str arguments.
    return hmac.compare_digest(
        secret.encode("utf-8", "surrogatepass"),
        expected.encode("utf-8", "surrogatepass"),
    )


def _parse_seed(request_text):
    """Return ``(seed, seedno)`` for the ``_seed`` marker in *request_text*.

    Without a marker the seed is random (below 1024) and ``seedno`` is 0.
    With a marker the seed is fixed at 1024 and ``seedno`` is the digit 1-4
    that directly follows the first marker, or 0 when no such digit follows.
    """
    marker = re.search(r'_seed([1-4])?', request_text)
    if marker is None:
        return random.randrange(1024), 0
    digit = marker.group(1)
    return 1024, int(digit) if digit else 0


@app.post("/draw", response_model=ItemOut)
def draw(data: Data):
    """Generate an image for the submitted text and save it as a PNG.

    The text is stripped of a leading '#' and of any ``_seed`` markers,
    translated to English, wrapped in a fixed photographic prompt and passed
    to Stable Diffusion 2. The result is written to ``/code/tmpdir``.

    Args:
        data: Request body with the text to draw and the member secret.

    Returns:
        ``{"status": "OK", "file": <png name>}`` on success, or the
        members-only notice with an empty file name when the secret does not
        match.
    """
    if not _is_member(data.member_secret):
        return {"status": "SORRY! This file is member only.", "file": ""}

    device = "cuda" if torch.cuda.is_available() else "cpu"
    seed, seedno = _parse_seed(data.string)

    # Remove the leading '#' and every seed marker before translating.
    text = re.sub(r'^#', '', data.string)
    text = re.sub(r'_seed[1-4]?', '', text)
    text = Translator().translate(text, dest="en").text

    prompt = f'(({text})) (( photograph )), highly detailed, sharp focus, 8k, 4k, (( photorealism )), detailed, saturated, portrait, 50mm, F/2.8, 1m away, ((( multicolor lights )))'
    n_prompt = 'text, blurry, art, painting, rendering, drawing, sketch, (( ugly )), (( duplicate )), ( morbid ), (( mutilated )), ( mutated ), ( deformed ), ( disfigured ), ( extra limbs ), ( malformed limbs ), ( missing arms ), ( missing legs ), ( extra arms ), ( extra legs ), ( fused fingers ), ( too many fingers ), long neck, low quality, worst quality'

    # https://huggingface.co/docs/hub/spaces-sdks-docker-first-demo
    # how to validation: https://qiita.com/bee2/items/75d9c0d7ba20e7a4a0e9
    # https://github.com/huggingface/diffusers
    model_id = 'stabilityai/stable-diffusion-2'
    # NOTE(review): the pipeline is loaded on every request and always in
    # float16, including when device is "cpu" -- confirm both are intended.
    pipe = StableDiffusionPipeline.from_pretrained(model_id, revision='fp16', torch_dtype=torch.float16)
    pipe.enable_attention_slicing()  # reduce gpu usage
    pipe = pipe.to(device)

    generator = torch.Generator(device).manual_seed(seed)
    images = pipe(prompt, negative_prompt=n_prompt, guidance_scale=7.5, generator=generator, num_images_per_prompt=1).images

    # Only one image is generated per request, so a seed number may point past
    # the end of the list; fall back to the grid instead of raising IndexError.
    if 0 < seedno <= len(images):
        grid = images[seedno - 1]
    else:
        # Limit of T4 small...
        grid = image_grid(images, rows=1, cols=len(images))

    fileName = "sd_" + str(time.time()) + '.png'
    grid.save("/code/tmpdir/" + fileName)
    print(fileName)
    return {"status": "OK", "file": fileName}
# Serve the directory that draw() saves its PNG files into, so a returned file
# name can be fetched under the /static path.
app.mount("/static", StaticFiles(directory="/code/tmpdir"), name="/static")
# helper function taken from: https://huggingface.co/blog/stable_diffusion
def image_grid(imgs, rows, cols):
    """Paste *imgs* row by row onto one new RGB image of rows x cols cells.

    The cell size is taken from the first image. The number of images must
    equal ``rows * cols``; this is checked with an ``assert``.
    """
    assert len(imgs) == rows*cols

    cell_w, cell_h = imgs[0].size
    sheet = Image.new('RGB', size=(cols * cell_w, rows * cell_h))
    for position, tile in enumerate(imgs):
        # Fill left to right, then top to bottom.
        row, col = divmod(position, cols)
        sheet.paste(tile, box=(col * cell_w, row * cell_h))
    return sheet