Daemontatox committed on
Commit
a9c0662
·
verified ·
1 Parent(s): 0dfba1d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +515 -177
app.py CHANGED
@@ -8,17 +8,20 @@ from PIL import Image
8
  from huggingface_hub import InferenceClient
9
  from google.generativeai import configure, GenerativeModel
10
  from google.ai.generativelanguage import Content, Part
 
 
11
 
12
  # Load API keys from environment variables
13
  inference_api_key = os.environ.get("HF_TOKEN")
14
- google_api_key = os.environ.get("GOOGLE_API_KEY") # New Google API key
15
 
16
  # Configure Google API
17
  configure(api_key=google_api_key)
18
 
19
- # Global variables to store the image data URL and prompt for the currently generated image.
20
  global_image_data_url = None
21
- global_image_prompt = None # Still stored if needed elsewhere
 
22
 
23
  def update_difficulty_label(active_session):
24
  return f"**Current Difficulty:** {active_session.get('difficulty', 'Very Simple')}"
@@ -36,28 +39,53 @@ def generate_prompt_from_options(difficulty, age, autism_level, topic_focus, tre
36
  - Autism Level: {autism_level}
37
  - Topic Focus: {topic_focus}
38
  - Treatment Plan: {treatment_plan}
39
-
40
  Emphasize that the image should be clear, calming, and support understanding and communication. The style should match the difficulty level: for example, "Very Simple" produces very basic visuals while "Very Detailed" produces rich visuals.
41
-
42
  The image should specifically focus on the topic: "{topic_focus}".
43
-
44
  Please generate a prompt that instructs the image generation engine to produce an image with:
45
  1. Clarity and simplicity (minimalist backgrounds, clear subject)
46
  2. Literal representation with defined borders and consistent style
47
  3. Soft, muted colors and reduced visual complexity
48
  4. Positive, calm scenes
49
  5. Clear focus on the specified topic
50
-
51
  Use descriptive and detailed language.
52
  """
53
  )
54
-
55
  # Initialize the Gemini Pro model
56
  model = GenerativeModel('gemini-2.0-flash-lite')
57
-
58
  # Generate content using the Gemini model
59
  response = model.generate_content(query)
 
60
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  return response.text.strip()
62
 
63
  def generate_image_fn(selected_prompt, guidance_scale=7.5,
@@ -84,100 +112,166 @@ def generate_image_fn(selected_prompt, guidance_scale=7.5,
84
  global_image_data_url = f"data:image/png;base64,{img_b64}"
85
  return image
86
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
  def generate_image_and_reset_chat(age, autism_level, topic_focus, treatment_plan, active_session, saved_sessions):
88
  """
89
  Generate a new image (with the current difficulty) and reset the chat.
90
  Now includes the topic_focus parameter to specify what the image should focus on.
91
  """
 
92
  new_sessions = saved_sessions.copy()
93
  if active_session.get("prompt"):
94
  new_sessions.append(active_session)
95
- # Use the current difficulty from the active session (which should be updated if advanced)
 
96
  current_difficulty = active_session.get("difficulty", "Very Simple")
 
 
97
  generated_prompt = generate_prompt_from_options(current_difficulty, age, autism_level, topic_focus, treatment_plan)
 
 
98
  image = generate_image_fn(generated_prompt)
 
 
 
 
 
 
 
 
 
99
  new_active_session = {
100
  "prompt": generated_prompt,
101
  "image": global_image_data_url,
 
102
  "chat": [],
103
  "treatment_plan": treatment_plan,
104
  "topic_focus": topic_focus,
 
105
  "identified_details": [],
 
106
  "difficulty": current_difficulty,
107
  "autism_level": autism_level,
108
  "age": age
109
  }
110
- return image, new_active_session, new_sessions
111
 
112
- def compare_details_chat_fn(user_details, treatment_plan, chat_history, identified_details):
 
 
 
 
 
 
 
 
113
  """
114
- Evaluate the child's description using Google's Gemini Vision model.
 
115
  """
116
- if not global_image_data_url:
117
  return "Please generate an image first."
118
 
 
 
 
 
 
119
  history_text = ""
120
  if chat_history:
121
  history_text = "\n\n### Previous Conversation:\n"
122
- for idx, (user_msg, bot_msg) in enumerate(chat_history, 1):
123
- history_text += f"Turn {idx}:\nUser: {user_msg}\nTeacher: {bot_msg}\n"
124
 
 
 
 
 
 
 
 
125
  identified_details_text = ""
126
  if identified_details:
127
  identified_details_text = "\n\n### Previously Identified Details:\n" + "\n".join(f"- {detail}" for detail in identified_details)
 
 
 
 
 
 
128
 
129
  message_text = (
130
- f"{history_text}{identified_details_text}\n\n"
131
- f"Based on the image provided above, please evaluate the following description given by the child:\n"
132
- f"'{user_details}'\n\n"
133
- "You are a kind and encouraging teacher speaking to a child. Use simple, clear language. "
134
- "Praise the child's correct observations and provide a gentle hint if something is missing. "
135
- "Keep your feedback positive and easy to understand.\n\n"
136
- "Focus on these evaluation criteria:\n"
137
- "1. **Object Identification** – Did the child mention the main objects?\n"
138
- "2. **Color & Shape Accuracy** – Were the colors and shapes described correctly?\n"
139
- "3. **Clarity & Simplicity** – Was the description clear and easy to understand?\n"
140
- "4. **Overall Communication** – How well did the child communicate their thoughts?\n\n"
141
- "Note: As difficulty increases, the expected level of detail is higher. Evaluate accordingly.\n\n"
142
- "Return your evaluation strictly as a JSON object with the following keys:\n"
 
 
 
 
143
  "{\n"
144
- " \"scores\": {\n"
145
- " \"object_identification\": <number>,\n"
146
- " \"color_shape_accuracy\": <number>,\n"
147
- " \"clarity_simplicity\": <number>,\n"
148
- " \"overall_communication\": <number>\n"
149
- " },\n"
150
- " \"final_score\": <number>,\n"
151
- " \"feedback\": \"<string>\",\n"
152
- " \"hint\": \"<string>\",\n"
153
- " \"advance\": <boolean>\n"
154
  "}\n\n"
155
- "Do not include any additional text outside the JSON."
156
  )
157
 
158
- # Remove the data:image/png;base64, prefix to get just the base64 string
159
- base64_img = global_image_data_url.split(",")[1]
160
-
161
- # Create a Gemini Vision Pro model
162
- vision_model = GenerativeModel('gemini-2.0-flash-thinking-exp-01-21')
163
-
164
- # Create the content with image and text using the correct parameters
165
- # Use 'inline_data' instead of 'content' for the image part
166
- image_part = Part(inline_data={"mime_type": "image/png", "data": base64.b64decode(base64_img)})
167
- text_part = Part(text=message_text)
168
- multimodal_content = Content(parts=[image_part, text_part])
169
-
170
- # Generate evaluation using the vision model
171
- response = vision_model.generate_content(multimodal_content)
172
 
 
 
173
  return response.text
174
 
175
- def evaluate_scores(evaluation_text, current_difficulty):
176
- """
177
- Parse the JSON evaluation and decide if the child advances.
178
- The threshold scales with difficulty:
179
- Very Simple: 70, Simple: 75, Moderate: 80, Detailed: 85, Very Detailed: 90.
180
- """
181
  try:
182
  json_match = re.search(r'\{.*\}', evaluation_text, re.DOTALL)
183
  if json_match:
@@ -185,73 +279,164 @@ def evaluate_scores(evaluation_text, current_difficulty):
185
  evaluation = json.loads(json_str)
186
  else:
187
  raise ValueError("No JSON object found in the response.")
188
- final_score = evaluation.get("final_score", 0)
189
- hint = evaluation.get("hint", "Keep trying!")
190
- advance = evaluation.get("advance", False)
191
- difficulty_thresholds = {
192
- "Very Simple": 70,
193
- "Simple": 75,
194
- "Moderate": 80,
195
- "Detailed": 85,
196
- "Very Detailed": 90
197
- }
198
- current_threshold = difficulty_thresholds.get(current_difficulty, 70)
199
- difficulty_mapping = {
200
- "Very Simple": "Simple",
201
- "Simple": "Moderate",
202
- "Moderate": "Detailed",
203
- "Detailed": "Very Detailed",
204
- "Very Detailed": "Very Detailed"
205
- }
206
- if final_score >= current_threshold or advance:
207
- new_difficulty = difficulty_mapping.get(current_difficulty, current_difficulty)
208
- response_msg = (f"Great job! Your final score is {final_score}, which meets the target of {current_threshold}. "
209
- f"You've advanced to {new_difficulty} difficulty.")
210
- return response_msg, new_difficulty
211
- else:
212
- response_msg = (f"Your final score is {final_score} (\n target: {current_threshold}). {hint} \n "
213
- f"Please try again at the {current_difficulty} level.")
214
- return response_msg, current_difficulty
 
 
 
 
 
 
 
 
 
 
 
 
215
  except Exception as e:
216
- return f"Error processing evaluation output: {str(e)}", current_difficulty
 
217
 
218
- def chat_respond(user_message, active_session, saved_sessions):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
219
  """
220
  Process a new chat message.
221
- Evaluate the child's description. If the evaluation indicates advancement,
222
- update the difficulty, generate a new image (resetting image and chat), and update the difficulty label.
223
  """
224
  if not active_session.get("image"):
225
  bot_message = "Please generate an image first."
226
- updated_chat = active_session.get("chat", []) + [(user_message, bot_message)]
227
  active_session["chat"] = updated_chat
228
- return "", updated_chat, saved_sessions, active_session
229
 
230
- chat_history = active_session.get("chat", [])
231
- identified_details = active_session.get("identified_details", [])
232
- raw_evaluation = compare_details_chat_fn(user_message, "", chat_history, identified_details)
233
- current_difficulty = active_session.get("difficulty", "Very Simple")
234
- evaluation_response, updated_difficulty = evaluate_scores(raw_evaluation, current_difficulty)
235
- bot_message = evaluation_response
 
 
 
 
 
 
 
 
 
 
 
 
236
 
237
- # If the child advanced, update difficulty and generate a new image
238
- if updated_difficulty != current_difficulty:
239
- # Update the active session's difficulty before generating a new prompt
240
- active_session["difficulty"] = updated_difficulty
 
 
 
241
  age = active_session.get("age", "3")
242
  autism_level = active_session.get("autism_level", "Level 1")
243
  topic_focus = active_session.get("topic_focus", "")
244
  treatment_plan = active_session.get("treatment_plan", "")
245
- new_image, new_active_session, new_sessions = generate_image_and_reset_chat(age, autism_level, topic_focus, treatment_plan, active_session, saved_sessions)
246
- new_active_session["chat"].append(("System", f"You advanced to {updated_difficulty} difficulty! A new image has been generated for you."))
247
- active_session = new_active_session
248
- bot_message = f"You advanced to {updated_difficulty} difficulty! A new image has been generated for you."
249
- saved_sessions = new_sessions
250
- else:
251
- updated_chat = active_session.get("chat", []) + [(user_message, bot_message)]
252
- active_session["chat"] = updated_chat
253
 
254
- return "", active_session["chat"], saved_sessions, active_session
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
255
 
256
  def update_sessions(saved_sessions, active_session):
257
  """
@@ -265,87 +450,240 @@ def update_sessions(saved_sessions, active_session):
265
  # Gradio Interface
266
  ##############################################
267
  with gr.Blocks() as demo:
268
- # The active session now starts with difficulty "Very Simple"
269
  active_session = gr.State({
270
  "prompt": None,
271
  "image": None,
 
272
  "chat": [],
273
  "treatment_plan": "",
274
  "topic_focus": "",
 
275
  "identified_details": [],
 
276
  "difficulty": "Very Simple",
277
  "age": "3",
278
  "autism_level": "Level 1"
279
  })
280
  saved_sessions = gr.State([])
281
-
282
- with gr.Column():
283
- gr.Markdown("# Image Generation & Chat Inference")
284
- # Display current difficulty label
285
- difficulty_label = gr.Markdown("**Current Difficulty:** Very Simple")
286
-
287
- # ----- Image Generation Section -----
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
288
  with gr.Column():
289
- gr.Markdown("## Generate Image")
290
- gr.Markdown("Enter your age, select your autism level, specify a topic focus, and provide the treatment plan to generate an image based on the current difficulty level.")
291
- with gr.Row():
292
- age_input = gr.Textbox(label="Age", placeholder="Enter age...", value="3")
293
- autism_level_dropdown = gr.Dropdown(label="Autism Level", choices=["Level 1", "Level 2", "Level 3"], value="Level 1")
294
-
295
- topic_focus_input = gr.Textbox(
296
- label="Topic Focus",
297
- placeholder="Enter a specific topic or detail to focus on (e.g., 'animals', 'emotions', 'daily routines')...",
298
- lines=1
299
- )
300
-
301
- treatment_plan_input = gr.Textbox(
302
- label="Treatment Plan",
303
- placeholder="Enter the treatment plan to guide the image generation...",
304
- lines=2
305
- )
306
- generate_btn = gr.Button("Generate Image")
307
- img_output = gr.Image(label="Generated Image")
308
- generate_btn.click(
309
- generate_image_and_reset_chat,
310
- inputs=[age_input, autism_level_dropdown, topic_focus_input, treatment_plan_input, active_session, saved_sessions],
311
- outputs=[img_output, active_session, saved_sessions]
312
- )
313
-
314
- # ----- Chat Section -----
315
- with gr.Column():
316
- gr.Markdown("## Chat about the Image")
317
  gr.Markdown(
318
- "After generating an image, type details or descriptions about it. "
319
- "Your message, along with the generated image and conversation history, will be sent for evaluation."
320
- )
321
- chatbot = gr.Chatbot(label="Chat History")
322
- with gr.Row():
323
- chat_input = gr.Textbox(label="Your Message", placeholder="Type your description here...", show_label=False)
324
- send_btn = gr.Button("Send")
325
- send_btn.click(
326
- chat_respond,
327
- inputs=[chat_input, active_session, saved_sessions],
328
- outputs=[chat_input, chatbot, saved_sessions, active_session]
329
- )
330
- chat_input.submit(
331
- chat_respond,
332
- inputs=[chat_input, active_session, saved_sessions],
333
- outputs=[chat_input, chatbot, saved_sessions, active_session]
334
  )
 
335
 
336
- # ----- Sidebar Section for Session Details -----
337
- with gr.Column(variant="sidebar"):
338
- gr.Markdown("## Saved Chat Sessions")
339
- gr.Markdown(
340
- "This sidebar automatically saves finished chat sessions. "
341
- "Each session includes the prompt used, the generated image (as a data URL), "
342
- "the topic focus, the treatment plan, the list of identified details, and the full chat history."
343
  )
344
- sessions_output = gr.JSON(label="Session Details", value={})
345
- active_session.change(update_sessions, inputs=[saved_sessions, active_session], outputs=sessions_output)
346
- # Update the current difficulty label when active_session changes.
347
- active_session.change(update_difficulty_label, inputs=[active_session], outputs=[difficulty_label])
348
- saved_sessions.change(update_sessions, inputs=[saved_sessions, active_session], outputs=sessions_output)
349
 
350
- # Launch the app with public sharing enabled.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
351
  demo.launch()
 
8
  from huggingface_hub import InferenceClient
9
  from google.generativeai import configure, GenerativeModel
10
  from google.ai.generativelanguage import Content, Part
11
+ from dotenv import load_dotenv
12
+ load_dotenv()
13
 
14
  # Load API keys from environment variables
15
  inference_api_key = os.environ.get("HF_TOKEN")
16
+ google_api_key = os.environ.get("GOOGLE_API_KEY")
17
 
18
  # Configure Google API
19
  configure(api_key=google_api_key)
20
 
21
+ # Global variables to store the image data URL, prompt, and detailed description
22
  global_image_data_url = None
23
+ global_image_prompt = None
24
+ global_image_description = None # New variable to store Gemini's detailed description
25
 
26
  def update_difficulty_label(active_session):
27
  return f"**Current Difficulty:** {active_session.get('difficulty', 'Very Simple')}"
 
39
  - Autism Level: {autism_level}
40
  - Topic Focus: {topic_focus}
41
  - Treatment Plan: {treatment_plan}
 
42
  Emphasize that the image should be clear, calming, and support understanding and communication. The style should match the difficulty level: for example, "Very Simple" produces very basic visuals while "Very Detailed" produces rich visuals.
 
43
  The image should specifically focus on the topic: "{topic_focus}".
 
44
  Please generate a prompt that instructs the image generation engine to produce an image with:
45
  1. Clarity and simplicity (minimalist backgrounds, clear subject)
46
  2. Literal representation with defined borders and consistent style
47
  3. Soft, muted colors and reduced visual complexity
48
  4. Positive, calm scenes
49
  5. Clear focus on the specified topic
 
50
  Use descriptive and detailed language.
51
  """
52
  )
 
53
  # Initialize the Gemini Pro model
54
  model = GenerativeModel('gemini-2.0-flash-lite')
 
55
  # Generate content using the Gemini model
56
  response = model.generate_content(query)
57
+ return response.text.strip()
58
 
59
+ def generate_detailed_description(image_data_url, prompt, difficulty, topic_focus):
60
+ """
61
+ Generate a detailed description of the image using Gemini Vision.
62
+ """
63
+ # Remove the data:image/png;base64, prefix to get just the base64 string
64
+ base64_img = image_data_url.split(",")[1]
65
+ query = (
66
+ f"""
67
+ You are an expert educator specializing in teaching children with autism.
68
+ Please provide a detailed description of this image that was generated based on the prompt:
69
+ "{prompt}"
70
+ The image is intended for a child with autism, focusing on the topic: "{topic_focus}" at a {difficulty} difficulty level.
71
+ In your description:
72
+ 1. List all key objects, characters, and elements present in the image
73
+ 2. Describe colors, shapes, positions, and relationships between elements
74
+ 3. Note any emotions, actions, or interactions depicted
75
+ 4. Highlight details that would be important for the child to notice
76
+ 5. Organize your description in a structured, clear way
77
+ Your description will be used as a reference to evaluate the child's observations,
78
+ so please be comprehensive but focus on observable details rather than interpretations.
79
+ """
80
+ )
81
+ # Create a Gemini Vision Pro model
82
+ vision_model = GenerativeModel('gemini-2.0-flash-thinking-exp-01-21')
83
+ # Create the content with image and text
84
+ image_part = Part(inline_data={"mime_type": "image/png", "data": base64.b64decode(base64_img)})
85
+ text_part = Part(text=query)
86
+ multimodal_content = Content(parts=[image_part, text_part])
87
+ # Generate description using the vision model
88
+ response = vision_model.generate_content(multimodal_content)
89
  return response.text.strip()
90
 
91
  def generate_image_fn(selected_prompt, guidance_scale=7.5,
 
112
  global_image_data_url = f"data:image/png;base64,{img_b64}"
113
  return image
114
 
115
+ def extract_key_details(description):
116
+ """
117
+ Extract key details from Gemini's description to use for tracking.
118
+ Returns a list of key elements/details from the description.
119
+ """
120
+ # Create a query to extract key details
121
+ query = (
122
+ f"""
123
+ From the following detailed image description, extract a list of 10-15 key details that a child might identify.
124
+ Each detail should be a simple, clear phrase describing one observable element.
125
+ Description:
126
+ {description}
127
+ Format your response as a JSON array of strings, each representing one key detail.
128
+ Example format: ["red ball on the grass", "smiling girl with brown hair", "blue sky with clouds"]
129
+ """
130
+ )
131
+ # Use Gemini text model to extract key details
132
+ model = GenerativeModel('gemini-2.0-flash-lite')
133
+ response = model.generate_content(query)
134
+ try:
135
+ # Parse the JSON response
136
+ details_match = re.search(r'\[.*\]', response.text, re.DOTALL)
137
+ if details_match:
138
+ details_json = details_match.group(0)
139
+ key_details = json.loads(details_json)
140
+ return key_details
141
+ else:
142
+ # If no JSON found, do basic extraction
143
+ lines = description.split('\n')
144
+ details = []
145
+ for line in lines:
146
+ if line.strip().startswith('-') or line.strip().startswith('*'):
147
+ details.append(line.strip()[1:].strip())
148
+ return details[:15] if details else ["object in image", "color", "shape", "background"]
149
+ except Exception as e:
150
+ print(f"Error extracting key details: {str(e)}")
151
+ return ["object in image", "color", "shape", "background"]
152
+
153
  def generate_image_and_reset_chat(age, autism_level, topic_focus, treatment_plan, active_session, saved_sessions):
154
  """
155
  Generate a new image (with the current difficulty) and reset the chat.
156
  Now includes the topic_focus parameter to specify what the image should focus on.
157
  """
158
+ global global_image_description
159
  new_sessions = saved_sessions.copy()
160
  if active_session.get("prompt"):
161
  new_sessions.append(active_session)
162
+
163
+ # Use the current difficulty from the active session
164
  current_difficulty = active_session.get("difficulty", "Very Simple")
165
+
166
+ # Generate the prompt for the image
167
  generated_prompt = generate_prompt_from_options(current_difficulty, age, autism_level, topic_focus, treatment_plan)
168
+
169
+ # Generate the image
170
  image = generate_image_fn(generated_prompt)
171
+
172
+ # Generate a detailed description of the image using Gemini Vision
173
+ image_description = generate_detailed_description(global_image_data_url, generated_prompt, current_difficulty, topic_focus)
174
+ global_image_description = image_description
175
+
176
+ # Extract key details to be identified
177
+ key_details = extract_key_details(image_description)
178
+
179
+ # Create a new active session with all the necessary information
180
  new_active_session = {
181
  "prompt": generated_prompt,
182
  "image": global_image_data_url,
183
+ "image_description": image_description,
184
  "chat": [],
185
  "treatment_plan": treatment_plan,
186
  "topic_focus": topic_focus,
187
+ "key_details": key_details, # Store the list of key details
188
  "identified_details": [],
189
+ "used_hints": [],
190
  "difficulty": current_difficulty,
191
  "autism_level": autism_level,
192
  "age": age
193
  }
 
194
 
195
+ # Create the checklist of items to identify
196
+ checklist_items = []
197
+ for i, detail in enumerate(key_details):
198
+ checklist_items.append({"detail": detail, "identified": False, "id": i})
199
+
200
+ # Return the updated state and checklist
201
+ return image, new_active_session, new_sessions, checklist_items
202
+
203
+ def compare_details_chat_fn(user_details, active_session):
204
  """
205
+ Evaluate the child's description using Google's Gemini model.
206
+ Now uses the image description and tracks identified details and used hints.
207
  """
208
+ if not global_image_data_url or not global_image_description:
209
  return "Please generate an image first."
210
 
211
+ # Get the detailed image description
212
+ image_description = active_session.get("image_description", global_image_description)
213
+
214
+ # Get chat history
215
+ chat_history = active_session.get("chat", [])
216
  history_text = ""
217
  if chat_history:
218
  history_text = "\n\n### Previous Conversation:\n"
219
+ for idx, (speaker, msg) in enumerate(chat_history, 1):
220
+ history_text += f"Turn {idx}:\n{speaker}: {msg}\n"
221
 
222
+ # Get key details, identified details and used hints
223
+ key_details = active_session.get("key_details", [])
224
+ identified_details = active_session.get("identified_details", [])
225
+ used_hints = active_session.get("used_hints", [])
226
+
227
+ # Format for the API
228
+ key_details_text = "\n\n### Key Details to Identify:\n" + "\n".join(f"- {detail}" for detail in key_details)
229
  identified_details_text = ""
230
  if identified_details:
231
  identified_details_text = "\n\n### Previously Identified Details:\n" + "\n".join(f"- {detail}" for detail in identified_details)
232
+ used_hints_text = ""
233
+ if used_hints:
234
+ used_hints_text = "\n\n### Previously Given Hints:\n" + "\n".join(f"- {hint}" for hint in used_hints)
235
+
236
+ # Current difficulty level
237
+ current_difficulty = active_session.get("difficulty", "Very Simple")
238
 
239
  message_text = (
240
+ f"You are a kind and encouraging teacher helping a child with autism describe an image.\n\n"
241
+ f"### Image Prompt:\n{active_session.get('prompt', 'No prompt available')}\n\n"
242
+ f"### Detailed Image Description (Reference):\n{image_description}\n\n"
243
+ f"### Current Difficulty Level: {current_difficulty}\n"
244
+ f"{key_details_text}{history_text}{identified_details_text}{used_hints_text}\n\n"
245
+ f"### Child's Current Description:\n'{user_details}'\n\n"
246
+ "Evaluate the child's description compared to the key details list. Use simple, clear language. "
247
+ "Praise specific correct observations. If something important is missing, provide a gentle hint "
248
+ "that hasn't been given before.\n\n"
249
+ "Follow these guidelines:\n"
250
+ "1. DO NOT mention that you're evaluating or scoring the child.\n"
251
+ "2. Keep feedback warm, positive, and encouraging.\n"
252
+ "3. If giving a hint, make it specific but not too obvious.\n"
253
+ "4. Never repeat hints that have already been given.\n"
254
+ "5. Focus on details the child hasn't yet identified.\n"
255
+ "6. Acknowledge the child's progress.\n\n"
256
+ "Return your response as a JSON object with the following format:\n"
257
  "{\n"
258
+ " \"feedback\": \"Your encouraging response to the child\",\n"
259
+ " \"newly_identified_details\": [\"list\", \"of\", \"new details\", \"the child identified\"],\n"
260
+ " \"hint\": \"A new hint about something not yet identified\",\n"
261
+ " \"score\": <number from 0-100 based on how complete the description is>,\n"
262
+ " \"advance_difficulty\": <boolean indicating if child should advance>\n"
 
 
 
 
 
263
  "}\n\n"
264
+ "Ensure the JSON is valid and contains all fields."
265
  )
266
 
267
+ # Create a Gemini model for evaluation
268
+ model = GenerativeModel('gemini-2.0-flash-thinking-exp-01-21')
 
 
 
 
 
 
 
 
 
 
 
 
269
 
270
+ # Generate evaluation using the model
271
+ response = model.generate_content(message_text)
272
  return response.text
273
 
274
+ def parse_evaluation(evaluation_text, active_session):
 
 
 
 
 
275
  try:
276
  json_match = re.search(r'\{.*\}', evaluation_text, re.DOTALL)
277
  if json_match:
 
279
  evaluation = json.loads(json_str)
280
  else:
281
  raise ValueError("No JSON object found in the response.")
282
+
283
+ # Extract data from the evaluation
284
+ feedback = evaluation.get("feedback", "Great effort! Keep describing what you see.")
285
+ newly_identified_details = evaluation.get("newly_identified_details", [])
286
+ hint = evaluation.get("hint", "")
287
+ score = evaluation.get("score", 0)
288
+ advance_difficulty = evaluation.get("advance_difficulty", False)
289
+
290
+ # Update the session with newly identified details
291
+ identified_details = active_session.get("identified_details", [])
292
+ for detail in newly_identified_details:
293
+ if detail not in identified_details:
294
+ identified_details.append(detail)
295
+ active_session["identified_details"] = identified_details
296
+
297
+ # Add the hint to used hints if one was provided
298
+ if hint:
299
+ used_hints = active_session.get("used_hints", [])
300
+ if hint not in used_hints:
301
+ used_hints.append(hint)
302
+ active_session["used_hints"] = used_hints
303
+
304
+ # Add the hint to the feedback if it's not already included
305
+ if hint.strip() and hint.strip() not in feedback:
306
+ feedback += f"\n\n💡 Hint: {hint}"
307
+
308
+ # Get current difficulty and check if it should be advanced
309
+ current_difficulty = active_session.get("difficulty", "Very Simple")
310
+ should_advance = False
311
+
312
+ if advance_difficulty:
313
+ difficulties = ["Very Simple", "Simple", "Moderate", "Detailed", "Very Detailed"]
314
+ current_index = difficulties.index(current_difficulty) if current_difficulty in difficulties else 0
315
+ if current_index < len(difficulties) - 1:
316
+ current_difficulty = difficulties[current_index + 1]
317
+ should_advance = True
318
+
319
+ return feedback, current_difficulty, should_advance, newly_identified_details
320
+
321
  except Exception as e:
322
+ print(f"Error processing evaluation: {str(e)}")
323
+ return f"That's interesting! Can you tell me more about what you see?", active_session.get("difficulty", "Very Simple"), False, []
324
 
325
+ def update_checklist(checklist, newly_identified, key_details):
326
+ """
327
+ Update the checklist based on newly identified details.
328
+ Returns an updated checklist.
329
+ """
330
+ new_checklist = []
331
+ for item in checklist:
332
+ detail = item["detail"]
333
+ # Check if this detail has been identified
334
+ is_identified = item["identified"]
335
+
336
+ # If newly identified, update status
337
+ for identified in newly_identified:
338
+ # Check if the identified detail matches or is similar to the key detail
339
+ if (identified.lower() in detail.lower() or detail.lower() in identified.lower() or
340
+ any(word for word in identified.lower().split() if word in detail.lower() and len(word) > 3)):
341
+ is_identified = True
342
+ break
343
+
344
+ new_checklist.append({"detail": detail, "identified": is_identified, "id": item["id"]})
345
+
346
+ return new_checklist
347
+
348
def chat_respond(user_message, active_session, saved_sessions, checklist):
    """
    Process one chat turn from the child.

    Evaluates the child's description of the current image, updates the
    checklist of identified details, and — when the evaluator advances the
    child or every detail has been found — archives the session and
    generates a fresh image with a new checklist.

    Parameters
    ----------
    user_message : str
        The child's description, as typed by the teacher.
    active_session : dict
        Current session state (image, chat, difficulty, key_details, ...).
    saved_sessions : list[dict]
        Previously archived sessions.
    checklist : list[dict]
        Entries of the form {"detail": str, "identified": bool, "id": int}.

    Returns
    -------
    tuple
        (cleared_input, chat_history, saved_sessions, active_session,
        checklist, new_image_or_None). The last element is a newly generated
        image only when one was produced; otherwise None so the UI keeps the
        currently displayed image.
    """
    # No image yet: nothing to evaluate, just nudge the user.
    if not active_session.get("image"):
        bot_message = "Please generate an image first."
        updated_chat = active_session.get("chat", []) + [("Child", user_message), ("Teacher", bot_message)]
        active_session["chat"] = updated_chat
        return "", updated_chat, saved_sessions, active_session, checklist, None  # Return None for image

    # Get the evaluation from Gemini.
    raw_evaluation = compare_details_chat_fn(user_message, active_session)

    # Parse feedback text, (possibly new) difficulty, advancement flag, and
    # the details newly identified in this turn.
    feedback, updated_difficulty, should_advance, newly_identified = parse_evaluation(raw_evaluation, active_session)

    # Update the checklist with newly identified details.
    updated_checklist = update_checklist(checklist, newly_identified, active_session.get("key_details", []))

    # Add the current exchange to the chat history.
    updated_chat = active_session.get("chat", []) + [("Child", user_message), ("Teacher", feedback)]
    active_session["chat"] = updated_chat

    # A new image is generated when the evaluator advances the child OR when
    # every detail of the current image has been identified.
    all_identified = all(item["identified"] for item in updated_checklist)
    should_generate_new_image = should_advance or all_identified

    if should_generate_new_image:
        # Archive the finished session before starting a new one.
        new_sessions = saved_sessions.copy()
        new_sessions.append(active_session.copy())

        # Carry the child's profile over to the next image.
        age = active_session.get("age", "3")
        autism_level = active_session.get("autism_level", "Level 1")
        topic_focus = active_session.get("topic_focus", "")
        treatment_plan = active_session.get("treatment_plan", "")

        # The original conditional here ("use updated unless it equals the
        # current difficulty, else use current") always evaluated to
        # updated_difficulty, so it is used directly.
        difficulty_to_use = updated_difficulty

        # Generate a new prompt with the chosen difficulty.
        generated_prompt = generate_prompt_from_options(difficulty_to_use, age, autism_level, topic_focus, treatment_plan)

        # Generate the new image (returns a PIL Image); this also refreshes
        # global_image_data_url as a side effect.
        new_image = generate_image_fn(generated_prompt)

        # Describe the new image via Gemini Vision and derive the details
        # the child should identify.
        image_description = generate_detailed_description(global_image_data_url, generated_prompt, difficulty_to_use, topic_focus)
        key_details = extract_key_details(image_description)

        # Fresh session state for the new image.
        new_active_session = {
            "prompt": generated_prompt,
            "image": global_image_data_url,
            "image_description": image_description,
            "chat": [],
            "treatment_plan": treatment_plan,
            "topic_focus": topic_focus,
            "key_details": key_details,
            "identified_details": [],
            "used_hints": [],
            "difficulty": difficulty_to_use,
            "autism_level": autism_level,
            "age": age
        }

        # Fresh, all-unchecked checklist for the new image.
        new_checklist = [
            {"detail": detail, "identified": False, "id": i}
            for i, detail in enumerate(key_details)
        ]

        # Seed the new chat with an appropriate system message.
        if updated_difficulty != active_session.get("difficulty", "Very Simple"):
            advancement_message = f"Congratulations! You've advanced to {updated_difficulty} difficulty! Here's a new image to describe."
        else:
            advancement_message = "Great job identifying all the details! Here's a new image at the same difficulty level."

        new_active_session["chat"] = [("System", advancement_message)]

        return "", new_active_session["chat"], new_sessions, new_active_session, new_checklist, new_image

    # If not advancing, return None for the image to indicate no change.
    return "", updated_chat, saved_sessions, active_session, updated_checklist, None
440
 
441
  def update_sessions(saved_sessions, active_session):
442
  """
 
450
# Gradio Interface
##############################################
with gr.Blocks() as demo:
    # Initialize the active session with default values.
    # Keys mirror what the chat handlers later read with .get().
    active_session = gr.State({
        "prompt": None,
        "image": None,
        "image_description": None,
        "chat": [],
        "treatment_plan": "",
        "topic_focus": "",
        "key_details": [],
        "identified_details": [],
        "used_hints": [],
        "difficulty": "Very Simple",
        "age": "3",
        "autism_level": "Level 1"
    })
    # Archived sessions, appended to when the child advances.
    saved_sessions = gr.State([])
    # Checklist entries: {"detail": str, "identified": bool, "id": int}.
    checklist_state = gr.State([])

    with gr.Row():
        # Main content area
        with gr.Column(scale=2):
            gr.Markdown("# Autism Education Image Description Tool")
            # Display current difficulty label
            difficulty_label = gr.Markdown("**Current Difficulty:** Very Simple")

            # ----- Image Generation Section -----
            with gr.Column():
                gr.Markdown("## Generate Image")
                gr.Markdown("Enter the child's details to generate an appropriate educational image.")
                with gr.Row():
                    age_input = gr.Textbox(label="Child's Age", placeholder="Enter age...", value="3")
                    autism_level_dropdown = gr.Dropdown(label="Autism Level", choices=["Level 1", "Level 2", "Level 3"], value="Level 1")
                # NOTE(review): the two boxes below are assumed full-width
                # (outside the age/level row) — confirm against original layout.
                topic_focus_input = gr.Textbox(
                    label="Topic Focus",
                    placeholder="Enter a specific topic or detail to focus on (e.g., 'animals', 'emotions', 'daily routines')...",
                    lines=1
                )
                treatment_plan_input = gr.Textbox(
                    label="Treatment Plan",
                    placeholder="Enter the treatment plan to guide the image generation...",
                    lines=2
                )
                generate_btn = gr.Button("Generate Image")
                img_output = gr.Image(label="Generated Image")

            # ----- Chat Section -----
            with gr.Column():
                gr.Markdown("## Image Description Practice")
                gr.Markdown(
                    "After generating an image, ask the child to describe what they see. "
                    "Type their description in the box below. The system will provide supportive feedback "
                    "and track their progress in identifying details."
                )
                chatbot = gr.Chatbot(label="Conversation History")
                with gr.Row():
                    chat_input = gr.Textbox(label="Child's Description", placeholder="Type what the child says about the image...", show_label=True)
                    send_btn = gr.Button("Submit")

        # Sidebar - Checklist of items to identify
        with gr.Column(scale=1):
            gr.Markdown("## Details to Identify")
            gr.Markdown("The child should try to identify these elements in the image:")

            # Create a custom HTML component to display the checklist with checkboxes
            checklist_html = gr.HTML("""
            <div id="checklist-container">
                <p>Generate an image to see details to identify.</p>
            </div>
            """)

            # Add a function to update the checklist HTML
524
def update_checklist_html(checklist):
    """Render the details-to-identify checklist as an HTML fragment.

    Parameters
    ----------
    checklist : list[dict]
        Entries of the form {"detail": str, "identified": bool, "id": int}.

    Returns
    -------
    str
        HTML in which identified items are struck through with a green
        check mark and unidentified items show an empty box.
    """
    # Local import keeps the escaping fix self-contained in this block.
    import html

    if not checklist:
        return """
        <div id="checklist-container">
            <p>Generate an image to see details to identify.</p>
        </div>
        """

    # Static stylesheet describing the two item states.
    parts = ["""
    <div id="checklist-container" style="padding: 10px;">
    <style>
    .checklist-item {
        display: flex;
        align-items: center;
        margin-bottom: 10px;
        padding: 8px;
        border-radius: 5px;
        transition: background-color 0.3s;
    }
    .identified {
        background-color: #e6f7e6;
        text-decoration: line-through;
        color: #4CAF50;
    }
    .not-identified {
        background-color: #f5f5f5;
    }
    .checkmark {
        margin-right: 10px;
        font-size: 1.2em;
    }
    </style>
    """]

    for item in checklist:
        # Escape the detail text: it originates from the language model and
        # could otherwise break the markup (or inject HTML).
        detail = html.escape(item["detail"])
        identified = item["identified"]
        css_class = "identified" if identified else "not-identified"
        checkmark = "βœ…" if identified else "⬜"

        parts.append(f"""
    <div class="checklist-item {css_class}">
        <span class="checkmark">{checkmark}</span>
        <span>{detail}</span>
    </div>
    """)

    parts.append("""
    </div>
    """)
    # Single join instead of repeated string concatenation.
    return "".join(parts)
575
+
576
            # Progress summary
            # Placeholder markup; replaced whenever the checklist state changes.
            progress_html = gr.HTML("""
            <div id="progress-container">
                <p>No active session.</p>
            </div>
            """)
583
def update_progress_html(checklist):
    """Render a progress bar plus an encouragement line for the checklist.

    Returns placeholder markup when no checklist exists yet; otherwise a bar
    sized by the fraction of identified details and a message keyed to the
    completion percentage.
    """
    if not checklist:
        return """
        <div id="progress-container">
            <p>No active session.</p>
        </div>
        """

    total = len(checklist)
    found = sum(1 for entry in checklist if entry["identified"])
    pct = (found / total) * 100 if total > 0 else 0
    done = found == total
    bar_width = f"{pct}%"

    # Pick the encouragement line for the current completion level.
    if done:
        message = "πŸŽ‰ Amazing! All details identified! πŸŽ‰"
    elif pct >= 75:
        message = "Almost there! Keep going!"
    elif pct >= 50:
        message = "Halfway there! You're doing great!"
    elif pct >= 25:
        message = "Good start! Keep looking!"
    else:
        message = "Let's find more details!"

    header = f"""
    <div id="progress-container" style="padding: 10px;">
        <h3>Progress: {found} / {total} details</h3>
        <div style="width: 100%; background-color: #f1f1f1; border-radius: 5px; margin-bottom: 10px;">
            <div style="width: {bar_width}; height: 24px; background-color: #4CAF50; border-radius: 5px;"></div>
        </div>
        <p style="font-size: 16px; font-weight: bold; text-align: center;">
    """
    footer = """
        </p>
    </div>
    """
    return header + message + footer
623
+
624
    # ----- Session Details Section -----
    with gr.Row():
        with gr.Column():
            gr.Markdown("## Progress Tracking")
            gr.Markdown(
                "This section tracks the child's progress across sessions. "
                "Each session includes the difficulty level, identified details, "
                "and the full conversation history."
            )
            # Raw JSON dump of the session data, refreshed by update_sessions
            # whenever the session states change.
            sessions_output = gr.JSON(label="Session Details", value={})
634
 
635
# Process chat and update image as needed
def process_chat_and_image(user_msg, active_session, saved_sessions, checklist):
    """Forward one chat turn to chat_respond and decide the image output.

    chat_respond returns a new image only when the child advances; in that
    case the image component is updated, otherwise a no-op update keeps the
    currently displayed image in place.
    """
    outcome = chat_respond(user_msg, active_session, saved_sessions, checklist)
    *state_outputs, maybe_image = outcome
    # gr.update() is Gradio's no-change marker for an output component.
    image_out = maybe_image if maybe_image is not None else gr.update()
    return (*state_outputs, image_out)
647
+
648
    # Connect event handlers
    # Image generation: produces a new image and resets chat/checklist state.
    generate_btn.click(
        generate_image_and_reset_chat,
        inputs=[age_input, autism_level_dropdown, topic_focus_input, treatment_plan_input, active_session, saved_sessions],
        outputs=[img_output, active_session, saved_sessions, checklist_state]
    )

    # Chat submission via the Submit button...
    send_btn.click(
        process_chat_and_image,
        inputs=[chat_input, active_session, saved_sessions, checklist_state],
        outputs=[chat_input, chatbot, saved_sessions, active_session, checklist_state, img_output]
    )

    # ...and via pressing Enter in the textbox (identical wiring).
    chat_input.submit(
        process_chat_and_image,
        inputs=[chat_input, active_session, saved_sessions, checklist_state],
        outputs=[chat_input, chatbot, saved_sessions, active_session, checklist_state, img_output]
    )

    # Update the checklist HTML when checklist state changes
    checklist_state.change(
        update_checklist_html,
        inputs=[checklist_state],
        outputs=[checklist_html]
    )

    # Update the progress HTML when checklist state changes
    checklist_state.change(
        update_progress_html,
        inputs=[checklist_state],
        outputs=[progress_html]
    )

    # Update the current difficulty label when active_session changes
    active_session.change(update_difficulty_label, inputs=[active_session], outputs=[difficulty_label])

    # Update sessions when active_session or saved_sessions change
    active_session.change(update_sessions, inputs=[saved_sessions, active_session], outputs=sessions_output)
    saved_sessions.change(update_sessions, inputs=[saved_sessions, active_session], outputs=sessions_output)

# Launch the app
demo.launch()