import warnings
from contextlib import contextmanager

import torch
import torchvision.transforms as T
from PIL import Image
from huggingface_hub import hf_hub_download
from fastapi import FastAPI, File, Request, UploadFile
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from transformers import (
    MBartTokenizer,
    ProcessorMixin,
    VisionEncoderDecoderModel,
    ViTImageProcessor,
    XLMRobertaTokenizer,
)


class CustomOCRProcessor(ProcessorMixin):
    attributes = ["image_processor", "tokenizer"]
    image_processor_class = "AutoImageProcessor"
    tokenizer_class = "AutoTokenizer"

    def __init__(self, image_processor=None, tokenizer=None, **kwargs):
        if "feature_extractor" in kwargs:
            warnings.warn(
                "The `feature_extractor` argument is deprecated and will be removed in v5, use `image_processor`"
                " instead.",
                FutureWarning,
            )
            feature_extractor = kwargs.pop("feature_extractor")
            image_processor = image_processor if image_processor is not None else feature_extractor

        if image_processor is None:
            raise ValueError("You need to specify an `image_processor`.")
        if tokenizer is None:
            raise ValueError("You need to specify a `tokenizer`.")

        super().__init__(image_processor, tokenizer)
        self.current_processor = self.image_processor
        self._in_target_context_manager = False

    def __call__(self, *args, **kwargs):
        # For backward compatibility
        if self._in_target_context_manager:
            return self.current_processor(*args, **kwargs)

        images = kwargs.pop("images", None)
        text = kwargs.pop("text", None)
        if len(args) > 0:
            images = args[0]
            args = args[1:]

        if images is None and text is None:
            raise ValueError("You need to specify either an `images` or `text` input to process.")

        if images is not None:
            inputs = self.image_processor(images, *args, **kwargs)
        if text is not None:
            encodings = self.tokenizer(text, **kwargs)

        if text is None:
            return inputs
        elif images is None:
            return encodings
        else:
            inputs["labels"] = encodings["input_ids"]
            return inputs

    def batch_decode(self, *args, **kwargs):
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        return self.tokenizer.decode(*args, **kwargs)


# Build the OCR processor from a Swin image processor and an mBART tokenizer.
image_processor = ViTImageProcessor.from_pretrained(
    "microsoft/swin-base-patch4-window12-384-in22k"
)
tokenizer = MBartTokenizer.from_pretrained("facebook/mbart-large-50")
processortext2 = CustomOCRProcessor(image_processor, tokenizer)

app = FastAPI()
app.mount("/static", StaticFiles(directory="static"), name="static")
templates = Jinja2Templates(directory="templates")

# Download and load the model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model2 = VisionEncoderDecoderModel.from_pretrained(
    "musadac/vilanocr-single-urdu", use_auth_token=True
).to(device)


@app.get("/", response_class=HTMLResponse)
async def root(request: Request):
    return templates.TemplateResponse("index.html", {"request": request})


@app.post("/upload/")
async def upload_image(image: UploadFile = File(...)):
    # Preprocess the uploaded image
    img = Image.open(image.file).convert("RGB")
    pixel_values = processortext2(img, return_tensors="pt").pixel_values.to(device)

    # Run the model
    with torch.no_grad():
        generated_ids = model2.generate(pixel_values)

    # Extract the OCR result
    result = processortext2.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return {"result": result}
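
# A minimal way to exercise this service, assuming the file above is saved as
# main.py and the static/ and templates/ directories (with index.html) exist;
# the file name and sample image name below are placeholders:
#
#   uvicorn main:app --host 0.0.0.0 --port 8000
#   curl -X POST -F "image=@sample_page.png" http://localhost:8000/upload/
#
# The curl call sends a multipart form with an `image` field, matching the
# UploadFile parameter of /upload/, and prints a JSON body such as
# {"result": "<recognized text>"}.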