# NOTE: this file was captured from a web file viewer. The viewer header that
# stood here (file size "8,974 Bytes", commit ids 3cef652 / 863b7e9, and a
# 1-262 line-number gutter) was page residue, not part of the program.
import gradio as gr
import openai
import requests
import os
from dotenv import load_dotenv
import io
import sys
import json
import PIL
import time
from stability_sdk import client
import stability_sdk.interfaces.gooseai.generation.generation_pb2 as generation
import markdown2
# --- UI text ---------------------------------------------------------------
# title, description and article are passed to gr.Interface in main().
title="Character Generator AI"
# Label of the input textbox built in main().
inputs_label="Please tell me the characteristics of the character you want to create"
# NOTE(review): outputs_label and visual_outputs_label are not referenced
# anywhere in the visible part of this file.
outputs_label="AI will generate a character description and visualize it"
visual_outputs_label="Character Visuals"
description="""
- Please input within the limits of 1000 characters. It may take about 50 seconds to generate a response.
"""
article = """
"""
# --- Credentials and API clients ---------------------------------------------
# load_dotenv() runs before the os.getenv() calls below, so keys may come from
# a local .env file as well as from the real environment.
load_dotenv()
openai.api_key = os.getenv('OPENAI_API_KEY')
# Set before the client is constructed.
# NOTE(review): presumably read by stability_sdk to pick the gRPC endpoint - confirm.
os.environ['STABILITY_HOST'] = 'grpc.stability.ai:443'
# Shared client used by StabilityAI.generate_image and OpenAI.chat_completion.
stability_api = client.StabilityInference(
key=os.getenv('STABILITY_KEY'),
verbose=True,
engine="stable-diffusion-xl-1024-v1-0",
)
# Chat model name used for every OpenAI request in this file.
MODEL = "gpt-4"
# Process-wide cache used when the caller does not supply one.
_FILETEXT_CACHE = {}


def get_filetext(filename, cache=None):
    """Return the text content of *filename*, reading it at most once per cache.

    Args:
        filename: Path of the text file to read.
        cache: Optional dict mapping filename -> text. When omitted, a
            module-level cache shared by all calls is used, so a file is read
            from disk only the first time it is requested.

    Returns:
        The file content as a string.

    Raises:
        ValueError: If *filename* does not exist.
    """
    # ``is None`` rather than truthiness: an empty dict passed by the caller
    # must still be used (and filled) as the cache.
    if cache is None:
        cache = _FILETEXT_CACHE

    try:
        return cache[filename]
    except KeyError:
        pass

    try:
        # Explicit UTF-8: this app handles Japanese text, and the platform
        # default encoding is not UTF-8 everywhere.
        with open(filename, "r", encoding="utf-8") as f:
            text = f.read()
    except FileNotFoundError as err:
        raise ValueError(f"File '{filename}' not found") from err

    cache[filename] = text
    return text
def get_functions_from_schema(filename):
    """Parse the JSON file *filename* and return its top-level "functions" entry.

    The file text is obtained through get_filetext(), so repeated calls for
    the same file reuse the cached text. Returns None when the parsed object
    has no "functions" key.
    """
    parsed = json.loads(get_filetext(filename))
    return parsed.get("functions")
class StabilityAI:
    """Image generation through the module-level ``stability_api`` client."""

    @classmethod
    def generate_image(cls, visualize_prompt):
        """Generate an image for *visualize_prompt* and return it.

        Args:
            visualize_prompt: Text prompt sent to ``stability_api.generate``.

        Returns:
            The first image artifact found in the responses, opened with
            Pillow, or None when no response contains an image artifact.
        """
        # Imported explicitly: a bare ``import PIL`` does not load the
        # ``PIL.Image`` submodule, so ``PIL.Image`` only resolves if some
        # other module happened to import it first.
        from PIL import Image

        print("visualize_prompt:"+visualize_prompt)
        answers = stability_api.generate(
            prompt=visualize_prompt,
        )
        for resp in answers:
            for artifact in resp.artifacts:
                if artifact.finish_reason == generation.FILTER:
                    # Only logged; the loop keeps looking for an image artifact.
                    print("NSFW")
                if artifact.type == generation.ARTIFACT_IMAGE:
                    return Image.open(io.BytesIO(artifact.binary))
        return None
class OpenAI:
    """Helpers that call the OpenAI chat completion API with the module-level ``MODEL``."""

    # (connect, read) timeout in seconds for the raw HTTP request. The read
    # timeout is generous because the UI text warns that a response may take
    # about 50 seconds.
    REQUEST_TIMEOUT = (10, 300)

    # Heading that precedes the image prompt inside the generated text.
    VISUAL_PROMPT_MARKER = "### Prompt for Visual Expression\n\n"

    @classmethod
    def chat_completion(cls, prompt, start_with=""):
        """Ask the chat model for a character sheet and start image generation.

        The system message is the content of ``constraints.md``; the content
        of ``template.md`` is sent as an assistant message, followed by the
        user *prompt* and a final assistant message holding *start_with*.

        Args:
            prompt: User text describing the character.
            start_with: Text placed in the trailing assistant message.

        Returns:
            The result of ``stability_api.generate`` for the text that follows
            the "### Prompt for Visual Expression" heading in the model output.
            (This value used to be computed and dropped, so the method always
            returned None.)

        Raises:
            requests.HTTPError: If the API answers with an error status.
            ValueError: If the model output lacks the visual-prompt heading.
        """
        constraints = get_filetext(filename="constraints.md")
        template = get_filetext(filename="template.md")
        data = {
            "model": MODEL,
            "messages": [
                {"role": "system", "content": constraints},
                {"role": "assistant", "content": template},
                {"role": "user", "content": prompt},
                {"role": "assistant", "content": start_with},
            ],
        }
        start = time.time()
        response = requests.post(
            "https://api.openai.com/v1/chat/completions",
            headers={
                "Content-Type": "application/json",
                "Authorization": f"Bearer {openai.api_key}",
            },
            json=data,
            # Without a timeout a stalled connection blocks forever.
            timeout=cls.REQUEST_TIMEOUT,
        )
        print("gpt generation time: "+str(time.time() - start))
        if not response.ok:
            # Keep the error body in the log, then fail with the HTTP status
            # instead of a KeyError on "choices" further down.
            print(response.text)
            response.raise_for_status()
        result = response.json()
        print(result)
        content = result["choices"][0]["message"]["content"].strip()
        sections = content.split(cls.VISUAL_PROMPT_MARKER)
        if len(sections) < 2:
            raise ValueError(
                "model output has no '### Prompt for Visual Expression' section"
            )
        visualize_prompt = sections[1]
        return stability_api.generate(
            prompt=visualize_prompt,
        )

    @classmethod
    def chat_completion_with_function(cls, prompt, messages, functions):
        """Call the ChatCompletion API, forcing a ``create_character`` function call.

        Args:
            prompt: User text; only printed for logging here.
            messages: Chat messages sent to the API.
            functions: Function definitions sent to the API.

        Returns:
            The message of the first choice in the API response.
        """
        print("prompt:"+prompt)
        # Measure how long text generation takes.
        start = time.time()
        # Call the ChatCompletion API.
        response = openai.ChatCompletion.create(
            model=MODEL,
            messages=messages,
            functions=functions,
            function_call={"name": "create_character"},
        )
        print("gpt generation time: "+str(time.time() - start))
        # Take the message returned by the ChatCompletion API.
        message = response.choices[0].message
        print("chat completion message: " + json.dumps(message, indent=2))
        return message
class NajiminoAI:
    """Turn a free-text character request into an image and an HTML description."""

    def __init__(self, user_message):
        """Store the user's request text for the later generation steps."""
        self.user_message = user_message

    def generate_recipe_prompt(self):
        """Return a prompt asking the model to fill in ``template.md`` from the request.

        NOTE(review): not called anywhere in the visible part of this file.
        """
        template = get_filetext(filename="template.md")
        # Built line by line; the resulting text is identical to the former
        # triple-quoted f-string (leading and trailing newline included).
        return (
            "\n"
            f"{self.user_message}\n"
            "---\n"
            "上記を元に、下記テンプレートを埋めてください。\n"
            "---\n"
            f"{template}\n"
        )

    def create_character(self, lang, title, description, prompt_for_visual_expression):
        """Fill ``template.md`` with the given fields and return the text.

        The template is formatted with ``str.format`` using the four
        parameter names as field names; the result is also printed.
        """
        template = get_filetext(filename="template.md")
        debug_message = template.format(
            lang=lang,
            title=title,
            description=description,
            prompt_for_visual_expression=prompt_for_visual_expression,
        )
        print("debug_message: "+debug_message)
        return debug_message

    @classmethod
    def generate(cls, user_message):
        """Entry point used by the UI: run the full pipeline for *user_message*.

        Returns:
            ``[image, html]`` as produced by generate_recipe().
        """
        return cls(user_message).generate_recipe()

    @staticmethod
    def _wrap_html(rendered):
        """Wrap already-rendered HTML in the scrollable container div."""
        # The wrapper used to open a stray "<p>" that was never closed.
        return "<div style='max-width:100%; overflow:auto'>" + rendered + "</div>"

    def generate_recipe(self):
        """Generate the character image and its HTML description.

        Sends the request text with the ``constraints.md`` system message and
        the functions from ``schema.json`` to the chat model. It then
        generates an image from the returned English visual prompt and
        renders the filled template from Markdown to HTML.

        Returns:
            ``[image, html]``. Both are None when the model response has no
            function call; ``image`` is None when the response carries no
            visual prompt or no image was produced.
        """
        constraints = get_filetext(filename="constraints.md")
        messages = [
            {"role": "system", "content": constraints},
            {"role": "user", "content": self.user_message},
        ]
        functions = get_functions_from_schema('schema.json')
        message = OpenAI.chat_completion_with_function(
            prompt=self.user_message, messages=messages, functions=functions
        )

        function_call = message.get("function_call")
        if not function_call:
            return [None, None]

        args = json.loads(function_call["arguments"])
        # A missing prompt used to crash on ``str + None``; treat it as
        # "no image" and still render the description.
        prompt_for_visual_expression = args.get("prompt_for_visual_expression_in_en") or ""
        print("prompt_for_visual_expression: "+prompt_for_visual_expression)

        image = None
        if prompt_for_visual_expression:
            # Measure how long image generation takes.
            start = time.time()
            image = StabilityAI.generate_image(prompt_for_visual_expression)
            print("image generation time: "+str(time.time() - start))

        function_response = self.create_character(
            lang=args.get("lang"),
            title=args.get("title"),
            description=args.get("description"),
            prompt_for_visual_expression=prompt_for_visual_expression,
        )
        html = self._wrap_html(markdown2.markdown(function_response))
        return [image, html]
def main():
    """Build the Gradio interface around ``NajiminoAI.generate`` and launch it.

    The interface takes one textbox as input and shows an image plus an HTML
    block as output, with three example requests offered in the UI.
    """
    # A nested ``click_example`` callback used to be defined here. It was
    # never registered with the interface and referenced names (``inputs``,
    # ``execute_button``) that do not exist, so it has been removed.
    iface = gr.Interface(
        fn=NajiminoAI.generate,
        examples=[
            ["子どもに愛される丸顔で青いイルカのキャラクター"],
            ["活発で夏が似合う赤髪の女の子"],
            ["黒い鎧を着た魔王"],
        ],
        inputs=gr.Textbox(label=inputs_label),
        outputs=[
            gr.Image(label="Visual Expression"),
            "html",
        ],
        title=title,
        description=description,
        article=article,
    )
    iface.launch()
if __name__ == '__main__':
    # An optional first command-line argument selects a smoke-test mode:
    #   generate        - run the full pipeline once with a fixed request
    #   generate_image  - generate one image from a fixed prompt, save it as image.png
    # Anything else (including no argument) launches the web UI.
    function = sys.argv[1] if len(sys.argv) > 1 else ''
    if function == 'generate':
        NajiminoAI.generate("A brave knight with a mighty sword and strong armor")
    elif function == 'generate_image':
        # Imported explicitly: a bare ``import PIL`` does not load this submodule.
        from PIL import PngImagePlugin

        image = StabilityAI.generate_image("Imagine a brave knight with a mighty sword and strong armor. He has a chiseled jawline and a confident expression on his face. His armor gleams under the sunlight, showing off its intricate design and craftsmanship. He holds his sword with pride, ready to protect his kingdom and its people at any cost.")
        # An f-string, because ``"image: " + image`` raises TypeError for an
        # image object (or None).
        print(f"image: {image}")
        if isinstance(image, PngImagePlugin.PngImageFile):
            image.save("image.png")
    else:
        main()