Spaces:
Runtime error
Runtime error
import argparse | |
import yaml | |
import gdown | |
import os | |
def load_config(path='configs/model.yaml'): | |
with open(path, 'r', encoding='utf-8') as f: | |
return yaml.load(f, Loader=yaml.FullLoader) | |
def str2bool(v): | |
return v.lower() in ("true", "yes", "t", "y", "1") | |
def init_args(): | |
parser = argparse.ArgumentParser() | |
# params for prediction engine | |
parser.add_argument("--use_gpu", type=str2bool, default=False) | |
parser.add_argument("--use_xpu", type=str2bool, default=False) | |
parser.add_argument("--use_npu", type=str2bool, default=False) | |
parser.add_argument("--ir_optim", type=str2bool, default=True) | |
parser.add_argument("--use_tensorrt", type=str2bool, default=False) | |
parser.add_argument("--min_subgraph_size", type=int, default=15) | |
parser.add_argument("--precision", type=str, default="fp32") | |
parser.add_argument("--gpu_mem", type=int, default=500) | |
parser.add_argument("--gpu_id", type=int, default=0) | |
# params for text detector | |
parser.add_argument("--image_dir", type=str) | |
parser.add_argument("--page_num", type=int, default=0) | |
parser.add_argument("--det_algorithm", type=str, default='DB') | |
parser.add_argument("--det_model_dir", type=str) | |
parser.add_argument("--det_limit_side_len", type=float, default=960) | |
parser.add_argument("--det_limit_type", type=str, default='max') | |
parser.add_argument("--det_box_type", type=str, default='quad') | |
# DB parmas | |
parser.add_argument("--det_db_thresh", type=float, default=0.3) | |
parser.add_argument("--det_db_box_thresh", type=float, default=0.6) | |
parser.add_argument("--det_db_unclip_ratio", type=float, default=1.5) | |
parser.add_argument("--max_batch_size", type=int, default=10) | |
parser.add_argument("--use_dilation", type=str2bool, default=False) | |
parser.add_argument("--det_db_score_mode", type=str, default="fast") | |
# EAST parmas | |
parser.add_argument("--det_east_score_thresh", type=float, default=0.8) | |
parser.add_argument("--det_east_cover_thresh", type=float, default=0.1) | |
parser.add_argument("--det_east_nms_thresh", type=float, default=0.2) | |
# SAST parmas | |
parser.add_argument("--det_sast_score_thresh", type=float, default=0.5) | |
parser.add_argument("--det_sast_nms_thresh", type=float, default=0.2) | |
# PSE parmas | |
parser.add_argument("--det_pse_thresh", type=float, default=0) | |
parser.add_argument("--det_pse_box_thresh", type=float, default=0.85) | |
parser.add_argument("--det_pse_min_area", type=float, default=16) | |
parser.add_argument("--det_pse_scale", type=int, default=1) | |
# FCE parmas | |
parser.add_argument("--scales", type=list, default=[8, 16, 32]) | |
parser.add_argument("--alpha", type=float, default=1.0) | |
parser.add_argument("--beta", type=float, default=1.0) | |
parser.add_argument("--fourier_degree", type=int, default=5) | |
# params for text recognizer | |
parser.add_argument("--rec_algorithm", type=str, default='SVTR_LCNet') | |
parser.add_argument("--rec_model_dir", type=str) | |
parser.add_argument("--rec_image_inverse", type=str2bool, default=True) | |
parser.add_argument("--rec_image_shape", type=str, default="3, 48, 320") | |
parser.add_argument("--rec_batch_num", type=int, default=6) | |
parser.add_argument("--max_text_length", type=int, default=25) | |
parser.add_argument( | |
"--rec_char_dict_path", | |
type=str, | |
default="./ppocr/utils/ppocr_keys_v1.txt") | |
parser.add_argument("--use_space_char", type=str2bool, default=True) | |
parser.add_argument( | |
"--vis_font_path", type=str, default="./doc/fonts/simfang.ttf") | |
parser.add_argument("--drop_score", type=float, default=0.5) | |
# params for e2e | |
parser.add_argument("--e2e_algorithm", type=str, default='PGNet') | |
parser.add_argument("--e2e_model_dir", type=str) | |
parser.add_argument("--e2e_limit_side_len", type=float, default=768) | |
parser.add_argument("--e2e_limit_type", type=str, default='max') | |
# PGNet parmas | |
parser.add_argument("--e2e_pgnet_score_thresh", type=float, default=0.5) | |
parser.add_argument( | |
"--e2e_char_dict_path", type=str, default="./ppocr/utils/ic15_dict.txt") | |
parser.add_argument("--e2e_pgnet_valid_set", type=str, default='totaltext') | |
parser.add_argument("--e2e_pgnet_mode", type=str, default='fast') | |
# params for text classifier | |
parser.add_argument("--use_angle_cls", type=str2bool, default=False) | |
parser.add_argument("--cls_model_dir", type=str) | |
parser.add_argument("--cls_image_shape", type=str, default="3, 48, 192") | |
parser.add_argument("--label_list", type=list, default=['0', '180']) | |
parser.add_argument("--cls_batch_num", type=int, default=6) | |
parser.add_argument("--cls_thresh", type=float, default=0.9) | |
parser.add_argument("--enable_mkldnn", type=str2bool, default=False) | |
parser.add_argument("--cpu_threads", type=int, default=10) | |
parser.add_argument("--use_pdserving", type=str2bool, default=False) | |
parser.add_argument("--warmup", type=str2bool, default=False) | |
# SR parmas | |
parser.add_argument("--sr_model_dir", type=str) | |
parser.add_argument("--sr_image_shape", type=str, default="3, 32, 128") | |
parser.add_argument("--sr_batch_num", type=int, default=1) | |
# | |
parser.add_argument( | |
"--draw_img_save_dir", type=str, default="./inference_results") | |
parser.add_argument("--save_crop_res", type=str2bool, default=False) | |
parser.add_argument("--crop_res_save_dir", type=str, default="./output") | |
# multi-process | |
parser.add_argument("--use_mp", type=str2bool, default=False) | |
parser.add_argument("--total_process_num", type=int, default=1) | |
parser.add_argument("--process_id", type=int, default=0) | |
parser.add_argument("--benchmark", type=str2bool, default=False) | |
parser.add_argument("--save_log_path", type=str, default="./log_output/") | |
parser.add_argument("--show_log", type=str2bool, default=True) | |
parser.add_argument("--use_onnx", type=str2bool, default=False) | |
return parser | |
def get_args(model_params): | |
print(model_params) | |
args, _ = init_args().parse_known_args() | |
for key, val in model_params.items(): | |
setattr(args, key, val) | |
return args | |
def download_model(filename, url, save_path='weights'): | |
gdown.download(url=url, output=os.path.join(save_path, filename), quiet=False, fuzzy=True, use_cookies=False) |