# Hugging Face Spaces page header (scraped): the Space status reads "Build error".
import numpy as np | |
import random | |
from PIL import Image | |
from ram.models import tag2text | |
from ram import inference_tag2text as inference | |
from ram import get_transform | |
import gradio as gr | |
import torch | |
import torchvision.transforms as transforms | |
from PIL import Image | |
# Build the Tag2Text model once at import time from the './tag2text.pth'
# checkpoint (384px input, Swin-B backbone) and place it on the CPU.
# The original spelled this name ``evice`` (typo) and never used it.
device = torch.device('cpu')
model = tag2text(pretrained='./tag2text.pth', image_size=384, vit='swin_b')
# Tagging confidence threshold; presumably read by the model's tagging
# step — confirm against ram.models.tag2text.
model.threshold = 0.68
model = model.to(device)
# Inference mode (disables dropout / uses running batch-norm statistics).
model.eval()
def tag_image(input_image):
    """Generate tags and a caption for one image with the Tag2Text model.

    Args:
        input_image: A ``PIL.Image.Image``, a NumPy array (converted with
            ``Image.fromarray``), or ``None`` when the input widget is empty.

    Returns:
        A ``(tags, caption)`` tuple of strings. Both are empty strings when
        ``input_image`` is ``None``.
    """
    # The interface runs with ``live=True``, so the handler also fires when
    # the image is cleared and receives None; Image.fromarray(None) would
    # raise, so short-circuit with empty outputs instead.
    if input_image is None:
        return "", ""

    if isinstance(input_image, Image.Image):
        img = input_image
    else:
        # Gradio's numpy-typed image input arrives as an array.
        img = Image.fromarray(input_image)

    print(f"Start processing, image size {img.size}")
    transform = get_transform(image_size=384)
    # Prepend a size-1 leading dimension (presumably the batch axis).
    image = transform(img).unsqueeze(0)

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        res = inference(image, model)

    # Indices 0 and 2 are taken as tags and caption, as in the original code.
    # Strip the ends and collapse whitespace runs (the previous
    # ``.replace(' ', ' ')`` was a no-op).
    tags = " ".join(res[0].split())
    caption = res[2]
    print(tags, caption)
    return tags, caption
# Gradio UI: one image in, two text boxes out (tags, caption).
# ``gr.inputs`` / ``gr.outputs`` and ``launch(enable_queue=...)`` are
# deprecated in Gradio 3.x and removed in 4.x; use the top-level components
# and ``queue()`` instead.
inputs = gr.Image(type="pil")
outputs = [gr.Textbox(label='Tags'), gr.Textbox(label='Caption')]

demo = gr.Interface(
    fn=tag_image,
    inputs=inputs,
    outputs=outputs,
    title="Tags and Captioning using Tag2Text",
    description="Upload an image and see the results in text boxes.",
    # Re-run the function automatically whenever the input changes.
    live=True,
)

# Launch the Gradio app only when run as a script (Spaces runs `python app.py`).
if __name__ == "__main__":
    # queue() replaces the removed ``enable_queue=True`` launch argument.
    demo.queue().launch(share=True, debug=True)