# -*- coding: utf-8 -*-
"""Demo.ipynb

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/1Icb8zeoaudyTDOKM1QySNay1cXzltRAp
"""

import gradio as gr
from PIL import Image
import re

import torch
import torch.nn as nn
from warnings import simplefilter

simplefilter('ignore')
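# Run on the GPU when one is available, otherwise fall back to CPU.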
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Setting up the models
from transformers import DonutProcessor, VisionEncoderDecoderModel

print('Loading the base model ....')
base_model = VisionEncoderDecoderModel.from_pretrained('Edgar404/donut-shivi-recognition')
base_processor = DonutProcessor.from_pretrained('Edgar404/donut-shivi-recognition')
print('Loading complete')
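
# (Assumption, inferred from the checkpoint names only: "_KD_320" is a
# knowledge-distilled, lower-resolution variant tuned for latency; "_1920"
# a higher-resolution variant tuned for accuracy.)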

print('Loading the latency-optimized model ....')
optimized_model = VisionEncoderDecoderModel.from_pretrained('Edgar404/donut-shivi-cheques_KD_320')
optimized_processor = DonutProcessor.from_pretrained('Edgar404/donut-shivi-cheques_KD_320')
print('Loading complete')

print('Loading the performance-optimized model ....')
performance_model = VisionEncoderDecoderModel.from_pretrained('Edgar404/donut-shivi-cheques_1920')
performance_processor = DonutProcessor.from_pretrained('Edgar404/donut-shivi-cheques_1920')
print('Loading complete')

# Map each mode name to its model / processor pair so the UI can switch at runtime.
models = {'baseline': base_model,
          'performance': performance_model,
          'latency': optimized_model}

# Named `processors` (plural) so it is not shadowed by the local `processor`
# variable inside process_image below.
processors = {'baseline': base_processor,
              'performance': performance_processor,
              'latency': optimized_processor}

# Inference


def process_image(image, mode='baseline'):
    """Run OCR on a document image with the Donut model via document parsing.

    Parameters
    ----------
    image : PIL.Image.Image or numpy.ndarray
        A machine-readable image of the document.
    mode : str
        One of 'baseline', 'performance' or 'latency'.
    """

    model = models[mode]
    processor = processors[mode]
    d_type = torch.float32

    model.to(device)
    model.eval()  # inference only: disable dropout


    # Donut is conditioned on a task prompt; "<s_cord-v2>" selects the
    # document-parsing task these checkpoints were fine-tuned for.
    task_prompt = "<s_cord-v2>"
    decoder_input_ids = processor.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt").input_ids

    # Preprocess the image into pixel values for the vision encoder.
    pixel_values = processor(image, return_tensors="pt").pixel_values

    # Greedy decoding, capped at the decoder's maximum sequence length;
    # the unknown token is banned so the output stays parseable.
    outputs = model.generate(
        pixel_values.to(device, dtype=d_type),
        decoder_input_ids=decoder_input_ids.to(device),
        max_length=model.decoder.config.max_position_embeddings,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
        use_cache=True,
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )

    # Decode, strip special tokens and the leading task-prompt tag, then
    # convert the remaining tag sequence to structured JSON.
    sequence = processor.batch_decode(outputs.sequences)[0]
    sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
    sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()
    output = processor.token2json(sequence)

    return output
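
# Example usage (illustrative; paths are the bundled test images):
#   result = process_image(Image.open('./test_images/test_0.jpg'), mode='baseline')
#   print(result)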


# Thin wrapper matching Gradio's (inputs -> output) calling convention.
def image_classifier(image, mode):
    return process_image(image, mode)



examples_list = [['./test_images/test_0.jpg', 'baseline'],
                 ['./test_images/test_1.jpg', 'baseline'],
                 ['./test_images/test_2.jpg', 'baseline'],
                 ['./test_images/test_3.jpg', 'baseline'],
                 ['./test_images/test_4.jpg', 'baseline'],
                 ['./test_images/test_5.jpg', 'baseline'],
                 ['./test_images/test_6.jpg', 'baseline'],
                 ['./test_images/test_7.jpg', 'baseline'],
                 ['./test_images/test_8.jpg', 'baseline'],
                 ['./test_images/test_9.jpg', 'baseline']
                 ]

demo = gr.Interface(fn=image_classifier,
                    inputs=["image",
                            gr.Radio(['baseline', 'performance', 'latency'], label="mode")],
                    outputs="text",
                    examples=examples_list)

# share=True serves a temporary public link; debug=True keeps the process in
# the foreground and prints errors to the console.
demo.launch(share=True, debug=True)