File size: 2,950 Bytes
32c980d
 
 
 
 
 
 
 
 
 
 
 
3c3c65b
32c980d
 
 
 
 
 
 
 
 
8828e34
32c980d
9e66ec3
32c980d
9e66ec3
32c980d
9e66ec3
 
 
 
 
 
 
 
32c980d
 
3c3c65b
32c980d
 
 
 
bfc1711
 
32c980d
 
654ae11
32c980d
 
 
9e66ec3
32c980d
3c3c65b
32c980d
9e66ec3
3c3c65b
32c980d
9e66ec3
 
3c3c65b
9e66ec3
 
 
 
 
9a4e757
9e66ec3
32c980d
 
9e66ec3
 
 
32c980d
9e66ec3
32c980d
 
 
 
9e66ec3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
#!/usr/bin/env python

from __future__ import annotations

import os
import string

import gradio as gr
import PIL.Image
import torch
from transformers import InstructBlipProcessor, InstructBlipForConditionalGeneration

# Markdown shown at the top of the demo page.
# NOTE(review): the heading says "BLIP-2" but the model loaded below is
# InstructBLIP — the link/label may be stale; confirm intended title.
DESCRIPTION = "# [BLIP-2 test](https://github.com/salesforce/LAVIS/tree/main/projects/blip2)"

if not torch.cuda.is_available():
    # Warn users that inference will be slow without a GPU.
    DESCRIPTION += "\n<p>Running on CPU.</p>"

# Pick the first CUDA device when available, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Hugging Face model id for the InstructBLIP (Flan-T5 XL) checkpoint.
MODEL_ID = "Salesforce/instructblip-flan-t5-xl"

# Processor handles both image preprocessing and prompt tokenization.
# The model is loaded in its default dtype (float32) and moved to `device`.
processor = InstructBlipProcessor.from_pretrained(MODEL_ID)
model = InstructBlipForConditionalGeneration.from_pretrained(MODEL_ID).to(device)

def answer_ad_listing_question(
    image: PIL.Image.Image,
    title: str,
) -> str:
    """Classify an ad listing's species and product type with InstructBLIP.

    Args:
        image: The listing's image, as a PIL image.
        title: The listing's advertisement title, interpolated into the prompt.

    Returns:
        The model's decoded answer, stripped of surrounding whitespace.
        Expected (but not guaranteed) to follow the
        "Product Type: [type]\\nSpecies: [species]" format requested in the prompt.
    """
    # The prompt template with the provided title
    prompt = f"""Given an ad listing with the title '{title}' and image, answer the following questions without any explanation or extra text:
    Identify the species mentioned in the text, including specific names, e.g., 'Nile crocodile' instead of just 'crocodile'.
    Select the product type from the following options: Animal fibers, Animal parts (bone or bone-like), Animal parts (fleshy), Coral product, Egg, Extract, Food, Ivory products, Live, Medicine, Nests, Organs and tissues, Powder, Scales or spines, Shells, Skin or leather products, Taxidermy, Insects.
    The response should be in the format:
    "Product Type: [type]
    Species: [species]"
    """

    # BUG FIX: the original cast inputs to torch.float16 while the model was
    # loaded in its default float32 dtype, so generate() failed with a dtype
    # mismatch. Cast floating-point tensors (pixel_values) to the model's own
    # dtype instead of hard-coding half precision.
    inputs = processor(images=image, text=prompt, return_tensors="pt").to(device, model.dtype)
    # The default generation length (~20 tokens) truncates the two-line
    # structured answer; allow enough room for "Product Type: ... / Species: ...".
    generated_ids = model.generate(**inputs, max_new_tokens=64)
    result = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
    return result

def postprocess_output(output: str) -> str:
    """Return the model output unchanged.

    An earlier revision appended a trailing period when the output ended
    without punctuation; that behavior is intentionally disabled. The
    function is kept as the single hook for any future output cleanup.
    (Dead commented-out code removed.)
    """
    return output

with gr.Blocks() as demo:
    gr.Markdown(DESCRIPTION)

    with gr.Group():
        # Image and ad title input
        image = gr.Image(type="pil")
        ad_title = gr.Textbox(label="Advertisement Title", placeholder="Enter the title here", lines=1)

        # Output section
        answer_output = gr.Textbox(label="Analysis", show_label=True, placeholder="Response.")

        # Submit and clear buttons
        with gr.Row():
            submit_button = gr.Button("Analyze Listing", variant="primary")
            clear_button = gr.Button("Clear")

    # Run the model on the image + title when "Analyze Listing" is clicked.
    submit_button.click(
        fn=answer_ad_listing_question,
        inputs=[image, ad_title],  # Only the image and ad title are inputs
        outputs=answer_output,
    )

    # Reset all inputs and outputs.
    # BUG FIX: the image component must be cleared with None — an empty
    # string is not a valid value for gr.Image(type="pil").
    clear_button.click(
        fn=lambda: (None, "", ""),
        inputs=None,
        outputs=[image, ad_title, answer_output],
        queue=False,
    )

if __name__ == "__main__":
    # Cap the request queue so a burst of users doesn't pile up unbounded work.
    demo.queue(max_size=10).launch()