Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -8,17 +8,20 @@ from PIL import Image
|
|
8 |
from huggingface_hub import InferenceClient
|
9 |
from google.generativeai import configure, GenerativeModel
|
10 |
from google.ai.generativelanguage import Content, Part
|
|
|
|
|
11 |
|
12 |
# Load API keys from environment variables
|
13 |
inference_api_key = os.environ.get("HF_TOKEN")
|
14 |
-
google_api_key = os.environ.get("GOOGLE_API_KEY")
|
15 |
|
16 |
# Configure Google API
|
17 |
configure(api_key=google_api_key)
|
18 |
|
19 |
-
# Global variables to store the image data URL
|
20 |
global_image_data_url = None
|
21 |
-
global_image_prompt = None
|
|
|
22 |
|
23 |
def update_difficulty_label(active_session):
|
24 |
return f"**Current Difficulty:** {active_session.get('difficulty', 'Very Simple')}"
|
@@ -36,28 +39,53 @@ def generate_prompt_from_options(difficulty, age, autism_level, topic_focus, tre
|
|
36 |
- Autism Level: {autism_level}
|
37 |
- Topic Focus: {topic_focus}
|
38 |
- Treatment Plan: {treatment_plan}
|
39 |
-
|
40 |
Emphasize that the image should be clear, calming, and support understanding and communication. The style should match the difficulty level: for example, "Very Simple" produces very basic visuals while "Very Detailed" produces rich visuals.
|
41 |
-
|
42 |
The image should specifically focus on the topic: "{topic_focus}".
|
43 |
-
|
44 |
Please generate a prompt that instructs the image generation engine to produce an image with:
|
45 |
1. Clarity and simplicity (minimalist backgrounds, clear subject)
|
46 |
2. Literal representation with defined borders and consistent style
|
47 |
3. Soft, muted colors and reduced visual complexity
|
48 |
4. Positive, calm scenes
|
49 |
5. Clear focus on the specified topic
|
50 |
-
|
51 |
Use descriptive and detailed language.
|
52 |
"""
|
53 |
)
|
54 |
-
|
55 |
# Initialize the Gemini Pro model
|
56 |
model = GenerativeModel('gemini-2.0-flash-lite')
|
57 |
-
|
58 |
# Generate content using the Gemini model
|
59 |
response = model.generate_content(query)
|
|
|
60 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
61 |
return response.text.strip()
|
62 |
|
63 |
def generate_image_fn(selected_prompt, guidance_scale=7.5,
|
@@ -84,100 +112,166 @@ def generate_image_fn(selected_prompt, guidance_scale=7.5,
|
|
84 |
global_image_data_url = f"data:image/png;base64,{img_b64}"
|
85 |
return image
|
86 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
87 |
def generate_image_and_reset_chat(age, autism_level, topic_focus, treatment_plan, active_session, saved_sessions):
|
88 |
"""
|
89 |
Generate a new image (with the current difficulty) and reset the chat.
|
90 |
Now includes the topic_focus parameter to specify what the image should focus on.
|
91 |
"""
|
|
|
92 |
new_sessions = saved_sessions.copy()
|
93 |
if active_session.get("prompt"):
|
94 |
new_sessions.append(active_session)
|
95 |
-
|
|
|
96 |
current_difficulty = active_session.get("difficulty", "Very Simple")
|
|
|
|
|
97 |
generated_prompt = generate_prompt_from_options(current_difficulty, age, autism_level, topic_focus, treatment_plan)
|
|
|
|
|
98 |
image = generate_image_fn(generated_prompt)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
99 |
new_active_session = {
|
100 |
"prompt": generated_prompt,
|
101 |
"image": global_image_data_url,
|
|
|
102 |
"chat": [],
|
103 |
"treatment_plan": treatment_plan,
|
104 |
"topic_focus": topic_focus,
|
|
|
105 |
"identified_details": [],
|
|
|
106 |
"difficulty": current_difficulty,
|
107 |
"autism_level": autism_level,
|
108 |
"age": age
|
109 |
}
|
110 |
-
return image, new_active_session, new_sessions
|
111 |
|
112 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
113 |
"""
|
114 |
-
Evaluate the child's description using Google's Gemini
|
|
|
115 |
"""
|
116 |
-
if not global_image_data_url:
|
117 |
return "Please generate an image first."
|
118 |
|
|
|
|
|
|
|
|
|
|
|
119 |
history_text = ""
|
120 |
if chat_history:
|
121 |
history_text = "\n\n### Previous Conversation:\n"
|
122 |
-
for idx, (
|
123 |
-
history_text += f"Turn {idx}:\
|
124 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
125 |
identified_details_text = ""
|
126 |
if identified_details:
|
127 |
identified_details_text = "\n\n### Previously Identified Details:\n" + "\n".join(f"- {detail}" for detail in identified_details)
|
|
|
|
|
|
|
|
|
|
|
|
|
128 |
|
129 |
message_text = (
|
130 |
-
f"
|
131 |
-
f"
|
132 |
-
f"
|
133 |
-
"
|
134 |
-
"
|
135 |
-
"
|
136 |
-
"
|
137 |
-
"
|
138 |
-
"
|
139 |
-
"
|
140 |
-
"
|
141 |
-
"
|
142 |
-
"
|
|
|
|
|
|
|
|
|
143 |
"{\n"
|
144 |
-
" \"
|
145 |
-
"
|
146 |
-
"
|
147 |
-
"
|
148 |
-
"
|
149 |
-
" },\n"
|
150 |
-
" \"final_score\": <number>,\n"
|
151 |
-
" \"feedback\": \"<string>\",\n"
|
152 |
-
" \"hint\": \"<string>\",\n"
|
153 |
-
" \"advance\": <boolean>\n"
|
154 |
"}\n\n"
|
155 |
-
"
|
156 |
)
|
157 |
|
158 |
-
#
|
159 |
-
|
160 |
-
|
161 |
-
# Create a Gemini Vision Pro model
|
162 |
-
vision_model = GenerativeModel('gemini-2.0-flash-thinking-exp-01-21')
|
163 |
-
|
164 |
-
# Create the content with image and text using the correct parameters
|
165 |
-
# Use 'inline_data' instead of 'content' for the image part
|
166 |
-
image_part = Part(inline_data={"mime_type": "image/png", "data": base64.b64decode(base64_img)})
|
167 |
-
text_part = Part(text=message_text)
|
168 |
-
multimodal_content = Content(parts=[image_part, text_part])
|
169 |
-
|
170 |
-
# Generate evaluation using the vision model
|
171 |
-
response = vision_model.generate_content(multimodal_content)
|
172 |
|
|
|
|
|
173 |
return response.text
|
174 |
|
175 |
-
def
|
176 |
-
"""
|
177 |
-
Parse the JSON evaluation and decide if the child advances.
|
178 |
-
The threshold scales with difficulty:
|
179 |
-
Very Simple: 70, Simple: 75, Moderate: 80, Detailed: 85, Very Detailed: 90.
|
180 |
-
"""
|
181 |
try:
|
182 |
json_match = re.search(r'\{.*\}', evaluation_text, re.DOTALL)
|
183 |
if json_match:
|
@@ -185,73 +279,164 @@ def evaluate_scores(evaluation_text, current_difficulty):
|
|
185 |
evaluation = json.loads(json_str)
|
186 |
else:
|
187 |
raise ValueError("No JSON object found in the response.")
|
188 |
-
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
|
209 |
-
|
210 |
-
|
211 |
-
|
212 |
-
|
213 |
-
|
214 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
215 |
except Exception as e:
|
216 |
-
|
|
|
217 |
|
218 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
219 |
"""
|
220 |
Process a new chat message.
|
221 |
-
Evaluate the child's description
|
222 |
-
update the difficulty, generate a new image (resetting image and chat), and update the difficulty label.
|
223 |
"""
|
224 |
if not active_session.get("image"):
|
225 |
bot_message = "Please generate an image first."
|
226 |
-
updated_chat = active_session.get("chat", []) + [(user_message, bot_message)]
|
227 |
active_session["chat"] = updated_chat
|
228 |
-
return "", updated_chat, saved_sessions, active_session
|
229 |
|
230 |
-
|
231 |
-
|
232 |
-
|
233 |
-
|
234 |
-
|
235 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
236 |
|
237 |
-
# If the child
|
238 |
-
if
|
239 |
-
#
|
240 |
-
|
|
|
|
|
|
|
241 |
age = active_session.get("age", "3")
|
242 |
autism_level = active_session.get("autism_level", "Level 1")
|
243 |
topic_focus = active_session.get("topic_focus", "")
|
244 |
treatment_plan = active_session.get("treatment_plan", "")
|
245 |
-
new_image, new_active_session, new_sessions = generate_image_and_reset_chat(age, autism_level, topic_focus, treatment_plan, active_session, saved_sessions)
|
246 |
-
new_active_session["chat"].append(("System", f"You advanced to {updated_difficulty} difficulty! A new image has been generated for you."))
|
247 |
-
active_session = new_active_session
|
248 |
-
bot_message = f"You advanced to {updated_difficulty} difficulty! A new image has been generated for you."
|
249 |
-
saved_sessions = new_sessions
|
250 |
-
else:
|
251 |
-
updated_chat = active_session.get("chat", []) + [(user_message, bot_message)]
|
252 |
-
active_session["chat"] = updated_chat
|
253 |
|
254 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
255 |
|
256 |
def update_sessions(saved_sessions, active_session):
|
257 |
"""
|
@@ -265,87 +450,240 @@ def update_sessions(saved_sessions, active_session):
|
|
265 |
# Gradio Interface
|
266 |
##############################################
|
267 |
with gr.Blocks() as demo:
|
268 |
-
#
|
269 |
active_session = gr.State({
|
270 |
"prompt": None,
|
271 |
"image": None,
|
|
|
272 |
"chat": [],
|
273 |
"treatment_plan": "",
|
274 |
"topic_focus": "",
|
|
|
275 |
"identified_details": [],
|
|
|
276 |
"difficulty": "Very Simple",
|
277 |
"age": "3",
|
278 |
"autism_level": "Level 1"
|
279 |
})
|
280 |
saved_sessions = gr.State([])
|
281 |
-
|
282 |
-
|
283 |
-
|
284 |
-
#
|
285 |
-
|
286 |
-
|
287 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
288 |
with gr.Column():
|
289 |
-
gr.Markdown("##
|
290 |
-
gr.Markdown("Enter your age, select your autism level, specify a topic focus, and provide the treatment plan to generate an image based on the current difficulty level.")
|
291 |
-
with gr.Row():
|
292 |
-
age_input = gr.Textbox(label="Age", placeholder="Enter age...", value="3")
|
293 |
-
autism_level_dropdown = gr.Dropdown(label="Autism Level", choices=["Level 1", "Level 2", "Level 3"], value="Level 1")
|
294 |
-
|
295 |
-
topic_focus_input = gr.Textbox(
|
296 |
-
label="Topic Focus",
|
297 |
-
placeholder="Enter a specific topic or detail to focus on (e.g., 'animals', 'emotions', 'daily routines')...",
|
298 |
-
lines=1
|
299 |
-
)
|
300 |
-
|
301 |
-
treatment_plan_input = gr.Textbox(
|
302 |
-
label="Treatment Plan",
|
303 |
-
placeholder="Enter the treatment plan to guide the image generation...",
|
304 |
-
lines=2
|
305 |
-
)
|
306 |
-
generate_btn = gr.Button("Generate Image")
|
307 |
-
img_output = gr.Image(label="Generated Image")
|
308 |
-
generate_btn.click(
|
309 |
-
generate_image_and_reset_chat,
|
310 |
-
inputs=[age_input, autism_level_dropdown, topic_focus_input, treatment_plan_input, active_session, saved_sessions],
|
311 |
-
outputs=[img_output, active_session, saved_sessions]
|
312 |
-
)
|
313 |
-
|
314 |
-
# ----- Chat Section -----
|
315 |
-
with gr.Column():
|
316 |
-
gr.Markdown("## Chat about the Image")
|
317 |
gr.Markdown(
|
318 |
-
"
|
319 |
-
"
|
320 |
-
|
321 |
-
chatbot = gr.Chatbot(label="Chat History")
|
322 |
-
with gr.Row():
|
323 |
-
chat_input = gr.Textbox(label="Your Message", placeholder="Type your description here...", show_label=False)
|
324 |
-
send_btn = gr.Button("Send")
|
325 |
-
send_btn.click(
|
326 |
-
chat_respond,
|
327 |
-
inputs=[chat_input, active_session, saved_sessions],
|
328 |
-
outputs=[chat_input, chatbot, saved_sessions, active_session]
|
329 |
-
)
|
330 |
-
chat_input.submit(
|
331 |
-
chat_respond,
|
332 |
-
inputs=[chat_input, active_session, saved_sessions],
|
333 |
-
outputs=[chat_input, chatbot, saved_sessions, active_session]
|
334 |
)
|
|
|
335 |
|
336 |
-
#
|
337 |
-
|
338 |
-
|
339 |
-
|
340 |
-
"This sidebar automatically saves finished chat sessions. "
|
341 |
-
"Each session includes the prompt used, the generated image (as a data URL), "
|
342 |
-
"the topic focus, the treatment plan, the list of identified details, and the full chat history."
|
343 |
)
|
344 |
-
sessions_output = gr.JSON(label="Session Details", value={})
|
345 |
-
active_session.change(update_sessions, inputs=[saved_sessions, active_session], outputs=sessions_output)
|
346 |
-
# Update the current difficulty label when active_session changes.
|
347 |
-
active_session.change(update_difficulty_label, inputs=[active_session], outputs=[difficulty_label])
|
348 |
-
saved_sessions.change(update_sessions, inputs=[saved_sessions, active_session], outputs=sessions_output)
|
349 |
|
350 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
351 |
demo.launch()
|
|
|
8 |
from huggingface_hub import InferenceClient
|
9 |
from google.generativeai import configure, GenerativeModel
|
10 |
from google.ai.generativelanguage import Content, Part
|
11 |
+
from dotenv import load_dotenv
|
12 |
+
load_dotenv()
|
13 |
|
14 |
# Load API keys from environment variables
|
15 |
inference_api_key = os.environ.get("HF_TOKEN")
|
16 |
+
google_api_key = os.environ.get("GOOGLE_API_KEY")
|
17 |
|
18 |
# Configure Google API
|
19 |
configure(api_key=google_api_key)
|
20 |
|
21 |
+
# Global variables to store the image data URL, prompt, and detailed description
|
22 |
global_image_data_url = None
|
23 |
+
global_image_prompt = None
|
24 |
+
global_image_description = None # New variable to store Gemini's detailed description
|
25 |
|
26 |
def update_difficulty_label(active_session):
|
27 |
return f"**Current Difficulty:** {active_session.get('difficulty', 'Very Simple')}"
|
|
|
39 |
- Autism Level: {autism_level}
|
40 |
- Topic Focus: {topic_focus}
|
41 |
- Treatment Plan: {treatment_plan}
|
|
|
42 |
Emphasize that the image should be clear, calming, and support understanding and communication. The style should match the difficulty level: for example, "Very Simple" produces very basic visuals while "Very Detailed" produces rich visuals.
|
|
|
43 |
The image should specifically focus on the topic: "{topic_focus}".
|
|
|
44 |
Please generate a prompt that instructs the image generation engine to produce an image with:
|
45 |
1. Clarity and simplicity (minimalist backgrounds, clear subject)
|
46 |
2. Literal representation with defined borders and consistent style
|
47 |
3. Soft, muted colors and reduced visual complexity
|
48 |
4. Positive, calm scenes
|
49 |
5. Clear focus on the specified topic
|
|
|
50 |
Use descriptive and detailed language.
|
51 |
"""
|
52 |
)
|
|
|
53 |
# Initialize the Gemini Pro model
|
54 |
model = GenerativeModel('gemini-2.0-flash-lite')
|
|
|
55 |
# Generate content using the Gemini model
|
56 |
response = model.generate_content(query)
|
57 |
+
return response.text.strip()
|
58 |
|
59 |
+
def generate_detailed_description(image_data_url, prompt, difficulty, topic_focus):
|
60 |
+
"""
|
61 |
+
Generate a detailed description of the image using Gemini Vision.
|
62 |
+
"""
|
63 |
+
# Remove the data:image/png;base64, prefix to get just the base64 string
|
64 |
+
base64_img = image_data_url.split(",")[1]
|
65 |
+
query = (
|
66 |
+
f"""
|
67 |
+
You are an expert educator specializing in teaching children with autism.
|
68 |
+
Please provide a detailed description of this image that was generated based on the prompt:
|
69 |
+
"{prompt}"
|
70 |
+
The image is intended for a child with autism, focusing on the topic: "{topic_focus}" at a {difficulty} difficulty level.
|
71 |
+
In your description:
|
72 |
+
1. List all key objects, characters, and elements present in the image
|
73 |
+
2. Describe colors, shapes, positions, and relationships between elements
|
74 |
+
3. Note any emotions, actions, or interactions depicted
|
75 |
+
4. Highlight details that would be important for the child to notice
|
76 |
+
5. Organize your description in a structured, clear way
|
77 |
+
Your description will be used as a reference to evaluate the child's observations,
|
78 |
+
so please be comprehensive but focus on observable details rather than interpretations.
|
79 |
+
"""
|
80 |
+
)
|
81 |
+
# Create a Gemini Vision Pro model
|
82 |
+
vision_model = GenerativeModel('gemini-2.0-flash-thinking-exp-01-21')
|
83 |
+
# Create the content with image and text
|
84 |
+
image_part = Part(inline_data={"mime_type": "image/png", "data": base64.b64decode(base64_img)})
|
85 |
+
text_part = Part(text=query)
|
86 |
+
multimodal_content = Content(parts=[image_part, text_part])
|
87 |
+
# Generate description using the vision model
|
88 |
+
response = vision_model.generate_content(multimodal_content)
|
89 |
return response.text.strip()
|
90 |
|
91 |
def generate_image_fn(selected_prompt, guidance_scale=7.5,
|
|
|
112 |
global_image_data_url = f"data:image/png;base64,{img_b64}"
|
113 |
return image
|
114 |
|
115 |
+
def extract_key_details(description):
|
116 |
+
"""
|
117 |
+
Extract key details from Gemini's description to use for tracking.
|
118 |
+
Returns a list of key elements/details from the description.
|
119 |
+
"""
|
120 |
+
# Create a query to extract key details
|
121 |
+
query = (
|
122 |
+
f"""
|
123 |
+
From the following detailed image description, extract a list of 10-15 key details that a child might identify.
|
124 |
+
Each detail should be a simple, clear phrase describing one observable element.
|
125 |
+
Description:
|
126 |
+
{description}
|
127 |
+
Format your response as a JSON array of strings, each representing one key detail.
|
128 |
+
Example format: ["red ball on the grass", "smiling girl with brown hair", "blue sky with clouds"]
|
129 |
+
"""
|
130 |
+
)
|
131 |
+
# Use Gemini text model to extract key details
|
132 |
+
model = GenerativeModel('gemini-2.0-flash-lite')
|
133 |
+
response = model.generate_content(query)
|
134 |
+
try:
|
135 |
+
# Parse the JSON response
|
136 |
+
details_match = re.search(r'\[.*\]', response.text, re.DOTALL)
|
137 |
+
if details_match:
|
138 |
+
details_json = details_match.group(0)
|
139 |
+
key_details = json.loads(details_json)
|
140 |
+
return key_details
|
141 |
+
else:
|
142 |
+
# If no JSON found, do basic extraction
|
143 |
+
lines = description.split('\n')
|
144 |
+
details = []
|
145 |
+
for line in lines:
|
146 |
+
if line.strip().startswith('-') or line.strip().startswith('*'):
|
147 |
+
details.append(line.strip()[1:].strip())
|
148 |
+
return details[:15] if details else ["object in image", "color", "shape", "background"]
|
149 |
+
except Exception as e:
|
150 |
+
print(f"Error extracting key details: {str(e)}")
|
151 |
+
return ["object in image", "color", "shape", "background"]
|
152 |
+
|
153 |
def generate_image_and_reset_chat(age, autism_level, topic_focus, treatment_plan, active_session, saved_sessions):
|
154 |
"""
|
155 |
Generate a new image (with the current difficulty) and reset the chat.
|
156 |
Now includes the topic_focus parameter to specify what the image should focus on.
|
157 |
"""
|
158 |
+
global global_image_description
|
159 |
new_sessions = saved_sessions.copy()
|
160 |
if active_session.get("prompt"):
|
161 |
new_sessions.append(active_session)
|
162 |
+
|
163 |
+
# Use the current difficulty from the active session
|
164 |
current_difficulty = active_session.get("difficulty", "Very Simple")
|
165 |
+
|
166 |
+
# Generate the prompt for the image
|
167 |
generated_prompt = generate_prompt_from_options(current_difficulty, age, autism_level, topic_focus, treatment_plan)
|
168 |
+
|
169 |
+
# Generate the image
|
170 |
image = generate_image_fn(generated_prompt)
|
171 |
+
|
172 |
+
# Generate a detailed description of the image using Gemini Vision
|
173 |
+
image_description = generate_detailed_description(global_image_data_url, generated_prompt, current_difficulty, topic_focus)
|
174 |
+
global_image_description = image_description
|
175 |
+
|
176 |
+
# Extract key details to be identified
|
177 |
+
key_details = extract_key_details(image_description)
|
178 |
+
|
179 |
+
# Create a new active session with all the necessary information
|
180 |
new_active_session = {
|
181 |
"prompt": generated_prompt,
|
182 |
"image": global_image_data_url,
|
183 |
+
"image_description": image_description,
|
184 |
"chat": [],
|
185 |
"treatment_plan": treatment_plan,
|
186 |
"topic_focus": topic_focus,
|
187 |
+
"key_details": key_details, # Store the list of key details
|
188 |
"identified_details": [],
|
189 |
+
"used_hints": [],
|
190 |
"difficulty": current_difficulty,
|
191 |
"autism_level": autism_level,
|
192 |
"age": age
|
193 |
}
|
|
|
194 |
|
195 |
+
# Create the checklist of items to identify
|
196 |
+
checklist_items = []
|
197 |
+
for i, detail in enumerate(key_details):
|
198 |
+
checklist_items.append({"detail": detail, "identified": False, "id": i})
|
199 |
+
|
200 |
+
# Return the updated state and checklist
|
201 |
+
return image, new_active_session, new_sessions, checklist_items
|
202 |
+
|
203 |
+
def compare_details_chat_fn(user_details, active_session):
|
204 |
"""
|
205 |
+
Evaluate the child's description using Google's Gemini model.
|
206 |
+
Now uses the image description and tracks identified details and used hints.
|
207 |
"""
|
208 |
+
if not global_image_data_url or not global_image_description:
|
209 |
return "Please generate an image first."
|
210 |
|
211 |
+
# Get the detailed image description
|
212 |
+
image_description = active_session.get("image_description", global_image_description)
|
213 |
+
|
214 |
+
# Get chat history
|
215 |
+
chat_history = active_session.get("chat", [])
|
216 |
history_text = ""
|
217 |
if chat_history:
|
218 |
history_text = "\n\n### Previous Conversation:\n"
|
219 |
+
for idx, (speaker, msg) in enumerate(chat_history, 1):
|
220 |
+
history_text += f"Turn {idx}:\n{speaker}: {msg}\n"
|
221 |
|
222 |
+
# Get key details, identified details and used hints
|
223 |
+
key_details = active_session.get("key_details", [])
|
224 |
+
identified_details = active_session.get("identified_details", [])
|
225 |
+
used_hints = active_session.get("used_hints", [])
|
226 |
+
|
227 |
+
# Format for the API
|
228 |
+
key_details_text = "\n\n### Key Details to Identify:\n" + "\n".join(f"- {detail}" for detail in key_details)
|
229 |
identified_details_text = ""
|
230 |
if identified_details:
|
231 |
identified_details_text = "\n\n### Previously Identified Details:\n" + "\n".join(f"- {detail}" for detail in identified_details)
|
232 |
+
used_hints_text = ""
|
233 |
+
if used_hints:
|
234 |
+
used_hints_text = "\n\n### Previously Given Hints:\n" + "\n".join(f"- {hint}" for hint in used_hints)
|
235 |
+
|
236 |
+
# Current difficulty level
|
237 |
+
current_difficulty = active_session.get("difficulty", "Very Simple")
|
238 |
|
239 |
message_text = (
|
240 |
+
f"You are a kind and encouraging teacher helping a child with autism describe an image.\n\n"
|
241 |
+
f"### Image Prompt:\n{active_session.get('prompt', 'No prompt available')}\n\n"
|
242 |
+
f"### Detailed Image Description (Reference):\n{image_description}\n\n"
|
243 |
+
f"### Current Difficulty Level: {current_difficulty}\n"
|
244 |
+
f"{key_details_text}{history_text}{identified_details_text}{used_hints_text}\n\n"
|
245 |
+
f"### Child's Current Description:\n'{user_details}'\n\n"
|
246 |
+
"Evaluate the child's description compared to the key details list. Use simple, clear language. "
|
247 |
+
"Praise specific correct observations. If something important is missing, provide a gentle hint "
|
248 |
+
"that hasn't been given before.\n\n"
|
249 |
+
"Follow these guidelines:\n"
|
250 |
+
"1. DO NOT mention that you're evaluating or scoring the child.\n"
|
251 |
+
"2. Keep feedback warm, positive, and encouraging.\n"
|
252 |
+
"3. If giving a hint, make it specific but not too obvious.\n"
|
253 |
+
"4. Never repeat hints that have already been given.\n"
|
254 |
+
"5. Focus on details the child hasn't yet identified.\n"
|
255 |
+
"6. Acknowledge the child's progress.\n\n"
|
256 |
+
"Return your response as a JSON object with the following format:\n"
|
257 |
"{\n"
|
258 |
+
" \"feedback\": \"Your encouraging response to the child\",\n"
|
259 |
+
" \"newly_identified_details\": [\"list\", \"of\", \"new details\", \"the child identified\"],\n"
|
260 |
+
" \"hint\": \"A new hint about something not yet identified\",\n"
|
261 |
+
" \"score\": <number from 0-100 based on how complete the description is>,\n"
|
262 |
+
" \"advance_difficulty\": <boolean indicating if child should advance>\n"
|
|
|
|
|
|
|
|
|
|
|
263 |
"}\n\n"
|
264 |
+
"Ensure the JSON is valid and contains all fields."
|
265 |
)
|
266 |
|
267 |
+
# Create a Gemini model for evaluation
|
268 |
+
model = GenerativeModel('gemini-2.0-flash-thinking-exp-01-21')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
269 |
|
270 |
+
# Generate evaluation using the model
|
271 |
+
response = model.generate_content(message_text)
|
272 |
return response.text
|
273 |
|
274 |
+
def parse_evaluation(evaluation_text, active_session):
|
|
|
|
|
|
|
|
|
|
|
275 |
try:
|
276 |
json_match = re.search(r'\{.*\}', evaluation_text, re.DOTALL)
|
277 |
if json_match:
|
|
|
279 |
evaluation = json.loads(json_str)
|
280 |
else:
|
281 |
raise ValueError("No JSON object found in the response.")
|
282 |
+
|
283 |
+
# Extract data from the evaluation
|
284 |
+
feedback = evaluation.get("feedback", "Great effort! Keep describing what you see.")
|
285 |
+
newly_identified_details = evaluation.get("newly_identified_details", [])
|
286 |
+
hint = evaluation.get("hint", "")
|
287 |
+
score = evaluation.get("score", 0)
|
288 |
+
advance_difficulty = evaluation.get("advance_difficulty", False)
|
289 |
+
|
290 |
+
# Update the session with newly identified details
|
291 |
+
identified_details = active_session.get("identified_details", [])
|
292 |
+
for detail in newly_identified_details:
|
293 |
+
if detail not in identified_details:
|
294 |
+
identified_details.append(detail)
|
295 |
+
active_session["identified_details"] = identified_details
|
296 |
+
|
297 |
+
# Add the hint to used hints if one was provided
|
298 |
+
if hint:
|
299 |
+
used_hints = active_session.get("used_hints", [])
|
300 |
+
if hint not in used_hints:
|
301 |
+
used_hints.append(hint)
|
302 |
+
active_session["used_hints"] = used_hints
|
303 |
+
|
304 |
+
# Add the hint to the feedback if it's not already included
|
305 |
+
if hint.strip() and hint.strip() not in feedback:
|
306 |
+
feedback += f"\n\nπ‘ Hint: {hint}"
|
307 |
+
|
308 |
+
# Get current difficulty and check if it should be advanced
|
309 |
+
current_difficulty = active_session.get("difficulty", "Very Simple")
|
310 |
+
should_advance = False
|
311 |
+
|
312 |
+
if advance_difficulty:
|
313 |
+
difficulties = ["Very Simple", "Simple", "Moderate", "Detailed", "Very Detailed"]
|
314 |
+
current_index = difficulties.index(current_difficulty) if current_difficulty in difficulties else 0
|
315 |
+
if current_index < len(difficulties) - 1:
|
316 |
+
current_difficulty = difficulties[current_index + 1]
|
317 |
+
should_advance = True
|
318 |
+
|
319 |
+
return feedback, current_difficulty, should_advance, newly_identified_details
|
320 |
+
|
321 |
except Exception as e:
|
322 |
+
print(f"Error processing evaluation: {str(e)}")
|
323 |
+
return f"That's interesting! Can you tell me more about what you see?", active_session.get("difficulty", "Very Simple"), False, []
|
324 |
|
325 |
+
def update_checklist(checklist, newly_identified, key_details):
|
326 |
+
"""
|
327 |
+
Update the checklist based on newly identified details.
|
328 |
+
Returns an updated checklist.
|
329 |
+
"""
|
330 |
+
new_checklist = []
|
331 |
+
for item in checklist:
|
332 |
+
detail = item["detail"]
|
333 |
+
# Check if this detail has been identified
|
334 |
+
is_identified = item["identified"]
|
335 |
+
|
336 |
+
# If newly identified, update status
|
337 |
+
for identified in newly_identified:
|
338 |
+
# Check if the identified detail matches or is similar to the key detail
|
339 |
+
if (identified.lower() in detail.lower() or detail.lower() in identified.lower() or
|
340 |
+
any(word for word in identified.lower().split() if word in detail.lower() and len(word) > 3)):
|
341 |
+
is_identified = True
|
342 |
+
break
|
343 |
+
|
344 |
+
new_checklist.append({"detail": detail, "identified": is_identified, "id": item["id"]})
|
345 |
+
|
346 |
+
return new_checklist
|
347 |
+
|
348 |
+
def chat_respond(user_message, active_session, saved_sessions, checklist):
|
349 |
"""
|
350 |
Process a new chat message.
|
351 |
+
Evaluate the child's description, update identified details, and advance difficulty if needed.
|
|
|
352 |
"""
|
353 |
if not active_session.get("image"):
|
354 |
bot_message = "Please generate an image first."
|
355 |
+
updated_chat = active_session.get("chat", []) + [("Child", user_message), ("Teacher", bot_message)]
|
356 |
active_session["chat"] = updated_chat
|
357 |
+
return "", updated_chat, saved_sessions, active_session, checklist, None # Return None for image
|
358 |
|
359 |
+
# Get the evaluation from Gemini
|
360 |
+
raw_evaluation = compare_details_chat_fn(user_message, active_session)
|
361 |
+
|
362 |
+
# Parse the evaluation and update session
|
363 |
+
feedback, updated_difficulty, should_advance, newly_identified = parse_evaluation(raw_evaluation, active_session)
|
364 |
+
|
365 |
+
# Update the checklist with newly identified details
|
366 |
+
updated_checklist = update_checklist(checklist, newly_identified, active_session.get("key_details", []))
|
367 |
+
|
368 |
+
# Add the current exchange to the chat history
|
369 |
+
updated_chat = active_session.get("chat", []) + [("Child", user_message), ("Teacher", feedback)]
|
370 |
+
active_session["chat"] = updated_chat
|
371 |
+
|
372 |
+
# Check if all items have been identified
|
373 |
+
all_identified = all(item["identified"] for item in updated_checklist)
|
374 |
+
|
375 |
+
# Modify this line to generate new image when all details are identified
|
376 |
+
should_generate_new_image = should_advance or all_identified
|
377 |
|
378 |
+
# If the child should advance to a new difficulty or has identified all details
|
379 |
+
if should_generate_new_image:
|
380 |
+
# Save the current session
|
381 |
+
new_sessions = saved_sessions.copy()
|
382 |
+
new_sessions.append(active_session.copy())
|
383 |
+
|
384 |
+
# Get parameters for generating new image
|
385 |
age = active_session.get("age", "3")
|
386 |
autism_level = active_session.get("autism_level", "Level 1")
|
387 |
topic_focus = active_session.get("topic_focus", "")
|
388 |
treatment_plan = active_session.get("treatment_plan", "")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
389 |
|
390 |
+
# Use current difficulty if not advancing, otherwise use updated difficulty
|
391 |
+
difficulty_to_use = updated_difficulty if updated_difficulty != active_session.get("difficulty", "Very Simple") else active_session.get("difficulty", "Very Simple")
|
392 |
+
|
393 |
+
# Generate a new prompt with the difficulty
|
394 |
+
generated_prompt = generate_prompt_from_options(difficulty_to_use, age, autism_level, topic_focus, treatment_plan)
|
395 |
+
|
396 |
+
# Generate the new image - returns a PIL Image
|
397 |
+
new_image = generate_image_fn(generated_prompt)
|
398 |
+
|
399 |
+
# Now the global_image_data_url should be updated
|
400 |
+
|
401 |
+
# Generate a detailed description of the image using Gemini Vision
|
402 |
+
image_description = generate_detailed_description(global_image_data_url, generated_prompt, difficulty_to_use, topic_focus)
|
403 |
+
|
404 |
+
# Extract key details to be identified
|
405 |
+
key_details = extract_key_details(image_description)
|
406 |
+
|
407 |
+
# Create fresh active session with the new image
|
408 |
+
new_active_session = {
|
409 |
+
"prompt": generated_prompt,
|
410 |
+
"image": global_image_data_url,
|
411 |
+
"image_description": image_description,
|
412 |
+
"chat": [],
|
413 |
+
"treatment_plan": treatment_plan,
|
414 |
+
"topic_focus": topic_focus,
|
415 |
+
"key_details": key_details,
|
416 |
+
"identified_details": [],
|
417 |
+
"used_hints": [],
|
418 |
+
"difficulty": difficulty_to_use,
|
419 |
+
"autism_level": autism_level,
|
420 |
+
"age": age
|
421 |
+
}
|
422 |
+
|
423 |
+
# Create new checklist for the new image
|
424 |
+
new_checklist = []
|
425 |
+
for i, detail in enumerate(key_details):
|
426 |
+
new_checklist.append({"detail": detail, "identified": False, "id": i})
|
427 |
+
|
428 |
+
# Initialize the new chat with an appropriate message
|
429 |
+
if updated_difficulty != active_session.get("difficulty", "Very Simple"):
|
430 |
+
advancement_message = f"Congratulations! You've advanced to {updated_difficulty} difficulty! Here's a new image to describe."
|
431 |
+
else:
|
432 |
+
advancement_message = "Great job identifying all the details! Here's a new image at the same difficulty level."
|
433 |
+
|
434 |
+
new_active_session["chat"] = [("System", advancement_message)]
|
435 |
+
|
436 |
+
return "", new_active_session["chat"], new_sessions, new_active_session, new_checklist, new_image
|
437 |
+
|
438 |
+
# If not advancing, return None for the image to indicate no change
|
439 |
+
return "", updated_chat, saved_sessions, active_session, updated_checklist, None
|
440 |
|
441 |
def update_sessions(saved_sessions, active_session):
|
442 |
"""
|
|
|
450 |
# Gradio Interface
|
451 |
##############################################
|
452 |
with gr.Blocks() as demo:
|
453 |
+
# Initialize the active session with default values
|
454 |
active_session = gr.State({
|
455 |
"prompt": None,
|
456 |
"image": None,
|
457 |
+
"image_description": None,
|
458 |
"chat": [],
|
459 |
"treatment_plan": "",
|
460 |
"topic_focus": "",
|
461 |
+
"key_details": [],
|
462 |
"identified_details": [],
|
463 |
+
"used_hints": [],
|
464 |
"difficulty": "Very Simple",
|
465 |
"age": "3",
|
466 |
"autism_level": "Level 1"
|
467 |
})
|
468 |
saved_sessions = gr.State([])
|
469 |
+
checklist_state = gr.State([])
|
470 |
+
|
471 |
+
with gr.Row():
|
472 |
+
# Main content area
|
473 |
+
with gr.Column(scale=2):
|
474 |
+
gr.Markdown("# Autism Education Image Description Tool")
|
475 |
+
# Display current difficulty label
|
476 |
+
difficulty_label = gr.Markdown("**Current Difficulty:** Very Simple")
|
477 |
+
|
478 |
+
# ----- Image Generation Section -----
|
479 |
+
with gr.Column():
|
480 |
+
gr.Markdown("## Generate Image")
|
481 |
+
gr.Markdown("Enter the child's details to generate an appropriate educational image.")
|
482 |
+
with gr.Row():
|
483 |
+
age_input = gr.Textbox(label="Child's Age", placeholder="Enter age...", value="3")
|
484 |
+
autism_level_dropdown = gr.Dropdown(label="Autism Level", choices=["Level 1", "Level 2", "Level 3"], value="Level 1")
|
485 |
+
topic_focus_input = gr.Textbox(
|
486 |
+
label="Topic Focus",
|
487 |
+
placeholder="Enter a specific topic or detail to focus on (e.g., 'animals', 'emotions', 'daily routines')...",
|
488 |
+
lines=1
|
489 |
+
)
|
490 |
+
treatment_plan_input = gr.Textbox(
|
491 |
+
label="Treatment Plan",
|
492 |
+
placeholder="Enter the treatment plan to guide the image generation...",
|
493 |
+
lines=2
|
494 |
+
)
|
495 |
+
generate_btn = gr.Button("Generate Image")
|
496 |
+
img_output = gr.Image(label="Generated Image")
|
497 |
+
|
498 |
+
# ----- Chat Section -----
|
499 |
+
with gr.Column():
|
500 |
+
gr.Markdown("## Image Description Practice")
|
501 |
+
gr.Markdown(
|
502 |
+
"After generating an image, ask the child to describe what they see. "
|
503 |
+
"Type their description in the box below. The system will provide supportive feedback "
|
504 |
+
"and track their progress in identifying details."
|
505 |
+
)
|
506 |
+
chatbot = gr.Chatbot(label="Conversation History")
|
507 |
+
with gr.Row():
|
508 |
+
chat_input = gr.Textbox(label="Child's Description", placeholder="Type what the child says about the image...", show_label=True)
|
509 |
+
send_btn = gr.Button("Submit")
|
510 |
+
|
511 |
+
# Sidebar - Checklist of items to identify
|
512 |
+
with gr.Column(scale=1):
|
513 |
+
gr.Markdown("## Details to Identify")
|
514 |
+
gr.Markdown("The child should try to identify these elements in the image:")
|
515 |
+
|
516 |
+
# Create a custom HTML component to display the checklist with checkboxes
|
517 |
+
checklist_html = gr.HTML("""
|
518 |
+
<div id="checklist-container">
|
519 |
+
<p>Generate an image to see details to identify.</p>
|
520 |
+
</div>
|
521 |
+
""")
|
522 |
+
|
523 |
+
# Add a function to update the checklist HTML
|
524 |
+
def update_checklist_html(checklist):
|
525 |
+
if not checklist:
|
526 |
+
return """
|
527 |
+
<div id="checklist-container">
|
528 |
+
<p>Generate an image to see details to identify.</p>
|
529 |
+
</div>
|
530 |
+
"""
|
531 |
+
|
532 |
+
html_content = """
|
533 |
+
<div id="checklist-container" style="padding: 10px;">
|
534 |
+
<style>
|
535 |
+
.checklist-item {
|
536 |
+
display: flex;
|
537 |
+
align-items: center;
|
538 |
+
margin-bottom: 10px;
|
539 |
+
padding: 8px;
|
540 |
+
border-radius: 5px;
|
541 |
+
transition: background-color 0.3s;
|
542 |
+
}
|
543 |
+
.identified {
|
544 |
+
background-color: #e6f7e6;
|
545 |
+
text-decoration: line-through;
|
546 |
+
color: #4CAF50;
|
547 |
+
}
|
548 |
+
.not-identified {
|
549 |
+
background-color: #f5f5f5;
|
550 |
+
}
|
551 |
+
.checkmark {
|
552 |
+
margin-right: 10px;
|
553 |
+
font-size: 1.2em;
|
554 |
+
}
|
555 |
+
</style>
|
556 |
+
"""
|
557 |
+
|
558 |
+
for item in checklist:
|
559 |
+
detail = item["detail"]
|
560 |
+
identified = item["identified"]
|
561 |
+
css_class = "identified" if identified else "not-identified"
|
562 |
+
checkmark = "β
" if identified else "β¬"
|
563 |
+
|
564 |
+
html_content += f"""
|
565 |
+
<div class="checklist-item {css_class}">
|
566 |
+
<span class="checkmark">{checkmark}</span>
|
567 |
+
<span>{detail}</span>
|
568 |
+
</div>
|
569 |
+
"""
|
570 |
+
|
571 |
+
html_content += """
|
572 |
+
</div>
|
573 |
+
"""
|
574 |
+
return html_content
|
575 |
+
|
576 |
+
# Progress summary
|
577 |
+
progress_html = gr.HTML("""
|
578 |
+
<div id="progress-container">
|
579 |
+
<p>No active session.</p>
|
580 |
+
</div>
|
581 |
+
""")
|
582 |
+
|
583 |
+
def update_progress_html(checklist):
|
584 |
+
if not checklist:
|
585 |
+
return """
|
586 |
+
<div id="progress-container">
|
587 |
+
<p>No active session.</p>
|
588 |
+
</div>
|
589 |
+
"""
|
590 |
+
|
591 |
+
total_items = len(checklist)
|
592 |
+
identified_items = sum(1 for item in checklist if item["identified"])
|
593 |
+
percentage = (identified_items / total_items) * 100 if total_items > 0 else 0
|
594 |
+
|
595 |
+
progress_bar_width = f"{percentage}%"
|
596 |
+
all_identified = identified_items == total_items
|
597 |
+
|
598 |
+
html_content = f"""
|
599 |
+
<div id="progress-container" style="padding: 10px;">
|
600 |
+
<h3>Progress: {identified_items} / {total_items} details</h3>
|
601 |
+
<div style="width: 100%; background-color: #f1f1f1; border-radius: 5px; margin-bottom: 10px;">
|
602 |
+
<div style="width: {progress_bar_width}; height: 24px; background-color: #4CAF50; border-radius: 5px;"></div>
|
603 |
+
</div>
|
604 |
+
<p style="font-size: 16px; font-weight: bold; text-align: center;">
|
605 |
+
"""
|
606 |
+
|
607 |
+
if all_identified:
|
608 |
+
html_content += "π Amazing! All details identified! π"
|
609 |
+
elif percentage >= 75:
|
610 |
+
html_content += "Almost there! Keep going!"
|
611 |
+
elif percentage >= 50:
|
612 |
+
html_content += "Halfway there! You're doing great!"
|
613 |
+
elif percentage >= 25:
|
614 |
+
html_content += "Good start! Keep looking!"
|
615 |
+
else:
|
616 |
+
html_content += "Let's find more details!"
|
617 |
+
|
618 |
+
html_content += """
|
619 |
+
</p>
|
620 |
+
</div>
|
621 |
+
"""
|
622 |
+
return html_content
|
623 |
+
|
624 |
+
# ----- Session Details Section -----
|
625 |
+
with gr.Row():
|
626 |
with gr.Column():
|
627 |
+
gr.Markdown("## Progress Tracking")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
628 |
gr.Markdown(
|
629 |
+
"This section tracks the child's progress across sessions. "
|
630 |
+
"Each session includes the difficulty level, identified details, "
|
631 |
+
"and the full conversation history."
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
632 |
)
|
633 |
+
sessions_output = gr.JSON(label="Session Details", value={})
|
634 |
|
635 |
+
# Process chat and update image as needed
|
636 |
+
def process_chat_and_image(user_msg, active_session, saved_sessions, checklist):
|
637 |
+
chat_input, chatbot, new_sessions, new_active_session, new_checklist, new_image = chat_respond(
|
638 |
+
user_msg, active_session, saved_sessions, checklist
|
|
|
|
|
|
|
639 |
)
|
|
|
|
|
|
|
|
|
|
|
640 |
|
641 |
+
# Only return a new image if one was generated (advancement case)
|
642 |
+
if new_image is not None:
|
643 |
+
return chat_input, chatbot, new_sessions, new_active_session, new_checklist, new_image
|
644 |
+
else:
|
645 |
+
# Return a no-update flag for the image to keep the current one
|
646 |
+
return chat_input, chatbot, new_sessions, new_active_session, new_checklist, gr.update()
|
647 |
+
|
648 |
+
# Connect event handlers
|
649 |
+
generate_btn.click(
|
650 |
+
generate_image_and_reset_chat,
|
651 |
+
inputs=[age_input, autism_level_dropdown, topic_focus_input, treatment_plan_input, active_session, saved_sessions],
|
652 |
+
outputs=[img_output, active_session, saved_sessions, checklist_state]
|
653 |
+
)
|
654 |
+
|
655 |
+
send_btn.click(
|
656 |
+
process_chat_and_image,
|
657 |
+
inputs=[chat_input, active_session, saved_sessions, checklist_state],
|
658 |
+
outputs=[chat_input, chatbot, saved_sessions, active_session, checklist_state, img_output]
|
659 |
+
)
|
660 |
+
|
661 |
+
chat_input.submit(
|
662 |
+
process_chat_and_image,
|
663 |
+
inputs=[chat_input, active_session, saved_sessions, checklist_state],
|
664 |
+
outputs=[chat_input, chatbot, saved_sessions, active_session, checklist_state, img_output]
|
665 |
+
)
|
666 |
+
|
667 |
+
# Update the checklist HTML when checklist state changes
|
668 |
+
checklist_state.change(
|
669 |
+
update_checklist_html,
|
670 |
+
inputs=[checklist_state],
|
671 |
+
outputs=[checklist_html]
|
672 |
+
)
|
673 |
+
|
674 |
+
# Update the progress HTML when checklist state changes
|
675 |
+
checklist_state.change(
|
676 |
+
update_progress_html,
|
677 |
+
inputs=[checklist_state],
|
678 |
+
outputs=[progress_html]
|
679 |
+
)
|
680 |
+
|
681 |
+
# Update the current difficulty label when active_session changes
|
682 |
+
active_session.change(update_difficulty_label, inputs=[active_session], outputs=[difficulty_label])
|
683 |
+
|
684 |
+
# Update sessions when active_session or saved_sessions change
|
685 |
+
active_session.change(update_sessions, inputs=[saved_sessions, active_session], outputs=sessions_output)
|
686 |
+
saved_sessions.change(update_sessions, inputs=[saved_sessions, active_session], outputs=sessions_output)
|
687 |
+
|
688 |
+
# Launch the app
|
689 |
demo.launch()
|