Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -1,170 +1,167 @@
|
|
1 |
-
import os
|
2 |
-
os.system('pip install -U transformers==4.44.2')
|
3 |
-
import sys
|
4 |
-
import shutil
|
5 |
-
import torch
|
6 |
-
import base64
|
7 |
-
import argparse
|
8 |
-
import gradio as gr
|
9 |
-
import numpy as np
|
10 |
-
from PIL import Image
|
11 |
-
from huggingface_hub import snapshot_download
|
12 |
-
import spaces
|
13 |
-
|
14 |
-
# == download weights ==
|
15 |
-
tiny_model_dir = snapshot_download('wanderkid/unimernet_tiny', local_dir='./models/unimernet_tiny')
|
16 |
-
small_model_dir = snapshot_download('wanderkid/unimernet_small', local_dir='./models/unimernet_small')
|
17 |
-
base_model_dir = snapshot_download('wanderkid/unimernet_base', local_dir='./models/unimernet_base')
|
18 |
-
os.system("ls -l models/unimernet_tiny")
|
19 |
-
os.system("ls -l models/unimernet_small")
|
20 |
-
os.system("ls -l models/unimernet_base")
|
21 |
-
# == download weights ==
|
22 |
-
|
23 |
-
sys.path.insert(0, os.path.join(os.getcwd(), ".."))
|
24 |
-
from unimernet.common.config import Config
|
25 |
-
import unimernet.tasks as tasks
|
26 |
-
from unimernet.processors import load_processor
|
27 |
-
|
28 |
-
|
29 |
-
template_html = """<!DOCTYPE html>
|
30 |
-
<html lang="en" data-lt-installed="true"><head>
|
31 |
-
<meta charset="UTF-8">
|
32 |
-
<title>Title</title>
|
33 |
-
<script>
|
34 |
-
const text =
|
35 |
-
</script>
|
36 |
-
<style>
|
37 |
-
#content {
|
38 |
-
max-width: 800px;
|
39 |
-
margin: auto;
|
40 |
-
}
|
41 |
-
</style>
|
42 |
-
<script>
|
43 |
-
let script = document.createElement('script');
|
44 |
-
script.src = "https://cdn.jsdelivr.net/npm/[email protected]/es5/bundle.js";
|
45 |
-
document.head.append(script);
|
46 |
-
|
47 |
-
script.onload = function() {
|
48 |
-
const isLoaded = window.loadMathJax();
|
49 |
-
if (isLoaded) {
|
50 |
-
console.log('Styles loaded!')
|
51 |
-
}
|
52 |
-
|
53 |
-
const el = window.document.getElementById('content-text');
|
54 |
-
if (el) {
|
55 |
-
const options = {
|
56 |
-
htmlTags: true
|
57 |
-
};
|
58 |
-
const html = window.render(text, options);
|
59 |
-
el.outerHTML = html;
|
60 |
-
}
|
61 |
-
};
|
62 |
-
</script>
|
63 |
-
</head>
|
64 |
-
<body>
|
65 |
-
<div id="content"><div id="content-text"></div></div>
|
66 |
-
</body>
|
67 |
-
</html>
|
68 |
-
"""
|
69 |
-
|
70 |
-
def latex2html(latex_code):
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
gt=
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
return
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
)
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
clear.click(gradio_reset, inputs=None, outputs=[input_img, pred_latex, output_html])
|
168 |
-
predict.click(recognize_image, inputs=[input_img, model_type], outputs=[pred_latex, output_html])
|
169 |
-
|
170 |
demo.launch(server_name="0.0.0.0", server_port=7860, debug=True)
|
|
|
1 |
+
import os
|
2 |
+
os.system('pip install -U transformers==4.44.2')
|
3 |
+
import sys
|
4 |
+
import shutil
|
5 |
+
import torch
|
6 |
+
import base64
|
7 |
+
import argparse
|
8 |
+
import gradio as gr
|
9 |
+
import numpy as np
|
10 |
+
from PIL import Image
|
11 |
+
from huggingface_hub import snapshot_download
|
12 |
+
import spaces
|
13 |
+
|
14 |
+
# == download weights ==
|
15 |
+
tiny_model_dir = snapshot_download('wanderkid/unimernet_tiny', local_dir='./models/unimernet_tiny')
|
16 |
+
small_model_dir = snapshot_download('wanderkid/unimernet_small', local_dir='./models/unimernet_small')
|
17 |
+
base_model_dir = snapshot_download('wanderkid/unimernet_base', local_dir='./models/unimernet_base')
|
18 |
+
os.system("ls -l models/unimernet_tiny")
|
19 |
+
os.system("ls -l models/unimernet_small")
|
20 |
+
os.system("ls -l models/unimernet_base")
|
21 |
+
# == download weights ==
|
22 |
+
|
23 |
+
sys.path.insert(0, os.path.join(os.getcwd(), ".."))
|
24 |
+
from unimernet.common.config import Config
|
25 |
+
import unimernet.tasks as tasks
|
26 |
+
from unimernet.processors import load_processor
|
27 |
+
|
28 |
+
|
29 |
+
template_html = """<!DOCTYPE html>
|
30 |
+
<html lang="en" data-lt-installed="true"><head>
|
31 |
+
<meta charset="UTF-8">
|
32 |
+
<title>Title</title>
|
33 |
+
<script>
|
34 |
+
const text =
|
35 |
+
</script>
|
36 |
+
<style>
|
37 |
+
#content {
|
38 |
+
max-width: 800px;
|
39 |
+
margin: auto;
|
40 |
+
}
|
41 |
+
</style>
|
42 |
+
<script>
|
43 |
+
let script = document.createElement('script');
|
44 |
+
script.src = "https://cdn.jsdelivr.net/npm/[email protected]/es5/bundle.js";
|
45 |
+
document.head.append(script);
|
46 |
+
|
47 |
+
script.onload = function() {
|
48 |
+
const isLoaded = window.loadMathJax();
|
49 |
+
if (isLoaded) {
|
50 |
+
console.log('Styles loaded!')
|
51 |
+
}
|
52 |
+
|
53 |
+
const el = window.document.getElementById('content-text');
|
54 |
+
if (el) {
|
55 |
+
const options = {
|
56 |
+
htmlTags: true
|
57 |
+
};
|
58 |
+
const html = window.render(text, options);
|
59 |
+
el.outerHTML = html;
|
60 |
+
}
|
61 |
+
};
|
62 |
+
</script>
|
63 |
+
</head>
|
64 |
+
<body>
|
65 |
+
<div id="content"><div id="content-text"></div></div>
|
66 |
+
</body>
|
67 |
+
</html>
|
68 |
+
"""
|
69 |
+
|
70 |
+
def latex2html(latex_code):
    """Embed *latex_code* into the MathJax/markdown-it HTML template.

    Returns a standalone HTML page (string) in which the JavaScript
    ``const text =`` placeholder of ``template_html`` is filled with the
    LaTeX source expressed as a JS string-concatenation expression.
    """
    # Strip \left / \right delimiter-sizing commands down to the plain
    # delimiters. Raw strings fix the original's invalid escape sequences
    # ('\left(' relies on '\l' not being an escape — SyntaxWarning on
    # Python 3.12+); the string VALUES are byte-identical to before.
    for cmd, plain in (
        (r'\left(', '('), (r'\right)', ')'),
        (r'\left[', '['), (r'\right]', ']'),
        (r'\left{', '{'), (r'\right}', '}'),
        (r'\left|', '|'), (r'\right|', '|'),
        (r'\left.', '.'), (r'\right.', '.'),
    ):
        latex_code = latex_code.replace(cmd, plain)

    # '"' would terminate the generated JS string literal; '$' would be
    # taken as an inline-math delimiter. Neutralize both.
    latex_code = latex_code.replace('"', '``').replace('$', '')

    # Build a JS expression: one double-quoted segment per source line,
    # joined with '+', each ending in a literal \n escape for the browser.
    gt = ''
    for segment in latex_code.split('\n'):
        gt += '"' + segment.replace('\\', '\\\\') + r'\n' + '"' + '+' + '\n'
    gt = gt[:-2]  # drop the trailing '+\n' left by the loop
    gt = "\\[" + gt + "\\]"  # display-math delimiters for MathJax

    # Splice the expression into the template at its single placeholder.
    head, tail = template_html.split("const text =", 1)
    return head + 'const text =' + gt + tail
|
86 |
+
|
87 |
+
def load_model_and_processor(cfg_path):
    """Build a UniMERNet model and its eval image processor from *cfg_path*.

    Returns a ``(model, vis_processor)`` pair ready for inference.
    """
    # Config expects an argparse-style namespace; no CLI overrides here.
    namespace = argparse.Namespace(cfg_path=cfg_path, options=None)
    cfg = Config(namespace)
    task = tasks.setup_task(cfg)
    model = task.build_model(cfg)
    # The eval-time image transform is described under the formula_rec_eval
    # dataset section of the config.
    eval_proc_cfg = cfg.config.datasets.formula_rec_eval.vis_processor.eval
    vis_processor = load_processor('formula_image_eval', eval_proc_cfg)
    return model, vis_processor
|
94 |
+
|
95 |
+
@spaces.GPU
def recognize_image(input_img, model_type):
    """Recognize a formula image and return (latex_code, iframe_html).

    Parameters
    ----------
    input_img : numpy.ndarray | None
        Image from the Gradio Image component (RGB HxWx3, or HxW for
        grayscale). None when the user submitted without uploading.
    model_type : str
        One of "tiny", "small", "base"; anything else falls back to tiny.
    """
    # Fix: Gradio passes None when no image was uploaded; the original
    # crashed with AttributeError on `.shape`. Return empty outputs instead.
    if input_img is None:
        return "", ""

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Select the (module-level, pre-loaded) model and move it to the device.
    if model_type == "base":
        model = model_base.to(device)
    elif model_type == "small":
        model = model_small.to(device)
    else:
        model = model_tiny.to(device)

    # Reverse the channel order of color images (RGB -> BGR); presumably the
    # model was trained on cv2-style BGR input — TODO confirm upstream.
    if len(input_img.shape) == 3:
        input_img = input_img[:, :, ::-1].copy()

    img = Image.fromarray(input_img)
    image = vis_processor(img).unsqueeze(0).to(device)
    output = model.generate({"image": image})
    latex_code = output["pred_str"][0]

    # Render the LaTeX to a full HTML page and embed it as a base64 data-URI
    # iframe so its scripts run isolated from the Gradio page.
    html_code = latex2html(latex_code)
    encoded_html = base64.b64encode(html_code.encode('utf-8')).decode('utf-8')
    iframe_src = f"data:text/html;base64,{encoded_html}"
    iframe = f'<iframe src="{iframe_src}" width="100%" height="600px"></iframe>'
    return latex_code, iframe
|
117 |
+
|
118 |
+
def gradio_reset():
    """Clear the input image, the predicted-LaTeX textbox, and the HTML pane."""
    # One gr.update(value=None) per wired output component.
    return tuple(gr.update(value=None) for _ in range(3))
|
120 |
+
|
121 |
+
|
122 |
+
if __name__ == "__main__":
    root_path = os.path.abspath(os.getcwd())
    # == load model ==
    # All three model sizes are loaded up-front (on CPU) so the Gradio
    # callback only has to move the selected one to the GPU.
    # NOTE(review): vis_processor is overwritten by each call; the app
    # assumes all three configs describe the same eval transform — confirm.
    print("load tiny model ...")
    model_tiny, vis_processor = load_model_and_processor(os.path.join(root_path, "cfg_tiny.yaml"))
    print("load small model ...")
    model_small, vis_processor = load_model_and_processor(os.path.join(root_path, "cfg_small.yaml"))
    print("load base model ...")
    model_base, vis_processor = load_model_and_processor(os.path.join(root_path, "cfg_base.yaml"))
    print("== load all models done. ==")
    # == load model ==

    # Static page header shipped alongside the app.
    with open("header.html", "r") as file:
        header = file.read()

    # Build the two-column UI: controls + examples on the left,
    # prediction outputs on the right.
    with gr.Blocks() as demo:
        gr.HTML(header)

        with gr.Row():
            with gr.Column():
                # Model size selector; must match the keys checked in
                # recognize_image ("base" / "small" / else tiny).
                model_type = gr.Radio(
                    choices=["tiny", "small", "base"],
                    value="tiny",
                    label="Model Type",
                    interactive=True,
                )
                input_img = gr.Image(label=" ", interactive=True)
                with gr.Row():
                    clear = gr.Button("Clear")
                    predict = gr.Button(value="Recognize", interactive=True, variant="primary")

                # Example images: every *.png under ./examples next to app.py.
                with gr.Accordion("Examples:"):
                    example_root = os.path.join(os.path.dirname(__file__), "examples")
                    gr.Examples(
                        examples=[os.path.join(example_root, _) for _ in os.listdir(example_root) if
                                  _.endswith("png")],
                        inputs=input_img,
                    )
            with gr.Column():
                # Disabled button used purely as a section label.
                gr.Button(value="Predict Result:", interactive=False)
                pred_latex = gr.Textbox(label='Predict Latex', interactive=False)
                output_html = gr.HTML(label='Output Html')

        # Wire events: clear resets all three components; predict runs
        # recognition and fills the LaTeX textbox and rendered-HTML pane.
        clear.click(gradio_reset, inputs=None, outputs=[input_img, pred_latex, output_html])
        predict.click(recognize_image, inputs=[input_img, model_type], outputs=[pred_latex, output_html])

    # Bind on all interfaces: required inside the Space's container.
    demo.launch(server_name="0.0.0.0", server_port=7860, debug=True)
|