import gradio as gr
import torch
from PIL import Image

from donut import DonutModel


def demo_process_vqa(input_img, question):
    """Answer a question about the uploaded document (used when task_name == "docvqa")."""
    global pretrained_model, task_prompt, task_name
    input_img = Image.fromarray(input_img)
    user_prompt = task_prompt.replace("{user_input}", question)
    output = pretrained_model.inference(image=input_img, prompt=user_prompt)["predictions"][0]
    return output


def demo_process(input_img):
    """Classify the uploaded document and extract its fields only if it is an invoice."""
    global pretrained_model, task_prompt, task_name, security_layer
    input_img = Image.fromarray(input_img)
    # Gate: the rvlcdip classifier returns a dict like {"class": "<document type>"}.
    sec = security_layer.inference(image=input_img, prompt="<s_rvlcdip>")["predictions"][0]
    print(sec)
    if sec["class"] == "invoice":
        # Parse the invoice with the cord-v2 model and return the extracted fields as JSON.
        output = pretrained_model.inference(image=input_img, prompt="<s_cord-v2>")["predictions"][0]
        return output
    return sec
task_name="cord-v2" | |
if "docvqa" == task_name: | |
task_prompt = "<s_docvqa><s_question>{user_input}</s_question><s_answer>" | |
else: # rvlcdip, cord, ... | |
task_prompt = f"<s_{task_name}>" | |

# Document-type classifier (gate) and invoice parser.
security_layer = DonutModel.from_pretrained("naver-clova-ix/donut-base-finetuned-rvlcdip")
pretrained_model = DonutModel.from_pretrained("naver-clova-ix/donut-base-finetuned-cord-v2")

if torch.cuda.is_available():
    # On GPU, run both models in fp16.
    pretrained_model.half()
    security_layer.half()
    device = torch.device("cuda")
    pretrained_model.to(device)
    security_layer.to(device)
else:
    # On CPU, cast only the encoders to bfloat16 to reduce memory use.
    pretrained_model.encoder.to(torch.bfloat16)
    security_layer.encoder.to(torch.bfloat16)

pretrained_model.eval()
security_layer.eval()

demo = gr.Interface(
    fn=demo_process_vqa if task_name == "docvqa" else demo_process,
    inputs=["image", "text"] if task_name == "docvqa" else "image",
    outputs="json",
    title=f"Donut 🍩 demonstration for `{task_name}` task",
    concurrency_limit=10,
    description="Classifies the uploaded document and returns the extracted invoice details if it is an invoice.",
)
demo.queue(default_concurrency_limit=2, max_size=5)
demo.launch(debug=True, share=True, inline=False)
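
# Minimal offline sanity check (a sketch, not part of the Space): call demo_process()
# directly on a numpy image instead of launching the Gradio UI. The filename
# "sample_invoice.png" is a hypothetical placeholder; run this in place of demo.launch().
#
#     import numpy as np
#     img = np.array(Image.open("sample_invoice.png").convert("RGB"))
#     print(demo_process(img))  # invoice fields if classified as "invoice", otherwise the classifier output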