marioggil committed on
Commit
f56681b
1 Parent(s): 0a2c293

Init commit

Browse files
Files changed (2) hide show
  1. app.py +44 -0
  2. requirements.txt +2 -0
app.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ from transformers import DonutProcessor, VisionEncoderDecoderModel
3
+ import torch
4
+ import re
5
+ import gradio as gr
6
+
7
# Single checkpoint identifier shared by the processor and the model so the
# two can never drift apart.
_CHECKPOINT = "naver-clova-ix/donut-base-finetuned-rvlcdip"

# Loaded once at import time and reused by every call to ClassificateDocs.
processor = DonutProcessor.from_pretrained(_CHECKPOINT)
model = VisionEncoderDecoderModel.from_pretrained(_CHECKPOINT)
9
+
10
def ClassificateDocs(pathimage):
    """Run the module-level Donut model on one document image.

    Args:
        pathimage: Anything ``PIL.Image.open`` accepts (the Gradio input
            below is configured with ``type='filepath'``, so in the app this
            is a path string).

    Returns:
        The value produced by ``processor.token2json`` for the decoded,
        cleaned-up generation output.
    """
    # Image.open is lazy and keeps the file handle open; the context manager
    # closes it, and convert() forces the pixel data to be loaded first.
    # Converting to RGB also normalises grayscale / palette / RGBA uploads to
    # the three channels the image processor is fed.
    with Image.open(pathimage) as source_image:
        image = source_image.convert("RGB")

    pixel_values = processor(image, return_tensors="pt").pixel_values

    # The decoder is primed with the task start token instead of free text.
    task_prompt = "<s_rvlcdip>"
    decoder_input_ids = processor.tokenizer(
        task_prompt, add_special_tokens=False, return_tensors="pt"
    ).input_ids

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)

    outputs = model.generate(
        pixel_values.to(device),
        decoder_input_ids=decoder_input_ids.to(device),
        max_length=model.decoder.config.max_position_embeddings,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
        use_cache=True,
        # Forbid the unknown token from being generated.
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )

    sequence = processor.batch_decode(outputs.sequences)[0]
    # Strip end-of-sequence and padding markers left in the decoded text.
    sequence = sequence.replace(processor.tokenizer.eos_token, "")
    sequence = sequence.replace(processor.tokenizer.pad_token, "")
    sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()  # remove first task start token
    return processor.token2json(sequence)
32
# No inference is run at import time: calling ClassificateDocs on a
# hard-coded sample path here would raise FileNotFoundError on any machine
# that lacks that file, before the UI below could be built.

# NOTE(review): this Blocks instance is not referenced by the interface
# below; kept so the module still exposes a ``demo`` name — confirm whether
# anything depends on it.
demo = gr.Blocks()

# Single-input interface: the uploaded image is handed to ClassificateDocs as
# a file path and the returned value is rendered as text.
gradio_app = gr.Interface(
    fn=ClassificateDocs,
    inputs=[
        gr.Image(type='filepath')
    ],
    outputs="text",
)

if __name__ == "__main__":
    gradio_app.launch()
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ git+https://github.com/huggingface/transformers
2
+ torch