File size: 3,438 Bytes
8dc81a8
 
 
f44af0a
 
fdee321
7e6f588
4f9ea63
 
98fc1c9
 
3675354
 
4f9ea63
 
 
98fc1c9
 
 
 
 
 
 
 
1ef45ad
 
 
 
 
 
98fc1c9
a79e202
 
 
8dc81a8
7e1b030
c98a613
 
d4119c7
7e1b030
 
 
f44af0a
c98a613
 
 
7e1b030
f44af0a
3675354
 
 
 
f44af0a
 
8dc81a8
 
 
 
 
 
3675354
8dc81a8
dc6ab6b
3675354
7e6f588
a79e202
8dc81a8
3675354
5400559
7e6f588
7e1b030
 
 
 
 
8dc81a8
 
7e6f588
8dc81a8
 
 
 
98fc1c9
 
0c8f446
 
7e6f588
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import functools
import hmac
import os
import random
import re
import sys
import threading
import time

import torch
from PIL import Image
from diffusers import StableDiffusionPipeline
from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
from googletrans import Translator
from pydantic import BaseModel
from pydantic import Field

# FastAPI application instance that every route and mount in this module registers on.
app = FastAPI()

class Data(BaseModel):
    # Request body for POST /draw.
    # Free-form prompt text; it may start with "#" and may carry a
    # "_seed" / "_seedN" marker that the endpoint strips and interprets.
    string: str
    # Shared secret checked against the MEMBER_SECRET environment variable.
    member_secret: str

class ItemOut(BaseModel):
    # Response body for POST /draw.
    # "OK" on success, otherwise the members-only refusal message.
    status: str
    # File name of the generated PNG; empty string when the request is refused.
    file: str

@app.get("/")
def index():
    """Root route: exposes nothing and always answers with the members-only notice."""
    refusal = "SORRY! This file is member only."
    return refusal

# https://huggingface.co/docs/hub/spaces-sdks-docker-first-demo
# how to validation: https://qiita.com/bee2/items/75d9c0d7ba20e7a4a0e9
# https://github.com/huggingface/diffusers
_MODEL_ID = 'stabilityai/stable-diffusion-2'

# Directory the generated PNGs are written to.
_OUTPUT_DIR = "/code/tmpdir/"

# Images rendered per request. Kept at 1 because of the memory limit of the
# T4 small hardware this runs on.
_NUM_IMAGES = 1

# The cached pipeline (including its scheduler state) is shared between
# requests, and FastAPI runs plain `def` endpoints in a thread pool, so loading
# and inference are serialized.
_PIPE_LOCK = threading.Lock()


def _is_member(secret):
    """Return True when *secret* is non-empty and equals $MEMBER_SECRET.

    The comparison is constant-time so the secret cannot be recovered by timing.
    """
    expected = os.environ.get("MEMBER_SECRET")
    if not secret or expected is None:
        return False
    return hmac.compare_digest(secret.encode("utf-8"), expected.encode("utf-8"))


def _parse_seed(string):
    """Return ``(seed, seedno)`` derived from the raw request string.

    A ``_seed`` marker pins the seed to 1024 so results are reproducible. An
    optional digit 1-4 right after the first marker (e.g. ``_seed3``) becomes
    ``seedno``; otherwise ``seedno`` is 0. Without a marker, the seed is drawn
    with ``random.randrange(1024)``.
    """
    if '_seed' not in string:
        return random.randrange(1024), 0
    match = re.search(r'_seed([1-4]?)', string)
    digit = match.group(1)
    return 1024, (int(digit) if digit else 0)


def _clean_prompt(string):
    """Strip one leading ``#`` and every ``_seed`` / ``_seedN`` marker from *string*."""
    text = re.sub(r'^#', '', string)
    return re.sub(r'_seed[1-4]?', '', text)


@functools.lru_cache(maxsize=None)
def _get_pipeline(device):
    """Load the Stable Diffusion pipeline once per *device* and reuse it afterwards."""
    # Half precision saves GPU memory. Many CPU kernels lack float16 support,
    # so the CPU fallback uses float32.
    dtype = torch.float16 if device == "cuda" else torch.float32
    pipe = StableDiffusionPipeline.from_pretrained(_MODEL_ID, revision='fp16', torch_dtype=dtype)
    pipe.enable_attention_slicing()  # reduce gpu usage
    return pipe.to(device)


@app.post("/draw", response_model=ItemOut)
def draw(data: Data):
    """Generate a photo-style image for ``data.string`` and save it as a PNG.

    The prompt is translated to English, wrapped in fixed positive/negative
    prompt templates, and rendered with Stable Diffusion 2. ``_seed`` /
    ``_seedN`` markers in the text select a reproducible seed (see
    ``_parse_seed``).

    Returns:
        ``{"status": "OK", "file": <png name>}`` on success, or the refusal
        message with an empty ``file`` when ``member_secret`` does not match.
    """
    if not _is_member(data.member_secret):
        return {"status": "SORRY! This file is member only.", "file": ""}

    device = "cuda" if torch.cuda.is_available() else "cpu"
    seed, seedno = _parse_seed(data.string)

    translator = Translator()
    text = translator.translate(_clean_prompt(data.string), dest="en").text

    prompt = '((' + text + ')) (( photograph )), highly detailed, sharp focus, 8k, 4k, (( photorealism )), detailed, saturated, portrait, 50mm, F/2.8, 1m away, ((( multicolor lights )))'
    n_prompt = 'text, blurry, art, painting, rendering, drawing, sketch, (( ugly )), (( duplicate )), ( morbid ), (( mutilated )), ( mutated ), ( deformed ), ( disfigured ), ( extra limbs ), ( malformed limbs ), ( missing arms ), ( missing legs ), ( extra arms ), ( extra legs ), ( fused fingers ), ( too many fingers ), long neck, low quality, worst quality'

    with _PIPE_LOCK:
        pipe = _get_pipeline(device)
        generator = torch.Generator(device).manual_seed(seed)
        images = pipe(prompt, negative_prompt=n_prompt, guidance_scale=7.5, generator=generator, num_images_per_prompt=_NUM_IMAGES).images

    # `_seedN` picks the N-th rendered image. Only _NUM_IMAGES images exist,
    # so an out-of-range N falls back to the grid instead of raising IndexError.
    if 0 < seedno <= len(images):
        grid = images[seedno - 1]
    else:
        grid = image_grid(images, rows=1, cols=len(images))

    file_name = "sd_" + str(time.time()) + '.png'
    grid.save(_OUTPUT_DIR + file_name)

    print(file_name)

    return {"status": "OK", "file": file_name}

# Serve files from /code/tmpdir under the /static URL prefix; /draw saves its
# generated PNGs into this same directory.
# NOTE(review): this mount performs no member_secret check, so anyone who knows
# (or guesses) a timestamp-based file name can download it — confirm that is
# intended. Starlette's StaticFiles checks the directory by default and will
# refuse to start if it is missing — confirm it is created at deploy time.
app.mount("/static", StaticFiles(directory="/code/tmpdir"), name="/static")

# helper function taken from: https://huggingface.co/blog/stable_diffusion
def image_grid(imgs, rows, cols):
    """Tile *imgs* row-major into a single RGB image of ``rows`` x ``cols`` cells.

    Every cell is sized like the first image, and each image is pasted at its
    cell's top-left corner.

    Raises:
        ValueError: if ``len(imgs) != rows * cols`` or *imgs* is empty.
    """
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if len(imgs) != rows * cols:
        raise ValueError(
            f"expected {rows * cols} images for a {rows}x{cols} grid, got {len(imgs)}"
        )
    if not imgs:
        raise ValueError("image_grid needs at least one image")

    w, h = imgs[0].size
    grid = Image.new('RGB', size=(cols * w, rows * h))
    for i, img in enumerate(imgs):
        row, col = divmod(i, cols)
        grid.paste(img, box=(col * w, row * h))
    return grid