from search import SemanticSearch, GoogleSearch, Document
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import BitsAndBytesConfig
from transformers.utils import is_flash_attn_2_available
import yaml
import torch


def load_configs(config_file: str) -> dict:
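    """Load a YAML config file into a dict.

    The file (rag.configs.yml in __main__) is assumed to define the "model"
    section whose keys (device, generation model id) are read in
    RAGModel.__init__ below.
    """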
    with open(config_file, "r") as f:
        configs = yaml.safe_load(f)

    return configs


class RAGModel:
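    """Retrieval-augmented generation wrapper around a Hugging Face causal LM.

    Loads the generation model and tokenizer named in the config, folds the
    retrieved context chunks into a chat prompt, and generates the answer.
    """
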
    def __init__(self, configs) -> None:
        self.configs = configs
        self.device = configs["model"]["device"]
        model_url = configs["model"]["genration_model"]
        # quantization_config = BitsAndBytesConfig(
        #     load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16
        # )

        self.model = AutoModelForCausalLM.from_pretrained(
            model_url,
            torch_dtype=torch.float16,
            # quantization_config=quantization_config,
            low_cpu_mem_usage=False,
            attn_implementation="sdpa",
        ).to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_url,
        )

    def create_prompt(self, query: str, topk_items: list[str]) -> str:
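        """Format the retrieved chunks as a bulleted context block and wrap
        the context plus the user query in the tokenizer's chat template."""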

        context = "- " + "\n- ".join(topk_items)

        base_prompt = f"""Take your time to read the context and then answer the query.
        Do not return your thinking process, just return the answer.
        If you do not find the answer, or if the query is offensive or in any other way harmful, just return "I'm not aware of it".
        Now use the following context items to answer the user query.
        context: {context}
        user query: {query}
        """

        dialog_template = [{"role": "user", "content": base_prompt}]

        prompt = self.tokenizer.apply_chat_template(
            conversation=dialog_template, tokenize=False, add_generation_prompt=True
        )
        return prompt

    def answer_query(self, query: str, topk_items: list[str]):
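        """Answer `query` using the retrieved `topk_items` as context.

        Note: the decoded string includes the prompt as well as the newly
        generated tokens.
        """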

        prompt = self.create_prompt(query, topk_items)
        input_ids = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        output = self.model.generate(**input_ids, max_new_tokens=512)
        text = self.tokenizer.decode(output[0])

        return text


if __name__ == "__main__":
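    # End-to-end pipeline: fetch pages via Google search, chunk them into
    # documents (512 is presumably the chunk size), rank the chunks against
    # the query with the all-mpnet-base-v2 sentence transformer on "mps",
    # and pass the top-32 chunks to the RAG model for answer generation.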

    configs = load_configs(config_file="rag.configs.yml")
    query = "what is computer vision"
    g = GoogleSearch(query)
    data = g.all_page_data
    d = Document(data, 512)
    doc_chunks = d.doc()
    s = SemanticSearch(doc_chunks, "all-mpnet-base-v2", "mps")
    topk, u = s.semantic_search(query=query, k=32)
    r = RAGModel(configs)
    output = r.answer_query(query=query, topk_items=topk)
    print(output)