import gradio as gr
from gradio.themes.base import Base
from PIL import Image
import torch
import torchvision.transforms as transforms
from transformers import VisionEncoderDecoderModel, AutoTokenizer, AutoModelForSeq2SeqLM
# Load the models
caption_model = VisionEncoderDecoderModel.from_pretrained('Mayada/AIC-transformer') # Your model on Hugging Face
caption_tokenizer = AutoTokenizer.from_pretrained('aubmindlab/bert-base-arabertv02')
question_model = AutoModelForSeq2SeqLM.from_pretrained("Mihakram/AraT5-base-question-generation")
question_tokenizer = AutoTokenizer.from_pretrained("Mihakram/AraT5-base-question-generation")
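# Optional sketch (not part of the original app): move the models to a GPU when one is
# available. Note that the inference functions below keep their tensors on CPU, so enabling
# this would also require sending img_tensor / input_ids to the same device.
# device = "cuda" if torch.cuda.is_available() else "cpu"
# caption_model.to(device)
# question_model.to(device)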
# Define the normalization and transformations
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],  # ImageNet mean
    std=[0.229, 0.224, 0.225],   # ImageNet standard deviation
)
inference_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    normalize,
])
# Load the dictionary (use it from your Hugging Face Space or include in the repo)
dictionary = {
    "caption": "alternative_caption"  # Replace with your actual dictionary
}
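# A minimal sketch, assuming the full correction dictionary is stored as a JSON file
# (e.g. "dictionary.json", a hypothetical name) committed alongside app.py in the Space:
# import json
# with open("dictionary.json", encoding="utf-8") as f:
#     dictionary = json.load(f)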
# Function to correct words in the caption using the dictionary
def correct_caption(caption):
    corrected_words = [dictionary.get(word, word) for word in caption.split()]
    corrected_caption = " ".join(corrected_words)
    return corrected_caption
# Function to generate captions for an image
def generate_captions(image):
    img_tensor = inference_transforms(image).unsqueeze(0)
    generated = caption_model.generate(
        img_tensor,
        num_beams=3,
        max_length=10,
        early_stopping=True,
        do_sample=True,
        top_k=1000,
        num_return_sequences=1,
    )
    captions = [caption_tokenizer.decode(g, skip_special_tokens=True).strip() for g in generated]
    return captions
# Function to generate questions given a context and answer
def generate_questions(context, answer):
text = "context: " + context + " " + "answer: " + answer + " </s>"
text_encoding = question_tokenizer.encode_plus(
text, return_tensors="pt"
)
question_model.eval()
generated_ids = question_model.generate(
input_ids=text_encoding['input_ids'],
attention_mask=text_encoding['attention_mask'],
max_length=64,
num_beams=5,
num_return_sequences=1
)
questions = [question_tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True).replace(
'question: ', ' ') for g in generated_ids]
return questions
# Gradio Interface Function
def caption_question_interface(image):
    captions = generate_captions(image)
    corrected_captions = [correct_caption(caption) for caption in captions]
    questions_with_answers = []
    for caption in corrected_captions:
        words = caption.split()
        # Candidate answers: the first few single words plus the leading two-word phrase.
        answer_candidates = []
        if len(words) > 0:
            answer_candidates.append(words[0])
        if len(words) > 1:
            answer_candidates.append(words[1])
            answer_candidates.append(" ".join(words[:2]))
        if len(words) > 2:
            answer_candidates.append(words[2])
        if len(words) > 3:
            answer_candidates.append(words[3])
        for answer in answer_candidates:
            questions = generate_questions(caption, answer)
            questions_with_answers.extend([(q, answer) for q in questions])
    formatted_questions = [f"Question: {q}\nAnswer: {a}" for q, a in questions_with_answers]
    formatted_questions = "\n".join(formatted_questions)
    return "\n".join(corrected_captions), formatted_questions
gr_interface = gr.Interface(
    fn=caption_question_interface,
    inputs=gr.Image(type="pil", label="Input Image"),
    outputs=[
        gr.Textbox(label="Generated Captions"),
        gr.Textbox(label="Generated Questions and Answers"),
    ],
    title="Image Captioning and Question Generation",
    description="Generate captions and questions for images using pre-trained models.",
)
gr_interface.launch()
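# On Hugging Face Spaces, launch() is sufficient; when running locally,
# gr_interface.launch(share=True) would additionally create a temporary public link.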