import gradio as gr
from PIL import Image
import torch
import torchvision.transforms as transforms
from transformers import VisionEncoderDecoderModel, AutoTokenizer, AutoModelForSeq2SeqLM
# Load the models
caption_model = VisionEncoderDecoderModel.from_pretrained('Mayada/AIC-transformer')  # Arabic image-captioning model hosted on the Hugging Face Hub
caption_tokenizer = AutoTokenizer.from_pretrained('aubmindlab/bert-base-arabertv02')
question_model = AutoModelForSeq2SeqLM.from_pretrained("Mihakram/AraT5-base-question-generation")
question_tokenizer = AutoTokenizer.from_pretrained("Mihakram/AraT5-base-question-generation")
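# Note (assumption, not in the original script): everything below runs on CPU by
# default. If a GPU is available, the models and inputs could be moved over, e.g.:
#   device = "cuda" if torch.cuda.is_available() else "cpu"
#   caption_model.to(device)
#   question_model.to(device)
# (input tensors would then also need .to(device) before each generate() call).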
# Define the normalization and transformations
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],  # ImageNet mean
    std=[0.229, 0.224, 0.225],   # ImageNet standard deviation
)
inference_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    normalize,
])
# Word-correction dictionary (placeholder: load the real mapping from this repo
# or from the Hugging Face Space files)
dictionary = {
    "caption": "alternative_caption"  # placeholder entry; replace with the actual word mappings
}
# Function to correct words in the caption using the dictionary
def correct_caption(caption):
    corrected_words = [dictionary.get(word, word) for word in caption.split()]
    corrected_caption = " ".join(corrected_words)
    return corrected_caption
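# Usage sketch (hypothetical dictionary entry, for illustration only):
#   with dictionary = {"قطه": "قطة"}, correct_caption("قطه نائمة") -> "قطة نائمة"
# Matching is on whole whitespace-separated tokens, so attached punctuation
# would prevent a hit.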
# Function to generate captions for an image
def generate_captions(image):
    img_tensor = inference_transforms(image.convert("RGB")).unsqueeze(0)  # force 3 channels, add batch dim
    with torch.no_grad():  # inference only, no gradients needed
        generated = caption_model.generate(
            img_tensor,
            num_beams=3,
            max_length=10,  # captions are kept deliberately short
            early_stopping=True,
            do_sample=True,  # beam sampling: sample within each of the 3 beams
            top_k=1000,
            num_return_sequences=1,
        )
    captions = [caption_tokenizer.decode(g, skip_special_tokens=True).strip() for g in generated]
    return captions
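# Usage sketch ("example.jpg" is a placeholder file name, not part of the repo):
#   captions = generate_captions(Image.open("example.jpg"))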
# Function to generate a question given a context and an answer span
def generate_questions(context, answer):
    text = "context: " + context + " answer: " + answer + " </s>"
    text_encoding = question_tokenizer(text, return_tensors="pt")
    question_model.eval()
    with torch.no_grad():
        generated_ids = question_model.generate(
            input_ids=text_encoding["input_ids"],
            attention_mask=text_encoding["attention_mask"],
            max_length=64,
            num_beams=5,
            num_return_sequences=1,
        )
    questions = [
        question_tokenizer.decode(
            g, skip_special_tokens=True, clean_up_tokenization_spaces=True
        ).replace("question: ", "").strip()
        for g in generated_ids
    ]
    return questions
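# Usage sketch: the AraT5 model expects Arabic context/answer pairs, e.g.
#   generate_questions("قطة نائمة على السرير", "قطة")
# returns a single-element list (num_return_sequences=1) holding the question.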
# Gradio interface function: captions the image, then builds question/answer pairs
def caption_question_interface(image):
    captions = generate_captions(image)
    corrected_captions = [correct_caption(caption) for caption in captions]
    questions_with_answers = []
    for caption in corrected_captions:
        words = caption.split()
        # Candidate answers: the first two words, the two-word phrase they
        # form, then the third and fourth words (when the caption is long enough)
        candidate_answers = words[:2]
        if len(words) > 1:
            candidate_answers.append(" ".join(words[:2]))
        candidate_answers.extend(words[2:4])
        for answer in candidate_answers:
            questions = generate_questions(caption, answer)
            questions_with_answers.extend([(q, answer) for q in questions])
    formatted_questions = "\n".join(f"Question: {q}\nAnswer: {a}" for q, a in questions_with_answers)
    return "\n".join(corrected_captions), formatted_questions
gr_interface = gr.Interface(
    fn=caption_question_interface,
    # gr.inputs / gr.outputs were deprecated and removed in Gradio 4.x;
    # the components now live at the top level of the gradio namespace
    inputs=gr.Image(type="pil", label="Input Image"),
    outputs=[
        gr.Textbox(label="Generated Captions"),
        gr.Textbox(label="Generated Questions and Answers"),
    ],
    title="Image Captioning and Question Generation",
    description="Generate captions and questions for images using pre-trained models.",
)
gr_interface.launch()
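# launch() serves the app locally (a Hugging Face Space calls it automatically);
# launch(share=True) would additionally create a temporary public Gradio link,
# useful when running outside a Space.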