#!/usr/bin/env python
# coding: utf-8
"""Gradio demo: chat with an OpenAI completion model and show generated images.

Each request sends the user's text to the OpenAI completion endpoint, appends
the exchange to the chat history, and runs the Stable Diffusion pipeline to
fill an image gallery.
"""

import contextlib
import os

import gradio as gr
import openai
import torch
from diffusers import StableDiffusionPipeline
from torch import autocast

# NOTE(review): the image prompt is hard-coded and is not derived from the
# chat; this looks like a placeholder -- confirm the intended prompt source.
IMAGE_PROMPT = "Yoda"
GUIDANCE_SCALE = 10
N_SAMPLES = 4
# Optional image shown in place of outputs flagged by the safety checker.
UNSAFE_PLACEHOLDER = "unsafe.png"

openai.api_key = os.getenv('openaikey')

device = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision is only used on GPU; fall back to float32 when running on CPU.
dtype = torch.float16 if device == "cuda" else torch.float32
pipe = StableDiffusionPipeline.from_pretrained("stale2000/sd-dnditem", torch_dtype=dtype)
pipe = pipe.to(device)


def _generate_images(prompt, n_samples=N_SAMPLES, scale=GUIDANCE_SCALE):
    """Run the diffusion pipeline and return a list of gallery items.

    Args:
        prompt: Text prompt, repeated ``n_samples`` times to form one batch.
        n_samples: Number of images to generate.
        scale: Value passed to the pipeline as ``guidance_scale``.

    Returns:
        One entry per generated image. Images the pipeline flags as NSFW are
        replaced by the ``UNSAFE_PLACEHOLDER`` file path when that file
        exists; otherwise the pipeline's own output image is returned.
    """
    # Mixed-precision autocast is only entered when a CUDA device is in use.
    amp = autocast("cuda") if device == "cuda" else contextlib.nullcontext()
    with amp:
        result = pipe(n_samples * [prompt], guidance_scale=scale)

    generated = result.images
    # Kept from the original script: dump every sample to the working
    # directory (files are overwritten on each request).
    for idx, im in enumerate(generated):
        im.save(f"{idx:06}.png")

    # The flag list may be None (e.g. no safety checker); treat as "all safe".
    flags = result.nsfw_content_detected or [False] * len(generated)
    have_placeholder = os.path.exists(UNSAFE_PLACEHOLDER)
    return [
        UNSAFE_PLACEHOLDER if flagged and have_placeholder else image
        for image, flagged in zip(generated, flags)
    ]


def predict(input, manual_query_repacement, history=None):
    """Answer ``input`` with the completion model and generate gallery images.

    Args:
        input: Text from the main textbox, used as the completion prompt.
        manual_query_repacement: When non-empty, used as the prompt instead
            of ``input``.
        history: List of ``(prompt, response)`` tuples from earlier turns, or
            ``None`` to start a new conversation.

    Returns:
        A ``(history, history, images)`` tuple: the updated history twice
        (once for the chatbot output, once for the state output) and the
        list of gallery items.
    """
    # A fresh list per call avoids sharing one default list between calls.
    if history is None:
        history = []
    if manual_query_repacement:
        input = manual_query_repacement

    response = openai.Completion.create(
        model="text-davinci-003",
        prompt=input,
        temperature=0.9,
        max_tokens=150,
        top_p=1,
        frequency_penalty=0,
        presence_penalty=0.6,
    )
    response_text = response["choices"][0]["text"]
    history.append((input, response_text))

    images = _generate_images(IMAGE_PROMPT)
    return history, history, images


inputText = gr.Textbox(value="tmp")
manual_query = gr.Textbox(placeholder="Input any query here, to replace the image generation query builder entirely.")
output_img = gr.Gallery(label="Generated image")
output_img.style(grid=2)

gr.Interface(
    fn=predict,
    inputs=[inputText, manual_query, 'state'],
    outputs=["chatbot", 'state', output_img],
).launch()