File size: 8,088 Bytes
e968589
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aba6f6c
 
e968589
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
import gradio as gr
import spaces
import time
import os
import torch
from PIL import Image
from threading import Thread
from transformers import TextIteratorStreamer, AutoConfig, AutoModelForCausalLM
from constants import (
    IMAGE_TOKEN_INDEX,
    DEFAULT_IMAGE_TOKEN,
    DEFAULT_IM_START_TOKEN,
    DEFAULT_IM_END_TOKEN,
)
from conversation import conv_templates
from eval_utils import load_maya_model
from utils import disable_torch_init
from mm_utils import tokenizer_image_token, process_images
from huggingface_hub._login import _login

# Import LLaVA modules to register model types
from model import *
from model.language_model.llava_cohere import LlavaCohereForCausalLM, LlavaCohereConfig

# Register the custom "llava_cohere" model type so the transformers Auto* factories
# can resolve its config class and causal-LM class.
AutoConfig.register("llava_cohere", LlavaCohereConfig)
AutoModelForCausalLM.register(LlavaCohereConfig, LlavaCohereForCausalLM)

# Authenticate with the Hugging Face Hub using the token in the "hf_token" environment
# variable. When it is unset, skip the login with an explicit warning instead of passing
# None into _login (which presumably expects a token string — NOTE(review): confirm).
hf_token = os.getenv("hf_token")
if hf_token:
    _login(token=hf_token, add_to_git_credential=False)
else:
    print("Warning: 'hf_token' environment variable is not set; skipping Hugging Face login.")

# Global Variables
# Checkpoint identifiers and mode forwarded to load_maya_model() by load_model().
MODEL_BASE = "CohereForAI/aya-23-8B"
MODEL_PATH = "maya-multimodal/maya"
MODE = "finetuned"

def load_model():
    """Load Maya and return the pieces needed for inference.

    Calls ``load_maya_model`` with the module-level ``MODEL_BASE``, ``MODEL_PATH``
    and ``MODE`` settings, moves the model to CUDA and puts it in eval mode.

    Returns:
        tuple: ``(model, tokenizer, image_processor)``. The fourth value produced
        by ``load_maya_model`` is discarded.
    """
    maya, tok, img_proc, _discarded = load_maya_model(MODEL_BASE, MODEL_PATH, None, MODE)
    maya = maya.cuda()
    maya.eval()  # inference only
    return maya, tok, img_proc

# Load model globally: this runs once at import time, and the resulting module-level
# `model`, `tokenizer` and `image_processor` names are what the request handler reads.
print("Loading model...")
model, tokenizer, image_processor = load_model()
print("Model loaded successfully!")

def validate_image_file(image_path):
    """Validate that the image file exists and is in a supported format.

    Args:
        image_path: Filesystem path of the image to check.

    Returns:
        True when the file exists and Pillow can verify it as an image.

    Raises:
        gr.Error: if the path is empty/None or not an existing file, or if Pillow
            cannot parse it as an image. The underlying exception is chained as
            ``__cause__`` for debugging.
    """
    # Guard falsy paths first: os.path.isfile(None) would raise a raw TypeError.
    if not image_path or not os.path.isfile(image_path):
        raise gr.Error(f"Error: File {image_path} does not exist.")

    try:
        with Image.open(image_path) as img:
            # Per the Pillow docs, verify() checks integrity without decoding pixels and
            # leaves the object unusable, so callers must reopen the file to use it.
            img.verify()
    except (OSError, SyntaxError, ValueError, Image.DecompressionBombError) as e:
        # OSError covers IOError and PIL.UnidentifiedImageError; verify() reports some
        # corruption as SyntaxError; DecompressionBombError is not an OSError subclass.
        raise gr.Error(f"Error: {image_path} is not a valid image file. {e}") from e
    return True

def _file_path(entry):
    """Return the filesystem path of a Gradio file entry (a {"path": ...} dict or a plain path)."""
    return entry["path"] if isinstance(entry, dict) else entry


def _find_image(message, history):
    """Locate the image to condition on.

    Returns:
        tuple: ``(path, from_current_message)``. The last file attached to the current
        message wins. Otherwise the most recent image found in ``history``. ``path``
        is None when no image is available anywhere.
    """
    current_files = message.get("files") or []
    if current_files:
        return _file_path(current_files[-1]), True

    for hist in reversed(history):
        content = hist["content"]
        # Gradio "messages" history stores uploaded files as a (path, [alt_text]) tuple.
        if isinstance(content, tuple):
            return content[0], False
        if isinstance(content, dict) and content.get("files"):
            return _file_path(content["files"][0]), False
    return None, False


def _user_text(content):
    """Return the text carried by a user history entry, or "" when it carries none."""
    if isinstance(content, str):
        return content
    if isinstance(content, tuple):
        # File entries: the optional second element is alt text; it is the only text present.
        return content[1] or "" if len(content) > 1 else ""
    if isinstance(content, dict):
        return content.get("text") or ""
    # Unknown content types (e.g. rendered components) carry no usable prompt text.
    return ""


def _build_prompt(message, history, image_in_current_message):
    """Render the Aya-template prompt for the conversation with exactly one image token.

    The history Gradio passes back holds the user's submitted text, not the prompt that
    was sent to the model, so the image placeholder must be re-inserted on every turn.
    It goes on the current message when the image was attached to it. Otherwise it goes
    on the first user turn that has text.
    """
    conv = conv_templates["aya"].copy()
    user_role, assistant_role = conv.roles[0], conv.roles[1]

    turns = []  # mutable [role, text] pairs so the image prefix can be attached in place
    for hist in history:
        if hist["role"] == "user":
            text = _user_text(hist["content"])
            # Skip file-only entries; they would otherwise become empty user turns.
            if text:
                turns.append([user_role, text])
        elif hist["role"] == "assistant" and isinstance(hist["content"], str):
            turns.append([assistant_role, hist["content"]])
    turns.append([user_role, message["text"]])

    if model.config.mm_use_im_start_end:
        image_prefix = f"{DEFAULT_IM_START_TOKEN}{DEFAULT_IMAGE_TOKEN}{DEFAULT_IM_END_TOKEN}"
    else:
        image_prefix = DEFAULT_IMAGE_TOKEN

    if image_in_current_message:
        target = turns[-1]
    else:
        # Always found: the current message itself is a user turn.
        target = next(turn for turn in turns if turn[0] == user_role)
    target[1] = f"{image_prefix}\n{target[1]}"

    for role, text in turns:
        conv.append_message(role, text)
    conv.append_message(assistant_role, None)  # empty slot the model completes
    return conv.get_prompt()


@spaces.GPU
def process_chat_stream(message, history):
    """Stream the assistant's reply to a multimodal chat message.

    Args:
        message: MultimodalTextbox value, a dict with "text" and optional "files".
        history: Prior turns in Gradio "messages" format (dicts with "role"/"content").

    Yields:
        dict: ``{"role": "assistant", "content": <text generated so far>}``.

    Raises:
        gr.Error: if no image is available, the image is invalid or cannot be processed,
            generation setup fails, or generation itself fails.
    """
    import traceback

    print(message)
    print("History:", history)
    history = history or []

    image_path, image_in_current_message = _find_image(message, history)
    if image_path is None:
        raise gr.Error("Please upload an image to start the conversation.")

    # Validate, then reopen: verify() leaves the validated handle unusable.
    validate_image_file(image_path)
    with Image.open(image_path) as raw_image:
        image = raw_image.convert("RGB")  # convert() returns a new, loaded image

    image_tensor = process_images([image], image_processor, model.config)
    if image_tensor is None:
        raise gr.Error("Failed to process image")
    image_tensor = image_tensor.cuda()

    prompt = _build_prompt(message, history, image_in_current_message)

    # Errors raised in the worker thread cannot reach Gradio directly; collect them here
    # and re-raise from this generator once the stream has ended.
    generation_errors = []

    try:
        input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt')
        if input_ids is None:
            raise ValueError("Tokenization returned None")

        # generate() expects a batch dimension.
        if len(input_ids.shape) == 1:
            input_ids = input_ids.unsqueeze(0)
        input_ids = input_ids.cuda()

        if not hasattr(model, 'get_vision_tower') or model.get_vision_tower() is None:
            raise ValueError("Model's vision tower is not properly initialized")

        streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True, skip_prompt=True)

        generation_kwargs = {
            "inputs": input_ids,
            "images": image_tensor,
            "image_sizes": [image.size],
            "streamer": streamer,
            "temperature": 0.3,
            "do_sample": True,
            "top_p": 0.9,
            "num_beams": 1,
            "max_new_tokens": 4096,
            "use_cache": True
        }

        def generate_with_error_handling():
            try:
                model.generate(**generation_kwargs)
            except Exception as e:
                generation_errors.append(
                    f"Generation error: {str(e)}\nTraceback:\n{traceback.format_exc()}"
                )
                # Terminate the stream; otherwise the consumer loop below blocks forever.
                streamer.end()

        thread = Thread(target=generate_with_error_handling)
        thread.start()

    except Exception as e:
        raise gr.Error(
            f"Setup error: {str(e)}\nTraceback:\n{traceback.format_exc()}"
        ) from e

    partial_message = ""
    for new_text in streamer:
        partial_message += new_text
        yield {"role": "assistant", "content": partial_message}

    # The stream has ended, so the worker is finishing; join before reporting its outcome.
    thread.join()
    if generation_errors:
        raise gr.Error(generation_errors[0])



# Create Gradio interface
# Chat display; type="messages" makes Gradio pass history as role/content dicts.
chatbot = gr.Chatbot(
    show_label=False,
    height=450,
    show_share_button=False,
    show_copy_button=False,
    avatar_images=None,
    container=True,
    render_markdown=True,
    scale=1,
    type="messages"
)
# Input box accepting text plus image uploads.
chat_input = gr.MultimodalTextbox(
    interactive=True,
    file_types=["image"],
    placeholder="Enter message or upload file...",
    show_label=False,
)
with gr.Blocks(fill_height=True) as demo:
    gr.ChatInterface(
        fn=process_chat_stream,
        title="Maya: Multilingual Multimodal Model",
        examples=[
            {"text": "Describe this photo in detail.", "files": ["./asian_food.jpg"]},
            {"text": "What is the name of this famous sight in the photo?", "files": ["./hawaii.jpg"]},
        ],
        description=(
            "Upload an image and start chatting about it, or simply try one of the examples below. "
            "If you don't upload an image, you will receive an error. "
            "[Read the research paper](https://huggingface.co/papers/2412.07112)\n\n"
            "Team 💚 Maya"
        ),
        stop_btn="Stop Generation",
        multimodal=True,
        textbox=chat_input,
        chatbot=chatbot,
    )

if __name__ == "__main__":
    # Enable request queuing with the API closed and hidden, then serve without a public share link.
    demo.queue(api_open=False)
    demo.launch(show_api=False, share=False)