# Hugging Face Space page status when this file was captured: "Runtime error".
#!/usr/bin/env python
# coding: utf-8
import os

import gradio as gr
import openai
import torch
from diffusers import StableDiffusionPipeline
from torch import autocast

# The OpenAI key is read from the "openaikey" environment variable.
openai.api_key = os.getenv('openaikey')

device = "cuda" if torch.cuda.is_available() else "cpu"
# float16 weights are only practical on GPU; many CPU kernels have no
# half-precision implementation, so load float32 weights when falling back to CPU.
torch_dtype = torch.float16 if device == "cuda" else torch.float32
pipe = StableDiffusionPipeline.from_pretrained("stale2000/sd-dnditem", torch_dtype=torch_dtype)
pipe = pipe.to(device)
# Shown in the gallery in place of any image the safety checker flags.
# NOTE(review): assumes unsafe.png ships next to this script — confirm.
UNSAFE_IMAGE_PATH = "unsafe.png"


def _query_gpt3(query):
    """Return the text of one text-davinci-003 completion for ``query``."""
    response = openai.Completion.create(
        model="text-davinci-003",
        prompt=query,
        temperature=0.9,
        max_tokens=150,
        top_p=1,
        frequency_penalty=0,
        presence_penalty=0.6,
    )
    return response["choices"][0]["text"]


def _generate_images(prompt, n_samples=4, guidance_scale=10):
    """Render ``n_samples`` images of ``prompt``; flagged ones become ``UNSAFE_IMAGE_PATH``."""
    # Mixed precision only applies on GPU; disabling autocast explicitly on CPU
    # avoids running a CUDA autocast context on a machine without CUDA.
    with autocast("cuda", enabled=pipe.device.type == "cuda"):
        output = pipe(n_samples * [prompt], guidance_scale=guidance_scale)
    # nsfw_content_detected may be None (e.g. when the safety checker is disabled).
    flags = output.nsfw_content_detected or [False] * len(output.images)
    # gr.Gallery accepts file paths as well as PIL images, so the placeholder
    # can be passed by path without importing PIL.
    return [
        UNSAFE_IMAGE_PATH if flagged else image
        for image, flagged in zip(output.images, flags)
    ]


def predict(input, manual_query_repacement, history=None):
    """Send a query to GPT-3, record the exchange, and render images from the reply.

    Args:
        input: Query text from the main textbox.
        manual_query_repacement: When non-empty, replaces ``input`` entirely as
            the query sent to GPT-3.
        history: Chat log of ``(query, response)`` tuples carried in Gradio
            state. ``None`` starts a new log.

    Returns:
        ``(history, history, images)``: the chat log for the Chatbot output,
        the same list as the new session state, and the gallery images.
    """
    # Gradio's "state" input can arrive as None on the first call, and a
    # mutable ``[]`` default would be shared by every session.
    if history is None:
        history = []
    query = manual_query_repacement if manual_query_repacement else input
    response_text = _query_gpt3(query)
    history.append((query, response_text))
    # The GPT-3 reply is the image prompt; fall back to the query if it is blank.
    prompt = response_text.strip() or query
    images = _generate_images(prompt)
    return history, history, images
# Gradio UI: the main query box, an optional override box, and an image gallery.
inputText = gr.Textbox(value="tmp")
manual_query = gr.Textbox(
    placeholder="Input any query here, to replace the image generation query builder entirely.",
)
output_img = gr.Gallery(label="Generated image")
output_img.style(grid=2)

# Inputs map positionally onto predict(input, manual_query_repacement, history);
# the "state" shortcut feeds predict's returned history back in on the next call.
demo = gr.Interface(
    fn=predict,
    inputs=[inputText, manual_query, "state"],
    outputs=["chatbot", "state", output_img],
)
demo.launch()