Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import warnings
from contextlib import contextmanager

import torch
import torchvision.transforms as T
from PIL import Image
from huggingface_hub import hf_hub_download
from transformers import (
    MBartTokenizer,
    ProcessorMixin,
    VisionEncoderDecoderModel,
    ViTImageProcessor,
    XLMRobertaTokenizer,
)
from fastapi import FastAPI, File, Request, UploadFile
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
|
16 |
+
|
17 |
+
|
18 |
+
class CustomOCRProcessor(ProcessorMixin):
    """Joint processor pairing an image processor (pixel inputs) with a
    tokenizer (text labels / decoding), mirroring the TrOCR processor API.

    Calling the instance with `images=` produces model pixel inputs, with
    `text=` produces tokenized labels, and with both attaches the token ids
    as `labels` on the image inputs.
    """

    attributes = ["image_processor", "tokenizer"]
    image_processor_class = "AutoImageProcessor"
    tokenizer_class = "AutoTokenizer"

    def __init__(self, image_processor=None, tokenizer=None, **kwargs):
        # Pre-bind so the fallback below is safe: the original left
        # `feature_extractor` unbound when the deprecated kwarg was absent,
        # turning the intended ValueError (no image_processor) into a
        # NameError.
        feature_extractor = None
        if "feature_extractor" in kwargs:
            warnings.warn(
                "The `feature_extractor` argument is deprecated and will be removed in v5, use `image_processor`"
                " instead.",
                FutureWarning,
            )
            feature_extractor = kwargs.pop("feature_extractor")

        # Deprecated `feature_extractor` is honored only when no
        # `image_processor` was given explicitly.
        image_processor = image_processor if image_processor is not None else feature_extractor
        if image_processor is None:
            raise ValueError("You need to specify an `image_processor`.")
        if tokenizer is None:
            raise ValueError("You need to specify a `tokenizer`.")

        super().__init__(image_processor, tokenizer)
        self.current_processor = self.image_processor
        self._in_target_context_manager = False

    def __call__(self, *args, **kwargs):
        """Dispatch to the image processor and/or tokenizer.

        Positional args are treated as images (backward compatibility);
        `images=`/`text=` keywords select the component(s) to run. When both
        are given, tokenized `input_ids` are attached as `labels`.
        """
        # For backward compatibility with the as_target_processor pattern.
        if self._in_target_context_manager:
            return self.current_processor(*args, **kwargs)

        images = kwargs.pop("images", None)
        text = kwargs.pop("text", None)
        if len(args) > 0:
            images = args[0]
            args = args[1:]

        if images is None and text is None:
            raise ValueError("You need to specify either an `images` or `text` input to process.")

        if images is not None:
            inputs = self.image_processor(images, *args, **kwargs)
        if text is not None:
            encodings = self.tokenizer(text, **kwargs)

        if text is None:
            return inputs
        elif images is None:
            return encodings
        else:
            inputs["labels"] = encodings["input_ids"]
            return inputs

    def batch_decode(self, *args, **kwargs):
        """Forward to the tokenizer's `batch_decode`."""
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """Forward to the tokenizer's `decode`."""
        return self.tokenizer.decode(*args, **kwargs)
|
74 |
+
|
75 |
+
|
76 |
+
# --- Model / app setup -------------------------------------------------------
# Inference device: the original referenced an undefined `device` name, which
# raised NameError at import time. Prefer CUDA when available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

image_processor = ViTImageProcessor.from_pretrained(
    'microsoft/swin-base-patch4-window12-384-in22k'
)
tokenizer = MBartTokenizer.from_pretrained(
    'facebook/mbart-large-50'
)
processortext2 = CustomOCRProcessor(image_processor, tokenizer)


app = FastAPI()
app.mount("/static", StaticFiles(directory="static"), name="static")
templates = Jinja2Templates(directory="templates")

# Download and load the OCR model onto the selected device.
model2 = VisionEncoderDecoderModel.from_pretrained(
    "musadac/vilanocr-single-urdu", use_auth_token=True
).to(device)
|
91 |
+
|
92 |
+
|
93 |
+
@app.get("/", response_class=HTMLResponse)
async def root(request: Request):
    """Serve the upload page.

    FastAPI injects the incoming `Request`; Jinja2Templates requires a real
    request object in the template context (the original passed None, which
    breaks `url_for` and is rejected by current Starlette).
    """
    return templates.TemplateResponse("index.html", {"request": request})
|
96 |
+
|
97 |
+
@app.post("/upload/", response_class=HTMLResponse)
async def upload_image(image: UploadFile = File(...)):
    """Run OCR on an uploaded image and return the recognized text.

    Fixes NameErrors in the original: `processortext` -> the module-level
    `processortext2`, and `img_tensor` -> the `pixel_values` tensor actually
    produced by preprocessing.
    """
    # Preprocess: decode the upload and normalize to RGB once.
    img = Image.open(image.file).convert("RGB")
    pixel_values = processortext2(img, return_tensors="pt").pixel_values

    # Run the model; move inputs to the model's device so generation works
    # whether the model lives on CPU or GPU.
    with torch.no_grad():
        generated_ids = model2.generate(pixel_values.to(model2.device))

    # Extract the OCR result (first — and only — sequence in the batch).
    result = processortext2.batch_decode(generated_ids, skip_special_tokens=True)[0]

    return {"result": result}
|