# ParsedBill / app.py
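# Gradio app exposing three Donut (VisionEncoderDecoder) pipelines behind a tabbed UI:
#   "class" — document classification with naver-clova-ix/donut-base-finetuned-rvlcdip
#   "parse" — receipt parsing with naver-clova-ix/donut-base-finetuned-cord-v2
#   "QA"    — document visual question answering with naver-clova-ix/donut-base-finetuned-docvqa
# All three checkpoints are loaded once at import time and moved to GPU (if available) on each call.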
from PIL import Image
from transformers import DonutProcessor, VisionEncoderDecoderModel
import torch
import re
import gradio as gr
processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base-finetuned-rvlcdip")
model = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base-finetuned-rvlcdip")
def ClassificateDocs(pathimage):
    """Classify a document image into an RVL-CDIP category."""
    image = Image.open(pathimage).convert("RGB")  # guard against RGBA/grayscale inputs
    pixel_values = processor(image, return_tensors="pt").pixel_values

    task_prompt = "<s_rvlcdip>"
    decoder_input_ids = processor.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt").input_ids

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)

    outputs = model.generate(
        pixel_values.to(device),
        decoder_input_ids=decoder_input_ids.to(device),
        max_length=model.decoder.config.max_position_embeddings,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
        use_cache=True,
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )

    sequence = processor.batch_decode(outputs.sequences)[0]
    sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
    sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()  # remove first task start token
    return processor.token2json(sequence)
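# Minimal usage sketch (hypothetical local file; not part of the app itself):
#   result = ClassificateDocs("sample_document.png")
#   # expected: a small dict whose "class" field names the predicted RVL-CDIP category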
processor_prs= DonutProcessor.from_pretrained("naver-clova-ix/donut-base-finetuned-cord-v2")
model_prs = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base-finetuned-cord-v2")
def ProcessBill(pathimage):
    """Parse a receipt image into structured JSON (CORD-v2)."""
    image = Image.open(pathimage).convert("RGB")  # guard against RGBA/grayscale inputs
    pixel_values = processor_prs(image, return_tensors="pt").pixel_values

    task_prompt = "<s_cord-v2>"
    decoder_input_ids = processor_prs.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt")["input_ids"]

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model_prs.to(device)

    outputs = model_prs.generate(
        pixel_values.to(device),
        decoder_input_ids=decoder_input_ids.to(device),
        max_length=model_prs.decoder.config.max_position_embeddings,
        early_stopping=True,
        pad_token_id=processor_prs.tokenizer.pad_token_id,
        eos_token_id=processor_prs.tokenizer.eos_token_id,
        use_cache=True,
        num_beams=1,
        bad_words_ids=[[processor_prs.tokenizer.unk_token_id]],
        return_dict_in_generate=True,
        output_scores=True,
    )

    sequence = processor_prs.batch_decode(outputs.sequences)[0]
    sequence = sequence.replace(processor_prs.tokenizer.eos_token, "").replace(processor_prs.tokenizer.pad_token, "")
    sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()  # remove first task start token
    return processor_prs.token2json(sequence)
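# Minimal usage sketch (hypothetical local file). The CORD-v2 checkpoint typically returns
# nested receipt fields such as "menu", "sub_total", and "total", but the exact keys depend
# on the document, so treat this as illustrative only:
#   parsed = ProcessBill("sample_receipt.jpg")
#   print(parsed)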
processor_qa= DonutProcessor.from_pretrained("naver-clova-ix/donut-base-finetuned-docvqa")
model_qa = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base-finetuned-docvqa")
def QAsBill(pathimage, question="When is the coffee break?"):
    """Answer a free-form question about a document image (DocVQA)."""
    image = Image.open(pathimage).convert("RGB")  # guard against RGBA/grayscale inputs
    pixel_values = processor_qa(image, return_tensors="pt").pixel_values

    task_prompt = "<s_docvqa><s_question>{user_input}</s_question><s_answer>"
    prompt = task_prompt.replace("{user_input}", question)
    decoder_input_ids = processor_qa.tokenizer(prompt, add_special_tokens=False, return_tensors="pt").input_ids

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model_qa.to(device)

    outputs = model_qa.generate(
        pixel_values.to(device),
        decoder_input_ids=decoder_input_ids.to(device),
        max_length=model_qa.decoder.config.max_position_embeddings,
        pad_token_id=processor_qa.tokenizer.pad_token_id,
        eos_token_id=processor_qa.tokenizer.eos_token_id,
        use_cache=True,
        bad_words_ids=[[processor_qa.tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )

    sequence = processor_qa.batch_decode(outputs.sequences)[0]
    sequence = sequence.replace(processor_qa.tokenizer.eos_token, "").replace(processor_qa.tokenizer.pad_token, "")
    sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()  # remove first task start token
    return processor_qa.token2json(sequence)
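# Minimal usage sketch (hypothetical file and question). The DocVQA checkpoint usually
# yields a dict with "question" and "answer" keys; illustrative only:
#   qa = QAsBill("sample_invoice.png", question="What is the invoice number?")
#   print(qa)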
gradio_app_cls = gr.Interface(
    fn=ClassificateDocs,
    inputs=[gr.Image(type="filepath")],
    outputs="text",
)

gradio_app_prs = gr.Interface(
    fn=ProcessBill,
    inputs=[gr.Image(type="filepath")],
    outputs="text",
)

gradio_app_qa = gr.Interface(
    fn=QAsBill,
    inputs=[gr.Image(type="filepath"), gr.Text()],
    outputs="text",
)

demo = gr.TabbedInterface(
    [gradio_app_cls, gradio_app_prs, gradio_app_qa],
    ["class", "parse", "QA"],
)

if __name__ == "__main__":
    demo.launch()