ashishraics commited on
Commit
ce75bf1
·
1 Parent(s): dc3d11f

change structure

Browse files
Files changed (1) hide show
  1. sentiment_clf_helper.py +18 -10
sentiment_clf_helper.py CHANGED
@@ -4,14 +4,21 @@ from onnxruntime.quantization import quantize_dynamic,QuantType
4
  import transformers.convert_graph_to_onnx as onnx_convert
5
  from pathlib import Path
6
  import os
7
-
8
  import torch
9
- from transformers import AutoModelForSequenceClassification,AutoTokenizer
10
 
11
- chkpt='distilbert-base-uncased-finetuned-sst-2-english'
12
- model=AutoModelForSequenceClassification.from_pretrained(chkpt)
13
- tokenizer=AutoTokenizer.from_pretrained(chkpt)
14
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
  def classify_sentiment(texts,model,tokenizer):
17
  """
@@ -30,7 +37,7 @@ def classify_sentiment(texts,model,tokenizer):
30
  return output
31
 
32
 
33
- def create_onnx_model_sentiment(_model, _tokenizer):
34
  """
35
 
36
  Args:
@@ -41,20 +48,21 @@ def create_onnx_model_sentiment(_model, _tokenizer):
41
  Creates a simple ONNX model & int8 Quantized Model in the directory "sent_clf_onnx/" if directory not present
42
 
43
  """
44
- if not os.path.exists('sent_clf_onnx_dir'):
45
  try:
46
- os.mkdir('sent_clf_onnx_dir')
47
  except:
48
  pass
49
  pipeline=transformers.pipeline("text-classification", model=_model, tokenizer=_tokenizer)
50
 
51
  onnx_convert.convert_pytorch(pipeline,
52
  opset=11,
53
- output=Path("sent_clf_onnx_dir/sentiment_classifier_onnx.onnx"),
54
  use_external_format=False
55
  )
56
 
57
- quantize_dynamic("sent_clf_onnx_dir/sentiment_classifier_onnx.onnx","sent_clf_onnx_dir/sentiment_classifier_onnx_quant.onnx",
 
58
  weight_type=QuantType.QUInt8)
59
  else:
60
  pass
 
4
  import transformers.convert_graph_to_onnx as onnx_convert
5
  from pathlib import Path
6
  import os
 
7
  import torch
 
8
 
 
 
 
9
 
10
+ import yaml
11
+ def read_yaml(file_path):
12
+ with open(file_path, "r") as f:
13
+ return yaml.safe_load(f)
14
+
15
+ config = read_yaml('config.yaml')
16
+
17
+ sent_chkpt=config['SENTIMENT_CLF']['sent_chkpt']
18
+ sent_mdl_dir=config['SENTIMENT_CLF']['sent_mdl_dir']
19
+ sent_onnx_mdl_dir=config['SENTIMENT_CLF']['sent_onnx_mdl_dir']
20
+ sent_onnx_mdl_name=config['SENTIMENT_CLF']['sent_onnx_mdl_name']
21
+ sent_onnx_quant_mdl_name=config['SENTIMENT_CLF']['sent_onnx_quant_mdl_name']
22
 
23
  def classify_sentiment(texts,model,tokenizer):
24
  """
 
37
  return output
38
 
39
 
40
+ def create_onnx_model_sentiment(_model, _tokenizer,sent_onnx_mdl_dir=sent_onnx_mdl_dir):
41
  """
42
 
43
  Args:
 
48
  Creates a simple ONNX model & int8 Quantized Model in the directory "sent_clf_onnx/" if directory not present
49
 
50
  """
51
+ if not os.path.exists(sent_onnx_mdl_dir):
52
  try:
53
+ os.mkdir(sent_onnx_mdl_dir)
54
  except:
55
  pass
56
  pipeline=transformers.pipeline("text-classification", model=_model, tokenizer=_tokenizer)
57
 
58
  onnx_convert.convert_pytorch(pipeline,
59
  opset=11,
60
+ output=Path(f"{sent_onnx_mdl_dir}/{sent_onnx_mdl_name}"),
61
  use_external_format=False
62
  )
63
 
64
+ quantize_dynamic(f"{sent_onnx_mdl_dir}/{sent_onnx_mdl_name}",
65
+ f"{sent_onnx_mdl_dir}/{sent_onnx_quant_mdl_name}",
66
  weight_type=QuantType.QUInt8)
67
  else:
68
  pass