#!/usr/bin/env python
# coding: utf-8
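"""Gradio demo: send a query to GPT-3 (text-davinci-003) and generate images
with the fine-tuned Stable Diffusion model stale2000/sd-dnditem."""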
import os
import openai
import gradio as gr
import torch
from diffusers import StableDiffusionPipeline
from torch import autocast
from PIL import Image  # needed for the "unsafe.png" fallback image
#from torchvision import transforms
#from diffusers import StableDiffusionImageVariationPipeline
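# The OpenAI key is read from the environment (e.g. a Space secret named 'openaikey').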
openai.api_key = os.getenv('openaikey')
device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = StableDiffusionPipeline.from_pretrained("stale2000/sd-dnditem", torch_dtype=torch.float16)
pipe = pipe.to(device)
def predict(input, manual_query_replacement, history=[]):
    # gpt3: if a manual query was entered, use it instead of the built input
    if manual_query_replacement != "":
        input = manual_query_replacement

    response = openai.Completion.create(
        model="text-davinci-003",
        prompt=input,
        temperature=0.9,
        max_tokens=150,
        top_p=1,
        frequency_penalty=0,
        presence_penalty=0.6)

    # record the prompt/completion pair for the chatbot history
    responseText = response["choices"][0]["text"]
    history.append((input, responseText))
    # img generation
    prompt = "Yoda"
    scale = 10
    n_samples = 4

    # Sometimes the nsfw checker is confused by the generated images, you can disable
    # it at your own risk here
    #disable_safety = False
    #if disable_safety:
    #    def null_safety(images, **kwargs):
    #        return images, False
    #    pipe.safety_checker = null_safety

    with autocast(device):
        output = pipe(n_samples * [prompt], guidance_scale=scale)

    # save the raw samples to disk
    for idx, im in enumerate(output.images):
        im.save(f"{idx:06}.png")

    # replace any image flagged by the safety checker with a placeholder
    images = []
    for i, image in enumerate(output.images):
        if output.nsfw_content_detected[i]:
            safe_image = Image.open(r"unsafe.png")
            images.append(safe_image)
        else:
            images.append(image)

    return history, history, images
inputText = gr.Textbox(value="tmp")
manual_query = gr.Textbox(placeholder="Input any query here, to replace the image generation query builder entirely.")
output_img = gr.Gallery(label="Generated image")
output_img.style(grid=2)
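# Wire the Gradio interface: text query, optional manual override, and chat state in;
# chatbot history, state, and image gallery out.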
gr.Interface(fn=predict,
inputs=[inputText,manual_query,'state'],
outputs=["chatbot",'state', output_img]).launch()