fedihch commited on
Commit
1880d16
·
1 Parent(s): 07a8821

first-commit

Browse files
Files changed (1) hide show
  1. classifier_app.py +36 -0
classifier_app.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
+
4
+ def classify(input_img):
5
+
6
+ from transformers import (
7
+ AutoModelForSequenceClassification,
8
+ LayoutLMv2FeatureExtractor,
9
+ LayoutLMv2Tokenizer,
10
+ LayoutLMv2Processor,
11
+ )
12
+
13
+ model = AutoModelForSequenceClassification.from_pretrained(
14
+ "fedihch/InvoiceReceiptClassifier"
15
+ )
16
+ feature_extractor = LayoutLMv2FeatureExtractor()
17
+ tokenizer = LayoutLMv2Tokenizer.from_pretrained("microsoft/layoutlmv2-base-uncased")
18
+ processor = LayoutLMv2Processor(feature_extractor, tokenizer)
19
+ encoded_inputs = processor(input_img, return_tensors="pt")
20
+ for k, v in encoded_inputs.items():
21
+ encoded_inputs[k] = v.to(model.device)
22
+ outputs = model(**encoded_inputs)
23
+ logits = outputs.logits
24
+ predicted_class_idx = logits.argmax(-1).item()
25
+
26
+ id2label = {0: "invoice", 1: "receipt"}
27
+ return id2label[predicted_class_idx]
28
+
29
+
30
+ demo = gr.Interface(
31
+ fn=classify,
32
+ inputs=gr.Image(shape=(200, 200)),
33
+ outputs="text",
34
+ allow_flagging="manual",
35
+ )
36
+ demo.launch(share=True)