Spaces:
Build error
Build error
hellopahe
committed on
Commit
·
e0738a2
1
Parent(s):
534bdc5
add luotuo summary
Browse files
app.py
CHANGED
@@ -2,7 +2,7 @@ import numpy
|
|
2 |
import torch
|
3 |
import gradio as gr
|
4 |
|
5 |
-
from transformers import PegasusForConditionalGeneration, Text2TextGenerationPipeline
|
6 |
from article_extractor.tokenizers_pegasus import PegasusTokenizer
|
7 |
from embed import Embed
|
8 |
|
@@ -12,6 +12,9 @@ from harvesttext import HarvestText
|
|
12 |
from sentence_transformers import SentenceTransformer, util
|
13 |
from LexRank import degree_centrality_scores
|
14 |
|
|
|
|
|
|
|
15 |
|
16 |
class SummaryExtractor(object):
|
17 |
def __init__(self):
|
@@ -24,6 +27,39 @@ class SummaryExtractor(object):
|
|
24 |
print(content)
|
25 |
return str(self.text2text_genr(content, min_length=20, do_sample=False, num_return_sequences=3)[0]["generated_text"])
|
26 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
27 |
class LexRank(object):
|
28 |
def __init__(self):
|
29 |
self.model = SentenceTransformer('paraphrase-multilingual-mpnet-base-v2')
|
@@ -56,35 +92,28 @@ class LexRank(object):
|
|
56 |
|
57 |
# ---===--- worker instances ---===---
|
58 |
t_randeng = SummaryExtractor()
|
|
|
|
|
59 |
embedder = Embed()
|
60 |
lex = LexRank()
|
61 |
|
62 |
|
63 |
def randeng_extract(content):
|
64 |
sentences = lex.find_central(content)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
65 |
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
ptr = index - 1
|
72 |
-
break
|
73 |
-
if num < 0 and index == 0:
|
74 |
-
ptr = index
|
75 |
-
break
|
76 |
-
print(">>>")
|
77 |
-
for ele in sentences[:ptr]:
|
78 |
-
print(ele)
|
79 |
-
return t_randeng.extract("".join(sentences[:ptr]))
|
80 |
-
|
81 |
-
|
82 |
-
def similarity_check(inputs: list):
|
83 |
-
doc_list = inputs[1].split("\n")
|
84 |
-
doc_list.append(inputs[0])
|
85 |
-
embedding_list = embedder.encode(doc_list)
|
86 |
-
scores = (embedding_list[-1] @ tf.transpose(embedding_list[:-1]))[0].numpy().tolist()
|
87 |
-
return numpy.array2string(scores, separator=',')
|
88 |
|
89 |
with gr.Blocks() as app:
|
90 |
gr.Markdown("从下面的标签选择测试模块 [摘要生成,相似度检测]")
|
@@ -92,10 +121,14 @@ with gr.Blocks() as app:
|
|
92 |
# text_input = gr.Textbox()
|
93 |
# text_output = gr.Textbox()
|
94 |
# text_button = gr.Button("生成摘要")
|
95 |
-
with gr.Tab("Randeng-Pegasus-523M"):
|
96 |
text_input_1 = gr.Textbox(label="请输入长文本:", max_lines=1000)
|
97 |
text_output_1 = gr.Textbox(label="摘要文本")
|
98 |
text_button_1 = gr.Button("生成摘要")
|
|
|
|
|
|
|
|
|
99 |
with gr.Tab("相似度检测"):
|
100 |
with gr.Row():
|
101 |
text_input_query = gr.Textbox(label="查询文本")
|
@@ -103,7 +136,7 @@ with gr.Blocks() as app:
|
|
103 |
text_button_similarity = gr.Button("对比相似度")
|
104 |
text_output_similarity = gr.Textbox()
|
105 |
|
106 |
-
|
107 |
text_button_1.click(randeng_extract, inputs=text_input_1, outputs=text_output_1)
|
108 |
text_button_similarity.click(similarity_check, inputs=[text_input_query, text_input_doc], outputs=text_output_similarity)
|
109 |
|
|
|
2 |
import torch
|
3 |
import gradio as gr
|
4 |
|
5 |
+
from transformers import PegasusForConditionalGeneration, Text2TextGenerationPipeline, AutoModel, AutoTokenizer
|
6 |
from article_extractor.tokenizers_pegasus import PegasusTokenizer
|
7 |
from embed import Embed
|
8 |
|
|
|
12 |
from sentence_transformers import SentenceTransformer, util
|
13 |
from LexRank import degree_centrality_scores
|
14 |
|
15 |
+
from luotuo_util import DeviceMap
|
16 |
+
from peft import get_peft_model, LoraConfig, TaskType
|
17 |
+
|
18 |
|
19 |
class SummaryExtractor(object):
|
20 |
def __init__(self):
|
|
|
27 |
print(content)
|
28 |
return str(self.text2text_genr(content, min_length=20, do_sample=False, num_return_sequences=3)[0]["generated_text"])
|
29 |
|
30 |
+
class Tuoling_6B_extractor(object):
|
31 |
+
def __init__(self):
|
32 |
+
torch.set_default_tensor_type(torch.cuda.HalfTensor)
|
33 |
+
self.tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
|
34 |
+
self.model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True, device_map=DeviceMap("ChatGLM").get())
|
35 |
+
|
36 |
+
# load fine-tuned pretrained model.
|
37 |
+
peft_path = "./luotuoC.pt"
|
38 |
+
peft_config = LoraConfig(task_type=TaskType.CAUSAL_LM, inference_mode=True, r=8, lora_alpha=32, lora_dropout=0.1)
|
39 |
+
self.model = get_peft_model(self.model, peft_config)
|
40 |
+
self.model.load_state_dict(torch.load(peft_path), strict=False)
|
41 |
+
torch.set_default_tensor_type(torch.cuda.FloatTensor)
|
42 |
+
|
43 |
+
@staticmethod
|
44 |
+
def format_example(example: dict) -> dict:
|
45 |
+
context = f"Instruction: {example['instruction']}\n"
|
46 |
+
if example.get("input"):
|
47 |
+
context += f"Input: {example['input']}\n"
|
48 |
+
context += "Answer: "
|
49 |
+
target = example["output"]
|
50 |
+
return {"context": context, "target": target}
|
51 |
+
|
52 |
+
def extract(self, instruction: str, input=None) -> str:
|
53 |
+
with torch.no_grad():
|
54 |
+
feature = Tuoling_6B_extractor.format_example(
|
55 |
+
{"instruction": "请帮我总结以下内容", "output": "", "input": f"{instruction}"}
|
56 |
+
)
|
57 |
+
input_text = feature["context"]
|
58 |
+
input_ids = self.tokenizer.encode(input_text, return_tensors="pt")
|
59 |
+
out = self.model.generate(input_ids=input_ids, max_length=2048, temperature=0)
|
60 |
+
answer = self.tokenizer.decode(out[0])
|
61 |
+
return answer.split('Answer:')[1]
|
62 |
+
|
63 |
class LexRank(object):
|
64 |
def __init__(self):
|
65 |
self.model = SentenceTransformer('paraphrase-multilingual-mpnet-base-v2')
|
|
|
92 |
|
93 |
# ---===--- worker instances ---===---
|
94 |
t_randeng = SummaryExtractor()
|
95 |
+
t_tuoling = Tuoling_6B_extractor()
|
96 |
+
|
97 |
embedder = Embed()
|
98 |
lex = LexRank()
|
99 |
|
100 |
|
101 |
def randeng_extract(content):
|
102 |
sentences = lex.find_central(content)
|
103 |
+
return str(list(t_randeng.extract(sentence) for sentence in sentences))
|
104 |
+
|
105 |
+
def tuoling_extract(content):
|
106 |
+
sentences = lex.find_central(content)
|
107 |
+
return str(list(t_tuoling.extract(sentence) for sentence in sentences))
|
108 |
+
|
109 |
+
def similarity_check(query, doc):
|
110 |
+
doc_list = doc.split("\n")
|
111 |
|
112 |
+
query_embedding = embedder.encode(query)
|
113 |
+
doc_embedding = embedder.encode(doc_list)
|
114 |
+
scores = (query_embedding @ tf.transpose(doc_embedding))[0].numpy().tolist()
|
115 |
+
# scores = list(util.cos_sim(embedding_list[-1], doc_embedding) for doc_embedding in embedding_list[:-1])
|
116 |
+
return str(scores)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
117 |
|
118 |
with gr.Blocks() as app:
|
119 |
gr.Markdown("从下面的标签选择测试模块 [摘要生成,相似度检测]")
|
|
|
121 |
# text_input = gr.Textbox()
|
122 |
# text_output = gr.Textbox()
|
123 |
# text_button = gr.Button("生成摘要")
|
124 |
+
with gr.Tab("LexRank->Randeng-Pegasus-523M"):
|
125 |
text_input_1 = gr.Textbox(label="请输入长文本:", max_lines=1000)
|
126 |
text_output_1 = gr.Textbox(label="摘要文本")
|
127 |
text_button_1 = gr.Button("生成摘要")
|
128 |
+
with gr.Tab("LexRank->Tuoling-6B-chatGLM"):
|
129 |
+
text_input = gr.Textbox(label="请输入长文本:", max_lines=1000)
|
130 |
+
text_output = gr.Textbox(label="摘要文本")
|
131 |
+
text_button = gr.Button("生成摘要")
|
132 |
with gr.Tab("相似度检测"):
|
133 |
with gr.Row():
|
134 |
text_input_query = gr.Textbox(label="查询文本")
|
|
|
136 |
text_button_similarity = gr.Button("对比相似度")
|
137 |
text_output_similarity = gr.Textbox()
|
138 |
|
139 |
+
text_button.click(tuoling_extract, inputs=text_input, outputs=text_output)
|
140 |
text_button_1.click(randeng_extract, inputs=text_input_1, outputs=text_output_1)
|
141 |
text_button_similarity.click(similarity_check, inputs=[text_input_query, text_input_doc], outputs=text_output_similarity)
|
142 |
|