captain-awesome commited on
Commit
9d2bf07
·
verified ·
1 Parent(s): 48e3505

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -17
app.py CHANGED
@@ -23,14 +23,16 @@ import torch
23
 
24
 
25
  def get_vector_store_from_url(url):
26
- model_name = "BAAI/bge-large-en"
27
- model_kwargs = {'device': 'cpu'}
28
- encode_kwargs = {'normalize_embeddings': False}
29
- embeddings = HuggingFaceBgeEmbeddings(
30
- model_name=model_name,
31
- model_kwargs=model_kwargs,
32
- encode_kwargs=encode_kwargs
33
- )
 
 
34
 
35
  loader = WebBaseLoader(url)
36
  document = loader.load()
@@ -114,17 +116,23 @@ def get_response(user_input):
114
  # lib="avx2", # for CPU
115
  # )
116
 
117
- model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
118
- # llm = HuggingFaceHub(
119
- # repo_id=llm_model,
120
- # model_kwargs={"temperature": 0.3, "max_new_tokens": 250, "top_k": 3}
 
 
 
 
 
 
 
121
  # )
122
 
123
- llm = transformers.AutoModelForCausalLM.from_pretrained(
124
- model_name,
125
- trust_remote_code=True,
126
- torch_dtype=torch.bfloat16,
127
- device_map='auto'
128
  )
129
  retriever_chain = get_context_retriever_chain(st.session_state.vector_store,llm)
130
  conversation_rag_chain = get_conversational_rag_chain(retriever_chain,llm)
 
23
 
24
 
25
  def get_vector_store_from_url(url):
26
+ # model_name = "BAAI/bge-large-en"
27
+ # model_kwargs = {'device': 'cpu'}
28
+ # encode_kwargs = {'normalize_embeddings': False}
29
+ # embeddings = HuggingFaceBgeEmbeddings(
30
+ # model_name=model_name,
31
+ # model_kwargs=model_kwargs,
32
+ # encode_kwargs=encode_kwargs
33
+ # )
34
+ embeddings = HuggingFaceEmbeddings(model_name='thenlper/gte-large',
35
+ model_kwargs={'device': 'cpu'})
36
 
37
  loader = WebBaseLoader(url)
38
  document = loader.load()
 
116
  # lib="avx2", # for CPU
117
  # )
118
 
119
+ # model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
120
+ # # llm = HuggingFaceHub(
121
+ # # repo_id=llm_model,
122
+ # # model_kwargs={"temperature": 0.3, "max_new_tokens": 250, "top_k": 3}
123
+ # # )
124
+
125
+ # llm = transformers.AutoModelForCausalLM.from_pretrained(
126
+ # model_name,
127
+ # trust_remote_code=True,
128
+ # torch_dtype=torch.bfloat16,
129
+ # device_map='auto'
130
  # )
131
 
132
+ llm = HuggingFacePipeline.from_model_id(
133
+ model_id="google/flan-t5-base",
134
+ task="text2text-generation",
135
+ # model_kwargs={"temperature": 0.2},
 
136
  )
137
  retriever_chain = get_context_retriever_chain(st.session_state.vector_store,llm)
138
  conversation_rag_chain = get_conversational_rag_chain(retriever_chain,llm)