File size: 1,161 Bytes
50f0fbb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import sys
from importlib import import_module
from datasets import load_dataset
import argparse


def main(argv=None):
    """Command-line entry point that runs a FengShen pipeline.

    Usage::

        fengshen_pipeline <pipeline_name> <method> [--model M] [--datasets D]
                          [--text T] [pipeline-specific args...]

    ``<pipeline_name>`` selects the module ``fengshen.pipelines.<pipeline_name>``,
    whose ``Pipeline`` class is instantiated with the parsed arguments.
    ``<method>`` is either ``predict`` (print the pipeline's output for
    ``--text``) or ``train`` (load ``--datasets`` with ``load_dataset`` and
    pass the result to ``pipeline.train``).

    Args:
        argv: Full argument vector including the program name, as in
            ``sys.argv``. Defaults to ``sys.argv`` so the console entry point
            keeps working unchanged; passing an explicit list allows the
            function to be driven programmatically.

    Raises:
        ValueError: if fewer than two positional arguments are given, if the
            method is not ``predict``/``train``, or if ``train`` is requested
            without ``--datasets``.
    """
    if argv is None:
        argv = sys.argv
    if len(argv) < 3:
        raise ValueError(
            'args len < 3, example: fengshen_pipeline text_classification predict xxxxx')
    pipeline_name = argv[1]
    method = argv[2]

    # Validate the method before importing the pipeline module and building the
    # model, so a typo fails immediately instead of after a (possibly slow) load.
    if method not in ('predict', 'train'):
        raise ValueError(
            'cmd not support, now only support {predict, train}')

    pipeline_class = getattr(import_module('fengshen.pipelines.' + pipeline_name), 'Pipeline')

    total_parser = argparse.ArgumentParser("FengShen Pipeline")
    total_parser.add_argument('--model', default='', type=str)
    total_parser.add_argument('--datasets', default='', type=str)
    total_parser.add_argument('--text', default='', type=str)
    # Let the selected pipeline register its own extra options on the parser.
    total_parser = pipeline_class.add_pipeline_specific_args(total_parser)
    args = total_parser.parse_args(args=argv[3:])

    # Without this check an empty --datasets would reach load_dataset('') and
    # fail with an unhelpful error only after the model has been constructed.
    if method == 'train' and not args.datasets:
        raise ValueError('--datasets is required for train')

    pipeline = pipeline_class(args=args, model=args.model)

    if method == 'predict':
        print(pipeline(args.text))
    else:  # method == 'train', guaranteed by the validation above
        datasets = load_dataset(args.datasets)
        pipeline.train(datasets)


# Run the CLI only when this file is executed directly, not when imported.
if __name__ == "__main__":
    main()