"""Command-line entry point that dispatches to a FengShen pipeline."""
import argparse
import sys
from importlib import import_module

from datasets import load_dataset


def main():
    """Run one FengShen pipeline command taken from ``sys.argv``.

    Expected invocation::

        <prog> <pipeline_name> <method> [--model M] [--datasets D] [--text T] ...

    ``pipeline_name`` selects the module ``fengshen.pipelines.<pipeline_name>``,
    from which the attribute ``Pipeline`` is loaded. ``method`` is one of:

    * ``predict`` -- call the pipeline object on ``--text`` and print the result.
    * ``train`` -- pass ``--datasets`` to ``load_dataset`` and hand the result
      to ``pipeline.train``.

    Arguments after the first two are parsed with an ``argparse`` parser that
    the pipeline class may extend via ``add_pipeline_specific_args``.

    Raises:
        Exception: if ``sys.argv`` has fewer than three entries (program name,
            pipeline name, method), or if ``method`` is not ``predict`` or
            ``train``.
    """
    if len(sys.argv) < 3:
        raise Exception(
            'args len < 3, example: fengshen_pipeline text_classification predict xxxxx')
    pipeline_name, method = sys.argv[1], sys.argv[2]

    # Reject unknown commands up front so that a typo is reported before the
    # pipeline module is imported and the pipeline object is constructed.
    if method not in ('predict', 'train'):
        raise Exception(
            'cmd not support, now only support {predict, train}')

    pipeline_class = getattr(
        import_module('fengshen.pipelines.' + pipeline_name), 'Pipeline')

    total_parser = argparse.ArgumentParser("FengShen Pipeline")
    total_parser.add_argument('--model', default='', type=str)
    total_parser.add_argument('--datasets', default='', type=str)
    total_parser.add_argument('--text', default='', type=str)
    # Let the selected pipeline register its own options on the shared parser.
    total_parser = pipeline_class.add_pipeline_specific_args(total_parser)
    # Only the arguments after <pipeline_name> <method> belong to the parser.
    args = total_parser.parse_args(args=sys.argv[3:])

    pipeline = pipeline_class(args=args, model=args.model)

    if method == 'predict':
        print(pipeline(args.text))
    else:
        # method == 'train' (validated above)
        train_data = load_dataset(args.datasets)
        pipeline.train(train_data)


if __name__ == '__main__':
    main()