# Tag2Text_Demo / app.py
# Author: gremlin97 -- commit 5360e91 ("Update app.py")
import numpy as np
import random
from PIL import Image
from ram.models import tag2text
from ram import inference_tag2text as inference
from ram import get_transform
import gradio as gr
import torch
import torchvision.transforms as transforms
from PIL import Image
# Run inference on CPU; switch to 'cuda' here if a GPU is available.
device = torch.device('cpu')
# Load the Tag2Text checkpoint (Swin-B backbone, 384x384 input).
model = tag2text(pretrained='tag2text.pth', image_size=384, vit='swin_b')
model.threshold = 0.68  # confidence threshold for emitted tags
# Fix: `device` was defined but never applied to the model (no-op on CPU,
# but required for the cuda case and keeps device handling explicit).
model = model.to(device)
model.eval()  # disable dropout/batch-norm updates for inference
def tag_image(input_image):
    """Generate tags and a caption for one image with Tag2Text.

    Args:
        input_image: Either a ``PIL.Image.Image`` or a NumPy array
            (the format the Gradio Image component delivers).

    Returns:
        tuple[str, str]: ``(tags, caption)`` as produced by the model.
    """
    transform = get_transform(image_size=384)
    # Normalize the Gradio input (NumPy array) to a PIL image.
    if isinstance(input_image, Image.Image):
        img = input_image
    else:
        img = Image.fromarray(input_image)
    print(f"Start processing, image size {img.size}")
    # Add a batch dimension and move to the model's device.
    image = transform(img).unsqueeze(0).to(device)
    # Pure inference: no gradients needed.
    with torch.no_grad():
        res = inference(image, model)
    # res[0] is the tag string, res[2] the generated caption.
    # Note: the original `.replace(' ', ' ')` replaced a space with an
    # identical space (a no-op) and has been removed.
    tags = res[0].strip(' ')
    caption = res[2]
    print(tags, caption)
    return tags, caption
# Interface for the demo
inputs = gr.inputs.Image()
outputs = [gr.outputs.Textbox(label='Tags'), gr.outputs.Textbox(label='Caption')]
# Launch the Gradio app
gr.Interface(
fn=tag_image,
inputs=inputs,
outputs=outputs,
title="Tags and Captioning using Tag2Text",
description="Upload an image and see the tags and captions in the output boxes",
live=True,
).launch(share=True, enable_queue=True, debug=True)