hellopahe committed
Commit e0738a2 · 1 Parent(s): 534bdc5

add luotuo summary

Files changed (1):
  1. app.py +58 -25
app.py CHANGED
@@ -2,7 +2,7 @@ import numpy
 import torch
 import gradio as gr
 
-from transformers import PegasusForConditionalGeneration, Text2TextGenerationPipeline
+from transformers import PegasusForConditionalGeneration, Text2TextGenerationPipeline, AutoModel, AutoTokenizer
 from article_extractor.tokenizers_pegasus import PegasusTokenizer
 from embed import Embed
 
@@ -12,6 +12,9 @@ from harvesttext import HarvestText
 from sentence_transformers import SentenceTransformer, util
 from LexRank import degree_centrality_scores
 
+from luotuo_util import DeviceMap
+from peft import get_peft_model, LoraConfig, TaskType
+
 
 class SummaryExtractor(object):
     def __init__(self):
@@ -24,6 +27,39 @@ class SummaryExtractor(object):
         print(content)
         return str(self.text2text_genr(content, min_length=20, do_sample=False, num_return_sequences=3)[0]["generated_text"])
 
+class Tuoling_6B_extractor(object):
+    def __init__(self):
+        torch.set_default_tensor_type(torch.cuda.HalfTensor)
+        self.tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
+        self.model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True, device_map=DeviceMap("ChatGLM").get())
+
+        # load fine-tuned pretrained model.
+        peft_path = "./luotuoC.pt"
+        peft_config = LoraConfig(task_type=TaskType.CAUSAL_LM, inference_mode=True, r=8, lora_alpha=32, lora_dropout=0.1)
+        self.model = get_peft_model(self.model, peft_config)
+        self.model.load_state_dict(torch.load(peft_path), strict=False)
+        torch.set_default_tensor_type(torch.cuda.FloatTensor)
+
+    @staticmethod
+    def format_example(example: dict) -> dict:
+        context = f"Instruction: {example['instruction']}\n"
+        if example.get("input"):
+            context += f"Input: {example['input']}\n"
+        context += "Answer: "
+        target = example["output"]
+        return {"context": context, "target": target}
+
+    def extract(self, instruction: str, input=None) -> str:
+        with torch.no_grad():
+            feature = Tuoling_6B_extractor.format_example(
+                {"instruction": "请帮我总结以下内容", "output": "", "input": f"{instruction}"}
+            )
+            input_text = feature["context"]
+            input_ids = self.tokenizer.encode(input_text, return_tensors="pt")
+            out = self.model.generate(input_ids=input_ids, max_length=2048, temperature=0)
+            answer = self.tokenizer.decode(out[0])
+            return answer.split('Answer:')[1]
+
 class LexRank(object):
     def __init__(self):
         self.model = SentenceTransformer('paraphrase-multilingual-mpnet-base-v2')
@@ -56,35 +92,28 @@ class LexRank(object):
 
 # ---===--- worker instances ---===---
 t_randeng = SummaryExtractor()
+t_tuoling = Tuoling_6B_extractor()
+
 embedder = Embed()
 lex = LexRank()
 
 
 def randeng_extract(content):
     sentences = lex.find_central(content)
+    return str(list(t_randeng.extract(sentence) for sentence in sentences))
+
+def tuoling_extract(content):
+    sentences = lex.find_central(content)
+    return str(list(t_tuoling.extract(sentence) for sentence in sentences))
+
+def similarity_check(query, doc):
+    doc_list = doc.split("\n")
 
-    num = 500
-    ptr = 0
-    for index, sentence in enumerate(sentences):
-        num -= len(sentence)
-        if num < 0 and index > 0:
-            ptr = index - 1
-            break
-        if num < 0 and index == 0:
-            ptr = index
-            break
-    print(">>>")
-    for ele in sentences[:ptr]:
-        print(ele)
-    return t_randeng.extract("".join(sentences[:ptr]))
-
-
-def similarity_check(inputs: list):
-    doc_list = inputs[1].split("\n")
-    doc_list.append(inputs[0])
-    embedding_list = embedder.encode(doc_list)
-    scores = (embedding_list[-1] @ tf.transpose(embedding_list[:-1]))[0].numpy().tolist()
-    return numpy.array2string(scores, separator=',')
+    query_embedding = embedder.encode(query)
+    doc_embedding = embedder.encode(doc_list)
+    scores = (query_embedding @ tf.transpose(doc_embedding))[0].numpy().tolist()
+    # scores = list(util.cos_sim(embedding_list[-1], doc_embedding) for doc_embedding in embedding_list[:-1])
+    return str(scores)
 
 with gr.Blocks() as app:
     gr.Markdown("从下面的标签选择测试模块 [摘要生成,相似度检测]")
@@ -92,10 +121,14 @@ with gr.Blocks() as app:
     # text_input = gr.Textbox()
     # text_output = gr.Textbox()
    # text_button = gr.Button("生成摘要")
-    with gr.Tab("Randeng-Pegasus-523M"):
+    with gr.Tab("LexRank->Randeng-Pegasus-523M"):
        text_input_1 = gr.Textbox(label="请输入长文本:", max_lines=1000)
        text_output_1 = gr.Textbox(label="摘要文本")
        text_button_1 = gr.Button("生成摘要")
+    with gr.Tab("LexRank->Tuoling-6B-chatGLM"):
+        text_input = gr.Textbox(label="请输入长文本:", max_lines=1000)
+        text_output = gr.Textbox(label="摘要文本")
+        text_button = gr.Button("生成摘要")
     with gr.Tab("相似度检测"):
         with gr.Row():
             text_input_query = gr.Textbox(label="查询文本")
@@ -103,7 +136,7 @@ with gr.Blocks() as app:
     text_button_similarity = gr.Button("对比相似度")
     text_output_similarity = gr.Textbox()
 
-    # text_button.click(tuoling_extract, inputs=text_input, outputs=text_output)
+    text_button.click(tuoling_extract, inputs=text_input, outputs=text_output)
     text_button_1.click(randeng_extract, inputs=text_input_1, outputs=text_output_1)
     text_button_similarity.click(similarity_check, inputs=[text_input_query, text_input_doc], outputs=text_output_similarity)
 
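A worked example of the prompt that the new `Tuoling_6B_extractor.format_example` helper builds for the fixed instruction used in `extract`. This is a sketch, not part of the commit: it assumes the `Tuoling_6B_extractor` class from this app.py is in scope, and the input text is made up; since `format_example` is a `@staticmethod`, it runs without loading the model.

```python
# Illustrative only: the prompt string format_example produces for the
# summarization instruction used in extract(); the input text is a sample.
example = {"instruction": "请帮我总结以下内容", "input": "这是一段需要被总结的长文本。", "output": ""}
feature = Tuoling_6B_extractor.format_example(example)
print(feature["context"])
# Instruction: 请帮我总结以下内容
# Input: 这是一段需要被总结的长文本。
# Answer:
```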
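For orientation, a minimal usage sketch of how the two summarization tabs are wired at runtime. It assumes the worker instances defined in app.py (`lex`, `t_randeng`, `t_tuoling`) have been constructed, which in turn assumes a CUDA device and the `./luotuoC.pt` adapter weights; the input text is a placeholder.

```python
# Minimal usage sketch, not part of the commit: both tabs run the same
# LexRank -> per-sentence summarization pipeline, differing only in the model.
long_text = "……"  # any long Chinese document (placeholder)

central_sentences = lex.find_central(long_text)  # LexRank picks the central sentences
randeng_summaries = [t_randeng.extract(s) for s in central_sentences]  # Pegasus path
tuoling_summaries = [t_tuoling.extract(s) for s in central_sentences]  # LoRA ChatGLM path

print(randeng_summaries)
print(tuoling_summaries)
```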
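The new `similarity_check` scores a query against newline-separated documents with a dot product over `Embed` outputs (the `tf.transpose` call assumes TensorFlow-style tensors from that wrapper). Its commented-out line points at an equivalent route through `sentence_transformers.util.cos_sim`; below is a self-contained sketch of that route using the same multilingual model the `LexRank` class loads. This is an illustration, not code from the commit.

```python
# Illustrative sketch (not from the commit): query-vs-documents cosine scoring
# done directly with sentence-transformers, mirroring the commented-out
# util.cos_sim alternative inside similarity_check.
from sentence_transformers import SentenceTransformer, util

model = SentenceTransformer("paraphrase-multilingual-mpnet-base-v2")

query = "查询文本"
docs = ["候选文档一", "候选文档二", "候选文档三"]  # one document per line in the UI textbox

query_embedding = model.encode(query, convert_to_tensor=True)
doc_embeddings = model.encode(docs, convert_to_tensor=True)

# One cosine-similarity score per document, in input order.
scores = util.cos_sim(query_embedding, doc_embeddings)[0].tolist()
print(scores)
```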