import gradio as gr
from gradio.themes.base import Base
from PIL import Image
import torch
import torchvision.transforms as transforms
from transformers import VisionEncoderDecoderModel, AutoTokenizer, AutoModelForSeq2SeqLM

# Load the pre-trained models from the Hugging Face Hub
caption_model = VisionEncoderDecoderModel.from_pretrained('Mayada/AIC-transformer')  # Arabic image-captioning model
caption_tokenizer = AutoTokenizer.from_pretrained('aubmindlab/bert-base-arabertv02')  # AraBERT tokenizer used to decode captions
question_model = AutoModelForSeq2SeqLM.from_pretrained("Mihakram/AraT5-base-question-generation")  # Arabic question generation
question_tokenizer = AutoTokenizer.from_pretrained("Mihakram/AraT5-base-question-generation")
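
# Optional (not part of the original script): run on a GPU when one is available.
# If enabled, the input tensors built in the functions below would also need to
# be moved with .to(device) before calling generate().
# device = "cuda" if torch.cuda.is_available() else "cpu"
# caption_model.to(device)
# question_model.to(device)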

# Define the normalization and transformations
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],  # ImageNet mean
    std=[0.229, 0.224, 0.225]  # ImageNet standard deviation
)

inference_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    normalize
])

# Word-correction dictionary (placeholder entry; replace with the actual mapping
# or load it from a file included in the Space/repo)
dictionary = {
    "caption": "alternative_caption"  # Replace with real word -> correction pairs
}
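
# In a fuller setup the correction mapping would likely be loaded from a file
# bundled with the Space instead of being hard-coded. A minimal sketch, assuming
# a UTF-8 JSON file named "dictionary.json" (the file name is an assumption):
# import json
# with open("dictionary.json", encoding="utf-8") as f:
#     dictionary = json.load(f)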

# Function to correct words in the caption using the dictionary
def correct_caption(caption):
    corrected_words = [dictionary.get(word, word) for word in caption.split()]
    corrected_caption = " ".join(corrected_words)
    return corrected_caption
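
# Illustrative behaviour (the dictionary entry here is hypothetical): with
# dictionary = {"قطه": "قطة"}, correct_caption("قطه صغيرة") returns "قطة صغيرة";
# words that have no entry are passed through unchanged.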

# Function to generate captions for an image
def generate_captions(image):
    img_tensor = inference_transforms(image).unsqueeze(0)  # add a batch dimension
    generated = caption_model.generate(
        img_tensor,
        num_beams=3,
        max_length=10,
        early_stopping=True,
        do_sample=True,
        top_k=1000,
        num_return_sequences=1,
    )
    captions = [caption_tokenizer.decode(g, skip_special_tokens=True).strip() for g in generated]
    return captions

# Function to generate questions given a context and an answer span
def generate_questions(context, answer):
    # Build the "context: ... answer: ..." prompt used by the question-generation model
    text = "context: " + context + " " + "answer: " + answer + " </s>"
    text_encoding = question_tokenizer(text, return_tensors="pt")
    question_model.eval()
    generated_ids = question_model.generate(
        input_ids=text_encoding['input_ids'],
        attention_mask=text_encoding['attention_mask'],
        max_length=64,
        num_beams=5,
        num_return_sequences=1
    )
    questions = [
        question_tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True)
        .replace('question: ', '').strip()
        for g in generated_ids
    ]
    return questions
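
# Illustrative call (the Arabic strings are placeholders, not from the original
# script); it returns a list with num_return_sequences generated questions:
# generate_questions("الولد يقرأ كتابا في الحديقة", "الولد")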

# Gradio interface function: caption the image, then build question/answer pairs
def caption_question_interface(image):
    captions = generate_captions(image)
    corrected_captions = [correct_caption(caption) for caption in captions]
    questions_with_answers = []

    for caption in corrected_captions:
        words = caption.split()
        # Candidate answer spans: the first word, the second word, the first two
        # words as a phrase, then the third and fourth words (when they exist)
        candidate_answers = []
        if len(words) > 0:
            candidate_answers.append(words[0])
        if len(words) > 1:
            candidate_answers.append(words[1])
            candidate_answers.append(" ".join(words[:2]))
        if len(words) > 2:
            candidate_answers.append(words[2])
        if len(words) > 3:
            candidate_answers.append(words[3])

        for answer in candidate_answers:
            questions = generate_questions(caption, answer)
            questions_with_answers.extend([(q, answer) for q in questions])

    formatted_questions = [f"Question: {q}\nAnswer: {a}" for q, a in questions_with_answers]
    formatted_questions = "\n".join(formatted_questions)

    return "\n".join(corrected_captions), formatted_questions

gr_interface = gr.Interface(
    fn=caption_question_interface,
    inputs=gr.Image(type="pil", label="Input Image"),
    outputs=[
        gr.Textbox(label="Generated Captions"),
        gr.Textbox(label="Generated Questions and Answers")
    ],
    title="Image Captioning and Question Generation",
    description="Generate captions and questions for images using pre-trained models."
)

gr_interface.launch()