Edgar404 commited on
Commit
01d7006
1 Parent(s): a1df1d3

update the app.py to include a new model

Browse files
Files changed (1) hide show
  1. app.py +18 -5
app.py CHANGED
@@ -26,15 +26,28 @@ base_model = VisionEncoderDecoderModel.from_pretrained('Edgar404/donut-shivi-rec
26
  base_processor = DonutProcessor.from_pretrained('Edgar404/donut-shivi-recognition')
27
  print('Loading complete')
28
 
29
- print('Loading the optimized model ....')
30
  optimized_model = VisionEncoderDecoderModel.from_pretrained('Edgar404/donut-shivi-cheques_KD_320')
31
  optimized_processor = DonutProcessor.from_pretrained('Edgar404/donut-shivi-cheques_KD_320')
32
  print('Loading complete')
33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  # setting
35
 
36
 
37
- def process_image(image , mode = 'optimized' ):
38
  """ Function that takes an image and perform an OCR using the model DonUT via the task document
39
  parsing
40
 
@@ -42,9 +55,9 @@ def process_image(image , mode = 'optimized' ):
42
  __________
43
  image : a machine readable image of class PIL or numpy"""
44
 
45
- model = optimized_model if mode == 'optimized' else base_model
46
- processor = optimized_processor if mode == 'optimized' else base_processor
47
- d_type = torch.bfloat16 if ((mode == 'optimized') & (device =='cuda')) else torch.float32
48
 
49
  model.to(device)
50
  model.eval()
 
26
  base_processor = DonutProcessor.from_pretrained('Edgar404/donut-shivi-recognition')
27
  print('Loading complete')
28
 
29
+ print('Loading the latence optimized model ....')
30
  optimized_model = VisionEncoderDecoderModel.from_pretrained('Edgar404/donut-shivi-cheques_KD_320')
31
  optimized_processor = DonutProcessor.from_pretrained('Edgar404/donut-shivi-cheques_KD_320')
32
  print('Loading complete')
33
 
34
+ print('Loading the performance optimized model ....')
35
+ performance_model = VisionEncoderDecoderModel.from_pretrained('Edgar404/donut-shivi-cheques_1920')
36
+ performance_processor = DonutProcessor.from_pretrained('Edgar404/donut-shivi-cheques_1920')
37
+ print('Loading complete')
38
+
39
+ models = {'baseline': base_model ,
40
+ 'performance': performance_model ,
41
+ 'latence': optimized_model}
42
+
43
+ processor = {'baseline': base_processor ,
44
+ 'performance': performance_processor ,
45
+ 'latence': optimized_processor}
46
+
47
  # setting
48
 
49
 
50
+ def process_image(image , mode = 'baseline' ):
51
  """ Function that takes an image and perform an OCR using the model DonUT via the task document
52
  parsing
53
 
 
55
  __________
56
  image : a machine readable image of class PIL or numpy"""
57
 
58
+ model = models[mode]
59
+ processor = processor[mode]
60
+ d_type = torch.float32
61
 
62
  model.to(device)
63
  model.eval()