HoneyTian commited on
Commit
5353e72
·
1 Parent(s): 0139e20
examples/gradio_client/predict.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import argparse
4
+
5
+ from gradio_client import Client, file
6
+
7
+ from project_settings import project_path
8
+
9
+
10
+ def get_args():
11
+ parser = argparse.ArgumentParser()
12
+ parser.add_argument(
13
+ "--filename",
14
+ default=(project_path / "data/test_wavs/paraformer-zh/si_chuan_hua.wav").as_posix(),
15
+ type=str
16
+ )
17
+ args = parser.parse_args()
18
+ return args
19
+
20
+
21
+ def main():
22
+ args = get_args()
23
+
24
+ filename = args.filename
25
+
26
+ client = Client("https://qgyd2021-asr.hf.space/")
27
+ result = client.predict(
28
+ language="Chinese",
29
+ repo_id="csukuangfj/wenet-chinese-model",
30
+ decoding_method="greedy_search",
31
+ num_active_paths=4,
32
+ add_punctuation="Yes",
33
+ in_filename=file(filename),
34
+ api_name="/partial"
35
+ )
36
+ transcript, note = result
37
+ print(transcript)
38
+
39
+ return
40
+
41
+
42
+ if __name__ == '__main__':
43
+ main()
server/__init__.py DELETED
@@ -1,5 +0,0 @@
1
- #!/usr/bin/python3
2
- # -*- coding: utf-8 -*-
3
-
4
- if __name__ == '__main__':
5
- pass
 
 
 
 
 
 
server/asr_server/run_asr_server.py DELETED
@@ -1,49 +0,0 @@
1
- #!/usr/bin/python3
2
- # -*- coding: utf-8 -*-
3
- import argparse
4
- import logging
5
- import os
6
- import sys
7
-
8
- pwd = os.path.abspath(os.path.dirname(__file__))
9
- sys.path.append(os.path.join(pwd, "../../"))
10
-
11
- from flask import Flask
12
- from gevent import pywsgi
13
-
14
- import log
15
- from server.asr_server import settings
16
-
17
- log.setup(log_directory=settings.log_directory)
18
-
19
- from server.flask_server.view_func.heart_beat import heart_beat
20
-
21
- logger = logging.getLogger("server")
22
-
23
-
24
- # 初始化服务
25
- flask_app = Flask(__name__)
26
- flask_app.add_url_rule(rule="/HeartBeat", view_func=heart_beat, methods=["GET", "POST"], endpoint="HeartBeat")
27
-
28
-
29
- if __name__ == "__main__":
30
- parser = argparse.ArgumentParser()
31
- parser.add_argument(
32
- "--port",
33
- default=settings.port,
34
- type=int,
35
- )
36
- args = parser.parse_args()
37
-
38
- logger.info("model server is already, port: {}".format(args.port))
39
-
40
- # flask_app.run(
41
- # host="0.0.0.0",
42
- # port=args.port,
43
- # )
44
-
45
- server = pywsgi.WSGIServer(
46
- listener=("0.0.0.0", args.port),
47
- application=flask_app
48
- )
49
- server.serve_forever()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
server/asr_server/settings.py DELETED
@@ -1,12 +0,0 @@
1
- #!/usr/bin/python3
2
- # -*- coding: utf-8 -*-
3
- from project_settings import project_path
4
-
5
- log_directory = project_path / "server/train_model_server/logs"
6
- log_directory.mkdir(parents=True, exist_ok=True)
7
-
8
- port = 9527
9
-
10
-
11
- if __name__ == "__main__":
12
- pass
 
 
 
 
 
 
 
 
 
 
 
 
 
server/asr_server/start.sh DELETED
@@ -1,7 +0,0 @@
1
- #!/usr/bin/env bash
2
-
3
- export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib/python3.8/site-packages/k2/lib/
4
-
5
- rm -rf logs/
6
-
7
- python3 run_asr_server.py
 
 
 
 
 
 
 
 
server/asr_server/stop.sh DELETED
@@ -1,3 +0,0 @@
1
- #!/usr/bin/env bash
2
-
3
- kill -9 `ps -aef | grep 'run_asr_server.py' | grep -v grep | awk '{print $2}'`
 
 
 
 
server/exception.py DELETED
@@ -1,8 +0,0 @@
1
-
2
-
3
- class ExpectedError(Exception):
4
- def __init__(self, status_code, message, traceback="", detail=""):
5
- self.status_code = status_code
6
- self.message = message
7
- self.traceback = traceback
8
- self.detail = detail
 
 
 
 
 
 
 
 
 
server/flask_server/__init__.py DELETED
File without changes
server/flask_server/route_wrap/__init__.py DELETED
File without changes
server/flask_server/route_wrap/common_route_wrap.py DELETED
@@ -1,77 +0,0 @@
1
- import logging
2
- import time
3
- import traceback
4
-
5
- import jsonschema
6
-
7
- from server.exception import ExpectedError
8
- from toolbox.logging.misc import json_2_str
9
-
10
- logger = logging.getLogger('server')
11
-
12
-
13
- result_schema = {
14
- 'type': 'object',
15
- 'required': ['result', 'debug'],
16
- 'properties': {
17
- 'result': {},
18
- 'debug': {}
19
- },
20
- 'additionalProperties': False
21
-
22
- }
23
-
24
-
25
- def common_route_wrap(f):
26
- def inner(*args, **kwargs):
27
- begin = time.time()
28
- try:
29
- ret = f(*args, **kwargs)
30
- try:
31
- jsonschema.validate(ret, result_schema)
32
- debug = ret['debug']
33
- result = ret['result']
34
- except jsonschema.exceptions.ValidationError as e:
35
- debug = None
36
- result = ret
37
-
38
- response = {
39
- 'status_code': 60200,
40
- 'result': result,
41
- 'debug': debug,
42
- 'message': 'success',
43
- 'detail': None
44
- }
45
- status_code = 200
46
- except ExpectedError as e:
47
- response = {
48
- 'status_code': e.status_code,
49
- 'result': None,
50
- 'message': e.message,
51
- 'detail': e.detail,
52
- 'traceback': e.traceback,
53
- }
54
- status_code = 400
55
-
56
- except Exception as e:
57
- response = {
58
- 'status_code': 60500,
59
- 'result': None,
60
- 'message': str(e),
61
- 'detail': None,
62
- 'traceback': traceback.format_exc(),
63
- }
64
- status_code = 500
65
-
66
- cost = time.time() - begin
67
- response['time_cost'] = round(cost, 4)
68
-
69
- abstract_response = json_2_str(response)
70
- if 'traceback' in response:
71
- abstract_response['traceback'] = response['traceback']
72
-
73
- logger.info('response: {}'.format(abstract_response))
74
- # logger.info('response: {}'.format(json.dumps(response, ensure_ascii=False)))
75
-
76
- return response, status_code
77
- return inner
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
server/flask_server/view_func/__init__.py DELETED
@@ -1,6 +0,0 @@
1
- #!/usr/bin/python3
2
- # -*- coding: utf-8 -*-
3
-
4
-
5
- if __name__ == '__main__':
6
- pass
 
 
 
 
 
 
 
server/flask_server/view_func/heart_beat.py DELETED
@@ -1,8 +0,0 @@
1
- # -*- encoding=UTF-8 -*-
2
- from server.flask_server.route_wrap.common_route_wrap import common_route_wrap
3
-
4
-
5
- @common_route_wrap
6
- def heart_beat():
7
- # curl -X POST http://127.0.0.1:9527/HeartBeat
8
- return "OK"
 
 
 
 
 
 
 
 
 
server/log.py DELETED
@@ -1,110 +0,0 @@
1
- #!/usr/bin/python3
2
- # -*- coding: utf-8 -*-
3
- import logging
4
- from logging.handlers import TimedRotatingFileHandler
5
- import os
6
-
7
-
8
- def setup(log_directory: str):
9
- fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"
10
-
11
- stream_handler = logging.StreamHandler()
12
- stream_handler.setLevel(logging.INFO)
13
- stream_handler.setFormatter(logging.Formatter(fmt))
14
-
15
- # main
16
- main_logger = logging.getLogger("main")
17
- main_logger.addHandler(stream_handler)
18
- main_info_file_handler = TimedRotatingFileHandler(
19
- filename=os.path.join(log_directory, "main.log"),
20
- encoding="utf-8",
21
- when="midnight",
22
- interval=1,
23
- backupCount=30
24
- )
25
- main_info_file_handler.setLevel(logging.INFO)
26
- main_info_file_handler.setFormatter(logging.Formatter(fmt))
27
- main_logger.addHandler(main_info_file_handler)
28
-
29
- # http
30
- http_logger = logging.getLogger("http")
31
- http_file_handler = TimedRotatingFileHandler(
32
- filename=os.path.join(log_directory, "http.log"),
33
- encoding='utf-8',
34
- when="midnight",
35
- interval=1,
36
- backupCount=30
37
- )
38
- http_file_handler.setLevel(logging.DEBUG)
39
- http_file_handler.setFormatter(logging.Formatter(fmt))
40
- http_logger.addHandler(http_file_handler)
41
-
42
- # api
43
- api_logger = logging.getLogger("api")
44
- api_file_handler = TimedRotatingFileHandler(
45
- filename=os.path.join(log_directory, "api.log"),
46
- encoding='utf-8',
47
- when="midnight",
48
- interval=1,
49
- backupCount=30
50
- )
51
- api_file_handler.setLevel(logging.DEBUG)
52
- api_file_handler.setFormatter(logging.Formatter(fmt))
53
- api_logger.addHandler(api_file_handler)
54
-
55
- # alarm
56
- alarm_logger = logging.getLogger("alarm")
57
- alarm_file_handler = TimedRotatingFileHandler(
58
- filename=os.path.join(log_directory, "alarm.log"),
59
- encoding="utf-8",
60
- when="midnight",
61
- interval=1,
62
- backupCount=30
63
- )
64
- alarm_file_handler.setLevel(logging.DEBUG)
65
- alarm_file_handler.setFormatter(logging.Formatter(fmt))
66
- alarm_logger.addHandler(alarm_file_handler)
67
-
68
- debug_file_handler = TimedRotatingFileHandler(
69
- filename=os.path.join(log_directory, "debug.log"),
70
- encoding="utf-8",
71
- when="D",
72
- interval=1,
73
- backupCount=7
74
- )
75
- debug_file_handler.setLevel(logging.DEBUG)
76
- debug_file_handler.setFormatter(logging.Formatter(fmt))
77
-
78
- info_file_handler = TimedRotatingFileHandler(
79
- filename=os.path.join(log_directory, "info.log"),
80
- encoding="utf-8",
81
- when="D",
82
- interval=1,
83
- backupCount=7
84
- )
85
- info_file_handler.setLevel(logging.INFO)
86
- info_file_handler.setFormatter(logging.Formatter(fmt))
87
-
88
- error_file_handler = TimedRotatingFileHandler(
89
- filename=os.path.join(log_directory, "error.log"),
90
- encoding="utf-8",
91
- when="D",
92
- interval=1,
93
- backupCount=7
94
- )
95
- error_file_handler.setLevel(logging.ERROR)
96
- error_file_handler.setFormatter(logging.Formatter(fmt))
97
-
98
- logging.basicConfig(
99
- level=logging.DEBUG,
100
- datefmt="%a, %d %b %Y %H:%M:%S",
101
- handlers=[
102
- debug_file_handler,
103
- info_file_handler,
104
- error_file_handler,
105
- ]
106
- )
107
-
108
-
109
- if __name__ == "__main__":
110
- pass