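# Gradio demo of Donut (Document Understanding Transformer) for three tasks:
# document classification (RVL-CDIP), receipt parsing (CORD-v2), and
# document visual question answering (DocVQA).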
from PIL import Image
from transformers import DonutProcessor, VisionEncoderDecoderModel
import torch
import re
import gradio as gr
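
# ---------------------------------------------------------------
# Task 1: document classification. Donut fine-tuned on RVL-CDIP maps a
# document image to one of 16 classes (letter, invoice, form, email, ...).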

processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base-finetuned-rvlcdip")
model = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base-finetuned-rvlcdip")

def ClassificateDocs(pathimage):
  image = Image.open(pathimage).convert("RGB")  # ensure 3-channel input
  pixel_values = processor(image, return_tensors="pt").pixel_values
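  # Donut is conditioned on a task-specific start token: the decoder
  # continues generation from this prompt.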
  task_prompt = "<s_rvlcdip>"
  decoder_input_ids = processor.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt").input_ids
  device = "cuda" if torch.cuda.is_available() else "cpu"
  model.to(device)
  outputs = model.generate(
      pixel_values.to(device),
      decoder_input_ids=decoder_input_ids.to(device),
      max_length=model.decoder.config.max_position_embeddings,
      pad_token_id=processor.tokenizer.pad_token_id,
      eos_token_id=processor.tokenizer.eos_token_id,
      use_cache=True,
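      # forbid <unk> so the decoder never emits unknown tokens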
      bad_words_ids=[[processor.tokenizer.unk_token_id]],
      return_dict_in_generate=True,
  )

  sequence = processor.batch_decode(outputs.sequences)[0]
  sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
  sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()  # remove first task start token
  return processor.token2json(sequence)
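
# ---------------------------------------------------------------
# Task 2: receipt parsing. Donut fine-tuned on CORD-v2 returns the line
# items, prices, and totals of a receipt as structured JSON.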

processor_prs= DonutProcessor.from_pretrained("naver-clova-ix/donut-base-finetuned-cord-v2")
model_prs = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base-finetuned-cord-v2")

def ProcessBill(pathimage):
  image = Image.open(pathimage).convert("RGB")  # ensure 3-channel input
  pixel_values = processor_prs(image, return_tensors="pt").pixel_values
  task_prompt = "<s_cord-v2>"
  decoder_input_ids = processor_prs.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt")["input_ids"]
  device = "cuda" if torch.cuda.is_available() else "cpu"
  model_prs.to(device)
  outputs = model_prs.generate(pixel_values.to(device),
                               decoder_input_ids=decoder_input_ids.to(device),
                               max_length=model_prs.decoder.config.max_position_embeddings,
                               early_stopping=True,
                               pad_token_id=processor_prs.tokenizer.pad_token_id,
                               eos_token_id=processor_prs.tokenizer.eos_token_id,
                               use_cache=True,
                               num_beams=1,
                               bad_words_ids=[[processor_prs.tokenizer.unk_token_id]],
                               return_dict_in_generate=True,
                               output_scores=True,)
  sequence = processor_prs.batch_decode(outputs.sequences)[0]
  sequence = sequence.replace(processor_prs.tokenizer.eos_token, "").replace(processor_prs.tokenizer.pad_token, "")
  sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()  # remove first task start token
  return processor_prs.token2json(sequence)
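
# ---------------------------------------------------------------
# Task 3: document VQA. Donut fine-tuned on DocVQA embeds the question in
# the decoder prompt and generates the answer tokens after <s_answer>.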

processor_qa= DonutProcessor.from_pretrained("naver-clova-ix/donut-base-finetuned-docvqa")
model_qa = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base-finetuned-docvqa")
def QAsBill(pathimage, question="When is the coffee break?"):
  image = Image.open(pathimage).convert("RGB")  # ensure 3-channel input
  pixel_values = processor_qa(image, return_tensors="pt").pixel_values
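  # The question is wrapped in <s_question> tags; the model completes
  # the <s_answer> span.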
  task_prompt = "<s_docvqa><s_question>{user_input}</s_question><s_answer>"
  prompt = task_prompt.replace("{user_input}", question)
  decoder_input_ids = processor_qa.tokenizer(prompt, add_special_tokens=False, return_tensors="pt").input_ids

  device = "cuda" if torch.cuda.is_available() else "cpu"
  model_qa.to(device)
  outputs = model_qa.generate(
    pixel_values.to(device),
    decoder_input_ids=decoder_input_ids.to(device),
    max_length=model_qa.decoder.config.max_position_embeddings,
    pad_token_id=processor_qa.tokenizer.pad_token_id,
    eos_token_id=processor_qa.tokenizer.eos_token_id,
    use_cache=True,
    bad_words_ids=[[processor_qa.tokenizer.unk_token_id]],
    return_dict_in_generate=True,
  )

  sequence = processor_qa.batch_decode(outputs.sequences)[0]
  sequence = sequence.replace(processor_qa.tokenizer.eos_token, "").replace(processor_qa.tokenizer.pad_token, "")
  sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()  # remove first task start token
  return processor_qa.token2json(sequence)


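# ---------------------------------------------------------------
# Gradio UI: one Interface per task, combined into a TabbedInterface.
# type='filepath' passes the uploaded image to each handler as a path.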

gradio_app_cls = gr.Interface(
    fn=ClassificateDocs,
    inputs=[
        gr.Image(type='filepath')
    ],
    outputs="text",
)
gradio_app_prs = gr.Interface(
    fn=ProcessBill,
    inputs=[
        gr.Image(type='filepath')
    ],
    outputs="text",
)
gradio_app_qa = gr.Interface(
    fn=QAsBill,
    inputs=[
        gr.Image(type='filepath'),
        gr.Text()
    ],
    outputs="text",
)


demo = gr.TabbedInterface([gradio_app_cls, gradio_app_prs, gradio_app_qa], ["Classify", "Parse", "QA"])

if __name__ == "__main__":
    demo.launch()