from PIL import Image
import gradio as gr
import torch
import torch.nn as nn
from torchvision import transforms
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    CLIPModel,
    CLIPProcessor,
    pipeline,
)

from model import build_mlp_vector_projector

device = "cpu"

# Load the CLIP vision encoder and its processor
clip_model_name = "openai/clip-vit-base-patch16"
clip_model = CLIPModel.from_pretrained(clip_model_name).to(device)
clip_processor = CLIPProcessor.from_pretrained(clip_model_name)
clip_transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ]
)


def process_image(img_path):
    """Encode an image with CLIP and project the embedding into phi-2 embedding space."""
    image = Image.open(img_path).convert("RGB")
    image = clip_transform(image)
    inputs = clip_processor(
        text=[""], images=image, return_tensors="pt", padding=True)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    img_embedding = clip_model(**inputs).image_embeds

    # Projection head trained in stage 2 maps CLIP image embeddings to LLM input embeddings
    img_proj_head = build_mlp_vector_projector().to(device)
    img_proj_head.load_state_dict(torch.load(
        'stage_2_proj_head_v3.pth', map_location=torch.device(device)))
    img_tokens = img_proj_head(img_embedding)
    return img_tokens


# Load the phi-2 tokenizer, the base model, and the stage-2 fine-tuned model
phi_model_name = "microsoft/phi-2"
text_tokenizer = AutoTokenizer.from_pretrained(
    phi_model_name, trust_remote_code=True)

with torch.no_grad():
    base_phi2_text = AutoModelForCausalLM.from_pretrained(
        phi_model_name, trust_remote_code=True,
        device_map="auto", torch_dtype=torch.float16
    )
    tuned_phi2 = AutoModelForCausalLM.from_pretrained(
        "stage2_adaptor", trust_remote_code=True,
    ).to("cpu")
print("phi2 model loaded")

# Whisper pipeline for speech-to-text
audio_model_name = "openai/whisper-small"
audio_pipe = pipeline(
    task="automatic-speech-recognition",
    model=audio_model_name,
    chunk_length_s=30,
    device=device,
)


def process_text(text, count):
    """Generate a text-only answer with the fine-tuned phi-2 model."""
    inputs = text_tokenizer.encode(text, return_tensors="pt")
    input_embeds = tuned_phi2.get_submodule(
        'model.embed_tokens')(inputs).to(device)
    prediction = text_tokenizer.batch_decode(
        tuned_phi2.generate(
            inputs_embeds=input_embeds,
            max_new_tokens=count,
            bos_token_id=text_tokenizer.bos_token_id,
            eos_token_id=text_tokenizer.eos_token_id,
            pad_token_id=text_tokenizer.pad_token_id,
        )
    )
    # Use replace() rather than rstrip(): rstrip would treat the marker as a character set
    return prediction[0].replace('<|endoftext|>', '').rstrip("\n")


def process_audio(audio):
    """Transcribe an audio file or recording with Whisper."""
    if audio is None:
        raise gr.Error(
            "Please provide an audio file or record your input")
    text = audio_pipe(
        audio, batch_size=8,
        generate_kwargs={"task": "transcribe"},
        return_timestamps=True,
    )["text"]
    return text


def generate_response(image, audio, text, count):
    """Build the prompt from whichever inputs were provided and query the model."""
    count = int(count)
    overall_input = ""
    if audio:
        overall_input = process_audio(audio)
    if text:
        overall_input = text + overall_input
    if image:
        # Prepend the projected image tokens to the question token embeddings
        img_tokens = process_image(image)
        overall_input = "Question: " + overall_input + "Answer:"
        q_tokens = text_tokenizer.encode(
            overall_input, return_tensors='pt').to(device)
        question_token_embeddings = tuned_phi2.get_submodule(
            'model.embed_tokens')(q_tokens).to(device)
        inputs = torch.cat(
            (img_tokens.unsqueeze(0), question_token_embeddings),
            dim=-2).to(device)
        prediction = text_tokenizer.batch_decode(
            tuned_phi2.generate(
                inputs_embeds=inputs,
                max_new_tokens=count,
                bos_token_id=text_tokenizer.bos_token_id,
                eos_token_id=text_tokenizer.eos_token_id,
                pad_token_id=text_tokenizer.pad_token_id,
            )
        )
        return prediction[0].replace('<|endoftext|>', '').rstrip("\n")
    return process_text(overall_input, count)


with gr.Blocks() as demo:
    gr.Markdown("# **AnyModeAssistant**")
    gr.Markdown("Use any mode (text/image/audio) to interact with the AI assistant")
    with gr.Row():
        with gr.Column(scale=4):
            with gr.Row():  # text input
                text_input = gr.Textbox(
                    placeholder="Enter your question here", label="Input")
            with gr.Row():  # image input
                image_input = gr.Image(type="filepath")
            with gr.Row():  # audio input
                audio_input = gr.Audio(type="filepath")
            with gr.Row():  # response length
                response_count = gr.Textbox(
                    placeholder="Number of tokens to respond",
                    value=20, label="Count")
        with gr.Column(scale=2):
            response = gr.Textbox(label="AI Response")
    with gr.Row():
        submit_button = gr.Button("Submit")
    submit_button.click(
        generate_response,
        inputs=[image_input, audio_input, text_input, response_count],
        outputs=response,
    )

    gr.Examples(
        examples=[
            ["dog_man_forest.jpg", "audio.wav",
             "Is there a dog present in the image?", 20],
        ],
        inputs=[image_input, audio_input, text_input, response_count],
        outputs=[response],
        fn=generate_response,
    )

demo.launch(share=True)