"""Autism Education Image Description Tool.

Generates educational images (HF Inference / Stable Diffusion) tailored to a
child's age, autism level, and topic focus, then runs a guided describe-the-image
practice loop evaluated by Google Gemini.  Progress is tracked per session with
a key-detail checklist, hint history, and an attempt limit.
"""

import gradio as gr
import io
import base64
import os
import json
import re
from PIL import Image
from huggingface_hub import InferenceClient
from google.generativeai import configure, GenerativeModel
from google.ai.generativelanguage import Content, Part

# Load API keys from environment variables
inference_api_key = os.environ.get("HF_TOKEN")
google_api_key = os.environ.get("GOOGLE_API_KEY")

# Configure Google API
configure(api_key=google_api_key)

# Module-level state shared across Gradio callbacks:
# the latest generated image (as a data URL), its prompt, and Gemini's
# detailed reference description of it.
global_image_data_url = None
global_image_prompt = None
global_image_description = None


def update_difficulty_label(active_session):
    """Render the current difficulty of the active session as Markdown."""
    return f"**Current Difficulty:** {active_session.get('difficulty', 'Very Simple')}"


def _build_image_content(image_data_url, query):
    """Build a multimodal Content (image part + text part) for a Gemini vision call.

    `image_data_url` is a ``data:image/png;base64,...`` URL; the base64 payload
    after the comma is decoded back to raw PNG bytes.
    """
    base64_img = image_data_url.split(",")[1]
    image_part = Part(inline_data={"mime_type": "image/png", "data": base64.b64decode(base64_img)})
    text_part = Part(text=query)
    return Content(parts=[image_part, text_part])


def _build_checklist(key_details):
    """Turn a list of key-detail strings into fresh (unidentified) checklist items."""
    return [{"detail": detail, "identified": False, "id": i}
            for i, detail in enumerate(key_details)]


def generate_prompt_from_options(difficulty, age, autism_level, topic_focus, treatment_plan=""):
    """
    Generate an image prompt using Google's Gemini model.
    """
    query = f"""
    Follow the instructions below to generate an image generation prompt for an educational image intended for autistic children.
    Consider the following parameters:
    - Difficulty: {difficulty}
    - Age: {age}
    - Autism Level: {autism_level}
    - Topic Focus: {topic_focus}
    - Treatment Plan: {treatment_plan}

    Emphasize that the image should be clear, calming, and support understanding and communication.
    The style should match the difficulty level: for example, "Very Simple" produces very basic visuals
    while "Very Detailed" produces rich visuals.
    The image should specifically focus on the topic: "{topic_focus}".

    Please generate a prompt that instructs the image generation engine to produce an image with:
    1. Clarity and simplicity (minimalist backgrounds, clear subject)
    2. Literal representation with defined borders and consistent style
    3. Soft, muted colors and reduced visual complexity
    4. Positive, calm scenes
    5. Clear focus on the specified topic
    Use descriptive and detailed language.
    """
    model = GenerativeModel('gemini-2.0-flash')
    response = model.generate_content(query)
    return response.text.strip()


def generate_detailed_description(image_data_url, prompt, difficulty, topic_focus):
    """
    Generate a detailed description of the image using Gemini Vision.

    The description serves as the reference against which the child's
    observations are later evaluated.
    """
    query = f"""
    You are an expert educator specializing in teaching children with autism.
    Please provide a detailed description of this image that was generated based on the prompt: "{prompt}"

    The image is intended for a child with autism, focusing on the topic: "{topic_focus}" at a {difficulty} difficulty level.

    In your description:
    1. List all key objects, characters, and elements present in the image
    2. Describe colors, shapes, positions, and relationships between elements
    3. Note any emotions, actions, or interactions depicted
    4. Highlight details that would be important for the child to notice
    5. Organize your description in a structured, clear way

    Your description will be used as a reference to evaluate the child's observations, so please be comprehensive but focus on observable details rather than interpretations.
    """
    vision_model = GenerativeModel('gemini-2.0-flash-thinking-exp-01-21')
    response = vision_model.generate_content(_build_image_content(image_data_url, query))
    return response.text.strip()


def extract_key_details(image_data_url, prompt, topic_focus):
    """
    Extract key details directly from the image using Gemini Vision.
    Returns a list of key elements/details from the image.
    """
    query = f"""
    You are analyzing an educational image created for a child with autism, based on the prompt: "{prompt}".
    The image focuses on the topic: "{topic_focus}".

    Please extract a list of 10-15 key details that a child might identify in this image.
    Each detail should be a simple, clear phrase describing one observable element.
    Focus on concrete, visible elements rather than abstract concepts.

    Format your response as a JSON array of strings, each representing one key detail.
    Example format: ["red ball on the grass", "smiling girl with brown hair", "blue sky with clouds"]

    Ensure each detail is:
    1. Directly observable in the image
    2. Unique (not a duplicate)
    3. Described in simple, concrete language
    4. Relevant to what a child would notice
    """
    vision_model = GenerativeModel('gemini-2.0-flash')
    response = vision_model.generate_content(_build_image_content(image_data_url, query))
    try:
        # Prefer a JSON array embedded anywhere in the response text.
        details_match = re.search(r'\[.*\]', response.text, re.DOTALL)
        if details_match:
            return json.loads(details_match.group(0))
        # Fallback: the model answered with bullet points instead of JSON.
        details = []
        for line in response.text.split('\n'):
            stripped = line.strip()
            if stripped.startswith('-') or stripped.startswith('*'):
                details.append(stripped[1:].strip())
        return details[:15] if details else ["object in image", "color", "shape", "background"]
    except Exception as e:
        print(f"Error extracting key details: {str(e)}")
        return ["object in image", "color", "shape", "background"]


def generate_image_fn(selected_prompt, guidance_scale=7.5,
                      negative_prompt="ugly, blurry, poorly drawn hands, nude, deformed, missing limbs, missing body parts",
                      num_inference_steps=45):
    """
    Generate an image from the prompt via the Hugging Face Inference API.
    Convert the image to a data URL (stored in module state) and return the PIL image.
    """
    global global_image_data_url, global_image_prompt
    global_image_prompt = selected_prompt
    image_client = InferenceClient(provider="hf-inference", api_key=inference_api_key)
    image = image_client.text_to_image(
        selected_prompt,
        model="stabilityai/stable-diffusion-3.5-large-turbo",
        guidance_scale=guidance_scale,
        negative_prompt=negative_prompt,
        num_inference_steps=num_inference_steps
    )
    buffered = io.BytesIO()
    image.save(buffered, format="PNG")
    img_b64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
    global_image_data_url = f"data:image/png;base64,{img_b64}"
    return image


def generate_image_and_reset_chat(age, autism_level, topic_focus, treatment_plan,
                                  attempt_limit_input, active_session, saved_sessions):
    """
    Generate a new image (with the current difficulty) and reset the chat.
    Also resets the attempt count and uses the user-entered attempt limit.

    Returns (image, new_active_session, updated_saved_sessions, checklist_items).
    """
    global global_image_description
    new_sessions = saved_sessions.copy()
    # Archive the previous session (if one was started) before replacing it.
    if active_session.get("prompt"):
        new_sessions.append(active_session)

    current_difficulty = active_session.get("difficulty", "Very Simple")
    generated_prompt = generate_prompt_from_options(
        current_difficulty, age, autism_level, topic_focus, treatment_plan)
    image = generate_image_fn(generated_prompt)
    image_description = generate_detailed_description(
        global_image_data_url, generated_prompt, current_difficulty, topic_focus)
    global_image_description = image_description
    key_details = extract_key_details(global_image_data_url, generated_prompt, topic_focus)

    new_active_session = {
        "prompt": generated_prompt,
        "image": global_image_data_url,
        "image_description": image_description,
        "chat": [],
        "treatment_plan": treatment_plan,
        "topic_focus": topic_focus,
        "key_details": key_details,
        "identified_details": [],
        "used_hints": [],
        "difficulty": current_difficulty,
        "autism_level": autism_level,
        "age": age,
        "attempt_limit": int(attempt_limit_input) if attempt_limit_input else 3,
        "attempt_count": 0
    }
    return image, new_active_session, new_sessions, _build_checklist(key_details)


def compare_details_chat_fn(user_details, active_session):
    """
    Evaluate the child's description using Google's Gemini model.

    Builds a teacher-persona prompt containing the reference description,
    the key-detail list, conversation history, previously identified details,
    and previously given hints, then asks Gemini for structured JSON feedback.
    """
    if not global_image_data_url or not global_image_description:
        return "Please generate an image first."

    image_description = active_session.get("image_description", global_image_description)
    chat_history = active_session.get("chat", [])
    history_text = ""
    if chat_history:
        history_text = "\n\n### Previous Conversation:\n"
        for idx, (speaker, msg) in enumerate(chat_history, 1):
            history_text += f"Turn {idx}:\n{speaker}: {msg}\n"

    key_details = active_session.get("key_details", [])
    identified_details = active_session.get("identified_details", [])
    used_hints = active_session.get("used_hints", [])

    key_details_text = "\n\n### Key Details to Identify:\n" + "\n".join(
        f"- {detail}" for detail in key_details)
    identified_details_text = ""
    if identified_details:
        identified_details_text = "\n\n### Previously Identified Details:\n" + "\n".join(
            f"- {detail}" for detail in identified_details)
    used_hints_text = ""
    if used_hints:
        used_hints_text = "\n\n### Previously Given Hints:\n" + "\n".join(
            f"- {hint}" for hint in used_hints)

    current_difficulty = active_session.get("difficulty", "Very Simple")
    message_text = (
        f"You are a kind and encouraging teacher helping a child with autism describe an image.\n\n"
        f"### Image Prompt:\n{active_session.get('prompt', 'No prompt available')}\n\n"
        f"### Detailed Image Description (Reference):\n{image_description}\n\n"
        f"### Current Difficulty Level: {current_difficulty}\n"
        f"{key_details_text}{history_text}{identified_details_text}{used_hints_text}\n\n"
        f"### Child's Current Description:\n'{user_details}'\n\n"
        "Evaluate the child's description compared to the key details list. Use simple, clear language. "
        "Praise specific correct observations. If something important is missing, provide a gentle hint "
        "that hasn't been given before.\n\n"
        "Follow these guidelines:\n"
        "1. DO NOT mention that you're evaluating or scoring the child.\n"
        "2. Keep feedback warm, positive, and encouraging.\n"
        "3. If giving a hint, make it specific but not too obvious.\n"
        "4. Never repeat hints that have already been given.\n"
        "5. Focus on details the child hasn't yet identified.\n"
        "6. Acknowledge the child's progress.\n\n"
        "Return your response as a JSON object with the following format:\n"
        "{\n"
        # FIX: the example previously showed invalid JSON ("score": , and a
        # bare "advance_difficulty": with no value) while demanding valid JSON.
        " \"feedback\": \"Your encouraging response to the child\",\n"
        " \"newly_identified_details\": [\"list\", \"of\", \"new details\", \"the child identified\"],\n"
        " \"hint\": \"A new hint about something not yet identified\",\n"
        " \"score\": 0,\n"
        " \"advance_difficulty\": false\n"
        "}\n\n"
        "Ensure the JSON is valid and contains all fields."
    )
    model = GenerativeModel('gemini-2.0-flash')
    response = model.generate_content(message_text)
    return response.text


def parse_evaluation(evaluation_text, active_session):
    """
    Parse the evaluation JSON and return feedback, updated difficulty,
    whether to advance, newly identified details, and the score.

    Mutates `active_session` in place: extends "identified_details" and
    "used_hints" with new entries.
    """
    try:
        json_match = re.search(r'\{.*\}', evaluation_text, re.DOTALL)
        if json_match:
            evaluation = json.loads(json_match.group(0))
        else:
            raise ValueError("No JSON object found in the response.")

        feedback = evaluation.get("feedback", "Great effort! Keep describing what you see.")
        newly_identified_details = evaluation.get("newly_identified_details", [])
        hint = evaluation.get("hint", "")
        score = evaluation.get("score", 0)
        advance_difficulty = evaluation.get("advance_difficulty", False)

        identified_details = active_session.get("identified_details", [])
        for detail in newly_identified_details:
            if detail not in identified_details:
                identified_details.append(detail)
        active_session["identified_details"] = identified_details

        if hint:
            used_hints = active_session.get("used_hints", [])
            if hint not in used_hints:
                used_hints.append(hint)
            active_session["used_hints"] = used_hints
            # Only surface the hint if the model didn't already weave it into
            # its feedback text.
            if hint.strip() and hint.strip() not in feedback:
                feedback += f"\n\nšŸ’” Hint: {hint}"

        current_difficulty = active_session.get("difficulty", "Very Simple")
        should_advance = False
        if advance_difficulty:
            difficulties = ["Very Simple", "Simple", "Moderate", "Detailed", "Very Detailed"]
            current_index = difficulties.index(current_difficulty) if current_difficulty in difficulties else 0
            if current_index < len(difficulties) - 1:
                current_difficulty = difficulties[current_index + 1]
                should_advance = True

        return feedback, current_difficulty, should_advance, newly_identified_details, score
    except Exception as e:
        print(f"Error processing evaluation: {str(e)}")
        return ("That's interesting! Can you tell me more about what you see?",
                active_session.get("difficulty", "Very Simple"), False, [], 0)


def update_checklist(checklist, newly_identified, key_details):
    """
    Update the checklist based on newly identified details.

    A checklist item is marked identified when a newly identified phrase
    matches it by substring in either direction, or shares any word longer
    than 3 characters (fuzzy heuristic).
    """
    new_checklist = []
    for item in checklist:
        detail = item["detail"]
        is_identified = item["identified"]
        for identified in newly_identified:
            if (identified.lower() in detail.lower()
                    or detail.lower() in identified.lower()
                    or any(word for word in identified.lower().split()
                           if word in detail.lower() and len(word) > 3)):
                is_identified = True
                break
        new_checklist.append({"detail": detail, "identified": is_identified, "id": item["id"]})
    return new_checklist


def chat_respond(user_message, active_session, saved_sessions, checklist):
    """
    Process a new chat message. Evaluate the child's description, update
    identified details, and advance difficulty if needed.
    Only increment the attempt count if no new details were identified.

    Returns (cleared_input, chat, saved_sessions, active_session, checklist, new_image_or_None).
    """
    if not active_session.get("image"):
        bot_message = "Please generate an image first."
        updated_chat = active_session.get("chat", []) + [("Child", user_message), ("Teacher", bot_message)]
        active_session["chat"] = updated_chat
        return "", updated_chat, saved_sessions, active_session, checklist, None

    raw_evaluation = compare_details_chat_fn(user_message, active_session)
    feedback, updated_difficulty, should_advance, newly_identified, score = parse_evaluation(
        raw_evaluation, active_session)

    # Only count a failed attempt if no new details were identified
    if not newly_identified:
        active_session["attempt_count"] = active_session.get("attempt_count", 0) + 1

    updated_checklist = update_checklist(checklist, newly_identified, active_session.get("key_details", []))
    updated_chat = active_session.get("chat", []) + [("Child", user_message), ("Teacher", feedback)]
    active_session["chat"] = updated_chat

    all_identified = all(item["identified"] for item in updated_checklist)
    attempts_exhausted = (active_session.get("attempt_count", 0)
                          >= active_session.get("attempt_limit", 3))
    should_generate_new_image = should_advance or all_identified or attempts_exhausted

    if should_generate_new_image:
        new_sessions = saved_sessions.copy()
        new_sessions.append(active_session.copy())

        age = active_session.get("age", "3")
        autism_level = active_session.get("autism_level", "Level 1")
        topic_focus = active_session.get("topic_focus", "")
        treatment_plan = active_session.get("treatment_plan", "")
        # FIX: original used a ternary that always evaluated to updated_difficulty.
        difficulty_to_use = updated_difficulty

        generated_prompt = generate_prompt_from_options(
            difficulty_to_use, age, autism_level, topic_focus, treatment_plan)
        new_image = generate_image_fn(generated_prompt)
        image_description = generate_detailed_description(
            global_image_data_url, generated_prompt, difficulty_to_use, topic_focus)
        key_details = extract_key_details(global_image_data_url, generated_prompt, topic_focus)

        new_active_session = {
            "prompt": generated_prompt,
            "image": global_image_data_url,
            "image_description": image_description,
            "chat": [],
            "treatment_plan": treatment_plan,
            "topic_focus": topic_focus,
            "key_details": key_details,
            "identified_details": [],
            "used_hints": [],
            "difficulty": difficulty_to_use,
            "autism_level": autism_level,
            "age": age,
            "attempt_limit": active_session.get("attempt_limit", 3),
            "attempt_count": 0
        }
        new_checklist = _build_checklist(key_details)

        if attempts_exhausted:
            advancement_message = "You've used all your allowed attempts. Let's try a new image."
        elif updated_difficulty != active_session.get("difficulty", "Very Simple"):
            advancement_message = f"Congratulations! You've advanced to {updated_difficulty} difficulty! Here's a new image to describe."
        else:
            advancement_message = "Great job identifying all the details! Here's a new image at the same difficulty level."

        new_active_session["chat"] = [("System", advancement_message)]
        return "", new_active_session["chat"], new_sessions, new_active_session, new_checklist, new_image

    return "", updated_chat, saved_sessions, active_session, updated_checklist, None


def update_sessions(saved_sessions, active_session):
    """
    Combine finished sessions with the active session for display.
    """
    if active_session and active_session.get("prompt"):
        return saved_sessions + [active_session]
    return saved_sessions


##############################################
# Gradio Interface
##############################################
with gr.Blocks() as demo:
    active_session = gr.State({
        "prompt": None,
        "image": None,
        "image_description": None,
        "chat": [],
        "treatment_plan": "",
        "topic_focus": "",
        "key_details": [],
        "identified_details": [],
        "used_hints": [],
        "difficulty": "Very Simple",
        "age": "3",
        "autism_level": "Level 1",
        "attempt_limit": 3,
        "attempt_count": 0
    })
    saved_sessions = gr.State([])
    checklist_state = gr.State([])

    with gr.Row():
        with gr.Column(scale=2):
            gr.Markdown("# Autism Education Image Description Tool")
            difficulty_label = gr.Markdown("**Current Difficulty:** Very Simple")
            with gr.Column():
                gr.Markdown("## Generate Image")
                gr.Markdown("Enter the child's details to generate an appropriate educational image.")
                with gr.Row():
                    age_input = gr.Textbox(label="Child's Age", placeholder="Enter age...", value="3")
                    autism_level_dropdown = gr.Dropdown(label="Autism Level",
                                                        choices=["Level 1", "Level 2", "Level 3"],
                                                        value="Level 1")
                topic_focus_input = gr.Textbox(
                    label="Topic Focus",
                    placeholder="Enter a specific topic or detail to focus on (e.g., 'animals', 'emotions', 'daily routines')...",
                    lines=1
                )
                treatment_plan_input = gr.Textbox(
                    label="Treatment Plan",
                    placeholder="Enter the treatment plan to guide the image generation...",
                    lines=2
                )
                attempt_limit_input = gr.Number(label="Allowed Attempts", value=3, precision=0)
                generate_btn = gr.Button("Generate Image")
                img_output = gr.Image(label="Generated Image")
            with gr.Column():
                gr.Markdown("## Image Description Practice")
                gr.Markdown(
                    "After generating an image, ask the child to describe what they see. "
                    "Type their description in the box below. The system will provide supportive feedback "
                    "and track their progress in identifying details."
                )
                chatbot = gr.Chatbot(label="Conversation History")
                with gr.Row():
                    chat_input = gr.Textbox(label="Child's Description",
                                            placeholder="Type what the child says about the image...",
                                            show_label=True)
                    send_btn = gr.Button("Submit")
        with gr.Column(scale=1):
            gr.Markdown("## Details to Identify")
            gr.Markdown("The child should try to identify these elements in the image:")
            # NOTE(review): the original inline HTML markup was lost in extraction;
            # minimal markup reconstructed around the surviving text.
            checklist_html = gr.HTML(
                "<div class='checklist-empty'>Generate an image to see details to identify.</div>"
            )
            attempt_counter_html = gr.HTML(
                "<div class='attempt-counter'>Attempts: 0/3</div>"
            )

            def update_checklist_html(checklist):
                """Render the key-detail checklist as HTML with āœ…/ā¬œ markers."""
                if not checklist:
                    return "<div class='checklist-empty'>Generate an image to see details to identify.</div>"
                html_content = "<div class='checklist'>"
                for item in checklist:
                    detail = item["detail"]
                    identified = item["identified"]
                    css_class = "identified" if identified else "not-identified"
                    checkmark = "āœ…" if identified else "ā¬œ"
                    html_content += f"<div class='{css_class}'>{checkmark} {detail}</div>"
                html_content += "</div>"
                return html_content

            progress_html = gr.HTML(
                "<div class='progress-empty'>No active session.</div>"
            )

            def update_progress_html(checklist):
                """Render a progress bar plus an encouragement message."""
                if not checklist:
                    return "<div class='progress-empty'>No active session.</div>"
                total_items = len(checklist)
                identified_items = sum(1 for item in checklist if item["identified"])
                percentage = (identified_items / total_items) * 100 if total_items > 0 else 0
                progress_bar_width = f"{percentage}%"
                all_identified = identified_items == total_items
                html_content = (
                    "<div class='progress'>"
                    f"<div class='progress-bar' style='width: {progress_bar_width}'></div>"
                    f"<div>Progress: {identified_items} / {total_items} details</div>"
                )
                if all_identified:
                    html_content += "šŸŽ‰ Amazing! All details identified! šŸŽ‰"
                elif percentage >= 75:
                    html_content += "Almost there! Keep going!"
                elif percentage >= 50:
                    html_content += "Halfway there! You're doing great!"
                elif percentage >= 25:
                    html_content += "Good start! Keep looking!"
                else:
                    html_content += "Let's find more details!"
                html_content += "</div>"
                return html_content

            def update_attempt_counter(active_session):
                """Render the 'Attempts: used/limit' counter for the active session."""
                current_count = active_session.get("attempt_count", 0)
                limit = active_session.get("attempt_limit", 3)
                return f"<div class='attempt-counter'>Attempts: {current_count}/{limit}</div>"

    with gr.Row():
        with gr.Column():
            gr.Markdown("## Progress Tracking")
            gr.Markdown(
                "This section tracks the child's progress across sessions. "
                "Each session includes the difficulty level, identified details, "
                "and the full conversation history."
            )
            sessions_output = gr.JSON(label="Session Details", value={})

    def process_chat_and_image(user_msg, active_session, saved_sessions, checklist):
        """Dispatch a chat turn; keep the current image unless a new one was generated."""
        chat_input_val, chatbot_val, new_sessions, new_active_session, new_checklist, new_image = chat_respond(
            user_msg, active_session, saved_sessions, checklist
        )
        if new_image is not None:
            return chat_input_val, chatbot_val, new_sessions, new_active_session, new_checklist, new_image
        # gr.update() leaves the image component unchanged.
        return chat_input_val, chatbot_val, new_sessions, new_active_session, new_checklist, gr.update()

    generate_btn.click(
        generate_image_and_reset_chat,
        inputs=[age_input, autism_level_dropdown, topic_focus_input, treatment_plan_input,
                attempt_limit_input, active_session, saved_sessions],
        outputs=[img_output, active_session, saved_sessions, checklist_state]
    )
    send_btn.click(
        process_chat_and_image,
        inputs=[chat_input, active_session, saved_sessions, checklist_state],
        outputs=[chat_input, chatbot, saved_sessions, active_session, checklist_state, img_output]
    )
    chat_input.submit(
        process_chat_and_image,
        inputs=[chat_input, active_session, saved_sessions, checklist_state],
        outputs=[chat_input, chatbot, saved_sessions, active_session, checklist_state, img_output]
    )
    checklist_state.change(
        update_checklist_html,
        inputs=[checklist_state],
        outputs=[checklist_html]
    )
    checklist_state.change(
        update_progress_html,
        inputs=[checklist_state],
        outputs=[progress_html]
    )
    active_session.change(update_difficulty_label, inputs=[active_session], outputs=[difficulty_label])
    active_session.change(update_attempt_counter, inputs=[active_session], outputs=[attempt_counter_html])
    active_session.change(update_sessions, inputs=[saved_sessions, active_session], outputs=sessions_output)
    saved_sessions.change(update_sessions, inputs=[saved_sessions, active_session], outputs=sessions_output)

demo.launch()