---
library_name: transformers
license: apache-2.0
datasets:
- alakxender/dhivehi-image-text
language:
- dv
base_model:
- facebook/deit-base-distilled-patch16-384
---
# TrOCR Finetuned for Dhivehi Text Recognition

A TrOCR model finetuned for Dhivehi (Divehi/Maldivian) text recognition, using a DeiT base encoder and a BERT decoder.
## Model Details

- Base models:
  - Encoder: facebook/deit-base-distilled-patch16-384
  - Decoder: alakxender/bert-base-dv
- Training data: 10k samples with a 90/10 train/test split
- Input size: 384x384 pixels
- Beam search parameters:
  - max_length: 64
  - num_beams: 4
  - early_stopping: True
  - length_penalty: 2.0
  - no_repeat_ngram_size: 3
## Training

The model was trained with:

- 7 epochs
- Batch size: 8
- Learning rate: 4e-5
- FP16 mixed precision
- Training augmentations:
  - Elastic transform (α=8.0, σ=5.0)
  - Gaussian blur (kernel size=(5,9), σ=(0.1,5))
  - Resize (384x384)
  - Normalization ([0.5,0.5,0.5], [0.5,0.5,0.5])
## Usage

```python
from PIL import Image
import torch
from torchvision import transforms
from transformers import (
    DeiTImageProcessor,
    TrOCRProcessor,
    VisionEncoderDecoderModel,
    AutoTokenizer
)

class OCRPredictor:
    """Dhivehi OCR predictor wrapping a TrOCR-style VisionEncoderDecoder model.

    Loads the finetuned DeiT-encoder/BERT-decoder checkpoint, the matching
    tokenizer and image processor, and exposes a single `predict` method
    that maps an image file path to recognized text.
    """

    def __init__(self, model_name="alakxender/trocr-dv-diet-base-bert"):
        # Prefer GPU when available; all tensors are moved to this device.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = self._load_model(model_name)
        self.processor = self._load_processor()
        self.transform = self._get_transforms()

    def _load_model(self, model_name):
        # Load the encoder-decoder weights and move them to the target device.
        model = VisionEncoderDecoderModel.from_pretrained(model_name)
        return model.to(self.device)

    def _load_processor(self):
        # Pair the Dhivehi tokenizer with the DeiT image processor the
        # encoder was pretrained with.
        tokenizer = AutoTokenizer.from_pretrained("alakxender/trocr-dv-diet-base-bert")
        image_processor = DeiTImageProcessor.from_pretrained("facebook/deit-base-distilled-patch16-384")
        return TrOCRProcessor(image_processor=image_processor, tokenizer=tokenizer)

    def _get_transforms(self):
        # Must mirror the training-time preprocessing: 384x384 resize and
        # per-channel [-1, 1] normalization (mean=std=0.5).
        return transforms.Compose([
            transforms.Resize((384, 384)),
            transforms.ToTensor(),
            transforms.Normalize([0.5] * 3, [0.5] * 3)
        ])

    def predict(self, image_path):
        """Run OCR on the image at `image_path` and return the decoded text."""
        image = Image.open(image_path).convert("RGB")
        pixel_values = self.transform(image).unsqueeze(0).to(self.device)

        # Inference only: disable autograd so generation does not build a
        # computation graph (saves memory and time).
        with torch.no_grad():
            outputs = self.model.generate(
                pixel_values,
                max_length=64,
                num_beams=4,
                early_stopping=True,
                length_penalty=2.0,
                no_repeat_ngram_size=3
            )

        return self.processor.decode(outputs[0], skip_special_tokens=True)
# Example: recognize Dhivehi text from an image file.
ocr = OCRPredictor()
recognized_text = ocr.predict("ocr2.png")
print(recognized_text)  # ތިން މިނިސްޓްރީއެއް ހިންގާ މ.ގްރީން ބިލްޑިންގުގައި މިދިޔަ ބުރާސްފަތި ދުވަހު ހިނގި ބޮޑު އަލިފާނުގެ.
```

## Evaluation Results

```json
[
  {
    "file_name": "data/images/DV01-04/DV01-04_140.jpg",
    "predicted_text": "ޤާނޫނުގެ 42 ވަނަ މާއްދާގައި ލާޒިމްކުރާ މި ރިޕޯޓު ތައްޔާރުކޮށް ފޮނުވުމުގެ ޒިންމާއަކީ ޤާނޫނުން އިދާރާގެ އިންފޮމޭޝަން އޮފިސަރު ކުރައްވަންޖެހޭ ކަމެކެވެ .",
    "true_text": "ޤާނޫނުގެ 42 ވަނަ މާއްދާގައި ލާޒިމްކުރާ މި ރިޕޯޓު ތައްޔާރުކޮށް ފޮނުވުމުގެ ޒިންމާއަކީ ޤާނޫނުން އިދާރާގެ އިންފޮމޭޝަން އޮފިސަރު ކުރައްވަންޖެހޭ ކަމެކެވެ."
  }
]
```