chinesesummary / fengshen /cli /fengshen_pipeline.py
HaloMaster's picture
add fengshen
50f0fbb
raw
history blame
1.16 kB
import sys
from importlib import import_module
from datasets import load_dataset
import argparse
def main(argv=None):
    """Run a FengShen pipeline from the command line.

    Expected invocation::

        fengshen_pipeline <pipeline_name> <method> [--model M] [--datasets D] [--text T] ...

    ``<pipeline_name>`` selects the module ``fengshen.pipelines.<pipeline_name>``,
    whose ``Pipeline`` class is instantiated with the parsed options.
    ``<method>`` is one of:

    * ``predict`` -- call the pipeline on ``--text`` and print the result.
    * ``train`` -- load ``--datasets`` via ``load_dataset`` and pass the
      result to ``Pipeline.train``.

    Args:
        argv: Optional list of arguments *without* the program name.
            When ``None`` (the default), ``sys.argv[1:]`` is used.

    Raises:
        ValueError: If fewer than two positional arguments are supplied, or
            if ``<method>`` is neither ``predict`` nor ``train``.
    """
    cli_args = sys.argv[1:] if argv is None else list(argv)

    # The message counts the program name as well, hence "< 3".
    if len(cli_args) < 2:
        raise ValueError(
            'args len < 3, example: fengshen_pipeline text_classification predict xxxxx')

    pipeline_name, method = cli_args[0], cli_args[1]

    # Reject unknown commands before importing the pipeline module or
    # constructing the pipeline, so a typo fails immediately.
    if method not in ('predict', 'train'):
        raise ValueError(
            'cmd not support, now only support {predict, train}')

    pipeline_class = getattr(
        import_module('fengshen.pipelines.' + pipeline_name), 'Pipeline')

    total_parser = argparse.ArgumentParser("FengShen Pipeline")
    total_parser.add_argument('--model', default='', type=str)
    total_parser.add_argument('--datasets', default='', type=str)
    total_parser.add_argument('--text', default='', type=str)
    # Let the selected pipeline register its own options on the same parser.
    total_parser = pipeline_class.add_pipeline_specific_args(total_parser)

    # Everything after <pipeline_name> <method> belongs to the option parser.
    args = total_parser.parse_args(args=cli_args[2:])
    pipeline = pipeline_class(args=args, model=args.model)

    if method == 'predict':
        print(pipeline(args.text))
    else:  # method == 'train' (validated above)
        datasets = load_dataset(args.datasets)
        pipeline.train(datasets)
# Script entry point: run the CLI only when executed directly, not on import.
if "__main__" == __name__:
    main()