Update app.py
app.py CHANGED
@@ -49,9 +49,7 @@ def generate_prompt_from_options(difficulty, age, autism_level, topic_focus, tre
 Use descriptive and detailed language.
 """
 )
-
-model = GenerativeModel('gemini-2.0-flash-lite')
-# Generate content using the Gemini model
+model = GenerativeModel('gemini-2.0-flash')
 response = model.generate_content(query)
 return response.text.strip()

@@ -59,7 +57,6 @@ def generate_detailed_description(image_data_url, prompt, difficulty, topic_focu
 """
 Generate a detailed description of the image using Gemini Vision.
 """
-# Remove the data:image/png;base64, prefix to get just the base64 string
 base64_img = image_data_url.split(",")[1]
 query = (
 f"""
@@ -77,69 +74,54 @@ def generate_detailed_description(image_data_url, prompt, difficulty, topic_focu
 so please be comprehensive but focus on observable details rather than interpretations.
 """
 )
-# Create a Gemini Vision Pro model
 vision_model = GenerativeModel('gemini-2.0-flash-thinking-exp-01-21')
-# Create the content with image and text
 image_part = Part(inline_data={"mime_type": "image/png", "data": base64.b64decode(base64_img)})
 text_part = Part(text=query)
 multimodal_content = Content(parts=[image_part, text_part])
-# Generate description using the vision model
 response = vision_model.generate_content(multimodal_content)
 return response.text.strip()

-def
-negative_prompt="ugly, blurry, poorly drawn hands, lewd, nude, deformed, missing limbs, missing eyes, missing arms, missing legs",
-num_inference_steps=50):
-"""
-Generate an image from the prompt via the Hugging Face Inference API.
-Convert the image to a data URL.
-"""
-global global_image_data_url, global_image_prompt
-global_image_prompt = selected_prompt
-image_client = InferenceClient(provider="hf-inference", api_key=inference_api_key)
-image = image_client.text_to_image(
-selected_prompt,
-model="stabilityai/stable-diffusion-3.5-large-turbo",
-guidance_scale=guidance_scale,
-negative_prompt=negative_prompt,
-num_inference_steps=num_inference_steps
-)
-buffered = io.BytesIO()
-image.save(buffered, format="PNG")
-img_bytes = buffered.getvalue()
-img_b64 = base64.b64encode(img_bytes).decode("utf-8")
-global_image_data_url = f"data:image/png;base64,{img_b64}"
-return image
-
-def extract_key_details(description):
+def extract_key_details(image_data_url, prompt, topic_focus):
 """
-Extract key details from
-Returns a list of key elements/details from the
+Extract key details directly from the image using Gemini Vision.
+Returns a list of key elements/details from the image.
 """
-
+base64_img = image_data_url.split(",")[1]
 query = (
 f"""
-
+You are analyzing an educational image created for a child with autism, based on the prompt: "{prompt}".
+The image focuses on the topic: "{topic_focus}".
+
+Please extract a list of 10-15 key details that a child might identify in this image.
 Each detail should be a simple, clear phrase describing one observable element.
-
-
+Focus on concrete, visible elements rather than abstract concepts.
+
 Format your response as a JSON array of strings, each representing one key detail.
 Example format: ["red ball on the grass", "smiling girl with brown hair", "blue sky with clouds"]
+
+Ensure each detail is:
+1. Directly observable in the image
+2. Unique (not a duplicate)
+3. Described in simple, concrete language
+4. Relevant to what a child would notice
 """
 )
-
-
-
+
+vision_model = GenerativeModel('gemini-2.0-flash')
+image_part = Part(inline_data={"mime_type": "image/png", "data": base64.b64decode(base64_img)})
+text_part = Part(text=query)
+multimodal_content = Content(parts=[image_part, text_part])
+response = vision_model.generate_content(multimodal_content)
+
 try:
-# Parse the JSON response
 details_match = re.search(r'\[.*\]', response.text, re.DOTALL)
 if details_match:
 details_json = details_match.group(0)
 key_details = json.loads(details_json)
 return key_details
 else:
-# If no JSON found,
-lines =
+# If no JSON array is found, try to extract bullet points or lines
+lines = response.text.split('\n')
 details = []
 for line in lines:
 if line.strip().startswith('-') or line.strip().startswith('*'):
@@ -149,33 +131,47 @@ def extract_key_details(description):
 print(f"Error extracting key details: {str(e)}")
 return ["object in image", "color", "shape", "background"]

-def
+def generate_image_fn(selected_prompt, guidance_scale=7.5,
+negative_prompt="ugly, blurry, poorly drawn hands, nude, deformed, missing limbs, missing body parts",
+num_inference_steps=45):
+"""
+Generate an image from the prompt via the Hugging Face Inference API.
+Convert the image to a data URL.
+"""
+global global_image_data_url, global_image_prompt
+global_image_prompt = selected_prompt
+image_client = InferenceClient(provider="hf-inference", api_key=inference_api_key)
+image = image_client.text_to_image(
+selected_prompt,
+model="stabilityai/stable-diffusion-3.5-large-turbo",
+guidance_scale=guidance_scale,
+negative_prompt=negative_prompt,
+num_inference_steps=num_inference_steps
+)
+buffered = io.BytesIO()
+image.save(buffered, format="PNG")
+img_bytes = buffered.getvalue()
+img_b64 = base64.b64encode(img_bytes).decode("utf-8")
+global_image_data_url = f"data:image/png;base64,{img_b64}"
+return image
+
+def generate_image_and_reset_chat(age, autism_level, topic_focus, treatment_plan, attempt_limit_input, active_session, saved_sessions):
 """
 Generate a new image (with the current difficulty) and reset the chat.
-
+Also resets the attempt count and uses the user-entered attempt limit.
 """
 global global_image_description
 new_sessions = saved_sessions.copy()
 if active_session.get("prompt"):
 new_sessions.append(active_session)

-# Use the current difficulty from the active session
 current_difficulty = active_session.get("difficulty", "Very Simple")
-
-# Generate the prompt for the image
 generated_prompt = generate_prompt_from_options(current_difficulty, age, autism_level, topic_focus, treatment_plan)
-
-# Generate the image
 image = generate_image_fn(generated_prompt)
-
-# Generate a detailed description of the image using Gemini Vision
 image_description = generate_detailed_description(global_image_data_url, generated_prompt, current_difficulty, topic_focus)
 global_image_description = image_description
+key_details = extract_key_details(global_image_data_url, generated_prompt, topic_focus)

-# Extract key details to be identified
-key_details = extract_key_details(image_description)
-
-# Create a new active session with all the necessary information
 new_active_session = {
 "prompt": generated_prompt,
 "image": global_image_data_url,
@@ -183,34 +179,30 @@ def generate_image_and_reset_chat(age, autism_level, topic_focus, treatment_plan
 "chat": [],
 "treatment_plan": treatment_plan,
 "topic_focus": topic_focus,
-"key_details": key_details,
+"key_details": key_details,
 "identified_details": [],
 "used_hints": [],
 "difficulty": current_difficulty,
 "autism_level": autism_level,
-"age": age
+"age": age,
+"attempt_limit": int(attempt_limit_input) if attempt_limit_input else 3,
+"attempt_count": 0
 }

-# Create the checklist of items to identify
 checklist_items = []
 for i, detail in enumerate(key_details):
 checklist_items.append({"detail": detail, "identified": False, "id": i})

-# Return the updated state and checklist
 return image, new_active_session, new_sessions, checklist_items

 def compare_details_chat_fn(user_details, active_session):
 """
 Evaluate the child's description using Google's Gemini model.
-Now uses the image description and tracks identified details and used hints.
 """
 if not global_image_data_url or not global_image_description:
 return "Please generate an image first."

-# Get the detailed image description
 image_description = active_session.get("image_description", global_image_description)
-
-# Get chat history
 chat_history = active_session.get("chat", [])
 history_text = ""
 if chat_history:
@@ -218,12 +210,10 @@ def compare_details_chat_fn(user_details, active_session):
 for idx, (speaker, msg) in enumerate(chat_history, 1):
 history_text += f"Turn {idx}:\n{speaker}: {msg}\n"

-# Get key details, identified details and used hints
 key_details = active_session.get("key_details", [])
 identified_details = active_session.get("identified_details", [])
 used_hints = active_session.get("used_hints", [])

-# Format for the API
 key_details_text = "\n\n### Key Details to Identify:\n" + "\n".join(f"- {detail}" for detail in key_details)
 identified_details_text = ""
 if identified_details:
@@ -232,9 +222,7 @@ def compare_details_chat_fn(user_details, active_session):
 if used_hints:
 used_hints_text = "\n\n### Previously Given Hints:\n" + "\n".join(f"- {hint}" for hint in used_hints)

-# Current difficulty level
 current_difficulty = active_session.get("difficulty", "Very Simple")
-
 message_text = (
 f"You are a kind and encouraging teacher helping a child with autism describe an image.\n\n"
 f"### Image Prompt:\n{active_session.get('prompt', 'No prompt available')}\n\n"
@@ -263,14 +251,15 @@ def compare_details_chat_fn(user_details, active_session):
 "Ensure the JSON is valid and contains all fields."
 )

-
-model = GenerativeModel('gemini-2.0-flash-thinking-exp-01-21')
-
-# Generate evaluation using the model
+model = GenerativeModel('gemini-2.0-flash')
 response = model.generate_content(message_text)
 return response.text

 def parse_evaluation(evaluation_text, active_session):
+"""
+Parse the evaluation JSON and return feedback, updated difficulty, whether to advance,
+newly identified details, and the score.
+"""
 try:
 json_match = re.search(r'\{.*\}', evaluation_text, re.DOTALL)
 if json_match:
@@ -279,35 +268,28 @@ def parse_evaluation(evaluation_text, active_session):
 else:
 raise ValueError("No JSON object found in the response.")

-# Extract data from the evaluation
 feedback = evaluation.get("feedback", "Great effort! Keep describing what you see.")
 newly_identified_details = evaluation.get("newly_identified_details", [])
 hint = evaluation.get("hint", "")
 score = evaluation.get("score", 0)
 advance_difficulty = evaluation.get("advance_difficulty", False)

-# Update the session with newly identified details
 identified_details = active_session.get("identified_details", [])
 for detail in newly_identified_details:
 if detail not in identified_details:
 identified_details.append(detail)
 active_session["identified_details"] = identified_details

-# Add the hint to used hints if one was provided
 if hint:
 used_hints = active_session.get("used_hints", [])
 if hint not in used_hints:
 used_hints.append(hint)
 active_session["used_hints"] = used_hints
-
-# Add the hint to the feedback if it's not already included
 if hint.strip() and hint.strip() not in feedback:
 feedback += f"\n\n💡 Hint: {hint}"

-# Get current difficulty and check if it should be advanced
 current_difficulty = active_session.get("difficulty", "Very Simple")
 should_advance = False
-
 if advance_difficulty:
 difficulties = ["Very Simple", "Simple", "Moderate", "Detailed", "Very Detailed"]
 current_index = difficulties.index(current_difficulty) if current_difficulty in difficulties else 0
@@ -315,95 +297,71 @@ def parse_evaluation(evaluation_text, active_session):
 current_difficulty = difficulties[current_index + 1]
 should_advance = True

-return feedback, current_difficulty, should_advance, newly_identified_details
-
+return feedback, current_difficulty, should_advance, newly_identified_details, score
 except Exception as e:
 print(f"Error processing evaluation: {str(e)}")
-return
+return ("That's interesting! Can you tell me more about what you see?",
+active_session.get("difficulty", "Very Simple"),
+False,
+[],
+0)

 def update_checklist(checklist, newly_identified, key_details):
 """
 Update the checklist based on newly identified details.
-Returns an updated checklist.
 """
 new_checklist = []
 for item in checklist:
 detail = item["detail"]
-# Check if this detail has been identified
 is_identified = item["identified"]
-
-# If newly identified, update status
 for identified in newly_identified:
-# Check if the identified detail matches or is similar to the key detail
 if (identified.lower() in detail.lower() or detail.lower() in identified.lower() or
 any(word for word in identified.lower().split() if word in detail.lower() and len(word) > 3)):
 is_identified = True
 break
-
 new_checklist.append({"detail": detail, "identified": is_identified, "id": item["id"]})
-
 return new_checklist

 def chat_respond(user_message, active_session, saved_sessions, checklist):
 """
 Process a new chat message.
 Evaluate the child's description, update identified details, and advance difficulty if needed.
+Only increment the attempt count if no new details were identified.
 """
 if not active_session.get("image"):
 bot_message = "Please generate an image first."
 updated_chat = active_session.get("chat", []) + [("Child", user_message), ("Teacher", bot_message)]
 active_session["chat"] = updated_chat
-return "", updated_chat, saved_sessions, active_session, checklist, None
+return "", updated_chat, saved_sessions, active_session, checklist, None

-# Get the evaluation from Gemini
 raw_evaluation = compare_details_chat_fn(user_message, active_session)
+feedback, updated_difficulty, should_advance, newly_identified, score = parse_evaluation(raw_evaluation, active_session)

-#
-
+# Only count a failed attempt if no new details were identified
+if not newly_identified:
+active_session["attempt_count"] = active_session.get("attempt_count", 0) + 1

-# Update the checklist with newly identified details
 updated_checklist = update_checklist(checklist, newly_identified, active_session.get("key_details", []))
-
-# Add the current exchange to the chat history
 updated_chat = active_session.get("chat", []) + [("Child", user_message), ("Teacher", feedback)]
 active_session["chat"] = updated_chat

-# Check if all items have been identified
 all_identified = all(item["identified"] for item in updated_checklist)
+attempts_exhausted = active_session.get("attempt_count", 0) >= active_session.get("attempt_limit", 3)
+should_generate_new_image = should_advance or all_identified or attempts_exhausted

-# Modify this line to generate new image when all details are identified
-should_generate_new_image = should_advance or all_identified
-
-# If the child should advance to a new difficulty or has identified all details
 if should_generate_new_image:
-# Save the current session
 new_sessions = saved_sessions.copy()
 new_sessions.append(active_session.copy())
-
-# Get parameters for generating new image
 age = active_session.get("age", "3")
 autism_level = active_session.get("autism_level", "Level 1")
 topic_focus = active_session.get("topic_focus", "")
 treatment_plan = active_session.get("treatment_plan", "")
-
-# Use current difficulty if not advancing, otherwise use updated difficulty
 difficulty_to_use = updated_difficulty if updated_difficulty != active_session.get("difficulty", "Very Simple") else active_session.get("difficulty", "Very Simple")
-
-# Generate a new prompt with the difficulty
 generated_prompt = generate_prompt_from_options(difficulty_to_use, age, autism_level, topic_focus, treatment_plan)
-
-# Generate the new image - returns a PIL Image
 new_image = generate_image_fn(generated_prompt)
-
-# Now the global_image_data_url should be updated
-
-# Generate a detailed description of the image using Gemini Vision
 image_description = generate_detailed_description(global_image_data_url, generated_prompt, difficulty_to_use, topic_focus)
+key_details = extract_key_details(global_image_data_url, generated_prompt, topic_focus)

-# Extract key details to be identified
-key_details = extract_key_details(image_description)
-
-# Create fresh active session with the new image
 new_active_session = {
 "prompt": generated_prompt,
 "image": global_image_data_url,
@@ -416,25 +374,25 @@ def chat_respond(user_message, active_session, saved_sessions, checklist):
 "used_hints": [],
 "difficulty": difficulty_to_use,
 "autism_level": autism_level,
-"age": age
+"age": age,
+"attempt_limit": active_session.get("attempt_limit", 3),
+"attempt_count": 0
 }

-# Create new checklist for the new image
 new_checklist = []
 for i, detail in enumerate(key_details):
 new_checklist.append({"detail": detail, "identified": False, "id": i})

-
-
+if attempts_exhausted:
+advancement_message = "You've used all your allowed attempts. Let's try a new image."
+elif updated_difficulty != active_session.get("difficulty", "Very Simple"):
 advancement_message = f"Congratulations! You've advanced to {updated_difficulty} difficulty! Here's a new image to describe."
 else:
 advancement_message = "Great job identifying all the details! Here's a new image at the same difficulty level."

 new_active_session["chat"] = [("System", advancement_message)]
-
 return "", new_active_session["chat"], new_sessions, new_active_session, new_checklist, new_image

-# If not advancing, return None for the image to indicate no change
 return "", updated_chat, saved_sessions, active_session, updated_checklist, None

 def update_sessions(saved_sessions, active_session):
@@ -449,7 +407,6 @@ def update_sessions(saved_sessions, active_session):
 # Gradio Interface
 ##############################################
 with gr.Blocks() as demo:
-# Initialize the active session with default values
 active_session = gr.State({
 "prompt": None,
 "image": None,
@@ -462,19 +419,17 @@ with gr.Blocks() as demo:
 "used_hints": [],
 "difficulty": "Very Simple",
 "age": "3",
-"autism_level": "Level 1"
+"autism_level": "Level 1",
+"attempt_limit": 3,
+"attempt_count": 0
 })
 saved_sessions = gr.State([])
 checklist_state = gr.State([])

 with gr.Row():
-# Main content area
 with gr.Column(scale=2):
 gr.Markdown("# Autism Education Image Description Tool")
-# Display current difficulty label
 difficulty_label = gr.Markdown("**Current Difficulty:** Very Simple")
-
-# ----- Image Generation Section -----
 with gr.Column():
 gr.Markdown("## Generate Image")
 gr.Markdown("Enter the child's details to generate an appropriate educational image.")
@@ -491,10 +446,9 @@ with gr.Blocks() as demo:
 placeholder="Enter the treatment plan to guide the image generation...",
 lines=2
 )
+attempt_limit_input = gr.Number(label="Allowed Attempts", value=3, precision=0)
 generate_btn = gr.Button("Generate Image")
 img_output = gr.Image(label="Generated Image")
-
-# ----- Chat Section -----
 with gr.Column():
 gr.Markdown("## Image Description Practice")
 gr.Markdown(
@@ -506,20 +460,19 @@ with gr.Blocks() as demo:
 with gr.Row():
 chat_input = gr.Textbox(label="Child's Description", placeholder="Type what the child says about the image...", show_label=True)
 send_btn = gr.Button("Submit")
-
-# Sidebar - Checklist of items to identify
 with gr.Column(scale=1):
 gr.Markdown("## Details to Identify")
 gr.Markdown("The child should try to identify these elements in the image:")
-
-# Create a custom HTML component to display the checklist with checkboxes
 checklist_html = gr.HTML("""
 <div id="checklist-container">
 <p>Generate an image to see details to identify.</p>
 </div>
 """)
-
-
+attempt_counter_html = gr.HTML("""
+<div id="attempt-counter" style="margin-top: 10px; padding: 10px; background-color: #030404; border-radius: 5px;">
+<p style="margin: 0; font-weight: bold;">Attempts: 0/3</p>
+</div>
+""")
 def update_checklist_html(checklist):
 if not checklist:
 return """
@@ -527,7 +480,6 @@ with gr.Blocks() as demo:
 <p>Generate an image to see details to identify.</p>
 </div>
 """
-
 html_content = """
 <div id="checklist-container" style="padding: 10px;">
 <style>
@@ -553,32 +505,26 @@ with gr.Blocks() as demo:
 }
 </style>
 """
-
 for item in checklist:
 detail = item["detail"]
 identified = item["identified"]
 css_class = "identified" if identified else "not-identified"
 checkmark = "✅" if identified else "⬜"
-
 html_content += f"""
 <div class="checklist-item {css_class}">
 <span class="checkmark">{checkmark}</span>
 <span>{detail}</span>
 </div>
 """
-
 html_content += """
 </div>
 """
 return html_content
-
-# Progress summary
 progress_html = gr.HTML("""
 <div id="progress-container">
 <p>No active session.</p>
 </div>
 """)
-
 def update_progress_html(checklist):
 if not checklist:
 return """
@@ -586,14 +532,11 @@ with gr.Blocks() as demo:
 <p>No active session.</p>
 </div>
 """
-
 total_items = len(checklist)
 identified_items = sum(1 for item in checklist if item["identified"])
 percentage = (identified_items / total_items) * 100 if total_items > 0 else 0
-
 progress_bar_width = f"{percentage}%"
 all_identified = identified_items == total_items
-
 html_content = f"""
 <div id="progress-container" style="padding: 10px;">
 <h3>Progress: {identified_items} / {total_items} details</h3>
@@ -602,7 +545,6 @@ with gr.Blocks() as demo:
 </div>
 <p style="font-size: 16px; font-weight: bold; text-align: center;">
 """
-
 if all_identified:
 html_content += "🎉 Amazing! All details identified! 🎉"
 elif percentage >= 75:
@@ -613,14 +555,20 @@ with gr.Blocks() as demo:
 html_content += "Good start! Keep looking!"
 else:
 html_content += "Let's find more details!"
-
 html_content += """
 </p>
 </div>
 """
 return html_content

-
+def update_attempt_counter(active_session):
+current_count = active_session.get("attempt_count", 0)
+limit = active_session.get("attempt_limit", 3)
+return f"""
+<div id="attempt-counter" style="margin-top: 10px; padding: 10px; background-color: #bfbfbf; border-radius: 5px; border: 1px solid #ddd;">
+<p style="margin: 0; font-weight: bold; text-align: center;">Attempts: {current_count}/{limit}</p>
+</div>
+"""
 with gr.Row():
 with gr.Column():
 gr.Markdown("## Progress Tracking")
@@ -630,59 +578,41 @@ with gr.Blocks() as demo:
 "and the full conversation history."
 )
 sessions_output = gr.JSON(label="Session Details", value={})
-
-# Process chat and update image as needed
 def process_chat_and_image(user_msg, active_session, saved_sessions, checklist):
-
+chat_input_val, chatbot_val, new_sessions, new_active_session, new_checklist, new_image = chat_respond(
 user_msg, active_session, saved_sessions, checklist
 )
-
-# Only return a new image if one was generated (advancement case)
 if new_image is not None:
-return
+return chat_input_val, chatbot_val, new_sessions, new_active_session, new_checklist, new_image
 else:
-
-return chat_input, chatbot, new_sessions, new_active_session, new_checklist, gr.update()
-
-# Connect event handlers
+return chat_input_val, chatbot_val, new_sessions, new_active_session, new_checklist, gr.update()
 generate_btn.click(
 generate_image_and_reset_chat,
-inputs=[age_input, autism_level_dropdown, topic_focus_input, treatment_plan_input, active_session, saved_sessions],
+inputs=[age_input, autism_level_dropdown, topic_focus_input, treatment_plan_input, attempt_limit_input, active_session, saved_sessions],
 outputs=[img_output, active_session, saved_sessions, checklist_state]
 )
-
 send_btn.click(
 process_chat_and_image,
 inputs=[chat_input, active_session, saved_sessions, checklist_state],
 outputs=[chat_input, chatbot, saved_sessions, active_session, checklist_state, img_output]
 )
-
 chat_input.submit(
 process_chat_and_image,
 inputs=[chat_input, active_session, saved_sessions, checklist_state],
 outputs=[chat_input, chatbot, saved_sessions, active_session, checklist_state, img_output]
 )
-
-# Update the checklist HTML when checklist state changes
 checklist_state.change(
 update_checklist_html,
 inputs=[checklist_state],
 outputs=[checklist_html]
 )
-
-# Update the progress HTML when checklist state changes
 checklist_state.change(
 update_progress_html,
 inputs=[checklist_state],
 outputs=[progress_html]
 )
-
-# Update the current difficulty label when active_session changes
 active_session.change(update_difficulty_label, inputs=[active_session], outputs=[difficulty_label])
-
-# Update sessions when active_session or saved_sessions change
+active_session.change(update_attempt_counter, inputs=[active_session], outputs=[attempt_counter_html])
 active_session.change(update_sessions, inputs=[saved_sessions, active_session], outputs=sessions_output)
 saved_sessions.change(update_sessions, inputs=[saved_sessions, active_session], outputs=sessions_output)
-
-# Launch the app
 demo.launch()
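The main behavioral change in this commit is the attempt-limit gating: a child turn that identifies no new detail counts as a failed attempt, and a fresh image is requested once the child advances, finds every detail, or runs out of attempts. The sketch below is not part of app.py; it is a minimal, framework-free illustration of that rule, assuming the session is a plain dict with the same keys the app keeps in gr.State ("attempt_count", "attempt_limit", "identified_details"), and the helper names register_turn and needs_new_image are invented for the example.

# Minimal sketch of the attempt gating added in this commit (illustrative only).
def register_turn(session: dict, newly_identified: list[str]) -> None:
    """Record one child turn: only a turn with no new details counts as a failed attempt."""
    if newly_identified:
        for detail in newly_identified:
            if detail not in session["identified_details"]:
                session["identified_details"].append(detail)
    else:
        session["attempt_count"] = session.get("attempt_count", 0) + 1

def needs_new_image(session: dict, checklist: list[dict], should_advance: bool) -> bool:
    """A new image is warranted when the child advances, finds every detail, or exhausts attempts."""
    all_identified = all(item["identified"] for item in checklist)
    attempts_exhausted = session.get("attempt_count", 0) >= session.get("attempt_limit", 3)
    return should_advance or all_identified or attempts_exhausted

if __name__ == "__main__":
    session = {"attempt_count": 0, "attempt_limit": 3, "identified_details": []}
    checklist = [{"detail": "red ball", "identified": False},
                 {"detail": "blue sky", "identified": True}]
    register_turn(session, [])            # no new details -> attempt 1
    register_turn(session, ["red ball"])  # progress -> no attempt charged
    register_turn(session, [])            # attempt 2
    register_turn(session, [])            # attempt 3 -> limit reached
    print(session["attempt_count"])                    # 3
    print(needs_new_image(session, checklist, False))  # True (attempts exhausted)

Under these assumptions the counter only advances on unproductive turns, which matches the commit's chat_respond logic of incrementing attempt_count only when newly_identified is empty.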