# Spaces: Running on Zero
import gradio as gr | |
import spaces | |
import time | |
import os | |
import torch | |
from PIL import Image | |
from threading import Thread | |
from transformers import TextIteratorStreamer, AutoConfig, AutoModelForCausalLM | |
from constants import ( | |
IMAGE_TOKEN_INDEX, | |
DEFAULT_IMAGE_TOKEN, | |
DEFAULT_IM_START_TOKEN, | |
DEFAULT_IM_END_TOKEN, | |
) | |
from conversation import conv_templates | |
from eval_utils import load_maya_model | |
from utils import disable_torch_init | |
from mm_utils import tokenizer_image_token, process_images | |
from huggingface_hub._login import _login | |
# Import LLaVA modules to register model types | |
from model import * | |
from model.language_model.llava_cohere import LlavaCohereForCausalLM, LlavaCohereConfig | |
# Register model type and config so the Auto* factories can resolve the
# custom "llava_cohere" architecture.
AutoConfig.register("llava_cohere", LlavaCohereConfig)
AutoModelForCausalLM.register(LlavaCohereConfig, LlavaCohereForCausalLM)

# Authenticate with the Hugging Face Hub using the "hf_token" environment variable.
hf_token = os.getenv("hf_token")
if hf_token:
    _login(token=hf_token, add_to_git_credential=False)
else:
    # NOTE(review): _login is a private huggingface_hub helper that expects a
    # str token; handing it None fails at import time with an opaque error.
    # Skip the login instead: repos that require authentication will then fail
    # at download time with an explicit Hub authentication error.
    print("hf_token is not set; skipping Hugging Face Hub login.")
# Global Variables
# First argument to load_maya_model; presumably the base LLM's Hub repo id.
MODEL_BASE: str = "CohereForAI/aya-23-8B"
# Hub repo id of the Maya checkpoint, passed as load_maya_model's second argument.
MODEL_PATH: str = "maya-multimodal/maya"
# Loading mode string forwarded to load_maya_model.
MODE: str = "finetuned"
def load_model():
    """Load the Maya model and its preprocessing components for inference.

    Returns:
        A ``(model, tokenizer, image_processor)`` tuple. The model has been moved
        to CUDA and put in eval mode. The fourth value returned by
        ``load_maya_model`` is discarded.
    """
    components = load_maya_model(MODEL_BASE, MODEL_PATH, None, MODE)
    maya_model, maya_tokenizer, maya_image_processor, _unused = components
    maya_model = maya_model.cuda()
    # Inference only: disable training-time behaviour such as dropout.
    maya_model.eval()
    return maya_model, maya_tokenizer, maya_image_processor


# Load model globally (once, at import time) so every request reuses it.
print("Loading model...")
model, tokenizer, image_processor = load_model()
print("Model loaded successfully!")
def validate_image_file(image_path):
    """Check that *image_path* names an existing file that Pillow accepts as an image.

    Args:
        image_path: Filesystem path of the uploaded image.

    Returns:
        True when the file exists and passes ``Image.verify()``.

    Raises:
        gr.Error: if the path is empty, the file does not exist, or Pillow
            rejects it (the underlying Pillow exception is chained as the cause).
    """
    if not image_path or not os.path.isfile(image_path):
        raise gr.Error(f"Error: File {image_path} does not exist.")
    try:
        # verify() leaves the image object unusable, so callers must reopen the
        # file to actually read pixels.
        with Image.open(image_path) as img:
            img.verify()
    except (IOError, SyntaxError, Image.DecompressionBombError) as e:
        # DecompressionBombError (oversized images) is not an OSError, so it
        # must be listed explicitly to get a user-facing message.
        raise gr.Error(f"Error: {image_path} is not a valid image file. {e}") from e
    return True
def _file_entry_path(entry):
    """Return the filesystem path of an uploaded-file entry.

    Entries may be dicts carrying a ``"path"`` key or bare path strings; both
    forms are accepted.
    """
    return entry["path"] if isinstance(entry, dict) else entry


def _find_image_path(message, history):
    """Return the path of the image to discuss, or None if none was uploaded.

    The last file attached to the current message wins. Otherwise the history
    is scanned newest-first and the most recently uploaded image is used.
    """
    current_files = message.get("files") or []
    if current_files:
        return _file_entry_path(current_files[-1])
    for hist in reversed(history or []):
        print("Processing history item:", hist)
        content = hist["content"]
        if isinstance(content, tuple):
            # Tuple content is treated as an upload whose first element is the path.
            return content[0]
        if isinstance(content, dict) and content.get("files"):
            # Take the last file, matching how the current message is handled.
            return _file_entry_path(content["files"][-1])
    return None


def _history_text(content):
    """Return the text carried by one history message's content ("" if none).

    Strings are used as-is. Tuple content (an upload) contributes its optional
    second element. Dict content contributes its ``"text"`` entry. Anything that
    is not a string, including a ``None`` caption, becomes "", so the
    conversation template never receives a non-string turn.
    """
    if isinstance(content, tuple):
        content = content[1] if len(content) > 1 else ""
    elif isinstance(content, dict):
        content = content.get("text", "")
    return content if isinstance(content, str) else ""


def _collect_turns(message, history):
    """Flatten history plus the new message into alternating ``[role, text]`` turns.

    Only "user" and "assistant" messages are kept. Consecutive messages from the
    same role (e.g. an upload entry followed by its text) are merged with a
    newline so the prompt alternates speakers. The last turn is always the new
    user message.
    """
    turns = []

    def add(role, text):
        if turns and turns[-1][0] == role:
            turns[-1][1] = "\n".join(part for part in (turns[-1][1], text) if part)
        else:
            turns.append([role, text])

    for hist in history or []:
        if hist["role"] in ("user", "assistant"):
            add(hist["role"], _history_text(hist["content"]))
    add("user", message.get("text") or "")
    return turns


def _build_prompt(message, history):
    """Render the conversation as an "aya" template prompt containing one image token.

    The image tensor is passed to generate() on every call, so the image
    placeholder must be in the prompt on every call too, not only when history
    is empty. Without it the model has nowhere to splice the image features
    (presumably they are then ignored — LLaVA-style behaviour). Assumes one
    image per conversation; the placeholder goes on the first user turn.
    """
    if model.config.mm_use_im_start_end:
        image_token = f"{DEFAULT_IM_START_TOKEN}{DEFAULT_IMAGE_TOKEN}{DEFAULT_IM_END_TOKEN}"
    else:
        image_token = DEFAULT_IMAGE_TOKEN

    turns = _collect_turns(message, history)
    # _collect_turns always ends with a user turn, so this lookup cannot fail.
    first_user = next(turn for turn in turns if turn[0] == "user")
    first_user[1] = f"{image_token}\n{first_user[1]}"

    conv = conv_templates["aya"].copy()
    for role, text in turns:
        conv.append_message(conv.roles[0] if role == "user" else conv.roles[1], text)
    # Empty assistant slot for the model to complete.
    conv.append_message(conv.roles[1], None)
    return conv.get_prompt()


# On ZeroGPU hardware this decorator is what grants the call a GPU; elsewhere
# it has no effect.
@spaces.GPU
def process_chat_stream(message, history):
    """Stream Maya's reply to *message* about the conversation's image.

    Args:
        message: Gradio multimodal message dict with a "text" entry and an
            optional "files" list.
        history: Prior messages as dicts with "role" and "content" keys.

    Yields:
        ``{"role": "assistant", "content": text}`` holding the reply generated so far.

    Raises:
        gr.Error: if no image is available, the image is invalid or cannot be
            processed, generation setup fails, or generation fails in the
            worker thread.
    """
    import traceback

    print(message)
    print("History:", history)

    image_path = _find_image_path(message, history)
    if image_path is None:
        raise gr.Error("Please upload an image to start the conversation.")
    validate_image_file(image_path)
    # Close the file handle promptly; convert() returns an independent copy.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert("RGB")

    image_tensor = process_images([image], image_processor, model.config)
    if image_tensor is None:
        raise gr.Error("Failed to process image")
    image_tensor = image_tensor.cuda()

    prompt = _build_prompt(message, history)

    # Exceptions raised inside the worker thread never propagate to this
    # generator, so they are collected here and re-raised after streaming.
    generation_errors = []
    try:
        input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt')
        if input_ids is None:
            raise ValueError("Tokenization returned None")
        # generate() expects a batch dimension.
        if len(input_ids.shape) == 1:
            input_ids = input_ids.unsqueeze(0)
        input_ids = input_ids.cuda()
        if not hasattr(model, 'get_vision_tower') or model.get_vision_tower() is None:
            raise ValueError("Model's vision tower is not properly initialized")

        streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True, skip_prompt=True)
        generation_kwargs = {
            "inputs": input_ids,
            "images": image_tensor,
            "image_sizes": [image.size],
            "streamer": streamer,
            "temperature": 0.3,
            "do_sample": True,
            "top_p": 0.9,
            "num_beams": 1,
            "max_new_tokens": 4096,
            "use_cache": True,
        }

        def generate_with_error_handling():
            try:
                model.generate(**generation_kwargs)
            except Exception as e:
                generation_errors.append(
                    f"Generation error: {str(e)}\nTraceback:\n{traceback.format_exc()}"
                )
                # Unblock the consumer loop; without this it would wait
                # forever for text that will never arrive.
                streamer.end()

        thread = Thread(target=generate_with_error_handling)
        thread.start()
    except Exception as e:
        error_msg = f"Setup error: {str(e)}\nTraceback:\n{traceback.format_exc()}"
        raise gr.Error(error_msg) from e

    partial_message = ""
    # Forward text as soon as it arrives. An artificial per-chunk delay would
    # only hold the request (and any GPU allocation) longer.
    for new_text in streamer:
        partial_message += new_text
        yield {"role": "assistant", "content": partial_message}
    thread.join()
    if generation_errors:
        raise gr.Error(generation_errors[0])
# Create Gradio interface.
# Components are built up front and handed to ChatInterface, which renders them.
chatbot = gr.Chatbot(
    type="messages",
    scale=1,
    height=450,
    container=True,
    render_markdown=True,
    avatar_images=None,
    show_label=False,
    show_copy_button=False,
    show_share_button=False,
)
chat_input = gr.MultimodalTextbox(
    interactive=True,
    show_label=False,
    file_types=["image"],
    placeholder="Enter message or upload file...",
)

# Clickable example conversations shown under the chat box.
EXAMPLE_PROMPTS = [
    {"text": "Describe this photo in detail.", "files": ["./asian_food.jpg"]},
    {"text": "What is the name of this famous sight in the photo?", "files": ["./hawaii.jpg"]},
]
# Markdown shown beneath the title.
APP_DESCRIPTION = (
    "Upload an image and start chatting about it, or simply try one of the examples below. "
    "If you don't upload an image, you will receive an error. "
    "[Read the research paper](https://huggingface.co/papers/2412.07112)\n\n"
    "Team π Maya"
)

with gr.Blocks(fill_height=True) as demo:
    gr.ChatInterface(
        fn=process_chat_stream,
        multimodal=True,
        chatbot=chatbot,
        textbox=chat_input,
        title="Maya: Multilingual Multimodal Model",
        description=APP_DESCRIPTION,
        examples=EXAMPLE_PROMPTS,
        stop_btn="Stop Generation",
    )
if __name__ == "__main__":
    # queue() returns the Blocks itself, so launch can be chained; both calls
    # keep the API closed/hidden and do not create a public share link.
    demo.queue(api_open=False).launch(show_api=False, share=False)