Spaces: Runtime error
Damien Benveniste committed
Commit d038098 · 1 parent: 0690567
deployment test
Files changed:
- Dockerfile +10 -1
- app/__init__.py +0 -0
- app/__pycache__/api_handlers.cpython-312.pyc +0 -0
- app/__pycache__/callbacks.cpython-312.pyc +0 -0
- app/__pycache__/chains.cpython-312.pyc +0 -0
- app/__pycache__/crud.cpython-312.pyc +0 -0
- app/__pycache__/data_indexing.cpython-312.pyc +0 -0
- app/__pycache__/database.cpython-312.pyc +0 -0
- app/__pycache__/main.cpython-312.pyc +0 -0
- app/__pycache__/models.cpython-312.pyc +0 -0
- app/__pycache__/prompts.cpython-312.pyc +0 -0
- app/__pycache__/schemas.cpython-312.pyc +0 -0
- app/callbacks.py +24 -0
- app/chains.py +51 -0
- app/crud.py +34 -0
- app/data_indexing.py +154 -0
- app/database.py +12 -0
- app/main.py +118 -0
- app/models.py +23 -0
- app/prompts.py +84 -0
- app/schemas.py +26 -0
- app/sources.txt +0 -0
- requirements.txt +11 -1
Dockerfile
CHANGED
@@ -6,8 +6,17 @@ ENV PATH="/home/user/.local/bin:$PATH"
 
 WORKDIR /app
 
+# Expose the secret SECRET_EXAMPLE at buildtime and use its value as git remote URL
+RUN --mount=type=secret,id=PINECONE_API_KEY,mode=0444,required=true \
+    git init && \
+    git remote add origin $(cat /run/secrets/PINECONE_API_KEY)
+
+RUN --mount=type=secret,id=OPENAI_API_KEY,mode=0444,required=true \
+    git init && \
+    git remote add origin $(cat /run/secrets/OPENAI_API_KEY)
+
 COPY --chown=user ./requirements.txt requirements.txt
 RUN pip install --no-cache-dir --upgrade -r requirements.txt
 
 COPY --chown=user . /app
-CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
+CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "7860"]
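Note: the two RUN --mount=type=secret steps above expose the PINECONE_API_KEY and OPENAI_API_KEY Space secrets under /run/secrets at build time (the git-remote trick mirrors the stock Hugging Face example comment). At runtime the application would more typically read the same keys from environment variables. A minimal sketch of that pattern, not part of this commit:

# Hypothetical runtime configuration helper (illustrative, not part of this commit).
# Assumes the keys are provided as environment variables on the Space.
import os

def get_required_env(name: str) -> str:
    """Return the value of an environment variable or fail loudly."""
    value = os.environ.get(name)
    if not value:
        raise RuntimeError(f"Missing required environment variable: {name}")
    return value

PINECONE_API_KEY = get_required_env("PINECONE_API_KEY")
OPENAI_API_KEY = get_required_env("OPENAI_API_KEY")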
app/__init__.py
ADDED (empty file)
app/__pycache__/api_handlers.cpython-312.pyc
ADDED (binary file, 511 Bytes)

app/__pycache__/callbacks.cpython-312.pyc
ADDED (binary file, 1.9 kB)

app/__pycache__/chains.cpython-312.pyc
ADDED (binary file, 1.85 kB)

app/__pycache__/crud.cpython-312.pyc
ADDED (binary file, 2.22 kB)

app/__pycache__/data_indexing.cpython-312.pyc
ADDED (binary file, 7.43 kB)

app/__pycache__/database.cpython-312.pyc
ADDED (binary file, 622 Bytes)

app/__pycache__/main.cpython-312.pyc
ADDED (binary file, 6.75 kB)

app/__pycache__/models.cpython-312.pyc
ADDED (binary file, 1.37 kB)

app/__pycache__/prompts.cpython-312.pyc
ADDED (binary file, 2.9 kB)

app/__pycache__/schemas.cpython-312.pyc
ADDED (binary file, 1.67 kB)
app/callbacks.py
ADDED (+24 lines)

from typing import Dict, Any, List
from langchain_core.callbacks import BaseCallbackHandler
import schemas
import crud


class LogResponseCallback(BaseCallbackHandler):

    def __init__(self, user_request: schemas.UserRequest, db):
        super().__init__()
        self.user_request = user_request
        self.db = db

    def on_llm_end(self, outputs: Dict[str, Any], **kwargs: Any) -> Any:
        """Run when chain ends running."""
        message = schemas.MessageBase(message=outputs.generations[0][0].text, type='AI')
        crud.add_message(self.db, message, self.user_request.username)

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> Any:
        for prompt in prompts:
            print(prompt)
app/chains.py
ADDED (+51 lines)

from langchain_huggingface import HuggingFaceEndpoint
from langchain_core.runnables import RunnableLambda, RunnablePassthrough
from langchain_core.output_parsers import CommaSeparatedListOutputParser
import schemas
from prompts import (
    raw_prompt_formatted,
    history_prompt_formatted,
    question_prompt_formatted,
    context_prompt_formatted,
    format_context,
    tokenizer
)
from data_indexing import DataIndexer
from operator import itemgetter


data_indexer = DataIndexer()

llm = HuggingFaceEndpoint(
    repo_id="meta-llama/Meta-Llama-3-8B-Instruct",
    max_new_tokens=512,
    stop_sequences=[tokenizer.eos_token]
)

formatted_chain = (
    raw_prompt_formatted
    | llm
).with_types(input_type=schemas.UserQuestion)

history_chain = (
    history_prompt_formatted
    | llm
).with_types(input_type=schemas.HistoryInput)

rag_chain = (
    {
        'question': question_prompt_formatted | llm,
        'hybrid_search': RunnablePassthrough()
    }
    | {
        'context': lambda x: format_context(data_indexer.search(x['question'], hybrid_search=x['hybrid_search'])),
        'standalone_question': lambda x: x['question']
    }
    | context_prompt_formatted
    | llm
).with_types(input_type=schemas.RagInput)
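For reference, rag_chain can be exercised outside the server by streaming it directly. A sketch, assuming the Pinecone index and the prompts module above are available and that the question text is a placeholder:

# Illustrative local invocation of rag_chain (not part of this commit).
from chains import rag_chain

example_input = {
    "question": "How do I create a custom retriever?",  # made-up question
    "chat_history": "",
    "hybrid_search": False,
}

# Runnable.stream yields text chunks produced by the HuggingFaceEndpoint LLM.
for chunk in rag_chain.stream(example_input):
    print(chunk, end="", flush=True)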
app/crud.py
ADDED (+34 lines)

from sqlalchemy.orm import Session
import models, schemas
from fastapi import HTTPException

# def create_user(db: Session, user: schemas.UserCreate):
#     db_user = models.User(username=user.username)
#     db.add(db_user)
#     db.commit()
#     db.refresh(db_user)
#     return db_user

def get_or_create_user(db: Session, username: str):
    user = db.query(models.User).filter(models.User.username == username).first()
    if not user:
        user = models.User(username=username)
        db.add(user)
        db.commit()
        db.refresh(user)
    return user

def add_message(db: Session, message: schemas.MessageBase, username: str):
    user = get_or_create_user(db, username)
    message = models.Message(**message.dict())
    message.user = user
    db.add(message)
    db.commit()
    db.refresh(message)
    return message

def get_user_chat_history(db: Session, username: str):
    user = db.query(models.User).filter(models.User.username == username).first()
    if not user:
        return []
    return db.query(models.Message).filter(models.Message.user_id == user.id).all()
app/data_indexing.py
ADDED (+154 lines)

import os
import uuid
from pathlib import Path
from pinecone.grpc import PineconeGRPC as Pinecone
from pinecone import ServerlessSpec
from langchain_community.vectorstores import Chroma
from langchain_openai import OpenAIEmbeddings

current_dir = Path(__file__).resolve().parent


os.environ['PINECONE_API_KEY'] = "988da8ab-3725-4047-b622-cc42d07ecb6c"
os.environ['OPENAI_API_KEY'] = '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'


class DataIndexer:

    source_file = os.path.join(current_dir, 'sources.txt')

    def __init__(self, index_name='langchain-repo') -> None:
        # self.embedding_client = InferenceClient(
        #     "dunzhang/stella_en_1.5B_v5",
        # )
        self.embedding_client = OpenAIEmbeddings()

        self.index_name = index_name
        self.pinecone_client = Pinecone(api_key=os.environ.get('PINECONE_API_KEY'))

        if index_name not in self.pinecone_client.list_indexes().names():
            self.pinecone_client.create_index(
                name=index_name,
                dimension=1536,
                metric='cosine',
                spec=ServerlessSpec(
                    cloud='aws',
                    region='us-east-1'
                )
            )

        self.index = self.pinecone_client.Index(self.index_name)
        self.source_index = self.get_source_index()
        # self.source_index = None

    def get_source_index(self):
        if not os.path.isfile(self.source_file):
            print('No source file')
            return None

        print('create source index')

        with open(self.source_file, 'r') as file:
            sources = file.readlines()

        sources = [s.rstrip('\n') for s in sources]
        vectorstore = Chroma.from_texts(
            sources, embedding=self.embedding_client
        )
        return vectorstore

    def index_data(self, docs, batch_size=32):

        with open(self.source_file, 'a') as file:
            for doc in docs:
                file.writelines(doc.metadata['source'] + '\n')

        for i in range(0, len(docs), batch_size):
            batch = docs[i: i + batch_size]
            values = self.embedding_client.embed_documents([
                doc.page_content for doc in batch
            ])

            # values = self.embedding_client.feature_extraction([
            #     doc.page_content for doc in batch
            # ])
            vector_ids = [str(uuid.uuid4()) for _ in batch]

            metadatas = [{
                'text': doc.page_content,
                **doc.metadata
            } for doc in batch]

            vectors = [{
                'id': vector_id,
                'values': value,
                'metadata': metadata
            } for vector_id, value, metadata in zip(vector_ids, values, metadatas)]

            try:
                upsert_response = self.index.upsert(vectors=vectors)
                print(upsert_response)
            except Exception as e:
                print(e)

    def search(self, text_query, top_k=5, hybrid_search=False):

        print('text query:', text_query)

        filter = None
        if hybrid_search and self.source_index:
            source_docs = self.source_index.similarity_search(text_query, 50)
            print("source_docs", source_docs)
            filter = {"source": {"$in": [doc.page_content for doc in source_docs]}}

        # vector = self.embedding_client.feature_extraction(text_query)
        vector = self.embedding_client.embed_query(text_query)
        result = self.index.query(
            vector=vector,
            top_k=top_k,
            include_metadata=True,
            filter=filter
        )

        docs = []
        for res in result["matches"]:
            metadata = res["metadata"]
            if 'text' in metadata:
                text = metadata.pop('text')
                docs.append(text)
        return docs


if __name__ == '__main__':

    from langchain_community.document_loaders import GitLoader
    from langchain_text_splitters import (
        Language,
        RecursiveCharacterTextSplitter,
    )

    loader = GitLoader(
        clone_url="https://github.com/langchain-ai/langchain",
        repo_path="./code_data/langchain_repo/",
        branch="master",
    )

    python_splitter = RecursiveCharacterTextSplitter.from_language(
        language=Language.PYTHON, chunk_size=10000, chunk_overlap=100
    )

    docs = loader.load()
    docs = [doc for doc in docs if doc.metadata['file_type'] in ['.py', '.md']]
    docs = [doc for doc in docs if len(doc.page_content) < 50000]
    docs = python_splitter.split_documents(docs)
    for doc in docs:
        doc.page_content = '# {}\n\n'.format(doc.metadata['source']) + doc.page_content

    indexer = DataIndexer()
    with open('/app/sources.txt', 'a') as file:
        for doc in docs:
            file.writelines(doc.metadata['source'] + '\n')
    print('DONE')
    indexer.index_data(docs)
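For context, DataIndexer.search returns the stored chunk texts for the top matches, and hybrid_search additionally restricts the Pinecone query to the source files that Chroma judges most similar to the query. A usage sketch (not part of this commit; the query string is made up, and the hardcoded keys above would normally come from Space secrets):

# Illustrative use of DataIndexer, assuming the Pinecone index and sources.txt exist.
from data_indexing import DataIndexer

indexer = DataIndexer(index_name="langchain-repo")

# Plain vector search over the Pinecone index.
docs = indexer.search("How does RecursiveCharacterTextSplitter work?", top_k=3)

# Hybrid search also filters by the most similar source files recorded in sources.txt.
filtered_docs = indexer.search("How does RecursiveCharacterTextSplitter work?",
                               top_k=3, hybrid_search=True)

for text in docs:
    print(text[:200])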
app/database.py
ADDED (+12 lines)

from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker

SQLALCHEMY_DATABASE_URL = "sqlite:///./test.db"

engine = create_engine(
    SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}
)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

Base = declarative_base()
app/main.py
ADDED (+118 lines)

from langchain_core.runnables import Runnable
from langchain_core.callbacks import BaseCallbackHandler
from fastapi import FastAPI, Request, Depends
from sse_starlette.sse import EventSourceResponse
from sqlalchemy.orm import Session
from langserve.serialization import WellKnownLCSerializer
from typing import Any, List
import crud, models, schemas
from database import SessionLocal, engine
from chains import llm, formatted_chain, history_chain, rag_chain
from prompts import format_chat_history, format_context
from callbacks import LogResponseCallback
from data_indexing import DataIndexer

models.Base.metadata.create_all(bind=engine)

app = FastAPI()

def get_db():
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()


async def generate_stream(input_data: schemas.BaseModel, runnable: Runnable, callbacks: List[BaseCallbackHandler]=[]):
    for output in runnable.stream(input_data.dict(), config={"callbacks": callbacks}):
        data = WellKnownLCSerializer().dumps(output).decode("utf-8")
        yield {'data': data, "event": "data"}
    yield {"event": "end"}


@app.get("/")
def greet_json():
    return {"Hello": "World!"}


@app.post("/simple/stream")
async def simple_stream(request: Request):
    data = await request.json()
    user_question = schemas.UserQuestion(**data['input'])
    return EventSourceResponse(generate_stream(user_question, llm))


@app.post("/formatted/stream")
async def formatted_stream(request: Request):
    data = await request.json()
    user_question = schemas.UserQuestion(**data['input'])
    return EventSourceResponse(generate_stream(user_question, formatted_chain))


@app.post("/history/stream")
async def history_stream(request: Request, db: Session = Depends(get_db)):
    data = await request.json()
    user_request = schemas.UserRequest(**data['input'])
    chat_history = crud.get_user_chat_history(db=db, username=user_request.username)
    message = schemas.MessageBase(message=user_request.question, type='User')
    crud.add_message(db, message, user_request.username)

    history_input = schemas.HistoryInput(
        question=user_request.question,
        chat_history=format_chat_history(chat_history)
    )

    return EventSourceResponse(generate_stream(
        history_input,
        history_chain,
        [LogResponseCallback(user_request, db)]
    ))


@app.post("/rag/stream")
async def rag_stream(request: Request, db: Session = Depends(get_db)):
    data = await request.json()
    user_request = schemas.UserRequest(**data['input'])
    chat_history = crud.get_user_chat_history(db=db, username=user_request.username)
    message = schemas.MessageBase(message=user_request.question, type='User')
    crud.add_message(db, message, user_request.username)

    rag_input = schemas.RagInput(
        question=user_request.question,
        chat_history=format_chat_history(chat_history),
    )

    return EventSourceResponse(generate_stream(
        rag_input,
        rag_chain,
        [LogResponseCallback(user_request, db)]
    ))

@app.post("/filtered_rag/stream")
async def filtered_rag_stream(request: Request, db: Session = Depends(get_db)):
    data = await request.json()
    print(data)
    user_request = schemas.UserRequest(**data['input'])
    chat_history = crud.get_user_chat_history(db=db, username=user_request.username)
    message = schemas.MessageBase(message=user_request.question, type='User')
    crud.add_message(db, message, user_request.username)

    rag_input = schemas.RagInput(
        question=user_request.question,
        chat_history=format_chat_history(chat_history),
        hybrid_search=True
    )

    return EventSourceResponse(generate_stream(
        rag_input,
        rag_chain,
        [LogResponseCallback(user_request, db)]
    ))


if __name__ == "__main__":
    import uvicorn
    uvicorn.run("main:app", host="localhost", reload=True, port=8002)
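The streaming endpoints above all read the request body as {"input": {...}} and return Server-Sent Events. A minimal client sketch (not part of this commit; the port matches the local uvicorn.run call above, and the question and username are placeholders):

# Illustrative SSE client for the /rag/stream endpoint.
import requests

payload = {"input": {"question": "What is LangChain?", "username": "alice"}}

with requests.post("http://localhost:8002/rag/stream", json=payload, stream=True) as response:
    response.raise_for_status()
    # Each SSE frame arrives as "event: ..." / "data: ..." lines.
    for line in response.iter_lines():
        if line:
            print(line.decode("utf-8"))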
app/models.py
ADDED (+23 lines)

from sqlalchemy import Column, ForeignKey, Integer, String, DateTime
from sqlalchemy.orm import relationship
from datetime import datetime

from database import Base

class User(Base):
    __tablename__ = "users"

    id = Column(Integer, primary_key=True, index=True)
    username = Column(String, unique=True, index=True)
    messages = relationship("Message", back_populates="user")

class Message(Base):
    __tablename__ = "messages"

    id = Column(Integer, primary_key=True, index=True)
    user_id = Column(Integer, ForeignKey("users.id"))
    message = Column(String)
    type = Column(String)
    timestamp = Column(DateTime, default=datetime.now)

    user = relationship("User", back_populates="messages")
app/prompts.py
ADDED (+84 lines)

from transformers import AutoTokenizer
from langchain_core.prompts import PromptTemplate
from typing import List
import models

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")

raw_prompt = "{question}"

history_prompt = """
Given the following conversation, provide a helpful answer to the follow up question.

Chat History:
{chat_history}

Follow Up question: {question}
helpful answer:
"""

question_prompt = """
Given the following conversation and a follow up question, rephrase the
follow up question to be a standalone question, in its original language.

Chat History:
{chat_history}

Follow Up Input: {question}

Standalone question:
"""

context_prompt = """
Answer the question based only on the following context:
{context}

Question: {standalone_question}
"""

map_prompt = """
Given the following list of file paths, return a comma separated list of the most likely files to have content that could potentially help answer the question. Return nothing if none of those would help.
Make sure to return the complete full paths as they are written in the original list.

File list:
{file_list}

Question: {question}

Return a comma separated list of files and nothing else!
Comma separated list:
"""


def format_prompt(prompt):
    chat = [
        {"role": "system", "content": "You are a helpful AI assistant."},
        {"role": "user", "content": prompt},
    ]

    formatted_prompt = tokenizer.apply_chat_template(
        chat,
        tokenize=False,
        add_generation_prompt=True
    )

    return PromptTemplate.from_template(formatted_prompt)


def format_chat_history(messages: List[models.Message]):
    return '\n'.join([
        '{}: {}'.format(message.type, message.message)
        for message in messages
    ])

def format_context(docs: List[str]):
    return '\n\n'.join(docs)


raw_prompt_formatted = format_prompt(raw_prompt)
history_prompt_formatted = format_prompt(history_prompt)
question_prompt_formatted = format_prompt(question_prompt)
context_prompt_formatted = format_prompt(context_prompt)
map_prompt_formatted = format_prompt(map_prompt)
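To illustrate, format_chat_history simply joins stored messages into a "type: message" transcript that the history and RAG prompts consume. A small sketch of the expected output (not part of this commit; assumes the gated Llama-3 tokenizer can be loaded when prompts is imported, and uses stand-in objects instead of real models.Message rows):

# Illustrative example of format_chat_history output.
from types import SimpleNamespace
from prompts import format_chat_history

# Stand-ins for models.Message rows; only .type and .message are accessed.
messages = [
    SimpleNamespace(type="User", message="What is LangServe?"),
    SimpleNamespace(type="AI", message="LangServe deploys LangChain runnables as REST APIs."),
]

print(format_chat_history(messages))
# User: What is LangServe?
# AI: LangServe deploys LangChain runnables as REST APIs.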
app/schemas.py
ADDED (+26 lines)

from pydantic.v1 import BaseModel
from datetime import datetime
from typing import Optional

class UserQuestion(BaseModel):
    question: str

class UserRequest(UserQuestion):
    username: str

class HistoryInput(BaseModel):
    chat_history: str
    question: str

class RagInput(HistoryInput):
    hybrid_search: bool = False

class MessageBase(BaseModel):
    id: Optional[int] = None
    user_id: Optional[int] = None
    message: str
    type: str
    timestamp: Optional[datetime] = None

    class Config:
        orm_mode = True
app/sources.txt
ADDED (diff too large to render)
requirements.txt
CHANGED
@@ -2,4 +2,14 @@ fastapi
|
|
2 |
uvicorn[standard]
|
3 |
langserve[server]
|
4 |
langchain
|
5 |
-
langchain-huggingface
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
uvicorn[standard]
|
3 |
langserve[server]
|
4 |
langchain
|
5 |
+
langchain-huggingface
|
6 |
+
langchain-community
|
7 |
+
langchain-openai
|
8 |
+
langchain-chroma
|
9 |
+
openai
|
10 |
+
pinecone-client[grpc]
|
11 |
+
google
|
12 |
+
pypd
|
13 |
+
pypdf
|
14 |
+
google-api-python-client
|
15 |
+
faiss-cpu
|