parkerjj commited on
Commit
fcfffd7
·
1 Parent(s): 9d3e0cb

优化预测函数,增加执行时间和输入文本长度的打印,调整处理逻辑以提高可读性

Browse files
Files changed (2) hide show
  1. blkeras.py +7 -1
  2. preprocess.py +2 -1
blkeras.py CHANGED
@@ -1,4 +1,5 @@
1
  import os
 
2
  from huggingface_hub import login
3
  from huggingface_hub import hf_hub_download
4
 
@@ -98,6 +99,8 @@ def predict(text: str, stock_codes: list):
98
  from preprocess import get_document_vector, get_stock_info, preprocessing_entry, process_entities, process_pos_tags, processing_entry
99
 
100
  try:
 
 
101
  input_text = text
102
  affected_stock_codes = stock_codes
103
 
@@ -110,7 +113,7 @@ def predict(text: str, stock_codes: list):
110
  processed_entry = processing_entry(input_text)
111
 
112
  # 解包 processed_entry 中的各个值
113
- lemmatized_entry, pos_tag, ner, dependency_parsing, sentiment_score = processed_entry
114
 
115
  # 分别打印每个变量,便于调试
116
  #print("Lemmatized Entry:", lemmatized_entry)
@@ -403,6 +406,9 @@ def predict(text: str, stock_codes: list):
403
  print(f"predict() error: {e}")
404
  print(traceback_str)
405
  return {"predict() error": str(e), "traceback": traceback_str}
 
 
 
406
 
407
 
408
  def stock_fix_for_1118_model(score, predictions, last_prices, is_index=True):
 
1
  import os
2
+ from tracemalloc import start
3
  from huggingface_hub import login
4
  from huggingface_hub import hf_hub_download
5
 
 
99
  from preprocess import get_document_vector, get_stock_info, preprocessing_entry, process_entities, process_pos_tags, processing_entry
100
 
101
  try:
102
+
103
+ start_time = datetime.now()
104
  input_text = text
105
  affected_stock_codes = stock_codes
106
 
 
113
  processed_entry = processing_entry(input_text)
114
 
115
  # 解包 processed_entry 中的各个值
116
+ lemmatized_entry, pos_tag, ner, _ , sentiment_score = processed_entry
117
 
118
  # 分别打印每个变量,便于调试
119
  #print("Lemmatized Entry:", lemmatized_entry)
 
406
  print(f"predict() error: {e}")
407
  print(traceback_str)
408
  return {"predict() error": str(e), "traceback": traceback_str}
409
+ finally:
410
+ end_time = datetime.now()
411
+ print(f"predict() Text: {input_text[:200] if len(input_text) > 200 else input_text} \n execution time: {end_time - start_time}, Text Length: {len(input_text)} \n")
412
 
413
 
414
  def stock_fix_for_1118_model(score, predictions, last_prices, is_index=True):
preprocess.py CHANGED
@@ -562,8 +562,9 @@ def processing_entry(entry):
562
  ner = named_entity_recognition(cleaned_text)
563
  # print(f"named_entity_recognition: {db_ner}")
564
 
565
- dependency_parsed = dependency_parsing(cleaned_text)
566
  # print(f"dependency_parsing: {db_dependency_parsing}")
 
567
 
568
  sentiment_score = get_sentiment_score(cleaned_text)
569
  # print(f"sentiment_score: {sentiment_score}")
 
562
  ner = named_entity_recognition(cleaned_text)
563
  # print(f"named_entity_recognition: {db_ner}")
564
 
565
+ # dependency_parsed = dependency_parsing(cleaned_text)
566
  # print(f"dependency_parsing: {db_dependency_parsing}")
567
+ dependency_parsed = None
568
 
569
  sentiment_score = get_sentiment_score(cleaned_text)
570
  # print(f"sentiment_score: {sentiment_score}")