File size: 8,974 Bytes
3cef652
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
863b7e9
3cef652
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
import gradio as gr
import openai
import requests
import os
from dotenv import load_dotenv
import io
import sys
import json
import PIL
import time
from stability_sdk import client
import stability_sdk.interfaces.gooseai.generation.generation_pb2 as generation
import markdown2

# --- Static UI text for the Gradio interface -------------------------------
title = "Character Generator AI"
inputs_label = "Please tell me the characteristics of the character you want to create"
# NOTE(review): outputs_label and visual_outputs_label are not referenced
# anywhere in the visible source - confirm whether they are still needed.
outputs_label = "AI will generate a character description and visualize it"
visual_outputs_label = "Character Visuals"

# Markdown usage note; the value starts and ends with a newline.
description = (
    "\n"
    "- Please input within the limits of 1000 characters."
    " It may take about 50 seconds to generate a response.\n"
)

# Article text: currently nothing but a single newline.
article = "\n"

# Module-level service configuration. Statement order matters: load_dotenv()
# runs before the os.getenv() lookups, and STABILITY_HOST is set before the
# Stability client is constructed.
load_dotenv()
# os.getenv returns None when OPENAI_API_KEY is unset; nothing here validates it.
openai.api_key = os.getenv('OPENAI_API_KEY')
# NOTE(review): presumably stability_sdk reads STABILITY_HOST from the
# environment when the client is created - confirm against the SDK.
os.environ['STABILITY_HOST'] = 'grpc.stability.ai:443'
# Single shared Stability client used by the image-generation code in this file.
stability_api = client.StabilityInference(
    key=os.getenv('STABILITY_KEY'), 
    verbose=True,
    engine="stable-diffusion-xl-1024-v1-0",
)
# Chat model name sent with every OpenAI request made in this file.
MODEL = "gpt-4"

def get_filetext(filename, cache={}):
    """Return the text of ``filename``, reading the file only on first request.

    Args:
        filename: Path of the text file to read.
        cache: Mapping of filename to text that was already read. The mutable
            default is intentional: it is created once and shared by every
            call, which makes it a process-wide memo. Pass your own dict to
            use a separate cache.

    Returns:
        The complete file contents as a string.

    Raises:
        ValueError: If ``filename`` does not exist.
    """
    try:
        return cache[filename]
    except KeyError:
        pass  # not read yet; fall through and load it from disk

    if not os.path.exists(filename):
        raise ValueError(f"File '{filename}' not found")

    # Decode explicitly as UTF-8. Without ``encoding`` Python uses the
    # platform's locale encoding, which fails or mis-decodes non-ASCII text
    # (this file itself builds prompts containing Japanese) on systems whose
    # default is not UTF-8.
    with open(filename, "r", encoding="utf-8") as f:
        text = f.read()
    cache[filename] = text
    return text

def get_functions_from_schema(filename):
    """Parse a JSON schema file and return its top-level "functions" entry.

    The file text is obtained through ``get_filetext``. Returns None when the
    parsed document has no "functions" key.
    """
    document = json.loads(get_filetext(filename))
    return document.get("functions")

class StabilityAI:
    """Thin wrapper around the module-level ``stability_api`` client."""

    @classmethod
    def generate_image(cls, visualize_prompt):
        """Generate an image for a text prompt.

        Args:
            visualize_prompt: Prompt text passed to ``stability_api.generate``.

        Returns:
            The first image artifact of the response decoded as a PIL image,
            or None when the response carries no image artifact.
        """
        # Import the submodule explicitly: a bare ``import PIL`` does not load
        # ``PIL.Image``, so ``PIL.Image.open`` only worked when some other
        # module happened to have imported it already.
        from PIL import Image

        print("visualize_prompt:"+visualize_prompt)

        answers = stability_api.generate(
            prompt=visualize_prompt,
        )

        for resp in answers:
            for artifact in resp.artifacts:
                # Report a filtered artifact but keep scanning for an image.
                if artifact.finish_reason == generation.FILTER:
                    print("NSFW")
                if artifact.type == generation.ARTIFACT_IMAGE:
                    return Image.open(io.BytesIO(artifact.binary))

        # No image artifact was found in the response.
        return None

class OpenAI:
    """Helpers for calling the OpenAI chat completion API."""

    @classmethod
    def chat_completion(cls, prompt, start_with="", timeout=120):
        """Ask the chat model for a character sheet and start image generation.

        The request is built from ``constraints.md`` (system message),
        ``template.md`` (assistant message), ``prompt`` (user message) and
        ``start_with`` (trailing assistant message). The text that follows the
        "### Prompt for Visual Expression" heading in the reply is passed to
        ``stability_api.generate``.

        Args:
            prompt: User message text.
            start_with: Content of the final assistant message in the request.
            timeout: Seconds to wait for the HTTP response before
                ``requests`` raises a timeout error.

        Returns:
            The value returned by ``stability_api.generate`` for the extracted
            visual prompt. (It used to be discarded, leaving the method with
            no observable result.)

        Raises:
            requests.HTTPError: If the API answers with an error status.
            IndexError: If the reply lacks the visual-prompt heading.
        """
        constraints = get_filetext(filename="constraints.md")
        template = get_filetext(filename="template.md")

        data = {
            "model": MODEL,
            "messages": [
                {"role": "system", "content": constraints},
                {"role": "assistant", "content": template},
                {"role": "user", "content": prompt},
                {"role": "assistant", "content": start_with},
            ],
        }

        start = time.time()
        response = requests.post(
            "https://api.openai.com/v1/chat/completions",
            headers={
                "Content-Type": "application/json",
                "Authorization": f"Bearer {openai.api_key}"
            },
            json=data,
            # Without a timeout requests can block indefinitely.
            timeout=timeout,
        )
        print("gpt generation time: "+str(time.time() - start))

        # Fail on HTTP errors here instead of a confusing KeyError on
        # "choices" further down.
        response.raise_for_status()

        result = response.json()
        print(result)

        content = result["choices"][0]["message"]["content"].strip()

        marker = "### Prompt for Visual Expression\n\n"
        sections = content.split(marker)
        if len(sections) < 2:
            # Same exception type as the bare ``[1]`` lookup raised before,
            # but with a message that says what is missing.
            raise IndexError(
                "chat completion reply has no '### Prompt for Visual Expression' section"
            )
        visualize_prompt = sections[1]

        return stability_api.generate(
            prompt=visualize_prompt,
        )

    @classmethod
    def chat_completion_with_function(cls, prompt, messages, functions):
        """Call ChatCompletion, forcing the ``create_character`` function call.

        Args:
            prompt: Only used for the debug print below.
            messages: Chat messages sent to the API.
            functions: Function definitions sent to the API.

        Returns:
            The ``message`` of the first choice in the API response.
        """
        print("prompt:"+prompt)

        # Measure how long the text generation takes.
        start = time.time()
        # Call the ChatCompletion API.
        response = openai.ChatCompletion.create(
                model=MODEL,
                messages=messages,
                functions=functions,
                function_call={"name": "create_character"}
            )
        print("gpt generation time: "+str(time.time() - start))

        # Take the message of the first choice returned by the API.
        message = response.choices[0].message
        print("chat completion message: " + json.dumps(message, indent=2))

        return message
        
class NajiminoAI:
    """Turn a user's character request into an image and an HTML description."""

    def __init__(self, user_message):
        """Store the user's free-text character request."""
        self.user_message = user_message

    def generate_recipe_prompt(self):
        """Return a prompt combining the user message with ``template.md``.

        The instruction line inside the prompt is Japanese and is part of the
        text sent at runtime, so it is kept verbatim.
        """
        template = get_filetext(filename="template.md")
        prompt = f"""
        {self.user_message}
        ---
        上記を元に、下記テンプレートを埋めてください。
        ---
        {template}
        """
        return prompt

    def create_character(self, lang, title, description, prompt_for_visual_expression):
        """Fill ``template.md`` with the character fields and return the text.

        The four arguments are substituted into the template with
        ``str.format`` under the same names. The result is also printed for
        debugging.
        """
        template = get_filetext(filename="template.md")
        debug_message = template.format(
            lang=lang,
            title=title,
            description=description,
            prompt_for_visual_expression=prompt_for_visual_expression,
        )

        print("debug_message: "+debug_message)

        return debug_message

    @classmethod
    def generate(cls, user_message):
        """Entry point used by the UI: return ``[image, html]`` for a request."""
        return cls(user_message).generate_recipe()

    def generate_recipe(self):
        """Run the pipeline: chat completion, image generation, HTML rendering.

        Returns:
            A two-item list ``[image, html]``. Both items are None when the
            model reply contains no function call. ``image`` is also None
            when the reply carries no visual prompt or no image is produced.
        """
        constraints = get_filetext(filename="constraints.md")

        messages = [
            {"role": "system", "content": constraints},
            {"role": "user", "content": self.user_message},
        ]

        functions = get_functions_from_schema('schema.json')

        message = OpenAI.chat_completion_with_function(
            prompt=self.user_message,
            messages=messages,
            functions=functions,
        )

        function_call = message.get("function_call")
        if not function_call:
            return [None, None]

        args = json.loads(function_call["arguments"])

        lang = args.get("lang")
        title = args.get("title")
        description = args.get("description")
        prompt_for_visual_expression = args.get("prompt_for_visual_expression_in_en")

        # f-string instead of "+" so a missing argument (None) is printed
        # rather than raising TypeError.
        print(f"prompt_for_visual_expression: {prompt_for_visual_expression}")

        image = None
        if prompt_for_visual_expression:
            # Measure how long the image generation takes.
            start = time.time()
            image = StabilityAI.generate_image(prompt_for_visual_expression)
            print("image generation time: "+str(time.time() - start))

        function_response = self.create_character(
            lang=lang,
            title=title,
            description=description,
            prompt_for_visual_expression=prompt_for_visual_expression,
        )

        # The stray, never-closed "<p>" that used to precede the rendered
        # markdown was dropped so the wrapper markup is balanced.
        html = (
            "<div style='max-width:100%; overflow:auto'>"
            + markdown2.markdown(function_response)
            + "</div>"
        )
        return [image, html]
            
def main():
    """Build the Gradio interface around ``NajiminoAI.generate`` and launch it.

    The interface has one text input, three clickable example prompts, and two
    outputs: an image and an HTML panel.
    """
    # The former nested ``click_example`` callback was removed: it was never
    # registered anywhere and referenced the undefined names ``inputs`` and
    # ``execute_button``.
    iface = gr.Interface(
        fn=NajiminoAI.generate,
        examples=[
            ["子どもに愛される丸顔で青いイルカのキャラクター"],
            ["活発で夏が似合う赤髪の女の子"],
            ["黒い鎧を着た魔王"],
        ],
        inputs=gr.Textbox(label=inputs_label),
        outputs=[
            gr.Image(label="Visual Expression"),
            "html",
        ],
        title=title,
        description=description,
        article=article,
    )

    iface.launch()

if __name__ == '__main__':
    # The optional first command-line argument selects a mode:
    #   generate        - run the full pipeline once with a sample request
    #   generate_image  - generate one sample image and save it as image.png
    #   anything else   - launch the Gradio UI
    function = sys.argv[1] if len(sys.argv) > 1 else ''

    if function == 'generate':
        NajiminoAI.generate("A brave knight with a mighty sword and strong armor")

    elif function == 'generate_image':
        # Import the submodule explicitly; a bare ``import PIL`` does not
        # guarantee that ``PIL.PngImagePlugin`` is loaded.
        from PIL import PngImagePlugin

        image = StabilityAI.generate_image("Imagine a brave knight with a mighty sword and strong armor. He has a chiseled jawline and a confident expression on his face. His armor gleams under the sunlight, showing off its intricate design and craftsmanship. He holds his sword with pride, ready to protect his kingdom and its people at any cost.")
        # Format instead of concatenating: ``"image: " + image`` raised
        # TypeError because ``image`` is an Image object (or None), not a str.
        print(f"image: {image}")

        if isinstance(image, PngImagePlugin.PngImageFile):
            image.save("image.png")

    else:
        main()