Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -2,292 +2,312 @@ import gradio as gr
|
|
2 |
import io
|
3 |
import base64
|
4 |
import os
|
|
|
|
|
5 |
from PIL import Image
|
6 |
from huggingface_hub import InferenceClient
|
7 |
-
from
|
|
|
8 |
|
9 |
# Load API keys from environment variables
|
10 |
inference_api_key = os.environ.get("HF_TOKEN")
|
11 |
-
|
12 |
-
chat_api_key2 = os.environ.get("OPENROUTER_TOKEN")
|
13 |
|
|
|
|
|
14 |
|
15 |
-
|
16 |
-
# Global variable to store the image data URL and prompt for the currently generated image.
|
17 |
global_image_data_url = None
|
18 |
global_image_prompt = None # Still stored if needed elsewhere
|
19 |
|
20 |
-
def
|
|
|
|
|
|
|
21 |
"""
|
22 |
-
|
23 |
-
based on the selected difficulty, age, autism level, and any extra details the user provides.
|
24 |
"""
|
25 |
query = (
|
26 |
-
|
27 |
f"""
|
28 |
-
Follow the instructions below to
|
29 |
-
Consider the following parameters
|
30 |
-
- Difficulty: {difficulty}
|
31 |
-
- Age: {age}
|
32 |
-
- Autism Level: {
|
33 |
-
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
* **Literal Interpretation:** The images should be highly literal. Avoid metaphors, symbolism, or implied meanings. If depicting a sequence of events, make each step visually distinct.
|
49 |
-
* **Defined Borders:** Consider using clear outlines or borders around objects and people to enhance visual separation and definition.
|
50 |
-
* **Consistent Style:** Maintain a consistent visual style across multiple images. This helps build familiarity and predictability.
|
51 |
-
|
52 |
-
**3. Sensory Considerations:**
|
53 |
-
|
54 |
-
* **Soft Color Palette:** Favor muted, calming colors. Avoid overly bright, saturated, or fluorescent colors.
|
55 |
-
* **Reduced Visual Complexity:** Limit the number of elements in the image to prevent sensory overload.
|
56 |
-
* **Smooth Textures:** If textures are depicted, they should appear smooth and non-threatening. Avoid rough, jagged, or overly detailed textures.
|
57 |
-
|
58 |
-
**4. Positive and Supportive Imagery:**
|
59 |
-
|
60 |
-
* **Positive Reinforcement:** Images should be encouraging and positive. Depict success, cooperation, and positive social interactions.
|
61 |
-
* **Calm and Relaxing Scenes:** Consider scenes that promote calmness, such as nature scenes (e.g., a quiet forest, a calm beach), or familiar, safe environments (e.g., a cozy bedroom, a well-organized classroom).
|
62 |
-
* **Avoidance of Triggers:** Be mindful of potential triggers for anxiety or distress. Avoid images that depict conflict, overwhelming crowds, or potentially frightening situations.
|
63 |
-
|
64 |
-
**5. Specific Use Cases (Adapt as needed):**
|
65 |
-
|
66 |
-
* **Social Stories:** If generating images for a social story, ensure each image clearly illustrates a single step in the sequence. Use consistent characters and settings throughout the story.
|
67 |
-
* **Visual Schedules:** If creating images for a visual schedule, make each activity easily identifiable and visually distinct.
|
68 |
-
* **Emotion Recognition:** If depicting emotions, use clear facial expressions and body language. Consider using a consistent character to represent different emotions.
|
69 |
-
* **Communication Aids:** If creating images for communication, ensure the objects or actions are clearly depicted and easily recognizable.
|
70 |
-
* **Daily Routines**: Brushing teeth, eating food, going to school.
|
71 |
-
* **Learning concepts**: Shapes, colors, animals, numbers, alphabet.
|
72 |
-
|
73 |
-
**Prompting Instructions:**
|
74 |
-
|
75 |
-
When providing a prompt to the model, be as specific as possible, including:
|
76 |
-
|
77 |
-
* **The subject of the image:** "A boy brushing his teeth."
|
78 |
-
* **The desired style:** "Simple, clear, with a solid light blue background."
|
79 |
-
* **The intended use:** "For a visual schedule."
|
80 |
-
* **Any specific details:** "The boy should be smiling. The toothbrush should be blue."
|
81 |
-
* **Emotions:** Clearly state the emotion "happy" or "calm."
|
82 |
-
|
83 |
-
**Example Prompts (using the above system prompt as a base):**
|
84 |
-
|
85 |
-
* "Generate an image for a visual schedule. The subject is 'eating lunch.' Show a child sitting at a table with a plate of food (sandwich, apple slices, and a glass of milk). The background should be a solid, pale green. The child should be smiling. Use a clear, simple style with defined outlines."
|
86 |
-
* "Generate an image to help with emotion recognition. The subject is 'sad.' Show a child's face with a single tear rolling down their cheek and a downturned mouth. The background should be a solid, light gray. Use a simple, realistic style."
|
87 |
-
* "Generate an image for a social story about going to the doctor. Show a child sitting in a doctor's waiting room, calmly looking at a book. The room should have a few simple toys and a window. The background should be a soft blue. The style should be clear and uncluttered."
|
88 |
-
* "Generate a picture of two block shapes in a simple, cartoon style. One red square and one blue circle. Place them on a white background."
|
89 |
-
* "Generate a cartoon image of a dog. Make the dog appear to be friendly and non-threatening. Use warm colors."
|
90 |
-
|
91 |
-
Ensure your Prompts are acccurate and ensure the images are accurate and dont have any irregularities or deforamtions in them.
|
92 |
-
use descriptive and detailed prompts
|
93 |
"""
|
94 |
)
|
95 |
|
96 |
-
|
97 |
-
|
98 |
-
"role": "user",
|
99 |
-
"content": query
|
100 |
-
}
|
101 |
-
]
|
102 |
-
|
103 |
-
client = OpenAI(
|
104 |
-
base_url="https://openrouter.ai/api/v1",
|
105 |
-
api_key=chat_api_key2
|
106 |
-
)
|
107 |
|
108 |
-
|
109 |
-
|
110 |
-
temperature=0.5,
|
111 |
-
messages=messages,
|
112 |
-
max_tokens=8192,
|
113 |
-
stream=True
|
114 |
-
)
|
115 |
|
116 |
-
|
117 |
-
for chunk in stream:
|
118 |
-
response_text += chunk.choices[0].delta.content
|
119 |
-
return response_text.strip()
|
120 |
|
121 |
-
def generate_image_fn(selected_prompt, guidance_scale=7.5,
|
|
|
|
|
122 |
"""
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
Additional parameters:
|
127 |
-
- guidance_scale: Influences how strongly the image generation adheres to the prompt.
|
128 |
-
- negative_prompt: Specifies undesirable elements to avoid in the generated image.
|
129 |
-
- num_inference_steps: The number of denoising steps for image generation.
|
130 |
"""
|
131 |
global global_image_data_url, global_image_prompt
|
132 |
-
|
133 |
-
# Save the chosen prompt for potential future use.
|
134 |
global_image_prompt = selected_prompt
|
135 |
-
|
136 |
-
image_client = InferenceClient(
|
137 |
-
provider="hf-inference",
|
138 |
-
api_key=inference_api_key
|
139 |
-
)
|
140 |
-
|
141 |
image = image_client.text_to_image(
|
142 |
selected_prompt,
|
143 |
-
model="stabilityai/stable-diffusion-3.5-large-turbo",
|
144 |
guidance_scale=guidance_scale,
|
145 |
negative_prompt=negative_prompt,
|
146 |
num_inference_steps=num_inference_steps
|
147 |
)
|
148 |
-
|
149 |
buffered = io.BytesIO()
|
150 |
image.save(buffered, format="PNG")
|
151 |
img_bytes = buffered.getvalue()
|
152 |
img_b64 = base64.b64encode(img_bytes).decode("utf-8")
|
153 |
global_image_data_url = f"data:image/png;base64,{img_b64}"
|
154 |
-
|
155 |
return image
|
156 |
|
157 |
-
def generate_image_and_reset_chat(
|
158 |
"""
|
159 |
-
|
160 |
-
|
161 |
"""
|
162 |
new_sessions = saved_sessions.copy()
|
163 |
if active_session.get("prompt"):
|
164 |
new_sessions.append(active_session)
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
new_active_session = {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
170 |
return image, new_active_session, new_sessions
|
171 |
|
172 |
-
def compare_details_chat_fn(user_details):
|
173 |
"""
|
174 |
-
|
175 |
-
The message includes both the image (using its data URL) and the user’s text.
|
176 |
"""
|
177 |
if not global_image_data_url:
|
178 |
return "Please generate an image first."
|
179 |
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
|
209 |
-
|
210 |
-
|
211 |
-
|
212 |
-
|
213 |
-
|
214 |
-
|
215 |
-
|
216 |
-
|
217 |
)
|
218 |
|
219 |
-
|
220 |
-
|
221 |
-
|
222 |
-
|
223 |
-
|
224 |
-
|
|
|
|
|
|
|
|
|
|
|
225 |
|
226 |
-
|
227 |
-
|
228 |
-
|
229 |
-
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
230 |
|
231 |
def chat_respond(user_message, active_session, saved_sessions):
|
232 |
"""
|
233 |
-
|
234 |
-
|
235 |
-
|
236 |
"""
|
237 |
if not active_session.get("image"):
|
238 |
bot_message = "Please generate an image first."
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
239 |
else:
|
240 |
-
|
|
|
241 |
|
242 |
-
|
243 |
-
active_session["chat"] = updated_chat
|
244 |
-
return "", updated_chat, saved_sessions, active_session
|
245 |
|
246 |
def update_sessions(saved_sessions, active_session):
|
247 |
"""
|
248 |
-
|
249 |
-
so that the sidebar always displays the complete session details.
|
250 |
"""
|
251 |
if active_session and active_session.get("prompt"):
|
252 |
return saved_sessions + [active_session]
|
253 |
return saved_sessions
|
254 |
|
255 |
##############################################
|
256 |
-
#
|
257 |
-
##############################################
|
258 |
-
difficulty_options = ["Simple", "Average", "Detailed"]
|
259 |
-
level_options = ["Level 1", "Level 2", "Level 3"]
|
260 |
-
|
261 |
-
##############################################
|
262 |
-
# Create the Gradio Interface (Single-Page) with a Sidebar for Session Details
|
263 |
##############################################
|
264 |
with gr.Blocks() as demo:
|
265 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
266 |
saved_sessions = gr.State([])
|
267 |
|
268 |
with gr.Column():
|
269 |
gr.Markdown("# Image Generation & Chat Inference")
|
|
|
|
|
270 |
|
271 |
# ----- Image Generation Section -----
|
272 |
with gr.Column():
|
273 |
gr.Markdown("## Generate Image")
|
274 |
-
gr.Markdown("
|
275 |
with gr.Row():
|
276 |
-
|
277 |
-
|
278 |
-
|
279 |
-
|
280 |
-
|
281 |
-
|
282 |
-
|
283 |
-
|
|
|
|
|
|
|
|
|
284 |
lines=2
|
285 |
)
|
286 |
generate_btn = gr.Button("Generate Image")
|
287 |
img_output = gr.Image(label="Generated Image")
|
288 |
generate_btn.click(
|
289 |
generate_image_and_reset_chat,
|
290 |
-
inputs=[
|
291 |
outputs=[img_output, active_session, saved_sessions]
|
292 |
)
|
293 |
|
@@ -296,15 +316,12 @@ with gr.Blocks() as demo:
|
|
296 |
gr.Markdown("## Chat about the Image")
|
297 |
gr.Markdown(
|
298 |
"After generating an image, type details or descriptions about it. "
|
299 |
-
"Your message
|
300 |
-
"which will evaluate your description based on what it sees in the image. "
|
301 |
-
"The response will include a correctness percentage and hints if needed."
|
302 |
)
|
303 |
chatbot = gr.Chatbot(label="Chat History")
|
304 |
with gr.Row():
|
305 |
chat_input = gr.Textbox(label="Your Message", placeholder="Type your description here...", show_label=False)
|
306 |
send_btn = gr.Button("Send")
|
307 |
-
|
308 |
send_btn.click(
|
309 |
chat_respond,
|
310 |
inputs=[chat_input, active_session, saved_sessions],
|
@@ -322,11 +339,13 @@ with gr.Blocks() as demo:
|
|
322 |
gr.Markdown(
|
323 |
"This sidebar automatically saves finished chat sessions. "
|
324 |
"Each session includes the prompt used, the generated image (as a data URL), "
|
325 |
-
"
|
326 |
)
|
327 |
sessions_output = gr.JSON(label="Session Details", value={})
|
328 |
active_session.change(update_sessions, inputs=[saved_sessions, active_session], outputs=sessions_output)
|
|
|
|
|
329 |
saved_sessions.change(update_sessions, inputs=[saved_sessions, active_session], outputs=sessions_output)
|
330 |
|
331 |
-
# Launch the app.
|
332 |
demo.launch()
|
|
|
2 |
import io
|
3 |
import base64
|
4 |
import os
|
5 |
+
import json
|
6 |
+
import re
|
7 |
from PIL import Image
|
8 |
from huggingface_hub import InferenceClient
|
9 |
+
from google.generativeai import configure, GenerativeModel
|
10 |
+
from google.ai.generativelanguage import Content, Part
|
11 |
|
12 |
# Load API keys from environment variables
|
13 |
inference_api_key = os.environ.get("HF_TOKEN")
|
14 |
+
google_api_key = os.environ.get("GOOGLE_API_KEY") # New Google API key
|
|
|
15 |
|
16 |
+
# Configure Google API
|
17 |
+
configure(api_key=google_api_key)
|
18 |
|
19 |
+
# Global variables to store the image data URL and prompt for the currently generated image.
|
|
|
20 |
global_image_data_url = None
|
21 |
global_image_prompt = None # Still stored if needed elsewhere
|
22 |
|
23 |
+
def update_difficulty_label(active_session):
|
24 |
+
return f"**Current Difficulty:** {active_session.get('difficulty', 'Very Simple')}"
|
25 |
+
|
26 |
+
def generate_prompt_from_options(difficulty, age, autism_level, topic_focus, treatment_plan=""):
|
27 |
"""
|
28 |
+
Generate an image prompt using Google's Gemini model.
|
|
|
29 |
"""
|
30 |
query = (
|
|
|
31 |
f"""
|
32 |
+
Follow the instructions below to generate an image generation prompt for an educational image intended for autistic children.
|
33 |
+
Consider the following parameters:
|
34 |
+
- Difficulty: {difficulty}
|
35 |
+
- Age: {age}
|
36 |
+
- Autism Level: {autism_level}
|
37 |
+
- Topic Focus: {topic_focus}
|
38 |
+
- Treatment Plan: {treatment_plan}
|
39 |
+
|
40 |
+
Emphasize that the image should be clear, calming, and support understanding and communication. The style should match the difficulty level: for example, "Very Simple" produces very basic visuals while "Very Detailed" produces rich visuals.
|
41 |
+
|
42 |
+
The image should specifically focus on the topic: "{topic_focus}".
|
43 |
+
|
44 |
+
Please generate a prompt that instructs the image generation engine to produce an image with:
|
45 |
+
1. Clarity and simplicity (minimalist backgrounds, clear subject)
|
46 |
+
2. Literal representation with defined borders and consistent style
|
47 |
+
3. Soft, muted colors and reduced visual complexity
|
48 |
+
4. Positive, calm scenes
|
49 |
+
5. Clear focus on the specified topic
|
50 |
+
|
51 |
+
Use descriptive and detailed language.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
52 |
"""
|
53 |
)
|
54 |
|
55 |
+
# Initialize the Gemini Pro model
|
56 |
+
model = GenerativeModel('gemini-2.0-flash-lite')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
57 |
|
58 |
+
# Generate content using the Gemini model
|
59 |
+
response = model.generate_content(query)
|
|
|
|
|
|
|
|
|
|
|
60 |
|
61 |
+
return response.text.strip()
|
|
|
|
|
|
|
62 |
|
63 |
+
def generate_image_fn(selected_prompt, guidance_scale=7.5,
|
64 |
+
negative_prompt="ugly, blurry, poorly drawn hands, lewd, nude, deformed, missing limbs, missing eyes, missing arms, missing legs",
|
65 |
+
num_inference_steps=50):
|
66 |
"""
|
67 |
+
Generate an image from the prompt via the Hugging Face Inference API.
|
68 |
+
Convert the image to a data URL.
|
|
|
|
|
|
|
|
|
|
|
69 |
"""
|
70 |
global global_image_data_url, global_image_prompt
|
|
|
|
|
71 |
global_image_prompt = selected_prompt
|
72 |
+
image_client = InferenceClient(provider="hf-inference", api_key=inference_api_key)
|
|
|
|
|
|
|
|
|
|
|
73 |
image = image_client.text_to_image(
|
74 |
selected_prompt,
|
75 |
+
model="stabilityai/stable-diffusion-3.5-large-turbo",
|
76 |
guidance_scale=guidance_scale,
|
77 |
negative_prompt=negative_prompt,
|
78 |
num_inference_steps=num_inference_steps
|
79 |
)
|
|
|
80 |
buffered = io.BytesIO()
|
81 |
image.save(buffered, format="PNG")
|
82 |
img_bytes = buffered.getvalue()
|
83 |
img_b64 = base64.b64encode(img_bytes).decode("utf-8")
|
84 |
global_image_data_url = f"data:image/png;base64,{img_b64}"
|
|
|
85 |
return image
|
86 |
|
87 |
+
def generate_image_and_reset_chat(age, autism_level, topic_focus, treatment_plan, active_session, saved_sessions):
|
88 |
"""
|
89 |
+
Generate a new image (with the current difficulty) and reset the chat.
|
90 |
+
Now includes the topic_focus parameter to specify what the image should focus on.
|
91 |
"""
|
92 |
new_sessions = saved_sessions.copy()
|
93 |
if active_session.get("prompt"):
|
94 |
new_sessions.append(active_session)
|
95 |
+
# Use the current difficulty from the active session (which should be updated if advanced)
|
96 |
+
current_difficulty = active_session.get("difficulty", "Very Simple")
|
97 |
+
generated_prompt = generate_prompt_from_options(current_difficulty, age, autism_level, topic_focus, treatment_plan)
|
98 |
+
image = generate_image_fn(generated_prompt)
|
99 |
+
new_active_session = {
|
100 |
+
"prompt": generated_prompt,
|
101 |
+
"image": global_image_data_url,
|
102 |
+
"chat": [],
|
103 |
+
"treatment_plan": treatment_plan,
|
104 |
+
"topic_focus": topic_focus,
|
105 |
+
"identified_details": [],
|
106 |
+
"difficulty": current_difficulty,
|
107 |
+
"autism_level": autism_level,
|
108 |
+
"age": age
|
109 |
+
}
|
110 |
return image, new_active_session, new_sessions
|
111 |
|
112 |
+
def compare_details_chat_fn(user_details, treatment_plan, chat_history, identified_details):
|
113 |
"""
|
114 |
+
Evaluate the child's description using Google's Gemini Vision model.
|
|
|
115 |
"""
|
116 |
if not global_image_data_url:
|
117 |
return "Please generate an image first."
|
118 |
|
119 |
+
history_text = ""
|
120 |
+
if chat_history:
|
121 |
+
history_text = "\n\n### Previous Conversation:\n"
|
122 |
+
for idx, (user_msg, bot_msg) in enumerate(chat_history, 1):
|
123 |
+
history_text += f"Turn {idx}:\nUser: {user_msg}\nTeacher: {bot_msg}\n"
|
124 |
+
|
125 |
+
identified_details_text = ""
|
126 |
+
if identified_details:
|
127 |
+
identified_details_text = "\n\n### Previously Identified Details:\n" + "\n".join(f"- {detail}" for detail in identified_details)
|
128 |
+
|
129 |
+
message_text = (
|
130 |
+
f"{history_text}{identified_details_text}\n\n"
|
131 |
+
f"Based on the image provided above, please evaluate the following description given by the child:\n"
|
132 |
+
f"'{user_details}'\n\n"
|
133 |
+
"You are a kind and encouraging teacher speaking to a child. Use simple, clear language. "
|
134 |
+
"Praise the child's correct observations and provide a gentle hint if something is missing. "
|
135 |
+
"Keep your feedback positive and easy to understand.\n\n"
|
136 |
+
"Focus on these evaluation criteria:\n"
|
137 |
+
"1. **Object Identification** – Did the child mention the main objects?\n"
|
138 |
+
"2. **Color & Shape Accuracy** – Were the colors and shapes described correctly?\n"
|
139 |
+
"3. **Clarity & Simplicity** – Was the description clear and easy to understand?\n"
|
140 |
+
"4. **Overall Communication** – How well did the child communicate their thoughts?\n\n"
|
141 |
+
"Note: As difficulty increases, the expected level of detail is higher. Evaluate accordingly.\n\n"
|
142 |
+
"Return your evaluation strictly as a JSON object with the following keys:\n"
|
143 |
+
"{\n"
|
144 |
+
" \"scores\": {\n"
|
145 |
+
" \"object_identification\": <number>,\n"
|
146 |
+
" \"color_shape_accuracy\": <number>,\n"
|
147 |
+
" \"clarity_simplicity\": <number>,\n"
|
148 |
+
" \"overall_communication\": <number>\n"
|
149 |
+
" },\n"
|
150 |
+
" \"final_score\": <number>,\n"
|
151 |
+
" \"feedback\": \"<string>\",\n"
|
152 |
+
" \"hint\": \"<string>\",\n"
|
153 |
+
" \"advance\": <boolean>\n"
|
154 |
+
"}\n\n"
|
155 |
+
"Do not include any additional text outside the JSON."
|
156 |
)
|
157 |
|
158 |
+
# Remove the data:image/png;base64, prefix to get just the base64 string
|
159 |
+
base64_img = global_image_data_url.split(",")[1]
|
160 |
+
|
161 |
+
# Create a Gemini Vision Pro model
|
162 |
+
vision_model = GenerativeModel('gemini-2.0-flash-thinking-exp-01-21')
|
163 |
+
|
164 |
+
# Create the content with image and text using the correct parameters
|
165 |
+
# Use 'inline_data' instead of 'content' for the image part
|
166 |
+
image_part = Part(inline_data={"mime_type": "image/png", "data": base64.b64decode(base64_img)})
|
167 |
+
text_part = Part(text=message_text)
|
168 |
+
multimodal_content = Content(parts=[image_part, text_part])
|
169 |
|
170 |
+
# Generate evaluation using the vision model
|
171 |
+
response = vision_model.generate_content(multimodal_content)
|
172 |
+
|
173 |
+
return response.text
|
174 |
+
|
175 |
+
def evaluate_scores(evaluation_text, current_difficulty):
|
176 |
+
"""
|
177 |
+
Parse the JSON evaluation and decide if the child advances.
|
178 |
+
The threshold scales with difficulty:
|
179 |
+
Very Simple: 70, Simple: 75, Moderate: 80, Detailed: 85, Very Detailed: 90.
|
180 |
+
"""
|
181 |
+
try:
|
182 |
+
json_match = re.search(r'\{.*\}', evaluation_text, re.DOTALL)
|
183 |
+
if json_match:
|
184 |
+
json_str = json_match.group(0)
|
185 |
+
evaluation = json.loads(json_str)
|
186 |
+
else:
|
187 |
+
raise ValueError("No JSON object found in the response.")
|
188 |
+
final_score = evaluation.get("final_score", 0)
|
189 |
+
hint = evaluation.get("hint", "Keep trying!")
|
190 |
+
advance = evaluation.get("advance", False)
|
191 |
+
difficulty_thresholds = {
|
192 |
+
"Very Simple": 70,
|
193 |
+
"Simple": 75,
|
194 |
+
"Moderate": 80,
|
195 |
+
"Detailed": 85,
|
196 |
+
"Very Detailed": 90
|
197 |
+
}
|
198 |
+
current_threshold = difficulty_thresholds.get(current_difficulty, 70)
|
199 |
+
difficulty_mapping = {
|
200 |
+
"Very Simple": "Simple",
|
201 |
+
"Simple": "Moderate",
|
202 |
+
"Moderate": "Detailed",
|
203 |
+
"Detailed": "Very Detailed",
|
204 |
+
"Very Detailed": "Very Detailed"
|
205 |
+
}
|
206 |
+
if final_score >= current_threshold or advance:
|
207 |
+
new_difficulty = difficulty_mapping.get(current_difficulty, current_difficulty)
|
208 |
+
response_msg = (f"Great job! Your final score is {final_score}, which meets the target of {current_threshold}. "
|
209 |
+
f"You've advanced to {new_difficulty} difficulty.")
|
210 |
+
return response_msg, new_difficulty
|
211 |
+
else:
|
212 |
+
response_msg = (f"Your final score is {final_score} (\n target: {current_threshold}). {hint} \n "
|
213 |
+
f"Please try again at the {current_difficulty} level.")
|
214 |
+
return response_msg, current_difficulty
|
215 |
+
except Exception as e:
|
216 |
+
return f"Error processing evaluation output: {str(e)}", current_difficulty
|
217 |
|
218 |
def chat_respond(user_message, active_session, saved_sessions):
|
219 |
"""
|
220 |
+
Process a new chat message.
|
221 |
+
Evaluate the child's description. If the evaluation indicates advancement,
|
222 |
+
update the difficulty, generate a new image (resetting image and chat), and update the difficulty label.
|
223 |
"""
|
224 |
if not active_session.get("image"):
|
225 |
bot_message = "Please generate an image first."
|
226 |
+
updated_chat = active_session.get("chat", []) + [(user_message, bot_message)]
|
227 |
+
active_session["chat"] = updated_chat
|
228 |
+
return "", updated_chat, saved_sessions, active_session
|
229 |
+
|
230 |
+
chat_history = active_session.get("chat", [])
|
231 |
+
identified_details = active_session.get("identified_details", [])
|
232 |
+
raw_evaluation = compare_details_chat_fn(user_message, "", chat_history, identified_details)
|
233 |
+
current_difficulty = active_session.get("difficulty", "Very Simple")
|
234 |
+
evaluation_response, updated_difficulty = evaluate_scores(raw_evaluation, current_difficulty)
|
235 |
+
bot_message = evaluation_response
|
236 |
+
|
237 |
+
# If the child advanced, update difficulty and generate a new image
|
238 |
+
if updated_difficulty != current_difficulty:
|
239 |
+
# Update the active session's difficulty before generating a new prompt
|
240 |
+
active_session["difficulty"] = updated_difficulty
|
241 |
+
age = active_session.get("age", "3")
|
242 |
+
autism_level = active_session.get("autism_level", "Level 1")
|
243 |
+
topic_focus = active_session.get("topic_focus", "")
|
244 |
+
treatment_plan = active_session.get("treatment_plan", "")
|
245 |
+
new_image, new_active_session, new_sessions = generate_image_and_reset_chat(age, autism_level, topic_focus, treatment_plan, active_session, saved_sessions)
|
246 |
+
new_active_session["chat"].append(("System", f"You advanced to {updated_difficulty} difficulty! A new image has been generated for you."))
|
247 |
+
active_session = new_active_session
|
248 |
+
bot_message = f"You advanced to {updated_difficulty} difficulty! A new image has been generated for you."
|
249 |
+
saved_sessions = new_sessions
|
250 |
else:
|
251 |
+
updated_chat = active_session.get("chat", []) + [(user_message, bot_message)]
|
252 |
+
active_session["chat"] = updated_chat
|
253 |
|
254 |
+
return "", active_session["chat"], saved_sessions, active_session
|
|
|
|
|
255 |
|
256 |
def update_sessions(saved_sessions, active_session):
|
257 |
"""
|
258 |
+
Combine finished sessions with the active session for display.
|
|
|
259 |
"""
|
260 |
if active_session and active_session.get("prompt"):
|
261 |
return saved_sessions + [active_session]
|
262 |
return saved_sessions
|
263 |
|
264 |
##############################################
|
265 |
+
# Gradio Interface
|
|
|
|
|
|
|
|
|
|
|
|
|
266 |
##############################################
|
267 |
with gr.Blocks() as demo:
|
268 |
+
# The active session now starts with difficulty "Very Simple"
|
269 |
+
active_session = gr.State({
|
270 |
+
"prompt": None,
|
271 |
+
"image": None,
|
272 |
+
"chat": [],
|
273 |
+
"treatment_plan": "",
|
274 |
+
"topic_focus": "",
|
275 |
+
"identified_details": [],
|
276 |
+
"difficulty": "Very Simple",
|
277 |
+
"age": "3",
|
278 |
+
"autism_level": "Level 1"
|
279 |
+
})
|
280 |
saved_sessions = gr.State([])
|
281 |
|
282 |
with gr.Column():
|
283 |
gr.Markdown("# Image Generation & Chat Inference")
|
284 |
+
# Display current difficulty label
|
285 |
+
difficulty_label = gr.Markdown("**Current Difficulty:** Very Simple")
|
286 |
|
287 |
# ----- Image Generation Section -----
|
288 |
with gr.Column():
|
289 |
gr.Markdown("## Generate Image")
|
290 |
+
gr.Markdown("Enter your age, select your autism level, specify a topic focus, and provide the treatment plan to generate an image based on the current difficulty level.")
|
291 |
with gr.Row():
|
292 |
+
age_input = gr.Textbox(label="Age", placeholder="Enter age...", value="3")
|
293 |
+
autism_level_dropdown = gr.Dropdown(label="Autism Level", choices=["Level 1", "Level 2", "Level 3"], value="Level 1")
|
294 |
+
|
295 |
+
topic_focus_input = gr.Textbox(
|
296 |
+
label="Topic Focus",
|
297 |
+
placeholder="Enter a specific topic or detail to focus on (e.g., 'animals', 'emotions', 'daily routines')...",
|
298 |
+
lines=1
|
299 |
+
)
|
300 |
+
|
301 |
+
treatment_plan_input = gr.Textbox(
|
302 |
+
label="Treatment Plan",
|
303 |
+
placeholder="Enter the treatment plan to guide the image generation...",
|
304 |
lines=2
|
305 |
)
|
306 |
generate_btn = gr.Button("Generate Image")
|
307 |
img_output = gr.Image(label="Generated Image")
|
308 |
generate_btn.click(
|
309 |
generate_image_and_reset_chat,
|
310 |
+
inputs=[age_input, autism_level_dropdown, topic_focus_input, treatment_plan_input, active_session, saved_sessions],
|
311 |
outputs=[img_output, active_session, saved_sessions]
|
312 |
)
|
313 |
|
|
|
316 |
gr.Markdown("## Chat about the Image")
|
317 |
gr.Markdown(
|
318 |
"After generating an image, type details or descriptions about it. "
|
319 |
+
"Your message, along with the generated image and conversation history, will be sent for evaluation."
|
|
|
|
|
320 |
)
|
321 |
chatbot = gr.Chatbot(label="Chat History")
|
322 |
with gr.Row():
|
323 |
chat_input = gr.Textbox(label="Your Message", placeholder="Type your description here...", show_label=False)
|
324 |
send_btn = gr.Button("Send")
|
|
|
325 |
send_btn.click(
|
326 |
chat_respond,
|
327 |
inputs=[chat_input, active_session, saved_sessions],
|
|
|
339 |
gr.Markdown(
|
340 |
"This sidebar automatically saves finished chat sessions. "
|
341 |
"Each session includes the prompt used, the generated image (as a data URL), "
|
342 |
+
"the topic focus, the treatment plan, the list of identified details, and the full chat history."
|
343 |
)
|
344 |
sessions_output = gr.JSON(label="Session Details", value={})
|
345 |
active_session.change(update_sessions, inputs=[saved_sessions, active_session], outputs=sessions_output)
|
346 |
+
# Update the current difficulty label when active_session changes.
|
347 |
+
active_session.change(update_difficulty_label, inputs=[active_session], outputs=[difficulty_label])
|
348 |
saved_sessions.change(update_sessions, inputs=[saved_sessions, active_session], outputs=sessions_output)
|
349 |
|
350 |
+
# Launch the app with public sharing enabled.
|
351 |
demo.launch()
|