Daemontatox committed on
Commit 7e52f0b · verified · 1 Parent(s): 0039777

Update app.py

Files changed (1)
  1. app.py +105 -175
app.py CHANGED
@@ -49,9 +49,7 @@ def generate_prompt_from_options(difficulty, age, autism_level, topic_focus, tre
49
  Use descriptive and detailed language.
50
  """
51
  )
52
- # Initialize the Gemini Pro model
53
- model = GenerativeModel('gemini-2.0-flash-lite')
54
- # Generate content using the Gemini model
55
  response = model.generate_content(query)
56
  return response.text.strip()
57
 
@@ -59,7 +57,6 @@ def generate_detailed_description(image_data_url, prompt, difficulty, topic_focu
59
  """
60
  Generate a detailed description of the image using Gemini Vision.
61
  """
62
- # Remove the data:image/png;base64, prefix to get just the base64 string
63
  base64_img = image_data_url.split(",")[1]
64
  query = (
65
  f"""
@@ -77,69 +74,54 @@ def generate_detailed_description(image_data_url, prompt, difficulty, topic_focu
77
  so please be comprehensive but focus on observable details rather than interpretations.
78
  """
79
  )
80
- # Create a Gemini Vision Pro model
81
  vision_model = GenerativeModel('gemini-2.0-flash-thinking-exp-01-21')
82
- # Create the content with image and text
83
  image_part = Part(inline_data={"mime_type": "image/png", "data": base64.b64decode(base64_img)})
84
  text_part = Part(text=query)
85
  multimodal_content = Content(parts=[image_part, text_part])
86
- # Generate description using the vision model
87
  response = vision_model.generate_content(multimodal_content)
88
  return response.text.strip()
89
 
90
- def generate_image_fn(selected_prompt, guidance_scale=7.5,
91
- negative_prompt="ugly, blurry, poorly drawn hands, lewd, nude, deformed, missing limbs, missing eyes, missing arms, missing legs",
92
- num_inference_steps=50):
93
- """
94
- Generate an image from the prompt via the Hugging Face Inference API.
95
- Convert the image to a data URL.
96
- """
97
- global global_image_data_url, global_image_prompt
98
- global_image_prompt = selected_prompt
99
- image_client = InferenceClient(provider="hf-inference", api_key=inference_api_key)
100
- image = image_client.text_to_image(
101
- selected_prompt,
102
- model="stabilityai/stable-diffusion-3.5-large-turbo",
103
- guidance_scale=guidance_scale,
104
- negative_prompt=negative_prompt,
105
- num_inference_steps=num_inference_steps
106
- )
107
- buffered = io.BytesIO()
108
- image.save(buffered, format="PNG")
109
- img_bytes = buffered.getvalue()
110
- img_b64 = base64.b64encode(img_bytes).decode("utf-8")
111
- global_image_data_url = f"data:image/png;base64,{img_b64}"
112
- return image
113
-
114
- def extract_key_details(description):
115
  """
116
- Extract key details from Gemini's description to use for tracking.
117
- Returns a list of key elements/details from the description.
118
  """
119
- # Create a query to extract key details
120
  query = (
121
  f"""
122
- From the following detailed image description, extract a list of 10-15 key details that a child might identify.
123
  Each detail should be a simple, clear phrase describing one observable element.
124
- Description:
125
- {description}
126
  Format your response as a JSON array of strings, each representing one key detail.
127
  Example format: ["red ball on the grass", "smiling girl with brown hair", "blue sky with clouds"]
128
  """
129
  )
130
- # Use Gemini text model to extract key details
131
- model = GenerativeModel('gemini-2.0-flash-lite')
132
- response = model.generate_content(query)
133
  try:
134
- # Parse the JSON response
135
  details_match = re.search(r'\[.*\]', response.text, re.DOTALL)
136
  if details_match:
137
  details_json = details_match.group(0)
138
  key_details = json.loads(details_json)
139
  return key_details
140
  else:
141
- # If no JSON found, do basic extraction
142
- lines = description.split('\n')
143
  details = []
144
  for line in lines:
145
  if line.strip().startswith('-') or line.strip().startswith('*'):
@@ -149,33 +131,47 @@ def extract_key_details(description):
149
  print(f"Error extracting key details: {str(e)}")
150
  return ["object in image", "color", "shape", "background"]
151
 
152
- def generate_image_and_reset_chat(age, autism_level, topic_focus, treatment_plan, active_session, saved_sessions):
153
  """
154
  Generate a new image (with the current difficulty) and reset the chat.
155
- Now includes the topic_focus parameter to specify what the image should focus on.
156
  """
157
  global global_image_description
158
  new_sessions = saved_sessions.copy()
159
  if active_session.get("prompt"):
160
  new_sessions.append(active_session)
161
 
162
- # Use the current difficulty from the active session
163
  current_difficulty = active_session.get("difficulty", "Very Simple")
164
-
165
- # Generate the prompt for the image
166
  generated_prompt = generate_prompt_from_options(current_difficulty, age, autism_level, topic_focus, treatment_plan)
167
-
168
- # Generate the image
169
  image = generate_image_fn(generated_prompt)
170
-
171
- # Generate a detailed description of the image using Gemini Vision
172
  image_description = generate_detailed_description(global_image_data_url, generated_prompt, current_difficulty, topic_focus)
173
  global_image_description = image_description
 
174
 
175
- # Extract key details to be identified
176
- key_details = extract_key_details(image_description)
177
-
178
- # Create a new active session with all the necessary information
179
  new_active_session = {
180
  "prompt": generated_prompt,
181
  "image": global_image_data_url,
@@ -183,34 +179,30 @@ def generate_image_and_reset_chat(age, autism_level, topic_focus, treatment_plan
183
  "chat": [],
184
  "treatment_plan": treatment_plan,
185
  "topic_focus": topic_focus,
186
- "key_details": key_details, # Store the list of key details
187
  "identified_details": [],
188
  "used_hints": [],
189
  "difficulty": current_difficulty,
190
  "autism_level": autism_level,
191
- "age": age
 
 
192
  }
193
 
194
- # Create the checklist of items to identify
195
  checklist_items = []
196
  for i, detail in enumerate(key_details):
197
  checklist_items.append({"detail": detail, "identified": False, "id": i})
198
 
199
- # Return the updated state and checklist
200
  return image, new_active_session, new_sessions, checklist_items
201
 
202
  def compare_details_chat_fn(user_details, active_session):
203
  """
204
  Evaluate the child's description using Google's Gemini model.
205
- Now uses the image description and tracks identified details and used hints.
206
  """
207
  if not global_image_data_url or not global_image_description:
208
  return "Please generate an image first."
209
 
210
- # Get the detailed image description
211
  image_description = active_session.get("image_description", global_image_description)
212
-
213
- # Get chat history
214
  chat_history = active_session.get("chat", [])
215
  history_text = ""
216
  if chat_history:
@@ -218,12 +210,10 @@ def compare_details_chat_fn(user_details, active_session):
218
  for idx, (speaker, msg) in enumerate(chat_history, 1):
219
  history_text += f"Turn {idx}:\n{speaker}: {msg}\n"
220
 
221
- # Get key details, identified details and used hints
222
  key_details = active_session.get("key_details", [])
223
  identified_details = active_session.get("identified_details", [])
224
  used_hints = active_session.get("used_hints", [])
225
 
226
- # Format for the API
227
  key_details_text = "\n\n### Key Details to Identify:\n" + "\n".join(f"- {detail}" for detail in key_details)
228
  identified_details_text = ""
229
  if identified_details:
@@ -232,9 +222,7 @@ def compare_details_chat_fn(user_details, active_session):
232
  if used_hints:
233
  used_hints_text = "\n\n### Previously Given Hints:\n" + "\n".join(f"- {hint}" for hint in used_hints)
234
 
235
- # Current difficulty level
236
  current_difficulty = active_session.get("difficulty", "Very Simple")
237
-
238
  message_text = (
239
  f"You are a kind and encouraging teacher helping a child with autism describe an image.\n\n"
240
  f"### Image Prompt:\n{active_session.get('prompt', 'No prompt available')}\n\n"
@@ -263,14 +251,15 @@ def compare_details_chat_fn(user_details, active_session):
263
  "Ensure the JSON is valid and contains all fields."
264
  )
265
 
266
- # Create a Gemini model for evaluation
267
- model = GenerativeModel('gemini-2.0-flash-thinking-exp-01-21')
268
-
269
- # Generate evaluation using the model
270
  response = model.generate_content(message_text)
271
  return response.text
272
 
273
  def parse_evaluation(evaluation_text, active_session):
274
  try:
275
  json_match = re.search(r'\{.*\}', evaluation_text, re.DOTALL)
276
  if json_match:
@@ -279,35 +268,28 @@ def parse_evaluation(evaluation_text, active_session):
279
  else:
280
  raise ValueError("No JSON object found in the response.")
281
 
282
- # Extract data from the evaluation
283
  feedback = evaluation.get("feedback", "Great effort! Keep describing what you see.")
284
  newly_identified_details = evaluation.get("newly_identified_details", [])
285
  hint = evaluation.get("hint", "")
286
  score = evaluation.get("score", 0)
287
  advance_difficulty = evaluation.get("advance_difficulty", False)
288
 
289
- # Update the session with newly identified details
290
  identified_details = active_session.get("identified_details", [])
291
  for detail in newly_identified_details:
292
  if detail not in identified_details:
293
  identified_details.append(detail)
294
  active_session["identified_details"] = identified_details
295
 
296
- # Add the hint to used hints if one was provided
297
  if hint:
298
  used_hints = active_session.get("used_hints", [])
299
  if hint not in used_hints:
300
  used_hints.append(hint)
301
  active_session["used_hints"] = used_hints
302
-
303
- # Add the hint to the feedback if it's not already included
304
  if hint.strip() and hint.strip() not in feedback:
305
  feedback += f"\n\n💡 Hint: {hint}"
306
 
307
- # Get current difficulty and check if it should be advanced
308
  current_difficulty = active_session.get("difficulty", "Very Simple")
309
  should_advance = False
310
-
311
  if advance_difficulty:
312
  difficulties = ["Very Simple", "Simple", "Moderate", "Detailed", "Very Detailed"]
313
  current_index = difficulties.index(current_difficulty) if current_difficulty in difficulties else 0
@@ -315,95 +297,71 @@ def parse_evaluation(evaluation_text, active_session):
315
  current_difficulty = difficulties[current_index + 1]
316
  should_advance = True
317
 
318
- return feedback, current_difficulty, should_advance, newly_identified_details
319
-
320
  except Exception as e:
321
  print(f"Error processing evaluation: {str(e)}")
322
- return f"That's interesting! Can you tell me more about what you see?", active_session.get("difficulty", "Very Simple"), False, []
323
 
324
  def update_checklist(checklist, newly_identified, key_details):
325
  """
326
  Update the checklist based on newly identified details.
327
- Returns an updated checklist.
328
  """
329
  new_checklist = []
330
  for item in checklist:
331
  detail = item["detail"]
332
- # Check if this detail has been identified
333
  is_identified = item["identified"]
334
-
335
- # If newly identified, update status
336
  for identified in newly_identified:
337
- # Check if the identified detail matches or is similar to the key detail
338
  if (identified.lower() in detail.lower() or detail.lower() in identified.lower() or
339
  any(word for word in identified.lower().split() if word in detail.lower() and len(word) > 3)):
340
  is_identified = True
341
  break
342
-
343
  new_checklist.append({"detail": detail, "identified": is_identified, "id": item["id"]})
344
-
345
  return new_checklist
346
 
347
  def chat_respond(user_message, active_session, saved_sessions, checklist):
348
  """
349
  Process a new chat message.
350
  Evaluate the child's description, update identified details, and advance difficulty if needed.
 
351
  """
352
  if not active_session.get("image"):
353
  bot_message = "Please generate an image first."
354
  updated_chat = active_session.get("chat", []) + [("Child", user_message), ("Teacher", bot_message)]
355
  active_session["chat"] = updated_chat
356
- return "", updated_chat, saved_sessions, active_session, checklist, None # Return None for image
357
 
358
- # Get the evaluation from Gemini
359
  raw_evaluation = compare_details_chat_fn(user_message, active_session)
 
360
 
361
- # Parse the evaluation and update session
362
- feedback, updated_difficulty, should_advance, newly_identified = parse_evaluation(raw_evaluation, active_session)
 
363
 
364
- # Update the checklist with newly identified details
365
  updated_checklist = update_checklist(checklist, newly_identified, active_session.get("key_details", []))
366
-
367
- # Add the current exchange to the chat history
368
  updated_chat = active_session.get("chat", []) + [("Child", user_message), ("Teacher", feedback)]
369
  active_session["chat"] = updated_chat
370
 
371
- # Check if all items have been identified
372
  all_identified = all(item["identified"] for item in updated_checklist)
 
 
373
 
374
- # Modify this line to generate new image when all details are identified
375
- should_generate_new_image = should_advance or all_identified
376
-
377
- # If the child should advance to a new difficulty or has identified all details
378
  if should_generate_new_image:
379
- # Save the current session
380
  new_sessions = saved_sessions.copy()
381
  new_sessions.append(active_session.copy())
382
-
383
- # Get parameters for generating new image
384
  age = active_session.get("age", "3")
385
  autism_level = active_session.get("autism_level", "Level 1")
386
  topic_focus = active_session.get("topic_focus", "")
387
  treatment_plan = active_session.get("treatment_plan", "")
388
-
389
- # Use current difficulty if not advancing, otherwise use updated difficulty
390
  difficulty_to_use = updated_difficulty if updated_difficulty != active_session.get("difficulty", "Very Simple") else active_session.get("difficulty", "Very Simple")
391
-
392
- # Generate a new prompt with the difficulty
393
  generated_prompt = generate_prompt_from_options(difficulty_to_use, age, autism_level, topic_focus, treatment_plan)
394
-
395
- # Generate the new image - returns a PIL Image
396
  new_image = generate_image_fn(generated_prompt)
397
-
398
- # Now the global_image_data_url should be updated
399
-
400
- # Generate a detailed description of the image using Gemini Vision
401
  image_description = generate_detailed_description(global_image_data_url, generated_prompt, difficulty_to_use, topic_focus)
 
402
 
403
- # Extract key details to be identified
404
- key_details = extract_key_details(image_description)
405
-
406
- # Create fresh active session with the new image
407
  new_active_session = {
408
  "prompt": generated_prompt,
409
  "image": global_image_data_url,
@@ -416,25 +374,25 @@ def chat_respond(user_message, active_session, saved_sessions, checklist):
416
  "used_hints": [],
417
  "difficulty": difficulty_to_use,
418
  "autism_level": autism_level,
419
- "age": age
 
 
420
  }
421
 
422
- # Create new checklist for the new image
423
  new_checklist = []
424
  for i, detail in enumerate(key_details):
425
  new_checklist.append({"detail": detail, "identified": False, "id": i})
426
 
427
- # Initialize the new chat with an appropriate message
428
- if updated_difficulty != active_session.get("difficulty", "Very Simple"):
 
429
  advancement_message = f"Congratulations! You've advanced to {updated_difficulty} difficulty! Here's a new image to describe."
430
  else:
431
  advancement_message = "Great job identifying all the details! Here's a new image at the same difficulty level."
432
 
433
  new_active_session["chat"] = [("System", advancement_message)]
434
-
435
  return "", new_active_session["chat"], new_sessions, new_active_session, new_checklist, new_image
436
 
437
- # If not advancing, return None for the image to indicate no change
438
  return "", updated_chat, saved_sessions, active_session, updated_checklist, None
439
 
440
  def update_sessions(saved_sessions, active_session):
@@ -449,7 +407,6 @@ def update_sessions(saved_sessions, active_session):
449
  # Gradio Interface
450
  ##############################################
451
  with gr.Blocks() as demo:
452
- # Initialize the active session with default values
453
  active_session = gr.State({
454
  "prompt": None,
455
  "image": None,
@@ -462,19 +419,17 @@ with gr.Blocks() as demo:
462
  "used_hints": [],
463
  "difficulty": "Very Simple",
464
  "age": "3",
465
- "autism_level": "Level 1"
 
 
466
  })
467
  saved_sessions = gr.State([])
468
  checklist_state = gr.State([])
469
 
470
  with gr.Row():
471
- # Main content area
472
  with gr.Column(scale=2):
473
  gr.Markdown("# Autism Education Image Description Tool")
474
- # Display current difficulty label
475
  difficulty_label = gr.Markdown("**Current Difficulty:** Very Simple")
476
-
477
- # ----- Image Generation Section -----
478
  with gr.Column():
479
  gr.Markdown("## Generate Image")
480
  gr.Markdown("Enter the child's details to generate an appropriate educational image.")
@@ -491,10 +446,9 @@ with gr.Blocks() as demo:
491
  placeholder="Enter the treatment plan to guide the image generation...",
492
  lines=2
493
  )
 
494
  generate_btn = gr.Button("Generate Image")
495
  img_output = gr.Image(label="Generated Image")
496
-
497
- # ----- Chat Section -----
498
  with gr.Column():
499
  gr.Markdown("## Image Description Practice")
500
  gr.Markdown(
@@ -506,20 +460,19 @@ with gr.Blocks() as demo:
506
  with gr.Row():
507
  chat_input = gr.Textbox(label="Child's Description", placeholder="Type what the child says about the image...", show_label=True)
508
  send_btn = gr.Button("Submit")
509
-
510
- # Sidebar - Checklist of items to identify
511
  with gr.Column(scale=1):
512
  gr.Markdown("## Details to Identify")
513
  gr.Markdown("The child should try to identify these elements in the image:")
514
-
515
- # Create a custom HTML component to display the checklist with checkboxes
516
  checklist_html = gr.HTML("""
517
  <div id="checklist-container">
518
  <p>Generate an image to see details to identify.</p>
519
  </div>
520
  """)
521
-
522
- # Add a function to update the checklist HTML
523
  def update_checklist_html(checklist):
524
  if not checklist:
525
  return """
@@ -527,7 +480,6 @@ with gr.Blocks() as demo:
527
  <p>Generate an image to see details to identify.</p>
528
  </div>
529
  """
530
-
531
  html_content = """
532
  <div id="checklist-container" style="padding: 10px;">
533
  <style>
@@ -553,32 +505,26 @@ with gr.Blocks() as demo:
553
  }
554
  </style>
555
  """
556
-
557
  for item in checklist:
558
  detail = item["detail"]
559
  identified = item["identified"]
560
  css_class = "identified" if identified else "not-identified"
561
  checkmark = "✅" if identified else "⬜"
562
-
563
  html_content += f"""
564
  <div class="checklist-item {css_class}">
565
  <span class="checkmark">{checkmark}</span>
566
  <span>{detail}</span>
567
  </div>
568
  """
569
-
570
  html_content += """
571
  </div>
572
  """
573
  return html_content
574
-
575
- # Progress summary
576
  progress_html = gr.HTML("""
577
  <div id="progress-container">
578
  <p>No active session.</p>
579
  </div>
580
  """)
581
-
582
  def update_progress_html(checklist):
583
  if not checklist:
584
  return """
@@ -586,14 +532,11 @@ with gr.Blocks() as demo:
586
  <p>No active session.</p>
587
  </div>
588
  """
589
-
590
  total_items = len(checklist)
591
  identified_items = sum(1 for item in checklist if item["identified"])
592
  percentage = (identified_items / total_items) * 100 if total_items > 0 else 0
593
-
594
  progress_bar_width = f"{percentage}%"
595
  all_identified = identified_items == total_items
596
-
597
  html_content = f"""
598
  <div id="progress-container" style="padding: 10px;">
599
  <h3>Progress: {identified_items} / {total_items} details</h3>
@@ -602,7 +545,6 @@ with gr.Blocks() as demo:
602
  </div>
603
  <p style="font-size: 16px; font-weight: bold; text-align: center;">
604
  """
605
-
606
  if all_identified:
607
  html_content += "🎉 Amazing! All details identified! 🎉"
608
  elif percentage >= 75:
@@ -613,14 +555,20 @@ with gr.Blocks() as demo:
613
  html_content += "Good start! Keep looking!"
614
  else:
615
  html_content += "Let's find more details!"
616
-
617
  html_content += """
618
  </p>
619
  </div>
620
  """
621
  return html_content
622
 
623
- # ----- Session Details Section -----
624
  with gr.Row():
625
  with gr.Column():
626
  gr.Markdown("## Progress Tracking")
@@ -630,59 +578,41 @@ with gr.Blocks() as demo:
630
  "and the full conversation history."
631
  )
632
  sessions_output = gr.JSON(label="Session Details", value={})
633
-
634
- # Process chat and update image as needed
635
  def process_chat_and_image(user_msg, active_session, saved_sessions, checklist):
636
- chat_input, chatbot, new_sessions, new_active_session, new_checklist, new_image = chat_respond(
637
  user_msg, active_session, saved_sessions, checklist
638
  )
639
-
640
- # Only return a new image if one was generated (advancement case)
641
  if new_image is not None:
642
- return chat_input, chatbot, new_sessions, new_active_session, new_checklist, new_image
643
  else:
644
- # Return a no-update flag for the image to keep the current one
645
- return chat_input, chatbot, new_sessions, new_active_session, new_checklist, gr.update()
646
-
647
- # Connect event handlers
648
  generate_btn.click(
649
  generate_image_and_reset_chat,
650
- inputs=[age_input, autism_level_dropdown, topic_focus_input, treatment_plan_input, active_session, saved_sessions],
651
  outputs=[img_output, active_session, saved_sessions, checklist_state]
652
  )
653
-
654
  send_btn.click(
655
  process_chat_and_image,
656
  inputs=[chat_input, active_session, saved_sessions, checklist_state],
657
  outputs=[chat_input, chatbot, saved_sessions, active_session, checklist_state, img_output]
658
  )
659
-
660
  chat_input.submit(
661
  process_chat_and_image,
662
  inputs=[chat_input, active_session, saved_sessions, checklist_state],
663
  outputs=[chat_input, chatbot, saved_sessions, active_session, checklist_state, img_output]
664
  )
665
-
666
- # Update the checklist HTML when checklist state changes
667
  checklist_state.change(
668
  update_checklist_html,
669
  inputs=[checklist_state],
670
  outputs=[checklist_html]
671
  )
672
-
673
- # Update the progress HTML when checklist state changes
674
  checklist_state.change(
675
  update_progress_html,
676
  inputs=[checklist_state],
677
  outputs=[progress_html]
678
  )
679
-
680
- # Update the current difficulty label when active_session changes
681
  active_session.change(update_difficulty_label, inputs=[active_session], outputs=[difficulty_label])
682
-
683
- # Update sessions when active_session or saved_sessions change
684
  active_session.change(update_sessions, inputs=[saved_sessions, active_session], outputs=sessions_output)
685
  saved_sessions.change(update_sessions, inputs=[saved_sessions, active_session], outputs=sessions_output)
686
-
687
- # Launch the app
688
  demo.launch()
 
49
  Use descriptive and detailed language.
50
  """
51
  )
52
+ model = GenerativeModel('gemini-2.0-flash')
 
 
53
  response = model.generate_content(query)
54
  return response.text.strip()
55
 
 
57
  """
58
  Generate a detailed description of the image using Gemini Vision.
59
  """
 
60
  base64_img = image_data_url.split(",")[1]
61
  query = (
62
  f"""
 
74
  so please be comprehensive but focus on observable details rather than interpretations.
75
  """
76
  )
 
77
  vision_model = GenerativeModel('gemini-2.0-flash-thinking-exp-01-21')
 
78
  image_part = Part(inline_data={"mime_type": "image/png", "data": base64.b64decode(base64_img)})
79
  text_part = Part(text=query)
80
  multimodal_content = Content(parts=[image_part, text_part])
 
81
  response = vision_model.generate_content(multimodal_content)
82
  return response.text.strip()
83
 
84
+ def extract_key_details(image_data_url, prompt, topic_focus):
85
  """
86
+ Extract key details directly from the image using Gemini Vision.
87
+ Returns a list of key elements/details from the image.
88
  """
89
+ base64_img = image_data_url.split(",")[1]
90
  query = (
91
  f"""
92
+ You are analyzing an educational image created for a child with autism, based on the prompt: "{prompt}".
93
+ The image focuses on the topic: "{topic_focus}".
94
+
95
+ Please extract a list of 10-15 key details that a child might identify in this image.
96
  Each detail should be a simple, clear phrase describing one observable element.
97
+ Focus on concrete, visible elements rather than abstract concepts.
98
+
99
  Format your response as a JSON array of strings, each representing one key detail.
100
  Example format: ["red ball on the grass", "smiling girl with brown hair", "blue sky with clouds"]
101
+
102
+ Ensure each detail is:
103
+ 1. Directly observable in the image
104
+ 2. Unique (not a duplicate)
105
+ 3. Described in simple, concrete language
106
+ 4. Relevant to what a child would notice
107
  """
108
  )
109
+
110
+ vision_model = GenerativeModel('gemini-2.0-flash')
111
+ image_part = Part(inline_data={"mime_type": "image/png", "data": base64.b64decode(base64_img)})
112
+ text_part = Part(text=query)
113
+ multimodal_content = Content(parts=[image_part, text_part])
114
+ response = vision_model.generate_content(multimodal_content)
115
+
116
  try:
 
117
  details_match = re.search(r'\[.*\]', response.text, re.DOTALL)
118
  if details_match:
119
  details_json = details_match.group(0)
120
  key_details = json.loads(details_json)
121
  return key_details
122
  else:
123
+ # If no JSON array is found, try to extract bullet points or lines
124
+ lines = response.text.split('\n')
125
  details = []
126
  for line in lines:
127
  if line.strip().startswith('-') or line.strip().startswith('*'):
 
131
  print(f"Error extracting key details: {str(e)}")
132
  return ["object in image", "color", "shape", "background"]
133
 
134
+ def generate_image_fn(selected_prompt, guidance_scale=7.5,
135
+ negative_prompt="ugly, blurry, poorly drawn hands, nude, deformed, missing limbs, missing body parts",
136
+ num_inference_steps=45):
137
+ """
138
+ Generate an image from the prompt via the Hugging Face Inference API.
139
+ Convert the image to a data URL.
140
+ """
141
+ global global_image_data_url, global_image_prompt
142
+ global_image_prompt = selected_prompt
143
+ image_client = InferenceClient(provider="hf-inference", api_key=inference_api_key)
144
+ image = image_client.text_to_image(
145
+ selected_prompt,
146
+ model="stabilityai/stable-diffusion-3.5-large-turbo",
147
+ guidance_scale=guidance_scale,
148
+ negative_prompt=negative_prompt,
149
+ num_inference_steps=num_inference_steps
150
+ )
151
+ buffered = io.BytesIO()
152
+ image.save(buffered, format="PNG")
153
+ img_bytes = buffered.getvalue()
154
+ img_b64 = base64.b64encode(img_bytes).decode("utf-8")
155
+ global_image_data_url = f"data:image/png;base64,{img_b64}"
156
+ return image
157
+
158
+ def generate_image_and_reset_chat(age, autism_level, topic_focus, treatment_plan, attempt_limit_input, active_session, saved_sessions):
159
  """
160
  Generate a new image (with the current difficulty) and reset the chat.
161
+ Also resets the attempt count and uses the user-entered attempt limit.
162
  """
163
  global global_image_description
164
  new_sessions = saved_sessions.copy()
165
  if active_session.get("prompt"):
166
  new_sessions.append(active_session)
167
 
 
168
  current_difficulty = active_session.get("difficulty", "Very Simple")
 
 
169
  generated_prompt = generate_prompt_from_options(current_difficulty, age, autism_level, topic_focus, treatment_plan)
 
 
170
  image = generate_image_fn(generated_prompt)
 
 
171
  image_description = generate_detailed_description(global_image_data_url, generated_prompt, current_difficulty, topic_focus)
172
  global_image_description = image_description
173
+ key_details = extract_key_details(global_image_data_url, generated_prompt, topic_focus)
174
175
  new_active_session = {
176
  "prompt": generated_prompt,
177
  "image": global_image_data_url,
 
179
  "chat": [],
180
  "treatment_plan": treatment_plan,
181
  "topic_focus": topic_focus,
182
+ "key_details": key_details,
183
  "identified_details": [],
184
  "used_hints": [],
185
  "difficulty": current_difficulty,
186
  "autism_level": autism_level,
187
+ "age": age,
188
+ "attempt_limit": int(attempt_limit_input) if attempt_limit_input else 3,
189
+ "attempt_count": 0
190
  }
191
 
 
192
  checklist_items = []
193
  for i, detail in enumerate(key_details):
194
  checklist_items.append({"detail": detail, "identified": False, "id": i})
195
 
 
196
  return image, new_active_session, new_sessions, checklist_items
197
 
198
  def compare_details_chat_fn(user_details, active_session):
199
  """
200
  Evaluate the child's description using Google's Gemini model.
 
201
  """
202
  if not global_image_data_url or not global_image_description:
203
  return "Please generate an image first."
204
 
 
205
  image_description = active_session.get("image_description", global_image_description)
 
 
206
  chat_history = active_session.get("chat", [])
207
  history_text = ""
208
  if chat_history:
 
210
  for idx, (speaker, msg) in enumerate(chat_history, 1):
211
  history_text += f"Turn {idx}:\n{speaker}: {msg}\n"
212
 
 
213
  key_details = active_session.get("key_details", [])
214
  identified_details = active_session.get("identified_details", [])
215
  used_hints = active_session.get("used_hints", [])
216
 
 
217
  key_details_text = "\n\n### Key Details to Identify:\n" + "\n".join(f"- {detail}" for detail in key_details)
218
  identified_details_text = ""
219
  if identified_details:
 
222
  if used_hints:
223
  used_hints_text = "\n\n### Previously Given Hints:\n" + "\n".join(f"- {hint}" for hint in used_hints)
224
 
 
225
  current_difficulty = active_session.get("difficulty", "Very Simple")
 
226
  message_text = (
227
  f"You are a kind and encouraging teacher helping a child with autism describe an image.\n\n"
228
  f"### Image Prompt:\n{active_session.get('prompt', 'No prompt available')}\n\n"
 
251
  "Ensure the JSON is valid and contains all fields."
252
  )
253
 
254
+ model = GenerativeModel('gemini-2.0-flash')
255
  response = model.generate_content(message_text)
256
  return response.text
257
 
258
  def parse_evaluation(evaluation_text, active_session):
259
+ """
260
+ Parse the evaluation JSON and return feedback, updated difficulty, whether to advance,
261
+ newly identified details, and the score.
262
+ """
263
  try:
264
  json_match = re.search(r'\{.*\}', evaluation_text, re.DOTALL)
265
  if json_match:
 
268
  else:
269
  raise ValueError("No JSON object found in the response.")
270
 
 
271
  feedback = evaluation.get("feedback", "Great effort! Keep describing what you see.")
272
  newly_identified_details = evaluation.get("newly_identified_details", [])
273
  hint = evaluation.get("hint", "")
274
  score = evaluation.get("score", 0)
275
  advance_difficulty = evaluation.get("advance_difficulty", False)
276
 
 
277
  identified_details = active_session.get("identified_details", [])
278
  for detail in newly_identified_details:
279
  if detail not in identified_details:
280
  identified_details.append(detail)
281
  active_session["identified_details"] = identified_details
282
 
 
283
  if hint:
284
  used_hints = active_session.get("used_hints", [])
285
  if hint not in used_hints:
286
  used_hints.append(hint)
287
  active_session["used_hints"] = used_hints
 
 
288
  if hint.strip() and hint.strip() not in feedback:
289
  feedback += f"\n\n💡 Hint: {hint}"
290
 
 
291
  current_difficulty = active_session.get("difficulty", "Very Simple")
292
  should_advance = False
 
293
  if advance_difficulty:
294
  difficulties = ["Very Simple", "Simple", "Moderate", "Detailed", "Very Detailed"]
295
  current_index = difficulties.index(current_difficulty) if current_difficulty in difficulties else 0
 
297
  current_difficulty = difficulties[current_index + 1]
298
  should_advance = True
299
 
300
+ return feedback, current_difficulty, should_advance, newly_identified_details, score
 
301
  except Exception as e:
302
  print(f"Error processing evaluation: {str(e)}")
303
+ return ("That's interesting! Can you tell me more about what you see?",
304
+ active_session.get("difficulty", "Very Simple"),
305
+ False,
306
+ [],
307
+ 0)
308
 
309
  def update_checklist(checklist, newly_identified, key_details):
310
  """
311
  Update the checklist based on newly identified details.
 
312
  """
313
  new_checklist = []
314
  for item in checklist:
315
  detail = item["detail"]
 
316
  is_identified = item["identified"]
 
 
317
  for identified in newly_identified:
 
318
  if (identified.lower() in detail.lower() or detail.lower() in identified.lower() or
319
  any(word for word in identified.lower().split() if word in detail.lower() and len(word) > 3)):
320
  is_identified = True
321
  break
 
322
  new_checklist.append({"detail": detail, "identified": is_identified, "id": item["id"]})
 
323
  return new_checklist
324
 
325
  def chat_respond(user_message, active_session, saved_sessions, checklist):
326
  """
327
  Process a new chat message.
328
  Evaluate the child's description, update identified details, and advance difficulty if needed.
329
+ Only increment the attempt count if no new details were identified.
330
  """
331
  if not active_session.get("image"):
332
  bot_message = "Please generate an image first."
333
  updated_chat = active_session.get("chat", []) + [("Child", user_message), ("Teacher", bot_message)]
334
  active_session["chat"] = updated_chat
335
+ return "", updated_chat, saved_sessions, active_session, checklist, None
336
 
 
337
  raw_evaluation = compare_details_chat_fn(user_message, active_session)
338
+ feedback, updated_difficulty, should_advance, newly_identified, score = parse_evaluation(raw_evaluation, active_session)
339
 
340
+ # Only count a failed attempt if no new details were identified
341
+ if not newly_identified:
342
+ active_session["attempt_count"] = active_session.get("attempt_count", 0) + 1
343
 
 
344
  updated_checklist = update_checklist(checklist, newly_identified, active_session.get("key_details", []))
 
 
345
  updated_chat = active_session.get("chat", []) + [("Child", user_message), ("Teacher", feedback)]
346
  active_session["chat"] = updated_chat
347
 
 
348
  all_identified = all(item["identified"] for item in updated_checklist)
349
+ attempts_exhausted = active_session.get("attempt_count", 0) >= active_session.get("attempt_limit", 3)
350
+ should_generate_new_image = should_advance or all_identified or attempts_exhausted
351
352
  if should_generate_new_image:
 
353
  new_sessions = saved_sessions.copy()
354
  new_sessions.append(active_session.copy())
 
 
355
  age = active_session.get("age", "3")
356
  autism_level = active_session.get("autism_level", "Level 1")
357
  topic_focus = active_session.get("topic_focus", "")
358
  treatment_plan = active_session.get("treatment_plan", "")
 
 
359
  difficulty_to_use = updated_difficulty if updated_difficulty != active_session.get("difficulty", "Very Simple") else active_session.get("difficulty", "Very Simple")
 
 
360
  generated_prompt = generate_prompt_from_options(difficulty_to_use, age, autism_level, topic_focus, treatment_plan)
 
 
361
  new_image = generate_image_fn(generated_prompt)
362
  image_description = generate_detailed_description(global_image_data_url, generated_prompt, difficulty_to_use, topic_focus)
363
+ key_details = extract_key_details(global_image_data_url, generated_prompt, topic_focus)
364
365
  new_active_session = {
366
  "prompt": generated_prompt,
367
  "image": global_image_data_url,
 
374
  "used_hints": [],
375
  "difficulty": difficulty_to_use,
376
  "autism_level": autism_level,
377
+ "age": age,
378
+ "attempt_limit": active_session.get("attempt_limit", 3),
379
+ "attempt_count": 0
380
  }
381
 
 
382
  new_checklist = []
383
  for i, detail in enumerate(key_details):
384
  new_checklist.append({"detail": detail, "identified": False, "id": i})
385
 
386
+ if attempts_exhausted:
387
+ advancement_message = "You've used all your allowed attempts. Let's try a new image."
388
+ elif updated_difficulty != active_session.get("difficulty", "Very Simple"):
389
  advancement_message = f"Congratulations! You've advanced to {updated_difficulty} difficulty! Here's a new image to describe."
390
  else:
391
  advancement_message = "Great job identifying all the details! Here's a new image at the same difficulty level."
392
 
393
  new_active_session["chat"] = [("System", advancement_message)]
 
394
  return "", new_active_session["chat"], new_sessions, new_active_session, new_checklist, new_image
395
 
 
396
  return "", updated_chat, saved_sessions, active_session, updated_checklist, None
397
 
398
  def update_sessions(saved_sessions, active_session):
 
407
  # Gradio Interface
408
  ##############################################
409
  with gr.Blocks() as demo:
 
410
  active_session = gr.State({
411
  "prompt": None,
412
  "image": None,
 
419
  "used_hints": [],
420
  "difficulty": "Very Simple",
421
  "age": "3",
422
+ "autism_level": "Level 1",
423
+ "attempt_limit": 3,
424
+ "attempt_count": 0
425
  })
426
  saved_sessions = gr.State([])
427
  checklist_state = gr.State([])
428
 
429
  with gr.Row():
 
430
  with gr.Column(scale=2):
431
  gr.Markdown("# Autism Education Image Description Tool")
 
432
  difficulty_label = gr.Markdown("**Current Difficulty:** Very Simple")
 
 
433
  with gr.Column():
434
  gr.Markdown("## Generate Image")
435
  gr.Markdown("Enter the child's details to generate an appropriate educational image.")
 
446
  placeholder="Enter the treatment plan to guide the image generation...",
447
  lines=2
448
  )
449
+ attempt_limit_input = gr.Number(label="Allowed Attempts", value=3, precision=0)
450
  generate_btn = gr.Button("Generate Image")
451
  img_output = gr.Image(label="Generated Image")
 
 
452
  with gr.Column():
453
  gr.Markdown("## Image Description Practice")
454
  gr.Markdown(
 
460
  with gr.Row():
461
  chat_input = gr.Textbox(label="Child's Description", placeholder="Type what the child says about the image...", show_label=True)
462
  send_btn = gr.Button("Submit")
 
 
463
  with gr.Column(scale=1):
464
  gr.Markdown("## Details to Identify")
465
  gr.Markdown("The child should try to identify these elements in the image:")
 
 
466
  checklist_html = gr.HTML("""
467
  <div id="checklist-container">
468
  <p>Generate an image to see details to identify.</p>
469
  </div>
470
  """)
471
+ attempt_counter_html = gr.HTML("""
472
+ <div id="attempt-counter" style="margin-top: 10px; padding: 10px; background-color: #030404; border-radius: 5px;">
473
+ <p style="margin: 0; font-weight: bold;">Attempts: 0/3</p>
474
+ </div>
475
+ """)
476
  def update_checklist_html(checklist):
477
  if not checklist:
478
  return """
 
480
  <p>Generate an image to see details to identify.</p>
481
  </div>
482
  """
 
483
  html_content = """
484
  <div id="checklist-container" style="padding: 10px;">
485
  <style>
 
505
  }
506
  </style>
507
  """
 
508
  for item in checklist:
509
  detail = item["detail"]
510
  identified = item["identified"]
511
  css_class = "identified" if identified else "not-identified"
512
  checkmark = "✅" if identified else "⬜"
 
513
  html_content += f"""
514
  <div class="checklist-item {css_class}">
515
  <span class="checkmark">{checkmark}</span>
516
  <span>{detail}</span>
517
  </div>
518
  """
 
519
  html_content += """
520
  </div>
521
  """
522
  return html_content
 
 
523
  progress_html = gr.HTML("""
524
  <div id="progress-container">
525
  <p>No active session.</p>
526
  </div>
527
  """)
 
528
  def update_progress_html(checklist):
529
  if not checklist:
530
  return """
 
532
  <p>No active session.</p>
533
  </div>
534
  """
 
535
  total_items = len(checklist)
536
  identified_items = sum(1 for item in checklist if item["identified"])
537
  percentage = (identified_items / total_items) * 100 if total_items > 0 else 0
 
538
  progress_bar_width = f"{percentage}%"
539
  all_identified = identified_items == total_items
 
540
  html_content = f"""
541
  <div id="progress-container" style="padding: 10px;">
542
  <h3>Progress: {identified_items} / {total_items} details</h3>
 
545
  </div>
546
  <p style="font-size: 16px; font-weight: bold; text-align: center;">
547
  """
 
548
  if all_identified:
549
  html_content += "🎉 Amazing! All details identified! 🎉"
550
  elif percentage >= 75:
 
555
  html_content += "Good start! Keep looking!"
556
  else:
557
  html_content += "Let's find more details!"
 
558
  html_content += """
559
  </p>
560
  </div>
561
  """
562
  return html_content
563
 
564
+ def update_attempt_counter(active_session):
565
+ current_count = active_session.get("attempt_count", 0)
566
+ limit = active_session.get("attempt_limit", 3)
567
+ return f"""
568
+ <div id="attempt-counter" style="margin-top: 10px; padding: 10px; background-color: #bfbfbf; border-radius: 5px; border: 1px solid #ddd;">
569
+ <p style="margin: 0; font-weight: bold; text-align: center;">Attempts: {current_count}/{limit}</p>
570
+ </div>
571
+ """
572
  with gr.Row():
573
  with gr.Column():
574
  gr.Markdown("## Progress Tracking")
 
578
  "and the full conversation history."
579
  )
580
  sessions_output = gr.JSON(label="Session Details", value={})
 
 
581
  def process_chat_and_image(user_msg, active_session, saved_sessions, checklist):
582
+ chat_input_val, chatbot_val, new_sessions, new_active_session, new_checklist, new_image = chat_respond(
583
  user_msg, active_session, saved_sessions, checklist
584
  )
 
 
585
  if new_image is not None:
586
+ return chat_input_val, chatbot_val, new_sessions, new_active_session, new_checklist, new_image
587
  else:
588
+ return chat_input_val, chatbot_val, new_sessions, new_active_session, new_checklist, gr.update()
589
  generate_btn.click(
590
  generate_image_and_reset_chat,
591
+ inputs=[age_input, autism_level_dropdown, topic_focus_input, treatment_plan_input, attempt_limit_input, active_session, saved_sessions],
592
  outputs=[img_output, active_session, saved_sessions, checklist_state]
593
  )
 
594
  send_btn.click(
595
  process_chat_and_image,
596
  inputs=[chat_input, active_session, saved_sessions, checklist_state],
597
  outputs=[chat_input, chatbot, saved_sessions, active_session, checklist_state, img_output]
598
  )
 
599
  chat_input.submit(
600
  process_chat_and_image,
601
  inputs=[chat_input, active_session, saved_sessions, checklist_state],
602
  outputs=[chat_input, chatbot, saved_sessions, active_session, checklist_state, img_output]
603
  )
  checklist_state.change(
605
  update_checklist_html,
606
  inputs=[checklist_state],
607
  outputs=[checklist_html]
608
  )
  checklist_state.change(
610
  update_progress_html,
611
  inputs=[checklist_state],
612
  outputs=[progress_html]
613
  )
  active_session.change(update_difficulty_label, inputs=[active_session], outputs=[difficulty_label])
615
+ active_session.change(update_attempt_counter, inputs=[active_session], outputs=[attempt_counter_html])
 
616
  active_session.change(update_sessions, inputs=[saved_sessions, active_session], outputs=sessions_output)
617
  saved_sessions.change(update_sessions, inputs=[saved_sessions, active_session], outputs=sessions_output)
 
 
618
  demo.launch()
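
For reference, a minimal, self-contained sketch of the PNG data-URL round trip that generate_image_fn and the vision helpers above rely on; the blank Pillow image is a stand-in for the Stable Diffusion output, and only base64, io, and PIL are assumed:

import base64
import io

from PIL import Image

# Stand-in for the image returned by the Hugging Face Inference API in generate_image_fn.
image = Image.new("RGB", (64, 64), color="white")

# Encode to a PNG data URL, as generate_image_fn does before storing global_image_data_url.
buffered = io.BytesIO()
image.save(buffered, format="PNG")
img_b64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
image_data_url = f"data:image/png;base64,{img_b64}"

# Split off the base64 payload and decode it back, as generate_detailed_description
# and extract_key_details do before handing the bytes to Gemini Vision.
base64_img = image_data_url.split(",")[1]
decoded = Image.open(io.BytesIO(base64.b64decode(base64_img)))
assert decoded.size == (64, 64)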