Spaces:
Runtime error
Runtime error
JustinLin610
commited on
Commit
•
582f2a6
1
Parent(s):
43773d5
add requirements
Browse files- app.py +30 -180
- requirements.txt +4 -0
app.py
CHANGED
@@ -1,184 +1,35 @@
|
|
1 |
-
import
|
|
|
|
|
2 |
import pandas as pd
|
|
|
3 |
|
4 |
-
os.system('cd fairseq;'
|
5 |
-
'pip install ./; cd ..')
|
6 |
-
|
7 |
-
os.system('cd ezocr;'
|
8 |
-
'pip install .; cd ..')
|
9 |
-
|
10 |
-
import torch
|
11 |
-
import numpy as np
|
12 |
-
from fairseq import utils, tasks
|
13 |
-
from fairseq import checkpoint_utils
|
14 |
-
from utils.eval_utils import eval_step
|
15 |
-
from data.mm_data.ocr_dataset import ocr_resize
|
16 |
-
from tasks.mm_tasks.ocr import OcrTask
|
17 |
-
from PIL import Image, ImageDraw
|
18 |
-
from torchvision import transforms
|
19 |
-
from typing import List, Tuple
|
20 |
-
import cv2
|
21 |
-
from easyocrlite import ReaderLite
|
22 |
import gradio as gr
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
FourPoint = Tuple[Tuple[int, int], Tuple[int, int], Tuple[int, int], Tuple[int, int]]
|
42 |
-
|
43 |
-
|
44 |
-
reader = ReaderLite(gpu=True)
|
45 |
-
overrides={"eval_cider": False, "beam": 5, "max_len_b": 64, "patch_image_size": 480,
|
46 |
-
"orig_patch_image_size": 224, "interpolate_position": True,
|
47 |
-
"no_repeat_ngram_size": 0, "seed": 42}
|
48 |
-
models, cfg, task = checkpoint_utils.load_model_ensemble_and_task(
|
49 |
-
utils.split_paths('checkpoints/ocr_general_clean.pt'),
|
50 |
-
arg_overrides=overrides
|
51 |
-
)
|
52 |
-
|
53 |
-
# Move models to GPU
|
54 |
-
for model in models:
|
55 |
-
model.eval()
|
56 |
-
if use_fp16:
|
57 |
-
model.half()
|
58 |
-
if use_cuda and not cfg.distributed_training.pipeline_model_parallel:
|
59 |
-
model.cuda()
|
60 |
-
model.prepare_for_inference_(cfg)
|
61 |
-
|
62 |
-
# Initialize generator
|
63 |
-
generator = task.build_generator(models, cfg.generation)
|
64 |
-
|
65 |
-
bos_item = torch.LongTensor([task.src_dict.bos()])
|
66 |
-
eos_item = torch.LongTensor([task.src_dict.eos()])
|
67 |
-
pad_idx = task.src_dict.pad()
|
68 |
-
|
69 |
-
|
70 |
-
def four_point_transform(image: np.ndarray, rect: FourPoint) -> np.ndarray:
|
71 |
-
(tl, tr, br, bl) = rect
|
72 |
-
|
73 |
-
widthA = np.sqrt(((br[0] - bl[0]) ** 2) + ((br[1] - bl[1]) ** 2))
|
74 |
-
widthB = np.sqrt(((tr[0] - tl[0]) ** 2) + ((tr[1] - tl[1]) ** 2))
|
75 |
-
maxWidth = max(int(widthA), int(widthB))
|
76 |
-
|
77 |
-
# compute the height of the new image, which will be the
|
78 |
-
# maximum distance between the top-right and bottom-right
|
79 |
-
# y-coordinates or the top-left and bottom-left y-coordinates
|
80 |
-
heightA = np.sqrt(((tr[0] - br[0]) ** 2) + ((tr[1] - br[1]) ** 2))
|
81 |
-
heightB = np.sqrt(((tl[0] - bl[0]) ** 2) + ((tl[1] - bl[1]) ** 2))
|
82 |
-
maxHeight = max(int(heightA), int(heightB))
|
83 |
-
|
84 |
-
dst = np.array(
|
85 |
-
[[0, 0], [maxWidth - 1, 0], [maxWidth - 1, maxHeight - 1], [0, maxHeight - 1]],
|
86 |
-
dtype="float32",
|
87 |
-
)
|
88 |
-
|
89 |
-
# compute the perspective transform matrix and then apply it
|
90 |
-
M = cv2.getPerspectiveTransform(rect, dst)
|
91 |
-
warped = cv2.warpPerspective(image, M, (maxWidth, maxHeight))
|
92 |
-
|
93 |
-
return warped
|
94 |
-
|
95 |
-
|
96 |
-
def get_images(img: str, reader: ReaderLite, **kwargs):
|
97 |
-
results = reader.process(img, **kwargs)
|
98 |
-
return results
|
99 |
-
|
100 |
-
|
101 |
-
def draw_boxes(image, bounds, color='red', width=4):
|
102 |
-
draw = ImageDraw.Draw(image)
|
103 |
-
for i, bound in enumerate(bounds):
|
104 |
-
p0, p1, p2, p3 = bound
|
105 |
-
draw.text((p0[0]+5, p0[1]+5), str(i+1), fill=color, align='center')
|
106 |
-
draw.line([*p0, *p1, *p2, *p3, *p0], fill=color, width=width)
|
107 |
-
return image
|
108 |
-
|
109 |
-
|
110 |
-
def encode_text(text, length=None, append_bos=False, append_eos=False):
|
111 |
-
s = task.tgt_dict.encode_line(
|
112 |
-
line=task.bpe.encode(text),
|
113 |
-
add_if_not_exist=False,
|
114 |
-
append_eos=False
|
115 |
-
).long()
|
116 |
-
if length is not None:
|
117 |
-
s = s[:length]
|
118 |
-
if append_bos:
|
119 |
-
s = torch.cat([bos_item, s])
|
120 |
-
if append_eos:
|
121 |
-
s = torch.cat([s, eos_item])
|
122 |
-
return s
|
123 |
-
|
124 |
-
|
125 |
-
def patch_resize_transform(patch_image_size=480, is_document=False):
|
126 |
-
_patch_resize_transform = transforms.Compose(
|
127 |
-
[
|
128 |
-
lambda image: ocr_resize(
|
129 |
-
image, patch_image_size, is_document=is_document, split='test',
|
130 |
-
),
|
131 |
-
transforms.ToTensor(),
|
132 |
-
transforms.Normalize(mean=mean, std=std),
|
133 |
-
]
|
134 |
-
)
|
135 |
-
|
136 |
-
return _patch_resize_transform
|
137 |
-
|
138 |
-
|
139 |
-
# Construct input for caption task
|
140 |
-
def construct_sample(image: Image, patch_image_size=480, is_document=False):
|
141 |
-
patch_image = patch_resize_transform(patch_image_size, is_document=is_document)(image).unsqueeze(0)
|
142 |
-
patch_mask = torch.tensor([True])
|
143 |
-
src_text = encode_text("图片上的文字是什么?", append_bos=True, append_eos=True).unsqueeze(0)
|
144 |
-
src_length = torch.LongTensor([s.ne(pad_idx).long().sum() for s in src_text])
|
145 |
-
sample = {
|
146 |
-
"id":np.array(['42']),
|
147 |
-
"net_input": {
|
148 |
-
"src_tokens": src_text,
|
149 |
-
"src_lengths": src_length,
|
150 |
-
"patch_images": patch_image,
|
151 |
-
"patch_masks": patch_mask,
|
152 |
-
},
|
153 |
-
"target": None
|
154 |
}
|
155 |
-
return sample
|
156 |
-
|
157 |
-
|
158 |
-
# Function to turn FP32 to FP16
|
159 |
-
def apply_half(t):
|
160 |
-
if t.dtype is torch.float32:
|
161 |
-
return t.to(dtype=torch.half)
|
162 |
-
return t
|
163 |
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
box_list, image_list = zip(*results)
|
170 |
-
draw_boxes(out_img, box_list)
|
171 |
-
|
172 |
-
ocr_result = []
|
173 |
-
for i, (box, image) in enumerate(zip(box_list, image_list)):
|
174 |
-
image = Image.fromarray(image)
|
175 |
-
sample = construct_sample(image, cfg.task.patch_image_size)
|
176 |
-
sample = utils.move_to_cuda(sample) if use_cuda else sample
|
177 |
-
sample = utils.apply_to_sample(apply_half, sample) if use_fp16 else sample
|
178 |
-
|
179 |
-
with torch.no_grad():
|
180 |
-
result, scores = eval_step(task, generator, models, sample)
|
181 |
-
ocr_result.append([str(i+1), result[0]['ocr'].replace(' ', '')])
|
182 |
|
183 |
result = pd.DataFrame(ocr_result, columns=['Box ID', 'Text'])
|
184 |
|
@@ -193,10 +44,9 @@ description = "Gradio Demo for Chinese OCR based on OFA-Base. "\
|
|
193 |
article = "<p style='text-align: center'><a href='https://github.com/OFA-Sys/OFA' target='_blank'>OFA Github " \
|
194 |
"Repo</a></p> "
|
195 |
examples = [['shupai.png'], ['chinese.jpg'], ['gaidao.jpeg'],
|
196 |
-
|
197 |
io = gr.Interface(fn=ocr, inputs=gr.inputs.Image(type='filepath', label='Image'),
|
198 |
outputs=[gr.outputs.Image(type='pil', label='Image'),
|
199 |
gr.outputs.Dataframe(headers=['Box ID', 'Text'], type='pandas', label='OCR Results')],
|
200 |
-
title=title, description=description, article=article
|
201 |
-
io.launch()
|
202 |
-
|
|
|
1 |
+
import base64
|
2 |
+
import json
|
3 |
+
from io import BytesIO
|
4 |
import pandas as pd
|
5 |
+
from PIL import Image
|
6 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
7 |
import gradio as gr
|
8 |
+
import requests
|
9 |
+
|
10 |
+
|
11 |
+
def ocr(image):
|
12 |
+
|
13 |
+
image = Image.open(image)
|
14 |
+
img_buffer = BytesIO()
|
15 |
+
image.save(img_buffer, format=image.format)
|
16 |
+
byte_data = img_buffer.getvalue()
|
17 |
+
base64_bytes = base64.b64encode(byte_data) # bytes
|
18 |
+
base64_str = base64_bytes.decode()
|
19 |
+
url = "https://www.modelscope.cn/api/v1/studio/damo/ofa_ocr_pipeline/gradio/api/predict/"
|
20 |
+
payload = json.dumps({
|
21 |
+
"data": [f"data:image/jpeg;base64,{base64_str}"],
|
22 |
+
"dataType": ["image"]
|
23 |
+
})
|
24 |
+
headers = {
|
25 |
+
'Content-Type': 'application/json'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
26 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
27 |
|
28 |
+
response = requests.request("POST", url, headers=headers, data=payload)
|
29 |
+
jobj = json.loads(response.text)
|
30 |
+
out_img_base64 = jobj['Data']['data'][0].replace('data:image/png;base64,','')
|
31 |
+
out_img = Image.open(BytesIO(base64.urlsafe_b64decode(out_img_base64)))
|
32 |
+
ocr_result = jobj['Data']['data'][1]['data']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
|
34 |
result = pd.DataFrame(ocr_result, columns=['Box ID', 'Text'])
|
35 |
|
|
|
44 |
article = "<p style='text-align: center'><a href='https://github.com/OFA-Sys/OFA' target='_blank'>OFA Github " \
|
45 |
"Repo</a></p> "
|
46 |
examples = [['shupai.png'], ['chinese.jpg'], ['gaidao.jpeg'],
|
47 |
+
['qiaodaima.png'], ['xsd.jpg']]
|
48 |
io = gr.Interface(fn=ocr, inputs=gr.inputs.Image(type='filepath', label='Image'),
|
49 |
outputs=[gr.outputs.Image(type='pil', label='Image'),
|
50 |
gr.outputs.Dataframe(headers=['Box ID', 'Text'], type='pandas', label='OCR Results')],
|
51 |
+
title=title, description=description, article=article)
|
52 |
+
io.launch()
|
|
requirements.txt
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch
|
2 |
+
pillow
|
3 |
+
pandas
|
4 |
+
requests
|