import gradio as gr
import io
import base64
import os
import json
import re
from PIL import Image
from huggingface_hub import InferenceClient
from google.generativeai import configure, GenerativeModel
from google.ai.generativelanguage import Content, Part

# Load API keys from environment variables.
inference_api_key = os.environ.get("HF_TOKEN")
google_api_key = os.environ.get("GOOGLE_API_KEY")  # Google Gemini API key

# Configure the Google Generative AI client (a missing key surfaces as an
# auth error on the first model call rather than here).
configure(api_key=google_api_key)

# Module-level state for the most recently generated image.
# global_image_data_url holds the PNG as a base64 data URL so it can be
# re-sent to the vision model; global_image_prompt keeps the prompt that
# produced it (still stored if needed elsewhere).
global_image_data_url = None
global_image_prompt = None


def update_difficulty_label(active_session):
    """Return the markdown label showing the session's current difficulty."""
    return f"**Current Difficulty:** {active_session.get('difficulty', 'Very Simple')}"


def generate_prompt_from_options(difficulty, age, autism_level, topic_focus, treatment_plan=""):
    """
    Generate an image-generation prompt using Google's Gemini model.

    Parameters describe the child (age, autism_level), the desired visual
    complexity (difficulty), the subject (topic_focus) and an optional
    treatment plan; all are interpolated into a meta-prompt that asks Gemini
    to write the actual text-to-image prompt.

    Returns the stripped text of Gemini's response.
    """
    query = f"""
    Follow the instructions below to generate an image generation prompt for an educational image intended for autistic children.
    Consider the following parameters:
    - Difficulty: {difficulty}
    - Age: {age}
    - Autism Level: {autism_level}
    - Topic Focus: {topic_focus}
    - Treatment Plan: {treatment_plan}

    Emphasize that the image should be clear, calming, and support understanding and communication.
    The style should match the difficulty level: for example, "Very Simple" produces very basic visuals while "Very Detailed" produces rich visuals.
    The image should specifically focus on the topic: "{topic_focus}".

    Please generate a prompt that instructs the image generation engine to produce an image with:
    1. Clarity and simplicity (minimalist backgrounds, clear subject)
    2. Literal representation with defined borders and consistent style
    3. Soft, muted colors and reduced visual complexity
    4. Positive, calm scenes
    5. Clear focus on the specified topic

    Use descriptive and detailed language.
    """
    # Initialize the Gemini model and generate the prompt text.
    model = GenerativeModel('gemini-2.0-flash-lite')
    response = model.generate_content(query)
    return response.text.strip()


def generate_image_fn(selected_prompt,
                      guidance_scale=7.5,
                      negative_prompt="ugly, blurry, poorly drawn hands, lewd, nude, deformed, missing limbs, missing eyes, missing arms, missing legs",
                      num_inference_steps=50):
    """
    Generate an image from the prompt via the Hugging Face Inference API.

    Side effects: stores the prompt in ``global_image_prompt`` and the PNG,
    encoded as a base64 data URL, in ``global_image_data_url`` (consumed later
    by :func:`compare_details_chat_fn`).

    Returns the PIL image produced by the model.
    """
    global global_image_data_url, global_image_prompt
    global_image_prompt = selected_prompt
    image_client = InferenceClient(provider="hf-inference", api_key=inference_api_key)
    image = image_client.text_to_image(
        selected_prompt,
        model="stabilityai/stable-diffusion-3.5-large-turbo",
        guidance_scale=guidance_scale,
        negative_prompt=negative_prompt,
        num_inference_steps=num_inference_steps,
    )
    # Serialize to PNG and cache as a data URL for the vision-model follow-up.
    buffered = io.BytesIO()
    image.save(buffered, format="PNG")
    img_bytes = buffered.getvalue()
    img_b64 = base64.b64encode(img_bytes).decode("utf-8")
    global_image_data_url = f"data:image/png;base64,{img_b64}"
    return image


def generate_image_and_reset_chat(age, autism_level, topic_focus, treatment_plan,
                                  active_session, saved_sessions):
    """
    Generate a new image (at the session's current difficulty) and reset the chat.

    If the active session already has a prompt it is archived into
    ``saved_sessions`` before a fresh session dict is created.

    Returns ``(image, new_active_session, new_sessions)`` matching the Gradio
    outputs ``[img_output, active_session, saved_sessions]``.
    """
    new_sessions = saved_sessions.copy()
    if active_session.get("prompt"):
        new_sessions.append(active_session)

    # Use the current difficulty from the active session (updated if the
    # child advanced on the previous image).
    current_difficulty = active_session.get("difficulty", "Very Simple")
    generated_prompt = generate_prompt_from_options(
        current_difficulty, age, autism_level, topic_focus, treatment_plan
    )
    image = generate_image_fn(generated_prompt)

    new_active_session = {
        "prompt": generated_prompt,
        "image": global_image_data_url,
        "chat": [],
        "treatment_plan": treatment_plan,
        "topic_focus": topic_focus,
        "identified_details": [],
        "difficulty": current_difficulty,
        "autism_level": autism_level,
        "age": age,
    }
    return image, new_active_session, new_sessions


def compare_details_chat_fn(user_details, treatment_plan, chat_history, identified_details):
    """
    Evaluate the child's description of the current image with Gemini Vision.

    Builds a multimodal request containing the cached image (from
    ``global_image_data_url``) plus a grading rubric, and asks the model to
    respond strictly as JSON (parsed later by :func:`evaluate_scores`).

    Returns the raw model text, or a reminder string if no image exists yet.
    """
    if not global_image_data_url:
        return "Please generate an image first."

    # Replay the prior conversation so the model grades in context.
    history_text = ""
    if chat_history:
        history_text = "\n\n### Previous Conversation:\n"
        for idx, (user_msg, bot_msg) in enumerate(chat_history, 1):
            history_text += f"Turn {idx}:\nUser: {user_msg}\nTeacher: {bot_msg}\n"

    identified_details_text = ""
    if identified_details:
        identified_details_text = "\n\n### Previously Identified Details:\n" + "\n".join(
            f"- {detail}" for detail in identified_details
        )

    message_text = (
        f"{history_text}{identified_details_text}\n\n"
        f"Based on the image provided above, please evaluate the following description given by the child:\n"
        f"'{user_details}'\n\n"
        "You are a kind and encouraging teacher speaking to a child. Use simple, clear language. "
        "Praise the child's correct observations and provide a gentle hint if something is missing. "
        "Keep your feedback positive and easy to understand.\n\n"
        "Focus on these evaluation criteria:\n"
        "1. **Object Identification** – Did the child mention the main objects?\n"
        "2. **Color & Shape Accuracy** – Were the colors and shapes described correctly?\n"
        "3. **Clarity & Simplicity** – Was the description clear and easy to understand?\n"
        "4. **Overall Communication** – How well did the child communicate their thoughts?\n\n"
        "Note: As difficulty increases, the expected level of detail is higher. Evaluate accordingly.\n\n"
        "Return your evaluation strictly as a JSON object with the following keys:\n"
        "{\n"
        "  \"scores\": {\n"
        "    \"object_identification\": ,\n"
        "    \"color_shape_accuracy\": ,\n"
        "    \"clarity_simplicity\": ,\n"
        "    \"overall_communication\": \n"
        "  },\n"
        "  \"final_score\": ,\n"
        "  \"feedback\": \"\",\n"
        "  \"hint\": \"\",\n"
        "  \"advance\": \n"
        "}\n\n"
        "Do not include any additional text outside the JSON."
    )

    # Strip the "data:image/png;base64," prefix to get the raw base64 payload.
    base64_img = global_image_data_url.split(",")[1]

    vision_model = GenerativeModel('gemini-2.0-flash-thinking-exp-01-21')

    # Build multimodal content: the image goes in as inline_data (raw bytes
    # plus mime type), the rubric as a text part.
    image_part = Part(inline_data={"mime_type": "image/png", "data": base64.b64decode(base64_img)})
    text_part = Part(text=message_text)
    multimodal_content = Content(parts=[image_part, text_part])

    response = vision_model.generate_content(multimodal_content)
    return response.text


def evaluate_scores(evaluation_text, current_difficulty):
    """
    Parse the JSON evaluation and decide whether the child advances.

    The passing threshold scales with difficulty:
    Very Simple: 70, Simple: 75, Moderate: 80, Detailed: 85, Very Detailed: 90.

    Returns ``(feedback_message, new_difficulty)``; on any parsing error the
    difficulty is left unchanged and the error is reported in the message.
    """
    try:
        # The model sometimes wraps the JSON in prose/code fences; grab the
        # outermost braces.
        json_match = re.search(r'\{.*\}', evaluation_text, re.DOTALL)
        if json_match:
            evaluation = json.loads(json_match.group(0))
        else:
            raise ValueError("No JSON object found in the response.")

        # FIX: the model may return final_score as a string (or omit it);
        # coerce to a number so the threshold comparison can't raise.
        try:
            final_score = float(evaluation.get("final_score", 0))
        except (TypeError, ValueError):
            final_score = 0
        hint = evaluation.get("hint", "Keep trying!")
        advance = evaluation.get("advance", False)

        difficulty_thresholds = {
            "Very Simple": 70,
            "Simple": 75,
            "Moderate": 80,
            "Detailed": 85,
            "Very Detailed": 90,
        }
        current_threshold = difficulty_thresholds.get(current_difficulty, 70)

        # "Very Detailed" is terminal: advancing keeps the same level.
        difficulty_mapping = {
            "Very Simple": "Simple",
            "Simple": "Moderate",
            "Moderate": "Detailed",
            "Detailed": "Very Detailed",
            "Very Detailed": "Very Detailed",
        }

        if final_score >= current_threshold or advance:
            new_difficulty = difficulty_mapping.get(current_difficulty, current_difficulty)
            response_msg = (
                f"Great job! Your final score is {final_score}, which meets the target of {current_threshold}. "
                f"You've advanced to {new_difficulty} difficulty."
            )
            return response_msg, new_difficulty
        else:
            # FIX: removed a stray embedded newline that rendered as
            # "(\n target: …)" in the user-facing message.
            response_msg = (
                f"Your final score is {final_score} (target: {current_threshold}). {hint}\n"
                f"Please try again at the {current_difficulty} level."
            )
            return response_msg, current_difficulty
    except Exception as e:
        return f"Error processing evaluation output: {str(e)}", current_difficulty


def chat_respond(user_message, active_session, saved_sessions):
    """
    Process a new chat message from the child.

    The description is graded by :func:`compare_details_chat_fn` /
    :func:`evaluate_scores`. If the child advances, the difficulty is bumped,
    a new image is generated and the chat is reset; otherwise the feedback is
    appended to the current chat.

    Returns ``("", chat, saved_sessions, active_session)`` matching the
    Gradio outputs ``[chat_input, chatbot, saved_sessions, active_session]``
    (the empty string clears the input box).
    """
    if not active_session.get("image"):
        bot_message = "Please generate an image first."
        updated_chat = active_session.get("chat", []) + [(user_message, bot_message)]
        active_session["chat"] = updated_chat
        return "", updated_chat, saved_sessions, active_session

    chat_history = active_session.get("chat", [])
    # NOTE(review): identified_details is read here but nothing ever appends
    # to it, so the "Previously Identified Details" section is always empty.
    identified_details = active_session.get("identified_details", [])
    raw_evaluation = compare_details_chat_fn(user_message, "", chat_history, identified_details)

    current_difficulty = active_session.get("difficulty", "Very Simple")
    evaluation_response, updated_difficulty = evaluate_scores(raw_evaluation, current_difficulty)
    bot_message = evaluation_response

    if updated_difficulty != current_difficulty:
        # The child advanced: bump the difficulty, then generate a fresh
        # image and session (archiving the finished one).
        active_session["difficulty"] = updated_difficulty
        age = active_session.get("age", "3")
        autism_level = active_session.get("autism_level", "Level 1")
        topic_focus = active_session.get("topic_focus", "")
        treatment_plan = active_session.get("treatment_plan", "")
        new_image, new_active_session, new_sessions = generate_image_and_reset_chat(
            age, autism_level, topic_focus, treatment_plan, active_session, saved_sessions
        )
        new_active_session["chat"].append(
            ("System", f"You advanced to {updated_difficulty} difficulty! A new image has been generated for you.")
        )
        active_session = new_active_session
        bot_message = f"You advanced to {updated_difficulty} difficulty! A new image has been generated for you."
        saved_sessions = new_sessions
    else:
        updated_chat = active_session.get("chat", []) + [(user_message, bot_message)]
        active_session["chat"] = updated_chat

    return "", active_session["chat"], saved_sessions, active_session


def update_sessions(saved_sessions, active_session):
    """Combine finished sessions with the active session for sidebar display."""
    if active_session and active_session.get("prompt"):
        return saved_sessions + [active_session]
    return saved_sessions


##############################################
# Gradio Interface
##############################################
with gr.Blocks() as demo:
    # The active session starts at the lowest difficulty.
    active_session = gr.State({
        "prompt": None,
        "image": None,
        "chat": [],
        "treatment_plan": "",
        "topic_focus": "",
        "identified_details": [],
        "difficulty": "Very Simple",
        "age": "3",
        "autism_level": "Level 1",
    })
    saved_sessions = gr.State([])

    with gr.Column():
        gr.Markdown("# Image Generation & Chat Inference")
        # Display current difficulty label.
        difficulty_label = gr.Markdown("**Current Difficulty:** Very Simple")

        # ----- Image Generation Section -----
        with gr.Column():
            gr.Markdown("## Generate Image")
            gr.Markdown("Enter your age, select your autism level, specify a topic focus, and provide the treatment plan to generate an image based on the current difficulty level.")
            with gr.Row():
                age_input = gr.Textbox(label="Age", placeholder="Enter age...", value="3")
                autism_level_dropdown = gr.Dropdown(label="Autism Level", choices=["Level 1", "Level 2", "Level 3"], value="Level 1")
            topic_focus_input = gr.Textbox(
                label="Topic Focus",
                placeholder="Enter a specific topic or detail to focus on (e.g., 'animals', 'emotions', 'daily routines')...",
                lines=1
            )
            treatment_plan_input = gr.Textbox(
                label="Treatment Plan",
                placeholder="Enter the treatment plan to guide the image generation...",
                lines=2
            )
            generate_btn = gr.Button("Generate Image")
            img_output = gr.Image(label="Generated Image")
            generate_btn.click(
                generate_image_and_reset_chat,
                inputs=[age_input, autism_level_dropdown, topic_focus_input, treatment_plan_input, active_session, saved_sessions],
                outputs=[img_output, active_session, saved_sessions]
            )

        # ----- Chat Section -----
        with gr.Column():
            gr.Markdown("## Chat about the Image")
            gr.Markdown(
                "After generating an image, type details or descriptions about it. "
                "Your message, along with the generated image and conversation history, will be sent for evaluation."
            )
            chatbot = gr.Chatbot(label="Chat History")
            with gr.Row():
                chat_input = gr.Textbox(label="Your Message", placeholder="Type your description here...", show_label=False)
                send_btn = gr.Button("Send")
            send_btn.click(
                chat_respond,
                inputs=[chat_input, active_session, saved_sessions],
                outputs=[chat_input, chatbot, saved_sessions, active_session]
            )
            chat_input.submit(
                chat_respond,
                inputs=[chat_input, active_session, saved_sessions],
                outputs=[chat_input, chatbot, saved_sessions, active_session]
            )

    # ----- Sidebar Section for Session Details -----
    with gr.Column(variant="sidebar"):
        gr.Markdown("## Saved Chat Sessions")
        gr.Markdown(
            "This sidebar automatically saves finished chat sessions. "
            "Each session includes the prompt used, the generated image (as a data URL), "
            "the topic focus, the treatment plan, the list of identified details, and the full chat history."
        )
        sessions_output = gr.JSON(label="Session Details", value={})
        active_session.change(update_sessions, inputs=[saved_sessions, active_session], outputs=sessions_output)
        # Update the current difficulty label when active_session changes.
        active_session.change(update_difficulty_label, inputs=[active_session], outputs=[difficulty_label])
        saved_sessions.change(update_sessions, inputs=[saved_sessions, active_session], outputs=sessions_output)

# Launch the app locally (pass share=True to demo.launch() for a public link).
demo.launch()