# Web-page header residue: ChatPDF / chatpdf.py · hahahafofo's picture · fix bug · 0e8caab
# -*- coding: utf-8 -*-
"""
@author:XuMing([email protected])
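@description: Question answering over a local document (PDF, DOCX, Markdown or TXT) using similarity retrieval plus a generative language model.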
"""
import logging
from similarities import Similarity
from textgen import ChatGlmModel, LlamaModel
from transformers import pipeline
from loguru import logger
# Chinese-language prompt filled in with str.format() by ChatPDF._generate_answer.
# Placeholders: {context_str} (the retrieved passages) and {query_str} (the user question).
# English gloss: "Based on the known information below, answer the user's question
# concisely and professionally. If no answer can be derived, say 'the question cannot
# be answered from the known information' or 'not enough relevant information was
# provided'; do not add fabricated content; answer in Chinese."
# The trailing backslash after the opening quotes suppresses a leading newline.
PROMPT_TEMPLATE = """\
基于以下已知信息,简洁和专业的来回答用户的问题。
如果无法从中得到答案,请说 "根据已知信息无法回答该问题" 或 "没有提供足够的相关信息",不允许在答案中添加编造成分,答案请使用中文。
已知内容:
{context_str}
问题:
{query_str}
"""
class ChatPDF:
    """Answer questions about a local document via retrieval plus generation.

    A document is split into text segments that are added to a similarity
    model's corpus; at query time the most similar segments are placed into
    ``PROMPT_TEMPLATE`` and handed to a generative model.
    """

    # Generator back-ends accepted by ``__init__``.
    _SUPPORTED_GEN_MODEL_TYPES = ('chatglm', 'llama', 't5')

    # A PDF line ending with one of these characters closes the current
    # segment. The full-width variants are written as escapes so they stay
    # distinct from their ASCII look-alikes.
    _SEGMENT_ENDINGS = (
        '.', '!', '?', '。', '…', ';', ':', '”', '’', ')', '】', '》', '」',
        '』', '〕', '〉', '〗', '〞', '〟', '»', '"', "'", ']', '}',
        '\uff01', '\uff1f', '\uff1b', '\uff1a', '\uff09',
    )

    def __init__(
            self,
            sim_model_name_or_path: str = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
            gen_model_type: str = "chatglm",
            gen_model_name_or_path: str = "THUDM/chatglm-6b-int4",
            lora_model_name_or_path: str = None,
    ):
        """Load the similarity model and the generative model.

        Args:
            sim_model_name_or_path: Model handed to ``Similarity``.
            gen_model_type: One of ``"chatglm"``, ``"llama"`` or ``"t5"``.
            gen_model_name_or_path: Model handed to the generator back-end.
            lora_model_name_or_path: Passed as ``lora_name`` to the chatglm
                and llama back-ends; not used for ``"t5"``.

        Raises:
            ValueError: If ``gen_model_type`` is not supported.
        """
        # Validate first so an invalid type fails before any model is loaded.
        if gen_model_type not in self._SUPPORTED_GEN_MODEL_TYPES:
            raise ValueError('gen_model_type must be chatglm, llama or t5.')
        self.sim_model = Similarity(model_name_or_path=sim_model_name_or_path)
        self.model_type = gen_model_type
        if gen_model_type == "chatglm":
            self.gen_model = ChatGlmModel(gen_model_type, gen_model_name_or_path, lora_name=lora_model_name_or_path)
        elif gen_model_type == "llama":
            self.gen_model = LlamaModel(gen_model_type, gen_model_name_or_path, lora_name=lora_model_name_or_path)
        else:  # "t5"
            self.gen_model = pipeline('text2text-generation', model=gen_model_name_or_path)
        # Conversation history kept between ``query(..., use_history=True)`` calls.
        self.history = None
        # Path of the most recently loaded document; used to derive index paths.
        self.pdf_path = None

    def load_pdf_file(self, pdf_path: str):
        """Extract text from a document and add it to the similarity corpus.

        The extractor is chosen by file extension (case-insensitive):
        ``.pdf``, ``.docx``, ``.md``; anything else is read as plain text.

        Args:
            pdf_path: Path of the document to load.
        """
        lowered = pdf_path.lower()
        if lowered.endswith('.pdf'):
            corpus = self.extract_text_from_pdf(pdf_path)
        elif lowered.endswith('.docx'):
            corpus = self.extract_text_from_docx(pdf_path)
        elif lowered.endswith('.md'):
            corpus = self.extract_text_from_markdown(pdf_path)
        else:
            corpus = self.extract_text_from_txt(pdf_path)
        self.sim_model.add_corpus(corpus)
        self.pdf_path = pdf_path

    @staticmethod
    def extract_text_from_pdf(file_path: str):
        """Extract text segments from a PDF file.

        Within each page, consecutive non-empty lines are concatenated until
        a line ends with one of ``_SEGMENT_ENDINGS``; segments never span
        pages.

        Args:
            file_path: Path of the PDF file.

        Returns:
            A list of text segments in reading order.
        """
        import PyPDF2
        contents = []
        with open(file_path, 'rb') as f:
            pdf_reader = PyPDF2.PdfReader(f)
            for page in pdf_reader.pages:
                # Guard against a page for which no text comes back.
                page_text = (page.extract_text() or '').strip()
                lines = [line.strip() for line in page_text.splitlines() if line.strip()]
                pending = []
                for line in lines:
                    pending.append(line)
                    if line.endswith(ChatPDF._SEGMENT_ENDINGS):
                        contents.append(''.join(pending))
                        pending = []
                # Flush whatever is left at the end of the page.
                if pending:
                    contents.append(''.join(pending))
        return contents

    @staticmethod
    def extract_text_from_txt(file_path: str):
        """Return the stripped, non-empty lines of a UTF-8 text file.

        Args:
            file_path: Path of the text file.
        """
        with open(file_path, 'r', encoding='utf-8') as f:
            return [line.strip() for line in f if line.strip()]

    @staticmethod
    def extract_text_from_docx(file_path: str):
        """Return the stripped, non-empty paragraph texts of a DOCX file.

        Args:
            file_path: Path of the DOCX file.
        """
        import docx
        document = docx.Document(file_path)
        contents = [paragraph.text.strip() for paragraph in document.paragraphs if paragraph.text.strip()]
        return contents

    @staticmethod
    def extract_text_from_markdown(file_path: str):
        """Return the stripped, non-empty text lines of a Markdown file.

        The Markdown is rendered to HTML and the markup is then discarded,
        so only the visible text remains.

        Args:
            file_path: Path of the UTF-8 Markdown file.
        """
        import markdown
        from bs4 import BeautifulSoup
        with open(file_path, 'r', encoding='utf-8') as f:
            markdown_text = f.read()
        html = markdown.markdown(markdown_text)
        soup = BeautifulSoup(html, 'html.parser')
        contents = [text.strip() for text in soup.get_text().splitlines() if text.strip()]
        return contents

    @staticmethod
    def _add_source_numbers(lst):
        """Prefix each string with a 1-based ``[n]`` marker and quote it."""
        return [f'[{idx + 1}]\t "{item}"' for idx, item in enumerate(lst)]

    def _generate_answer(self, query_str, context_str, history=None, max_length=1024):
        """Generate an answer for ``query_str`` given retrieved context.

        Args:
            query_str: The user question.
            context_str: Retrieved passages to place into the prompt.
            history: Conversation history forwarded to the chat model.
            max_length: Generation length limit forwarded to the model.

        Returns:
            ``(response, history)``. For ``"t5"`` the input history is
            returned unchanged.
        """
        if self.model_type == "t5":
            # NOTE(review): this branch sends only the question; context_str is
            # not used here. Confirm whether that is intended.
            response = self.gen_model(query_str, max_length=max_length, do_sample=True)[0]['generated_text']
            return response, history
        prompt = PROMPT_TEMPLATE.format(context_str=context_str, query_str=query_str)
        response, out_history = self.gen_model.chat(prompt, history, max_length=max_length)
        return response, out_history

    def chat(self, query_str, history=None, max_length=1024):
        """Send ``query_str`` straight to the generative model, without retrieval.

        Args:
            query_str: Text passed to the model as-is.
            history: Conversation history forwarded to the chat model.
            max_length: Generation length limit forwarded to the model.

        Returns:
            ``(response, history)``. For ``"t5"`` the input history is
            returned unchanged.
        """
        if self.model_type == "t5":
            response = self.gen_model(query_str, max_length=max_length, do_sample=True)[0]['generated_text']
            logger.debug(response)
            return response, history
        response, out_history = self.gen_model.chat(query_str, history, max_length=max_length)
        return response, out_history

    def query(
            self,
            query,
            topn: int = 5,
            max_length: int = 1024,
            max_input_size: int = 1024,
            use_history: bool = False
    ):
        """Answer ``query`` from the loaded corpus.

        Args:
            query: The user question.
            topn: Number of similar segments requested from the similarity model.
            max_length: Generation length limit forwarded to the model.
            max_input_size: Character budget for the prompt; the context is
                cut to this value minus the length of ``PROMPT_TEMPLATE``.
            use_history: When true, ``self.history`` is sent to the model and
                replaced with the history it returns.

        Returns:
            ``(response, history, reference_results)``. When nothing is
            retrieved, ``response`` is a fixed "not enough information"
            message and ``reference_results`` is empty.
        """
        sim_contents = self.sim_model.most_similar(query, topn=topn)
        # Presumably {query_id: {corpus_id: score}}; only the corpus ids are used.
        reference_results = []
        for id_score_dict in sim_contents.values():
            for corpus_id in id_score_dict:
                reference_results.append(self.sim_model.corpus[corpus_id])
        if not reference_results:
            # Same 3-tuple shape as the normal path so callers can always unpack.
            history = self.history if use_history else None
            return '没有提供足够的相关信息', history, reference_results
        reference_results = self._add_source_numbers(reference_results)
        # Clamp at 0: a negative bound would slice from the end instead of
        # yielding an empty context.
        context_budget = max(0, max_input_size - len(PROMPT_TEMPLATE))
        context_str = '\n'.join(reference_results)[:context_budget]
        if use_history:
            response, out_history = self._generate_answer(query, context_str, self.history, max_length=max_length)
            self.history = out_history
        else:
            response, out_history = self._generate_answer(query, context_str, max_length=max_length)
        return response, out_history, reference_results

    def _default_index_path(self):
        """Derive ``<document path without extension>_index.json``.

        Raises:
            ValueError: If no document has been loaded yet.
        """
        import os
        if self.pdf_path is None:
            raise ValueError('index_path must be given when no document has been loaded.')
        # splitext only strips a real extension, so extension-less paths and
        # dots in directory names are left intact.
        return os.path.splitext(self.pdf_path)[0] + '_index.json'

    def save_index(self, index_path=None):
        """Save the similarity index.

        Args:
            index_path: Destination; defaults to a path next to the loaded
                document (see ``_default_index_path``).

        Raises:
            ValueError: If ``index_path`` is omitted and no document is loaded.
        """
        if index_path is None:
            index_path = self._default_index_path()
        self.sim_model.save_index(index_path)

    def load_index(self, index_path=None):
        """Load a previously saved similarity index.

        Args:
            index_path: Source; defaults to a path next to the loaded
                document (see ``_default_index_path``).

        Raises:
            ValueError: If ``index_path`` is omitted and no document is loaded.
        """
        if index_path is None:
            index_path = self._default_index_path()
        self.sim_model.load_index(index_path)
if __name__ == "__main__":
    # Demo: load sample.pdf and ask two questions, printing each answer.
    import sys

    # sys.argv[0] is the script name, so a single user argument gives
    # len(sys.argv) == 2; the check must therefore be "> 1".
    if len(sys.argv) > 1:
        gen_model_name_or_path = sys.argv[1]
    else:
        # No model given: show the usage hint and fall back to the default.
        print('Usage: python chatpdf.py <gen_model_name_or_path>')
        gen_model_name_or_path = "THUDM/chatglm-6b-int4"
    m = ChatPDF(gen_model_name_or_path=gen_model_name_or_path)
    m.load_pdf_file(pdf_path='sample.pdf')
    response = m.query('自然语言中的非平行迁移是指什么?')
    print(response[0])
    response = m.query('本文作者是谁?')
    print(response[0])