import streamlit as st
from PIL import Image
import torch
from transformers import (
    BlipProcessor,
    BlipForConditionalGeneration,
    GPT2LMHeadModel,
    GPT2Tokenizer,
)

# ---------------------------- Load Models ---------------------------- #
@st.cache_resource
def load_blip_model():
    """Loads the BLIP captioning processor and model once and caches them across reruns."""
    processor = BlipProcessor.from_pretrained('Salesforce/blip-image-captioning-base')
    model = BlipForConditionalGeneration.from_pretrained('Salesforce/blip-image-captioning-base')
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()
    return processor, model, device


@st.cache_resource
def load_finetuned_gpt2(model_dir='jyotikh/gpt2_finetuned'):
    """Loads the fine-tuned GPT-2 tokenizer and model once and caches them across reruns."""
    tokenizer = GPT2Tokenizer.from_pretrained(model_dir)
    model = GPT2LMHeadModel.from_pretrained(model_dir)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()
    return tokenizer, model, device


# Load models
processor_blip, model_blip, device_blip = load_blip_model()
tokenizer_gpt2, model_gpt2, device_gpt2 = load_finetuned_gpt2()

# ---------------------------- Streamlit UI ---------------------------- #
st.title("Image Caption Generator")

# File uploader
uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])

if uploaded_file is not None:
    # Display the uploaded image
    image = Image.open(uploaded_file).convert('RGB')
    st.image(image, caption='Uploaded Image', use_column_width=True)

    with st.spinner('Generating captions...'):
        # Generate a base caption with BLIP
        inputs = processor_blip(images=image, return_tensors="pt").to(device_blip)
        with torch.no_grad():
            out = model_blip.generate(**inputs)
        base_caption = processor_blip.decode(out[0], skip_special_tokens=True)

        # Rewrite the base caption as a funny caption with the fine-tuned GPT-2
        prompt = f"Generate a funny caption for: {base_caption} \nFunny Caption:"
        inputs_gpt = tokenizer_gpt2.encode(prompt, return_tensors='pt').to(device_gpt2)
        outputs_gpt = model_gpt2.generate(
            inputs_gpt,
            max_length=len(inputs_gpt[0]) + 50,
            temperature=0.7,
            top_p=0.95,
            do_sample=True,
            num_return_sequences=1,
            pad_token_id=tokenizer_gpt2.eos_token_id,
        )
        generated_text = tokenizer_gpt2.decode(outputs_gpt[0], skip_special_tokens=True)
        # The decoded text still contains the prompt; keep only what follows the last "Caption:" marker
        funny_caption = generated_text.split("Caption:")[-1].strip()
        st.write(funny_caption)
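
# Usage note (assumes this script is saved as app.py; adjust the filename as needed):
#   streamlit run app.py
# The first run downloads the BLIP weights and the fine-tuned GPT-2 weights from the
# Hugging Face Hub; later reruns reuse the models cached by st.cache_resource.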