from loguru import logger
import json
import os
import openai
from bin_public.utils.utils_db import *
from bin_public.config.presets import MIGRAINE_PROMPT
import PyPDF2
import pinecone
from langchain.vectorstores import Pinecone
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter

PINECONE_API_KEY = os.environ['PINECONE_API_KEY']
PINECONE_API_ENV = os.environ['PINECONE_API_ENV']

def load_local_file_PDF(path, file_name):
    """Read a local PDF and return a dict mapping '<file-stem>_<i>' to each text chunk."""
    result = {}
    temp = ''
    with open(path, 'rb') as f:
        pdf_reader = PyPDF2.PdfReader(f)
        for page in pdf_reader.pages:
            temp += page.extract_text()
    if file_name.endswith('.pdf'):
        index = file_name[:-4]
    temp = temp.replace('\n', '').replace('\t', '')
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
    texts = text_splitter.split_text(temp)
    for i, content in enumerate(texts):
        result[f'{index}_{i}'] = content
    return result

def holo_query_insert_file_contents(file_name, file_content):
    """Insert one chunk of file content into the s_context table via holo_query_func."""
    run_sql = f"""
    insert into s_context(
        file_name,
        content
    )
    select
        '{file_name}' as file_name,
        '{file_content}' as content
    """
    holo_query_func(run_sql, is_query=0)

def holo_query_get_content(run_sql):
    """Run the given query and return the content column with newlines and tabs stripped."""
    temp = []
    data = holo_query_func(run_sql, is_query=1)
    for i in data:
        temp.append(i[1].replace('\n', '').replace('\t', ''))
    return temp

def pdf2database(path, file_name):
    """Split a local PDF into 1000-character chunks and store each chunk in the s_context table."""
    temp = ''
    with open(path, 'rb') as f:
        pdf_reader = PyPDF2.PdfReader(f)
        for page in pdf_reader.pages:
            temp += page.extract_text()
    if file_name.endswith('.pdf'):
        index = file_name[:-4]
    temp = temp.replace('\n', '').replace('\t', '')
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
    texts = text_splitter.split_text(temp)
    for i in range(len(texts)):
        holo_query_insert_file_contents(f'{index}_{i}', f'{texts[i]}')
        logger.info(f'{index}_{i} stored')

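# Usage sketch (illustrative; the path and file name below are placeholders, not
# part of this module). Each chunk ends up as one s_context row keyed by
# '<file-stem>_<chunk index>':
#   pdf2database('./docs/migraine.pdf', 'migraine.pdf')
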
def load_json(path):
    with open(path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    return data

def get_content_from_json(path):
    """Flatten a JSON list of single-key dicts into 'key,value' strings."""
    result = []
    data = load_json(path)
    for item in data:
        key = list(item.keys())[0]
        value = item[key]
        result.append(key + ',' + value)
    return result

def data2embeddings(index_name, data, embeddings):
    """Embed the given texts and upsert them into the named Pinecone index."""
    pinecone.init(api_key=PINECONE_API_KEY, environment=PINECONE_API_ENV)
    Pinecone.from_texts([t for t in data], embeddings, index_name=index_name)
    logger.info("Stored successfully")

def context_construction(api_key, query, model, pinecone_api_key, pinecone_api_env,
                         temperature, index_name, mode="map_reduce"):
    """Retrieve the two most similar chunks from Pinecone and build a context prefix for the query."""
    temp = []
    embeddings = OpenAIEmbeddings(openai_api_key=api_key)
    # llm = OpenAI(temperature=temperature, openai_api_key=api_key, model_name=model)
    pinecone.init(api_key=pinecone_api_key, environment=pinecone_api_env)
    docsearch = Pinecone.from_existing_index(index_name=index_name, embedding=embeddings)
    # chain = load_qa_chain(llm, chain_type=mode)
    if not any(char.isalnum() for char in query):
        # No meaningful query text: fall back to the default migraine prompt.
        return " ", MIGRAINE_PROMPT, "Connecting to Pinecone"
    else:
        docs = docsearch.similarity_search(query, include_metadata=True, k=2)
        # response = chain.run(input_documents=docs, question=str(query))
        for i in docs:
            temp.append(i.page_content)
        # Prefix means "Answer with the help of the following material".
        return '用以下资料进行辅助回答\n' + ' '.join(temp), '\n' + ' '.join(temp), "Connecting to Pinecone"

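# Usage sketch (illustrative; the index name 'migraine' and the query are
# hypothetical). The first return value is the retrieval-augmented prefix a
# caller would prepend to the user's question:
#   context, display, status = context_construction(
#       api_key=os.environ['OPENAI_API_KEY'], query='How can migraine be relieved?',
#       model='gpt-3.5-turbo', pinecone_api_key=PINECONE_API_KEY,
#       pinecone_api_env=PINECONE_API_ENV, temperature=0, index_name='migraine')
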
def chat_prerequisites(input, filter, embeddings, top_k=4):
    # filter: dict of Pinecone metadata filters
    # The Chinese prompt below tells the model: "Next I will give you an
    # 'unstandardized job posting description' plus four standardized job-category
    # descriptions labelled 选项一..选项四 (option one..four). Classify the posting
    # as one of the four options and reply with nothing but the option label."
    # input_prompt = '只基于以下规范的两种分类对形如 "position_name: xx job_name: xx job_description: xxx"的描述进行分类,只要回复规范的类别名'
    input_prompt = '接下来我会给你一段"不规范的招聘职位描述",以及4个用(选项一,选项二,选项三,选项四)四个选项表示的规范的职业分类描述。' \
                   '你需要将"不规范的招聘职位描述"归类为”选项一“或“选项二”或“选项三”或“选项四”。' \
                   '你只需要回复”选项一“或“选项二”或“选项三”或“选项四”,不要回复任何别的东西'
    query = input_prompt + input
    temp = []
    docsearch = Pinecone.from_existing_index(index_name=pinecone.list_indexes()[0], embedding=embeddings)
    docs = docsearch.similarity_search(query, k=top_k, filter=filter)
    # Label the retrieved candidates 选项一..选项四 in retrieval order.
    for index, i in enumerate(docs):
        if index == 0:
            temp.append("选项一:" + i.page_content + "##")
        if index == 1:
            temp.append("选项二:" + i.page_content + "##")
        if index == 2:
            temp.append("选项三:" + i.page_content + "##")
        if index == 3:
            temp.append("选项四:" + i.page_content + "##")
    system_prompt = ' '.join(temp)
    return system_prompt, query

def chat(input, filter, embeddings):
    """Ask gpt-3.5-turbo to pick one of the retrieved category options for the given input."""
    system_prompt, query = chat_prerequisites(input, filter, embeddings)
    logger.info('prerequisites satisfied')
    completion = openai.ChatCompletion.create(
        model="gpt-3.5-turbo",
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": query}
        ])
    return completion.choices[0].message['content'], system_prompt

def chat_data_cleaning(input):
    """Extract the job title and a one-to-two-sentence job summary from raw posting text.

    The Chinese prompt asks the model to drop irrelevant details such as salary and
    location and to reply strictly in the format '岗位名称: xxx # 岗位描述: xxx # '
    (job title / job description).
    """
    clean_prompt = '我要求你提取出这段文字中的岗位名称、岗位描述(用一句或者两句话概括),去除无关紧要的信息,比如工资,地点等等,并严格遵守"岗位名称: xxx # 岗位描述: xxx # "的格式进行回复'
    completion = openai.ChatCompletion.create(
        model="gpt-3.5-turbo",
        messages=[
            {"role": "system", "content": clean_prompt},
            {"role": "user", "content": clean_prompt + input}
        ])
    return completion.choices[0].message['content']

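# Usage sketch for the job-classification flow (illustrative; the raw text and the
# metadata filter value are assumptions): clean the raw posting first, then let
# chat() pick one of the four retrieved category options.
#   cleaned = chat_data_cleaning(raw_job_posting_text)
#   embeddings = OpenAIEmbeddings(openai_api_key=os.environ['OPENAI_API_KEY'])
#   answer, options_prompt = chat(cleaned, {"level": 1}, embeddings)
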
def local_emb2pinecone(PINECONE_API_KEY, PINECONE_API_ENV, level, emb_path, text_path, delete=False):
    """Upsert locally stored embeddings (JSON of id -> vector) and their source texts into Pinecone."""
    pinecone.init(api_key=PINECONE_API_KEY, environment=PINECONE_API_ENV)
    logger.info('Pinecone initialized')
    logger.info(pinecone.list_indexes()[0])
    l = load_json(emb_path)
    print(f'level{level} loaded')
    with open(text_path, 'r', encoding='utf-8') as f:
        texts = f.readlines()
        texts = [i.replace('\n', '') for i in texts]
    index = pinecone.Index(pinecone.list_indexes()[0])
    if delete:
        # Optionally wipe the index before uploading.
        if input('press y to delete all the vectors: ') == 'y':
            index.delete(delete_all=True)
            logger.info('all vectors deleted')
    for key, value, text in zip(list(l.keys()), list(l.values()), texts):
        index.upsert([(key, value, {"text": text, "level": level})])
    logger.info('uploaded successfully')
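
# Usage sketch (illustrative; the paths and level are placeholders). Assumes
# emb_path points to a JSON object mapping vector ids to embedding vectors and
# text_path holds one source text per line, in the same order:
#   local_emb2pinecone(PINECONE_API_KEY, PINECONE_API_ENV, level=1,
#                      emb_path='./emb_level1.json', text_path='./texts_level1.txt')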