testwildlife / app.py
ki1207's picture
Update app.py
9e66ec3 verified
raw
history blame
3.86 kB
#!/usr/bin/env python
from __future__ import annotations
import os
import string
import gradio as gr
import PIL.Image
import torch
from transformers import InstructBlipProcessor, InstructBlipForConditionalGeneration
# Markdown heading rendered at the top of the demo page.
# NOTE(review): the heading and link mention BLIP-2, while MODEL_ID below is an
# InstructBLIP checkpoint -- confirm which name the page should advertise.
DESCRIPTION = "# [BLIP-2 VQA Ad Listing Analysis](https://github.com/salesforce/LAVIS/tree/main/projects/blip2)"

# Use the first CUDA device when one is available, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if device.type != "cuda":
    # Tell the user that inference is running without a GPU.
    DESCRIPTION += "\n<p>Running on CPU.</p>"

# Checkpoint shared by the processor and the model.
MODEL_ID = "Salesforce/instructblip-flan-t5-xl"

# Load the processor and the model once at import time, then move the model
# to the selected device.
processor = InstructBlipProcessor.from_pretrained(MODEL_ID)
model = InstructBlipForConditionalGeneration.from_pretrained(MODEL_ID)
model = model.to(device)
def answer_ad_listing_question(
    image: PIL.Image.Image,
    title: str,
    decoding_method: str = "Nucleus sampling",
    temperature: float = 1.0,
    length_penalty: float = 1.0,
    repetition_penalty: float = 1.5,
    max_length: int = 50,
    min_length: int = 1,
    num_beams: int = 5,
    top_p: float = 0.9,
) -> str:
    """Ask the model for the product type and species of an ad listing.

    Builds a fixed instruction prompt around ``title``, runs it together with
    ``image`` through the module-level ``processor`` and ``model``, and
    returns the decoded text of the first generated sequence.

    Args:
        image: Picture attached to the ad listing.
        title: Title text of the ad listing; interpolated into the prompt.
        decoding_method: Sampling is enabled only when this equals
            ``"Nucleus sampling"``; any other value disables ``do_sample``.
        temperature: Forwarded to ``model.generate``.
        length_penalty: Forwarded to ``model.generate``.
        repetition_penalty: Forwarded to ``model.generate``.
        max_length: Forwarded to ``model.generate``.
        min_length: Forwarded to ``model.generate``.
        num_beams: Forwarded to ``model.generate``.
        top_p: Forwarded to ``model.generate``.

    Returns:
        The decoded model output with special tokens removed and surrounding
        whitespace stripped.

    Raises:
        gr.Error: If no image was supplied.
    """
    # The image component yields None when nothing has been uploaded; fail
    # with a readable message instead of an opaque error from the processor.
    if image is None:
        raise gr.Error("Please upload an image of the ad listing before analyzing.")

    # The prompt template with the provided title. The literal is kept flush
    # left so the text sent to the model contains no stray indentation.
    prompt = f"""Given an ad listing with the title '{title}' and image, answer the following questions without any explanation or extra text:
Identify the species mentioned in the text, including specific names, e.g., 'Nile crocodile' instead of just 'crocodile'.
Select the product type from the following options: Animal fibers, Animal parts (bone or bone-like), Animal parts (fleshy), Coral product, Egg, Extract, Food, Ivory products, Live, Medicine, Nests, Organs and tissues, Powder, Scales or spines, Shells, Skin or leather products, Taxidermy, Insects.
The response should be in the format:
"Product Type: [type]
Species: [species]"
"""

    # Cast floating-point inputs to the dtype the model was actually loaded
    # in. Hard-coding float16 here mismatches a model loaded in its default
    # precision and makes the forward pass fail.
    inputs = processor(images=image, text=prompt, return_tensors="pt").to(device, model.dtype)

    generated_ids = model.generate(
        **inputs,
        do_sample=decoding_method == "Nucleus sampling",
        temperature=temperature,
        length_penalty=length_penalty,
        repetition_penalty=repetition_penalty,
        max_length=max_length,
        min_length=min_length,
        num_beams=num_beams,
        top_p=top_p,
    )

    # Only one prompt is submitted, so the first decoded sequence is the answer.
    result = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
    return result
def postprocess_output(output: str) -> str:
    """Return ``output`` terminated by punctuation.

    A period is appended when the text is non-empty and its last character is
    not in ``string.punctuation``. Empty input is returned unchanged.
    """
    if not output:
        return output
    already_terminated = output[-1] in string.punctuation
    return output if already_terminated else output + "."
# Build the Gradio interface: an image and a title go in, the model's analysis
# of the listing comes out.
with gr.Blocks() as demo:
    gr.Markdown(DESCRIPTION)

    with gr.Group():
        # Image and ad title input
        image = gr.Image(type="pil")
        ad_title = gr.Textbox(label="Ad Title", placeholder="Enter the ad title here", lines=1)

        # Output section
        answer_output = gr.Textbox(label="Ad Listing Analysis", show_label=True, placeholder="Response will appear here.")

        # Submit and clear buttons
        with gr.Row():
            submit_button = gr.Button("Analyze Ad Listing", variant="primary")
            clear_button = gr.Button("Clear")

    # Logic to handle clicking on "Analyze Ad Listing".
    # Event inputs must be Gradio components, so only the two user-facing
    # components are wired in. The decoding settings (decoding method,
    # temperature, penalties, lengths, beams, top_p) come from the defaults
    # declared on answer_ad_listing_question.
    submit_button.click(
        fn=answer_ad_listing_question,
        inputs=[image, ad_title],
        outputs=answer_output,
    )

    # Logic to handle clearing the inputs and outputs.
    # The image component is reset with None (no image); the two textboxes
    # are reset with empty strings.
    clear_button.click(
        fn=lambda: (None, "", ""),
        inputs=None,
        outputs=[image, ad_title, answer_output],
        queue=False,
    )
if __name__ == "__main__":
    # Enable the request queue, holding at most 10 pending requests, then
    # start the web server.
    queued_demo = demo.queue(max_size=10)
    queued_demo.launch()