th0mascat committed on
Commit
7ff20b3
1 Parent(s): 08706e3

files added

Pipfile ADDED
@@ -0,0 +1,21 @@
+ [[source]]
+ url = "https://pypi.org/simple"
+ verify_ssl = true
+ name = "pypi"
+
+ [packages]
+ pytesseract = "*"
+ fastapi = "*"
+ opencv-python = "*"
+ mltu = "*"
+ python-multipart = "*"
+ uvicorn = "*"
+ symspellpy = "*"
+ textblob = "*"
+ swig = "*"
+ happytransformer = "*"
+
+ [dev-packages]
+
+ [requires]
+ python_version = "3.11"
Pipfile.lock ADDED
The diff for this file is too large to render.
 
__pycache__/handwritting_fastapi.cpython-311.pyc ADDED
Binary file (6.15 kB).
 
configs.yaml ADDED
@@ -0,0 +1,9 @@
+ batch_size: 32
+ height: 96
+ learning_rate: 0.0005
+ max_text_length: 73
+ model_path: ./models/model.onnx
+ train_epochs: 1000
+ train_workers: 20
+ vocab: '''3.FR20JWIe8CyBowxTV5rgOYQ,ipPcqDGnMAK(Eb6)fH:"9LlUt;jsz m4&1#kZ-adNhvu7!S?'
+ width: 1408
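
These settings are consumed at startup by mltu's BaseModelConfigs.load in handwritting_fastapi.py below; as a rough sketch, the load amounts to plain YAML parsing plus attribute access (the direct yaml usage here is illustrative, not mltu's actual implementation):

import yaml

# configs.yaml drives inference: model_path locates the ONNX model and
# vocab is the character set the CTC decoder can emit
with open("./configs.yaml") as f:
    cfg = yaml.safe_load(f)

print(cfg["model_path"])  # ./models/model.onnx
print(len(cfg["vocab"]))  # size of the CTC character set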
dockerfile ADDED
@@ -0,0 +1,10 @@
+ FROM python:3.11
+
+ WORKDIR /
+
+ COPY . .
+
+ RUN pip install --no-cache-dir --upgrade -r /requirements.txt
+
+
+ CMD ["uvicorn", "handwritting_fastapi:app", "--host", "0.0.0.0", "--port", "7860"]
handwritting_fastapi.py ADDED
@@ -0,0 +1,110 @@
+ import cv2
+ import io
+ import numpy as np
+ from PIL import Image
+
+ import pytesseract
+
+ from fastapi import FastAPI, UploadFile, File
+ from fastapi.middleware.cors import CORSMiddleware
+
+ from mltu.inferenceModel import OnnxInferenceModel
+ from mltu.utils.text_utils import ctc_decoder
+ from mltu.transformers import ImageResizer
+ from mltu.configs import BaseModelConfigs
+
+ from textblob import TextBlob
+ from happytransformer import HappyTextToText, TTSettings
+ from transformers import AutoTokenizer, T5ForConditionalGeneration
+ from pydantic import BaseModel
+
+ configs = BaseModelConfigs.load("./configs.yaml")
+
+ # Grammar-correction model applied to every OCR result
+ happy_tt = HappyTextToText("T5", "vennify/t5-base-grammar-correction")
+
+ beam_settings = TTSettings(num_beams=5, min_length=1, max_length=100)
+
+ app = FastAPI()
+
+ origins = ["*"]
+
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=origins,
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+
+ # ONNX CTC handwriting recognizer (mltu); model path and vocab come
+ # from configs.yaml
+ class ImageToWordModel(OnnxInferenceModel):
+     def __init__(self, char_list, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+         self.char_list = char_list
+
+     def predict(self, image: np.ndarray):
+         image = ImageResizer.resize_maintaining_aspect_ratio(
+             image, *self.input_shape[:2][::-1]
+         )
+
+         image_pred = np.expand_dims(image, axis=0).astype(np.float32)
+
+         preds = self.model.run(None, {self.input_name: image_pred})[0]
+
+         text = ctc_decoder(preds, self.char_list)[0]
+
+         return text
+
+
+ model = ImageToWordModel(model_path=configs.model_path, char_list=configs.vocab)
+ extracted_text = ""
+
+
+ @app.post("/extract_handwritten_text/")
+ async def predict_text(image: UploadFile):
+     global extracted_text
+     # Read the uploaded image
+     img = await image.read()
+     nparr = np.frombuffer(img, np.uint8)
+     img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
+
+     # Make a prediction, then run grammar correction; .text unwraps
+     # happytransformer's result object into a plain string
+     extracted_text = model.predict(img)
+     corrected_text = happy_tt.generate_text(extracted_text, beam_settings).text
+
+     return {"text": extracted_text, "corrected_text": corrected_text}
+
+
+ @app.post("/extract_text/")
+ async def extract_text_from_image(image: UploadFile):
+     global extracted_text
+     # Check if the uploaded file is an image
+     if image.content_type.startswith("image/"):
+         # Read the image from the uploaded file
+         image_bytes = await image.read()
+         img = Image.open(io.BytesIO(image_bytes))
+
+         # Perform OCR on the image (printed-text path, via Tesseract)
+         extracted_text = pytesseract.image_to_string(img)
+         corrected_text = happy_tt.generate_text(extracted_text, beam_settings).text
+
+         return {"text": extracted_text, "corrected_text": corrected_text}
+     else:
+         return {"error": "Invalid file format. Please upload an image."}
+
+
+ # Instruction-following editor backing /chat_prompt/
+ tokenizer = AutoTokenizer.from_pretrained("grammarly/coedit-large")
+ chatModel = T5ForConditionalGeneration.from_pretrained("grammarly/coedit-large")
+
+
+ class ChatPrompt(BaseModel):
+     prompt: str
+
+
+ @app.post("/chat_prompt/")
+ async def chat_prompt(request: ChatPrompt):
+     global extracted_text
+     # Apply the user's instruction to the most recently extracted text
+     input_text = request.prompt + ": " + extracted_text
+     print(input_text)
+     input_ids = tokenizer(input_text, return_tensors="pt").input_ids
+     outputs = chatModel.generate(input_ids, max_length=256)
+     edited_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+     return {"edited_text": edited_text}
models/mnist_model.h5 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e92accaee40ef1e16c8822fb60686fd5c26183b3cf833446c135701bd61344c2
+ size 7435768
models/model.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3447f46751b438e16020536605597a8508d7cb6f98118b62d3f21258e4be83aa
+ size 9718812
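
Both model files are stored as Git LFS pointers; in the LFS v1 pointer format the oid is the SHA-256 of the actual payload, so after a git lfs pull the downloads can be verified against the digests above. A minimal sketch:

import hashlib

def sha256_of(path: str) -> str:
    # Stream the file in 1 MiB chunks so large models don't load into memory
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            h.update(chunk)
    return h.hexdigest()

# Should print the oid recorded in the pointer for models/model.onnx:
# 3447f46751b438e16020536605597a8508d7cb6f98118b62d3f21258e4be83aa
print(sha256_of("models/model.onnx"))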
requirements.txt ADDED
Binary file (5.48 kB).