wufan commited on
Commit
1d3853f
·
verified ·
1 Parent(s): 6fa1649

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +166 -169
app.py CHANGED
@@ -1,170 +1,167 @@
1
- import os
2
- os.system('pip install -U transformers==4.44.2')
3
- import sys
4
- import shutil
5
- import torch
6
- import base64
7
- import argparse
8
- import gradio as gr
9
- import numpy as np
10
- from PIL import Image
11
- from huggingface_hub import snapshot_download
12
- import spaces
13
-
14
- # == download weights ==
15
- tiny_model_dir = snapshot_download('wanderkid/unimernet_tiny', local_dir='./models/unimernet_tiny')
16
- small_model_dir = snapshot_download('wanderkid/unimernet_small', local_dir='./models/unimernet_small')
17
- base_model_dir = snapshot_download('wanderkid/unimernet_base', local_dir='./models/unimernet_base')
18
- os.system("ls -l models/unimernet_tiny")
19
- os.system("ls -l models/unimernet_small")
20
- os.system("ls -l models/unimernet_base")
21
- # == download weights ==
22
-
23
- sys.path.insert(0, os.path.join(os.getcwd(), ".."))
24
- from unimernet.common.config import Config
25
- import unimernet.tasks as tasks
26
- from unimernet.processors import load_processor
27
-
28
-
29
- template_html = """<!DOCTYPE html>
30
- <html lang="en" data-lt-installed="true"><head>
31
- <meta charset="UTF-8">
32
- <title>Title</title>
33
- <script>
34
- const text =
35
- </script>
36
- <style>
37
- #content {
38
- max-width: 800px;
39
- margin: auto;
40
- }
41
- </style>
42
- <script>
43
- let script = document.createElement('script');
44
- script.src = "https://cdn.jsdelivr.net/npm/[email protected]/es5/bundle.js";
45
- document.head.append(script);
46
-
47
- script.onload = function() {
48
- const isLoaded = window.loadMathJax();
49
- if (isLoaded) {
50
- console.log('Styles loaded!')
51
- }
52
-
53
- const el = window.document.getElementById('content-text');
54
- if (el) {
55
- const options = {
56
- htmlTags: true
57
- };
58
- const html = window.render(text, options);
59
- el.outerHTML = html;
60
- }
61
- };
62
- </script>
63
- </head>
64
- <body>
65
- <div id="content"><div id="content-text"></div></div>
66
- </body>
67
- </html>
68
- """
69
-
70
- def latex2html(latex_code):
71
- right_num = latex_code.count('\\right')
72
- left_num = latex_code.count('\left')
73
-
74
- if right_num != left_num:
75
- latex_code = latex_code.replace('\left(', '(').replace('\\right)', ')').replace('\left[', '[').replace('\\right]', ']').replace('\left{', '{').replace('\\right}', '}').replace('\left|', '|').replace('\\right|', '|').replace('\left.', '.').replace('\\right.', '.')
76
-
77
- latex_code = latex_code.replace('"', '``').replace('$', '')
78
-
79
- latex_code_list = latex_code.split('\n')
80
- gt= ''
81
- for out in latex_code_list:
82
- gt += '"' + out.replace('\\', '\\\\') + r'\n' + '"' + '+' + '\n'
83
-
84
- gt = gt[:-2]
85
-
86
- lines = template_html.split("const text =")
87
- new_web = lines[0] + 'const text =' + gt + lines[1]
88
- return new_web
89
-
90
- def load_model_and_processor(cfg_path):
91
- args = argparse.Namespace(cfg_path=cfg_path, options=None)
92
- cfg = Config(args)
93
- task = tasks.setup_task(cfg)
94
- model = task.build_model(cfg)
95
- vis_processor = load_processor('formula_image_eval', cfg.config.datasets.formula_rec_eval.vis_processor.eval)
96
- return model, vis_processor
97
-
98
- @spaces.GPU
99
- def recognize_image(input_img, model_type):
100
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
101
- if model_type == "base":
102
- model = model_base.to(device)
103
- elif model_type == "small":
104
- model = model_small.to(device)
105
- else:
106
- model = model_tiny.to(device)
107
-
108
- if len(input_img.shape) == 3:
109
- input_img = input_img[:, :, ::-1].copy()
110
-
111
- img = Image.fromarray(input_img)
112
- image = vis_processor(img).unsqueeze(0).to(device)
113
- output = model.generate({"image": image})
114
- latex_code = output["pred_str"][0]
115
- html_code = latex2html(latex_code)
116
- encoded_html = base64.b64encode(html_code.encode('utf-8')).decode('utf-8')
117
- iframe_src = f"data:text/html;base64,{encoded_html}"
118
- iframe = f'<iframe src="{iframe_src}" width="100%" height="600px"></iframe>'
119
- return latex_code, iframe
120
-
121
- def gradio_reset():
122
- return gr.update(value=None), gr.update(value=None), gr.update(value=None)
123
-
124
-
125
- if __name__ == "__main__":
126
- root_path = os.path.abspath(os.getcwd())
127
- # == load model ==
128
- print("load tiny model ...")
129
- model_tiny, vis_processor = load_model_and_processor(os.path.join(root_path, "cfg_tiny.yaml"))
130
- print("load small model ...")
131
- model_small, vis_processor = load_model_and_processor(os.path.join(root_path, "cfg_small.yaml"))
132
- print("load base model ...")
133
- model_base, vis_processor = load_model_and_processor(os.path.join(root_path, "cfg_base.yaml"))
134
- print("== load all models done. ==")
135
- # == load model ==
136
-
137
- with open("header.html", "r") as file:
138
- header = file.read()
139
- with gr.Blocks() as demo:
140
- gr.HTML(header)
141
-
142
- with gr.Row():
143
- with gr.Column():
144
- model_type = gr.Radio(
145
- choices=["tiny", "small", "base"],
146
- value="tiny",
147
- label="Model Type",
148
- interactive=True,
149
- )
150
- input_img = gr.Image(label=" ", interactive=True)
151
- with gr.Row():
152
- clear = gr.Button("Clear")
153
- predict = gr.Button(value="Recognize", interactive=True, variant="primary")
154
-
155
- with gr.Accordion("Examples:"):
156
- example_root = os.path.join(os.path.dirname(__file__), "examples")
157
- gr.Examples(
158
- examples=[os.path.join(example_root, _) for _ in os.listdir(example_root) if
159
- _.endswith("png")],
160
- inputs=input_img,
161
- )
162
- with gr.Column():
163
- gr.Button(value="Predict Result:", interactive=False)
164
- pred_latex = gr.Textbox(label='Predict Latex', interactive=False)
165
- output_html = gr.HTML(label='Output Html')
166
-
167
- clear.click(gradio_reset, inputs=None, outputs=[input_img, pred_latex, output_html])
168
- predict.click(recognize_image, inputs=[input_img, model_type], outputs=[pred_latex, output_html])
169
-
170
  demo.launch(server_name="0.0.0.0", server_port=7860, debug=True)
 
1
import os

# NOTE(review): in-process pip install; if transformers was already imported
# the pin only takes effect after a restart of the Space.
os.system('pip install -U transformers==4.44.2')
import sys
import shutil
import torch
import base64
import argparse
import gradio as gr
import numpy as np
from PIL import Image
from huggingface_hub import snapshot_download
import spaces

# == download weights ==
# Fetch the three UniMERNet checkpoints from the HF Hub into ./models/.
tiny_model_dir = snapshot_download('wanderkid/unimernet_tiny', local_dir='./models/unimernet_tiny')
small_model_dir = snapshot_download('wanderkid/unimernet_small', local_dir='./models/unimernet_small')
base_model_dir = snapshot_download('wanderkid/unimernet_base', local_dir='./models/unimernet_base')
# List the downloaded files so they show up in the Space build/run logs.
os.system("ls -l models/unimernet_tiny")
os.system("ls -l models/unimernet_small")
os.system("ls -l models/unimernet_base")
# == download weights ==

# Make the sibling `unimernet` package importable when running from a checkout.
sys.path.insert(0, os.path.join(os.getcwd(), ".."))
from unimernet.common.config import Config
import unimernet.tasks as tasks
from unimernet.processors import load_processor
27
+
28
+
29
# HTML shell shown inside the result iframe.  latex2html() splices a
# JavaScript string-concatenation expression in after "const text =";
# the in-page script then loads the bundled renderer and turns the text
# into displayed math inside #content-text.
# NOTE(review): the script.src below looks like a CDN URL whose package
# name was mangled by e-mail obfuscation ("[email protected]") — confirm the real
# bundle URL against the deployed Space.
template_html = """<!DOCTYPE html>
<html lang="en" data-lt-installed="true"><head>
<meta charset="UTF-8">
<title>Title</title>
<script>
const text =
</script>
<style>
#content {
max-width: 800px;
margin: auto;
}
</style>
<script>
let script = document.createElement('script');
script.src = "https://cdn.jsdelivr.net/npm/[email protected]/es5/bundle.js";
document.head.append(script);

script.onload = function() {
const isLoaded = window.loadMathJax();
if (isLoaded) {
console.log('Styles loaded!')
}

const el = window.document.getElementById('content-text');
if (el) {
const options = {
htmlTags: true
};
const html = window.render(text, options);
el.outerHTML = html;
}
};
</script>
</head>
<body>
<div id="content"><div id="content-text"></div></div>
</body>
</html>
"""

def latex2html(latex_code):
    """Embed *latex_code* into ``template_html`` as a renderable math page.

    The LaTeX source is converted into a JavaScript string-concatenation
    expression and spliced in after ``const text =`` so the in-page script
    can render it.

    Args:
        latex_code: LaTeX source, possibly spanning multiple lines.

    Returns:
        A complete HTML document as a single string.
    """
    # Strip \left/\right sizing macros unconditionally.
    # (Raw strings: '\l' is an invalid escape sequence and raises a
    # SyntaxWarning on Python 3.12+; the values are unchanged.)
    latex_code = (
        latex_code
        .replace(r'\left(', '(').replace(r'\right)', ')')
        .replace(r'\left[', '[').replace(r'\right]', ']')
        .replace(r'\left{', '{').replace(r'\right}', '}')
        .replace(r'\left|', '|').replace(r'\right|', '|')
        .replace(r'\left.', '.').replace(r'\right.', '.')
    )

    # Double quotes and dollar signs would break the JS string literals below.
    latex_code = latex_code.replace('"', '``').replace('$', '')

    # One JS string literal per LaTeX line, each ending with a literal \n,
    # joined with '+' so the browser concatenates them back into one text.
    js_lines = ['"' + line.replace('\\', '\\\\') + r'\n' + '"'
                for line in latex_code.split('\n')]
    gt = '+\n'.join(js_lines)

    # BUG FIX: the previous code prepended the delimiters as bare characters
    # (gt = "\[" + gt + "\]"), which produced `const text =\["..."` — a
    # JavaScript syntax error that prevented any rendering.  Emit them as JS
    # string literals instead, so the evaluated text is wrapped in the
    # display-math delimiters \[ ... \] (assumed to be what the bundled
    # renderer expects — confirm against the markdown-it/MathJax build).
    gt = '"\\\\[" + ' + gt + ' + "\\\\]"'

    head, tail = template_html.split("const text =")
    return head + 'const text =' + gt + tail
86
+
87
def load_model_and_processor(cfg_path):
    """Build a UniMERNet model and its image pre-processor from a YAML config.

    Args:
        cfg_path: Path to the task configuration file.

    Returns:
        Tuple of (model, vis_processor).
    """
    config = Config(argparse.Namespace(cfg_path=cfg_path, options=None))
    built_model = tasks.setup_task(config).build_model(config)
    eval_cfg = config.config.datasets.formula_rec_eval.vis_processor.eval
    processor = load_processor('formula_image_eval', eval_cfg)
    return built_model, processor
94
+
95
@spaces.GPU
def recognize_image(input_img, model_type):
    """Recognize the formula in *input_img* with the selected model size.

    Args:
        input_img: Image as a numpy array (H, W) or (H, W, C).
        model_type: One of "tiny", "small", "base"; anything else falls
            back to the tiny model.

    Returns:
        Tuple of (latex_code, iframe_html) for the two output widgets.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    selected = {"base": model_base, "small": model_small}.get(model_type, model_tiny)
    model = selected.to(device)

    # Reverse the channel order of 3-channel inputs before PIL conversion.
    # NOTE(review): this swaps RGB<->BGR — presumably the processor was
    # trained on BGR-ordered arrays; confirm against the training pipeline.
    if len(input_img.shape) == 3:
        input_img = input_img[:, :, ::-1].copy()

    pil_img = Image.fromarray(input_img)
    batch = vis_processor(pil_img).unsqueeze(0).to(device)
    latex_code = model.generate({"image": batch})["pred_str"][0]

    # Ship the rendered page as a base64 data: URI inside an iframe so the
    # embedded <script> can run in isolation.
    html_code = latex2html(latex_code)
    encoded = base64.b64encode(html_code.encode('utf-8')).decode('utf-8')
    iframe = f'<iframe src="data:text/html;base64,{encoded}" width="100%" height="600px"></iframe>'
    return latex_code, iframe
117
+
118
def gradio_reset():
    """Clear the image input, LaTeX textbox, and HTML output widgets."""
    cleared = [gr.update(value=None) for _ in range(3)]
    return tuple(cleared)
120
+
121
+
122
if __name__ == "__main__":
    root_path = os.path.abspath(os.getcwd())
    # == load model ==
    # Load all three checkpoint sizes up-front so switching models in the UI
    # costs nothing at request time.
    print("load tiny model ...")
    model_tiny, vis_processor = load_model_and_processor(os.path.join(root_path, "cfg_tiny.yaml"))
    print("load small model ...")
    model_small, vis_processor = load_model_and_processor(os.path.join(root_path, "cfg_small.yaml"))
    print("load base model ...")
    model_base, vis_processor = load_model_and_processor(os.path.join(root_path, "cfg_base.yaml"))
    # NOTE(review): vis_processor is overwritten on each load; only the one
    # from cfg_base.yaml survives — presumably all three configs share the
    # same eval processor, but confirm.
    print("== load all models done. ==")
    # == load model ==

    # Page header markup shipped alongside the app.
    with open("header.html", "r") as file:
        header = file.read()
    with gr.Blocks() as demo:
        gr.HTML(header)

        with gr.Row():
            # Left column: model picker, image input, action buttons, examples.
            with gr.Column():
                model_type = gr.Radio(
                    choices=["tiny", "small", "base"],
                    value="tiny",
                    label="Model Type",
                    interactive=True,
                )
                input_img = gr.Image(label=" ", interactive=True)
                with gr.Row():
                    clear = gr.Button("Clear")
                    predict = gr.Button(value="Recognize", interactive=True, variant="primary")

                with gr.Accordion("Examples:"):
                    example_root = os.path.join(os.path.dirname(__file__), "examples")
                    gr.Examples(
                        examples=[os.path.join(example_root, _) for _ in os.listdir(example_root) if
                                  _.endswith("png")],
                        inputs=input_img,
                    )
            # Right column: recognized LaTeX plus the rendered HTML preview.
            with gr.Column():
                gr.Button(value="Predict Result:", interactive=False)
                pred_latex = gr.Textbox(label='Predict Latex', interactive=False)
                output_html = gr.HTML(label='Output Html')

        clear.click(gradio_reset, inputs=None, outputs=[input_img, pred_latex, output_html])
        predict.click(recognize_image, inputs=[input_img, model_type], outputs=[pred_latex, output_html])

    # Bind to 0.0.0.0 so the server is reachable from outside the container.
    demo.launch(server_name="0.0.0.0", server_port=7860, debug=True)