|
import gradio as gr |
|
import openai |
|
import requests |
|
import os |
|
from dotenv import load_dotenv |
|
import io |
|
import sys |
|
import json |
|
import PIL |
|
import time |
|
from stability_sdk import client |
|
import stability_sdk.interfaces.gooseai.generation.generation_pb2 as generation |
|
import markdown2 |
|
|
|
# Static text for the Gradio interface.  `title`, `description` and `article`
# are passed to `gr.Interface` in `main()`; `inputs_label` labels the textbox.
title = "Character Generator AI"
inputs_label = "Please tell me the characteristics of the character you want to create"
outputs_label = "AI will generate a character description and visualize it"
visual_outputs_label = "Character Visuals"

# Usage note handed to `gr.Interface(description=...)`.
description = (
    "\n"
    "- Please input within the limits of 1000 characters."
    " It may take about 50 seconds to generate a response.\n"
)

# Handed to `gr.Interface(article=...)`; currently contains only a newline.
article = "\n"
|
|
|
# Load variables from a .env file (python-dotenv) before the keys are read.
load_dotenv()

openai.api_key = os.environ.get("OPENAI_API_KEY")

# Exported before the Stability client is constructed.
# NOTE(review): presumably the SDK reads STABILITY_HOST at construction time - confirm.
os.environ["STABILITY_HOST"] = "grpc.stability.ai:443"

# Shared Stability client used by every image-generation call in this module.
stability_api = client.StabilityInference(
    key=os.environ.get("STABILITY_KEY"),
    verbose=True,
    engine="stable-diffusion-xl-1024-v1-0",
)

# Chat model name sent with every OpenAI request in this module.
MODEL = "gpt-4"
|
|
|
# Process-wide cache used when the caller does not pass its own `cache`.
_FILETEXT_CACHE = {}


def get_filetext(filename, cache=None):
    """Return the text content of *filename*, reading it at most once per cache.

    Args:
        filename: Path of the file to read.
        cache: Optional dict mapping filename -> text.  When omitted, a
            module-level cache shared by all calls is used, so a file is read
            from disk only on its first request.

    Returns:
        The file content as a string.

    Raises:
        ValueError: If the file does not exist and is not already cached.
    """
    if cache is None:
        cache = _FILETEXT_CACHE

    if filename in cache:
        return cache[filename]

    try:
        # Decode explicitly as UTF-8 so non-ASCII text is read the same way
        # regardless of the platform's locale encoding.
        with open(filename, "r", encoding="utf-8") as f:
            text = f.read()
    except FileNotFoundError as err:
        raise ValueError(f"File '{filename}' not found") from err

    cache[filename] = text
    return text
|
|
|
def get_functions_from_schema(filename):
    """Return the value stored under the top-level ``"functions"`` key of a JSON file.

    The file text is obtained through ``get_filetext`` and parsed as JSON.
    ``None`` is returned when the parsed object has no ``"functions"`` key.
    """
    parsed = json.loads(get_filetext(filename))
    return parsed.get("functions")
|
|
|
class StabilityAI:
    """Image generation through the module-level ``stability_api`` client."""

    @classmethod
    def generate_image(cls, visualize_prompt):
        """Generate an image for *visualize_prompt*.

        Args:
            visualize_prompt: Text prompt forwarded to ``stability_api.generate``.

        Returns:
            The first image artifact found in the responses, opened as a PIL
            image, or ``None`` when the responses contain no image artifact.
        """
        # A bare ``import PIL`` does not load the ``Image`` submodule, so
        # ``PIL.Image`` only works if some other module imported it first.
        from PIL import Image

        print(f"visualize_prompt:{visualize_prompt}")

        answers = stability_api.generate(prompt=visualize_prompt)

        for resp in answers:
            for artifact in resp.artifacts:
                if artifact.finish_reason == generation.FILTER:
                    # Only reported; the remaining artifacts are still inspected.
                    print("NSFW")
                if artifact.type == generation.ARTIFACT_IMAGE:
                    return Image.open(io.BytesIO(artifact.binary))

        return None
|
|
|
class OpenAI:
    """Helpers for calling the OpenAI chat-completion API."""

    # Heading that separates the description from the image prompt in the
    # model's free-text answer.
    VISUAL_PROMPT_MARKER = "### Prompt for Visual Expression\n\n"

    # `requests` waits indefinitely unless a timeout is given; keep it generous
    # because a completion can take a long time.
    REQUEST_TIMEOUT_SECONDS = 300

    @classmethod
    def chat_completion(cls, prompt, start_with=""):
        """Request a completion over HTTP and start image generation from it.

        The conversation consists of ``constraints.md`` (system),
        ``template.md`` (assistant), *prompt* (user) and *start_with*
        (assistant).  The text after ``VISUAL_PROMPT_MARKER`` in the reply is
        passed to ``stability_api.generate``.

        Args:
            prompt: User message content.
            start_with: Content of the trailing assistant message.

        Returns:
            None.  NOTE(review): the result of ``stability_api.generate`` is
            discarded and nothing is returned - confirm whether this method is
            still in use.

        Raises:
            requests.Timeout: If the API does not answer within
                ``REQUEST_TIMEOUT_SECONDS``.
            IndexError: If the reply does not contain ``VISUAL_PROMPT_MARKER``.
        """
        constraints = get_filetext(filename="constraints.md")
        template = get_filetext(filename="template.md")

        data = {
            "model": MODEL,
            "messages": [
                {"role": "system", "content": constraints},
                {"role": "assistant", "content": template},
                {"role": "user", "content": prompt},
                {"role": "assistant", "content": start_with},
            ],
        }

        start = time.time()
        response = requests.post(
            "https://api.openai.com/v1/chat/completions",
            headers={
                "Content-Type": "application/json",
                "Authorization": f"Bearer {openai.api_key}",
            },
            json=data,
            timeout=cls.REQUEST_TIMEOUT_SECONDS,
        )
        print("gpt generation time: "+str(time.time() - start))

        result = response.json()
        print(result)

        content = result["choices"][0]["message"]["content"].strip()

        sections = content.split(cls.VISUAL_PROMPT_MARKER)
        if len(sections) < 2:
            # Same exception type as the bare `[1]` lookup, with a usable message.
            raise IndexError(
                "chat completion has no '### Prompt for Visual Expression' section"
            )
        visualize_prompt = sections[1]

        stability_api.generate(prompt=visualize_prompt)

    @classmethod
    def chat_completion_with_function(cls, prompt, messages, functions,
                                      function_name="create_character"):
        """Request a completion that is forced to call one declared function.

        Args:
            prompt: User text; only printed here, the request itself is built
                from *messages*.
            messages: Chat messages passed to ``openai.ChatCompletion.create``.
            functions: Function declarations passed as ``functions=``.
            function_name: Name sent as ``function_call={"name": ...}``.

        Returns:
            ``response.choices[0].message`` from the API response.
        """
        print(f"prompt:{prompt}")

        start = time.time()
        response = openai.ChatCompletion.create(
            model=MODEL,
            messages=messages,
            functions=functions,
            function_call={"name": function_name},
        )
        print("gpt generation time: "+str(time.time() - start))

        message = response.choices[0].message
        print("chat completion message: " + json.dumps(message, indent=2))

        return message
|
|
|
class NajiminoAI:
    """Turns a free-text character request into an image and an HTML description."""

    def __init__(self, user_message):
        """Store the user's request text for the later generation steps."""
        self.user_message = user_message

    def generate_recipe_prompt(self):
        """Build a prompt that appends ``template.md`` to the user's message.

        Returns:
            The user message, a separator, a Japanese instruction sentence,
            another separator and the template text, each on its own line.
        """
        template = get_filetext(filename="template.md")
        # NOTE(review): the instruction sentence was mojibake in the source
        # (UTF-8 bytes shown through a Thai code page, some bytes lost).  It is
        # reconstructed here as "Based on the above, fill in the template
        # below." - confirm against the original file.
        return (
            f"\n{self.user_message}\n"
            "---\n"
            "上記を元に、下記テンプレートを埋めてください。\n"
            "---\n"
            f"{template}\n"
        )

    def create_character(self, lang, title, description, prompt_for_visual_expression):
        """Fill ``template.md`` with the character fields.

        The template text is used as a ``str.format`` template with the
        keyword arguments ``lang``, ``title``, ``description`` and
        ``prompt_for_visual_expression``.

        Returns:
            The formatted text, which is also printed for debugging.
        """
        template = get_filetext(filename="template.md")
        debug_message = template.format(
            lang=lang,
            title=title,
            description=description,
            prompt_for_visual_expression=prompt_for_visual_expression,
        )

        print("debug_message: "+debug_message)

        return debug_message

    @classmethod
    def generate(cls, user_message):
        """Run the whole pipeline for *user_message*; see ``generate_recipe``."""
        return cls(user_message).generate_recipe()

    def generate_recipe(self):
        """Ask the chat model for a character and render it.

        Sends ``constraints.md`` plus the user message together with the
        function declarations from ``schema.json``.  If the reply contains a
        function call, its JSON arguments are used to generate an image and
        an HTML description.

        Returns:
            ``[image, html]``.  Both entries are ``None`` when the reply holds
            no function call; ``image`` alone is ``None`` when the arguments
            carry no visual prompt or no image was produced.
        """
        constraints = get_filetext(filename="constraints.md")
        messages = [
            {"role": "system", "content": constraints},
            {"role": "user", "content": self.user_message},
        ]
        functions = get_functions_from_schema('schema.json')

        message = OpenAI.chat_completion_with_function(
            prompt=self.user_message, messages=messages, functions=functions
        )

        function_call = message.get("function_call")
        if not function_call:
            return [None, None]

        args = json.loads(function_call["arguments"])
        lang = args.get("lang")
        title = args.get("title")
        description = args.get("description")
        prompt_for_visual_expression = args.get("prompt_for_visual_expression_in_en")

        # Formatted rather than concatenated so a missing prompt (None) does
        # not raise TypeError.
        print(f"prompt_for_visual_expression: {prompt_for_visual_expression}")

        image = None
        if prompt_for_visual_expression:
            start = time.time()
            image = StabilityAI.generate_image(prompt_for_visual_expression)
            print("image generation time: "+str(time.time() - start))

        function_response = self.create_character(
            lang=lang,
            title=title,
            description=description,
            prompt_for_visual_expression=prompt_for_visual_expression,
        )

        # NOTE(review): model output is rendered as HTML without escaping;
        # consider markdown2's safe mode if template.md needs no raw HTML.
        html = (
            "<div style='max-width:100%; overflow:auto'>"
            + markdown2.markdown(function_response)
            + "</div>"
        )
        return [image, html]
|
|
|
def main():
    """Build the Gradio interface around ``NajiminoAI.generate`` and launch it.

    The interface takes one textbox as input and shows an image plus an HTML
    block as output.
    """
    # NOTE(review): these example prompts were mojibake in the source (UTF-8
    # bytes shown through a Thai code page, some bytes lost) and are
    # reconstructed from the surviving bytes - confirm against the original.
    # Roughly: a round-faced blue dolphin character loved by children / a
    # lively red-haired girl who suits summer / a demon king in black armour.
    examples = [
        ["子どもに愛される丸顔で青いイルカのキャラクター"],
        ["活発で夏が似合う赤髪の女の子"],
        ["黒い鎧を着た魔王"],
    ]

    iface = gr.Interface(
        fn=NajiminoAI.generate,
        examples=examples,
        inputs=gr.Textbox(label=inputs_label),
        outputs=[
            gr.Image(label="Visual Expression"),
            "html",
        ],
        title=title,
        description=description,
        article=article,
    )

    iface.launch()
|
|
|
if __name__ == '__main__':
    # The optional first command-line argument selects a manual check;
    # without it (or with any other value) the Gradio UI is launched.
    function = sys.argv[1] if len(sys.argv) > 1 else ''

    if function == 'generate':
        NajiminoAI.generate("A brave knight with a mighty sword and strong armor")

    elif function == 'generate_image':
        # A bare ``import PIL`` does not guarantee the submodule is loaded.
        from PIL import PngImagePlugin

        image = StabilityAI.generate_image("Imagine a brave knight with a mighty sword and strong armor. He has a chiseled jawline and a confident expression on his face. His armor gleams under the sunlight, showing off its intricate design and craftsmanship. He holds his sword with pride, ready to protect his kingdom and its people at any cost.")
        # Formatted rather than concatenated: `image` is an image object or
        # None, and "str" + object raises TypeError.
        print(f"image: {image}")

        if isinstance(image, PngImagePlugin.PngImageFile):
            image.save("image.png")

    else:
        main()