8bitnand committed on
Commit
871255a
1 Parent(s): 87d5c64

Added support for streamlit and rag model

Files changed (7)
  1. .gitignore +1 -0
  2. __init__.py +1 -0
  3. __pycache__/google.cpython-39.pyc +0 -0
  4. app.py +32 -3
  5. google.py +10 -7
  6. model.py +71 -8
  7. rag.configs.yml +3 -3
.gitignore ADDED
@@ -0,0 +1 @@
+ __pycache__/
__init__.py ADDED
@@ -0,0 +1 @@
+ from google import GoogleSearch, Document, SemanticSearch
__pycache__/google.cpython-39.pyc DELETED
Binary file (5.39 kB)
 
app.py CHANGED
@@ -1,10 +1,33 @@
+ import sys
  import streamlit as st
+ from google import SemanticSearch, GoogleSearch, Document
+ from model import RAGModel, load_configs
+
+
+ def run_on_start():
+     global r
+     global configs
+     configs = load_configs(config_file="rag.configs.yml")
+     r = RAGModel(configs)
+
+
+ def search(query):
+     g = GoogleSearch(query)
+     data = g.all_page_data
+     d = Document(data, min_char_len=configs["document"]["min_char_length"])
+     st.session_state.doc = d.doc()[0]
+

  st.title("LLM powred Google search")

  if "messages" not in st.session_state:
+     run_on_start()
      st.session_state.messages = []

+ if "doc" not in st.session_state:
+     st.session_state.doc = None
+
+
  for message in st.session_state.messages:
      with st.chat_message(message["role"]):
          st.markdown(message["content"])
@@ -14,10 +37,16 @@ if prompt := st.chat_input("Search Here insetad of Google"):
      st.chat_message("user").markdown(prompt)
      st.session_state.messages.append({"role": "user", "content": prompt})

-     response = (
-         f"Ans - {prompt}"  # TODO add answer to the prompt by calling the answer method
+     search(prompt)
+     s = SemanticSearch(
+         prompt,
+         st.session_state.doc,
+         configs["model"]["embeding_model"],
+         configs["model"]["device"],
      )
-
+     topk = s.semantic_search(query=prompt, k=32)
+     output = r.answer_query(query=prompt, topk_items=topk)
+     response = output
      with st.chat_message("assistant"):
          st.markdown(response)
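For reference, the chat UI in app.py builds on Streamlit's standard session-state chat pattern. Below is a minimal, self-contained sketch of that pattern (a plain echo bot using only the public Streamlit API; the retrieval and RAG calls from this commit are left out):

import streamlit as st

st.title("Minimal chat")

if "messages" not in st.session_state:
    st.session_state.messages = []  # history survives Streamlit reruns

# replay the stored history on every rerun
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

if prompt := st.chat_input("Ask something"):
    st.chat_message("user").markdown(prompt)
    st.session_state.messages.append({"role": "user", "content": prompt})

    response = f"Echo: {prompt}"  # app.py swaps this for retrieval + generation
    with st.chat_message("assistant"):
        st.markdown(response)
    st.session_state.messages.append({"role": "assistant", "content": response})

The committed app.py replaces the echo line with search(prompt), SemanticSearch, and RAGModel.answer_query, and is launched the usual way with streamlit run app.py.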
google.py CHANGED
@@ -13,7 +13,7 @@ class GoogleSearch:
          escaped_query = urllib.parse.quote_plus(query)
          self.URL = f"https://www.google.com/search?q={escaped_query}"

-         self.headers = headers = {
+         self.headers = {
              "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/72.0.3538.102 Safari/537.36"
          }
          self.links = self.get_initial_links()
@@ -46,7 +46,7 @@ class GoogleSearch:
          """
          scrape google for the query with keyword based search
          """
-
+         print("Searching Google...")
          response = requests.get(self.URL, headers=self.headers)
          soup = BeautifulSoup(response.text, "html.parser")
          anchors = soup.find_all("a", href=True)
@@ -95,6 +95,7 @@ class Document:
          return min_len_chunks

      def doc(self) -> tuple[list[str], list[str]]:
+         print("Creating Document...")
          chunked_data: list[str] = []
          urls: list[str] = []
          for url, dataitem in self.data:
@@ -108,16 +109,17 @@

  class SemanticSearch:
      def __init__(
-         self, query: str, d: Document, g: GoogleSearch, model_path: str, device: str
+         self, doc_chunks: tuple[list, list], model_path: str, device: str
      ) -> None:
          query = query
-         self.doc_chunks, self.urls = d.doc()
+         self.doc_chunks, self.urls = doc_chunks
          self.st = SentenceTransformer(
              model_path,
              device,
          )

-     def semanti_search(self, query: str, k: int = 10):
+     def semantic_search(self, query: str, k: int = 10):
+         print("Searching Top k in document...")
          query_embeding = self.get_embeding(query)
          doc_embeding = self.get_embeding(self.doc_chunks)
          scores = util.dot_score(a=query_embeding, b=doc_embeding)[0]
@@ -136,8 +138,9 @@ if __name__ == "__main__":
      g = GoogleSearch(query)
      data = g.all_page_data
      d = Document(data, 333)
-     s = SemanticSearch(query, d, g, "all-mpnet-base-v2", "mps")
-     print(len(s.semanti_search(query, k=64)))
+
+     s = SemanticSearch("all-mpnet-base-v2", "mps")
+     print(len(s.semantic_search(query, k=64)))

      # g = GoogleSearch("what is LLM")
      # d = Document(g.all_page_data)
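The new semantic_search ranks document chunks by embedding dot product. Below is a minimal sketch of that ranking step, calling sentence-transformers directly rather than the class's get_embeding helper; the model name is the embeding_model value from rag.configs.yml and the chunks are made-up examples:

from sentence_transformers import SentenceTransformer, util

model = SentenceTransformer("all-mpnet-base-v2")

doc_chunks = [
    "Streamlit is a Python framework for building data apps.",
    "A large language model is a neural network trained on large text corpora.",
    "Retrieval-augmented generation grounds answers in retrieved context.",
]
query = "what is an LLM"

query_emb = model.encode(query, convert_to_tensor=True)
chunk_emb = model.encode(doc_chunks, convert_to_tensor=True)

scores = util.dot_score(a=query_emb, b=chunk_emb)[0]  # one score per chunk
top = scores.topk(k=2)
for score, idx in zip(top.values.tolist(), top.indices.tolist()):
    print(f"{score:.3f}  {doc_chunks[idx]}")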
model.py CHANGED
@@ -1,15 +1,78 @@
- from google import SemanticSearch
- from transformers import AutoTokenizer, AutoModel
+ from google import SemanticSearch, GoogleSearch, Document
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ from transformers import BitsAndBytesConfig
+ from transformers.utils import is_flash_attn_2_available
+ import yaml
+ import torch
+
+
+ def load_configs(config_file: str) -> dict:
+     with open(config_file, "r") as f:
+         configs = yaml.safe_load(f)
+
+     return configs


  class RAGModel:
      def __init__(self, configs) -> None:
          self.configs = configs
-         model_url = configs["RAG"]["genration_model"]
-         self.model = AutoModel.from_pretrained(model_url)
-         self.tokenizer = AutoTokenizer.from_pretrained(model_url)
+         self.device = configs["model"]["device"]
+         model_url = configs["model"]["genration_model"]
+         # quantization_config = BitsAndBytesConfig(
+         #     load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16
+         # )
+
+         self.model = AutoModelForCausalLM.from_pretrained(
+             model_url,
+             torch_dtype=torch.float16,
+             # quantization_config=quantization_config,
+             low_cpu_mem_usage=False,
+             attn_implementation="sdpa",
+         ).to(self.device)
+         self.tokenizer = AutoTokenizer.from_pretrained(
+             model_url,
+         )
+
+     def create_prompt(self, query, topk_items: list[str]):
+
+         context = "_ " + "\n-".join(c for c in topk_items)
+
+         base_prompt = f"""Based on the following context items, please answer the query.
+         Give time for yourself to read the context and then answer the query.
+         Do not return thinking process, just return the answer.
+         If you do not find the answer, or if the query is offensive or in any other way harmful, just return "I'm not aware of it"
+         Now use the following context items to answer the user query.
+         {context}.
+         user query : {query}
+         """
+
+         dialog_template = [{"role": "user", "content": base_prompt}]
+
+         prompt = self.tokenizer.apply_chat_template(
+             conversation=dialog_template, tokenize=False, add_generation_prompt=True
+         )
+         return prompt
+
+     def answer_query(self, query: str, topk_items: list[str]):
+
+         prompt = self.create_prompt(query, topk_items)
+         print(prompt)
+         input_ids = self.tokenizer(prompt, return_tensors="pt").to(self.device)
+         output = self.model.generate(**input_ids, max_new_tokens=512)
+         text = self.tokenizer.decode(output[0])
+
+         return text
+

-     def create_propmt(self, topk_items: list[str]):
+ if __name__ == "__main__":

-
-     def answer_query(self, query: str, context: list[str]) :
+     configs = load_configs(config_file="rag.configs.yml")
+     query = "what is LLM"
+     # g = GoogleSearch(query)
+     # data = g.all_page_data
+     # d = Document(data, 512)
+     # s = SemanticSearch( "all-mpnet-base-v2", "mps")
+     # topk = s.semantic_search(query=query, k=32)
+     r = RAGModel(configs)
+     output = r.answer_query(query=query, topk_items=[""])
+     print(output)
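answer_query follows the usual transformers path: build a prompt from the retrieved context, render it with the tokenizer's chat template, generate, and decode. A rough standalone sketch of that path, assuming access to the gated google/gemma-2b-it checkpoint named in rag.configs.yml (the context item and query below are made up):

from transformers import AutoModelForCausalLM, AutoTokenizer

model_url = "google/gemma-2b-it"  # from rag.configs.yml; gated repo, needs HF access
device = "cpu"                    # the config uses "mps"; use whatever is available

tokenizer = AutoTokenizer.from_pretrained(model_url)
model = AutoModelForCausalLM.from_pretrained(model_url).to(device)

context_items = ["A large language model is a neural network trained on large text corpora."]
query = "what is an LLM"
base_prompt = (
    "Based on the following context items, please answer the query.\n"
    + "\n".join(f"- {c}" for c in context_items)
    + f"\nuser query: {query}"
)

# render the single-turn conversation with the model's chat template
dialog = [{"role": "user", "content": base_prompt}]
prompt = tokenizer.apply_chat_template(
    conversation=dialog, tokenize=False, add_generation_prompt=True
)

inputs = tokenizer(prompt, return_tensors="pt").to(device)
output = model.generate(**inputs, max_new_tokens=256)
print(tokenizer.decode(output[0], skip_special_tokens=True))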
rag.configs.yml CHANGED
@@ -1,8 +1,8 @@
  document:
    min_char_length: 333

- common:
+ model:
    embeding_model: all-mpnet-base-v2
-   genration_model: meta-llama/Llama-2-7b
-   device: cpu
+   genration_model: google/gemma-2b-it
+   device: mps
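A quick way to confirm that the renamed model: section is what load_configs hands to RAGModel and SemanticSearch (expected values shown as comments):

import yaml

with open("rag.configs.yml", "r") as f:
    configs = yaml.safe_load(f)

print(configs["document"]["min_char_length"])  # 333
print(configs["model"]["embeding_model"])      # all-mpnet-base-v2
print(configs["model"]["genration_model"])     # google/gemma-2b-it
print(configs["model"]["device"])              # mps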