MikeCraBash committed on
Commit
7f049a0
1 Parent(s): 9ba91e5
Files changed (1)
  1. app.py +15 -12
app.py CHANGED
@@ -1,14 +1,14 @@
-# HACK AI MAKERSPACE PREPR
+#
+# HACK AI MAKERSPACE PREPR
 # Date: 2024-5-16

 # Basic Imports & Setup
 import os
-from openai import AsyncOpenAI
+from transformers import AutoModelForCausalLM, AutoTokenizer

 # Using Chainlit for our UI
 import chainlit as cl
 from chainlit.prompt import Prompt, PromptMessage
-from chainlit.playground.providers import ChatOpenAI

 # Getting the API key from the .env file
 from dotenv import load_dotenv
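
Note: this hunk swaps the AsyncOpenAI client for local transformers inference but keeps the .env loading. A minimal sketch of what the environment handling reduces to, assuming local generation (HF_TOKEN is a hypothetical variable, only relevant for gated Hub checkpoints):

    import os
    from dotenv import load_dotenv

    load_dotenv()
    # Assumption: no OPENAI_API_KEY is needed once generation runs locally;
    # a Hub token would only matter for gated repositories.
    hf_token = os.getenv("HF_TOKEN")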
@@ -27,7 +27,7 @@ docs = PyMuPDFLoader(direct_url).load()

 import tiktoken
 def tiktoken_len(text):
-    tokens = tiktoken.encoding_for_model("gpt-3.5-turbo").encode(
+    tokens = tiktoken.encoding_for_model("solar-10.7b").encode(
         text,
     )
     return len(tokens)
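
Note: tiktoken only ships encodings for OpenAI models, so encoding_for_model("solar-10.7b") raises a KeyError the first time tiktoken_len runs. A minimal sketch of a drop-in replacement, assuming the SOLAR tokenizer from the Hub is acceptable for measuring chunk sizes:

    from transformers import AutoTokenizer

    # Reuse the model's own tokenizer so chunk lengths match what
    # SOLAR-10.7B will actually see.
    solar_tokenizer = AutoTokenizer.from_pretrained("Upstage/SOLAR-10.7B-v1.0")

    def tiktoken_len(text):
        # tiktoken has no mapping for "solar-10.7b"; count SOLAR tokens instead.
        return len(solar_tokenizer.encode(text))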
@@ -44,12 +44,12 @@ text_splitter = RecursiveCharacterTextSplitter(
 split_chunks = text_splitter.split_documents(docs)

 # Load the embeddings model
-from langchain_openai.embeddings import OpenAIEmbeddings
+from langchain.embeddings import HuggingFaceEmbeddings

-embedding_model = OpenAIEmbeddings(model="text-embedding-3-small")
+embedding_model = HuggingFaceEmbeddings(model_name="solar-10.7b")

 # Load the vector store and retriever from Qdrant
-from langchain_community.vectorstores import Qdrant
+from langchain.vectorstores import Qdrant

 qdrant_vectorstore = Qdrant.from_documents(
     split_chunks,
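
Note: HuggingFaceEmbeddings forwards model_name to sentence-transformers, and "solar-10.7b" is neither a Hub repo id nor an embedding checkpoint, so this load will fail; a 10.7B causal LM is also not an embedding model. A minimal sketch with a stock embedding checkpoint instead; the location and collection_name kwargs are assumptions, since the from_documents call is truncated in this hunk:

    from langchain.embeddings import HuggingFaceEmbeddings
    from langchain.vectorstores import Qdrant

    # Assumption: a small, real sentence-transformers checkpoint
    # stands in for "solar-10.7b".
    embedding_model = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2"
    )

    qdrant_vectorstore = Qdrant.from_documents(
        split_chunks,
        embedding_model,
        location=":memory:",           # assumption: local in-memory collection
        collection_name="prepr_docs",  # hypothetical name, not in the commit
    )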
@@ -60,10 +60,11 @@ qdrant_vectorstore = Qdrant.from_documents(

 qdrant_retriever = qdrant_vectorstore.as_retriever()

-from langchain_openai import ChatOpenAI
-openai_chat_model = ChatOpenAI(model="gpt-3.5-turbo")
+# Load the Solar 10.7B model
+tokenizer = AutoTokenizer.from_pretrained("Upstage/SOLAR-10.7B-v1.0")
+model = AutoModelForCausalLM.from_pretrained("Upstage/SOLAR-10.7B-v1.0")

-from langchain_core.prompts import ChatPromptTemplate
+from langchain.prompts import ChatPromptTemplate

 RAG_PROMPT = """
 SYSTEM:
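
Note: AutoModelForCausalLM.from_pretrained with no extra arguments loads roughly 43 GB of fp32 weights, and a raw transformers model is not a LangChain Runnable, so the rag_prompt | model composition in the chain below will not produce text (the model's forward pass expects token tensors, not a prompt value). A minimal sketch of the usual bridge, assuming a CUDA GPU and LangChain's HuggingFacePipeline wrapper:

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
    from langchain.llms import HuggingFacePipeline

    tokenizer = AutoTokenizer.from_pretrained("Upstage/SOLAR-10.7B-v1.0")
    model = AutoModelForCausalLM.from_pretrained(
        "Upstage/SOLAR-10.7B-v1.0",
        torch_dtype=torch.float16,  # assumption: ~21 GB of VRAM available
        device_map="auto",
    )

    generation_pipeline = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=500,      # mirrors max_tokens in the Chainlit settings
        do_sample=False,         # mirrors temperature=0
        return_full_text=False,  # return only the completion, not the prompt
    )

    # llm is a Runnable, so rag_prompt | llm composes cleanly in LCEL.
    llm = HuggingFacePipeline(pipeline=generation_pipeline)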
@@ -120,14 +121,14 @@ from langchain.schema.runnable import RunnablePassthrough
 retrieval_augmented_qa_chain = (
     {"context": itemgetter("question") | qdrant_retriever, "question": itemgetter("question")}
     | RunnablePassthrough.assign(context=itemgetter("context"))
-    | {"response": rag_prompt | openai_chat_model, "context": itemgetter("context")}
+    | {"response": rag_prompt | model, "context": itemgetter("context")}
 )

 # Chainlit App
 @cl.on_chat_start
 async def start_chat():
     settings = {
-        "model": "gpt-3.5-turbo",
+        "model": "solar-10.7b",
         "temperature": 0,
         "max_tokens": 500,
         "top_p": 1,
@@ -145,3 +146,5 @@ async def main(message: cl.Message):

     msg = cl.Message(content=chainlit_answer)
     await msg.send()
+
+