Pclanglais committed on
Commit
094bf8b
·
verified ·
1 Parent(s): b407d63

Upload 5 files

Browse files
Files changed (5) hide show
  1. README.md +7 -6
  2. app.py +215 -0
  3. gitattributes +4 -0
  4. requirements.txt +14 -0
  5. theme_builder.py +3 -0
README.md CHANGED
@@ -1,12 +1,13 @@
1
  ---
2
- title: SDN
3
- emoji: 🐢
4
- colorFrom: red
5
- colorTo: red
6
  sdk: gradio
7
- sdk_version: 4.21.0
8
  app_file: app.py
9
  pinned: false
 
10
  ---
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: Motta
3
+ emoji: 📜
4
+ colorFrom: gray
5
+ colorTo: blue
6
  sdk: gradio
7
+ sdk_version: 3.50.2
8
  app_file: app.py
9
  pinned: false
10
+ license: mit
11
  ---
12
 
13
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,215 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import transformers
2
+ import re
3
+ from transformers import AutoConfig, AutoTokenizer, AutoModel, AutoModelForCausalLM
4
+ from vllm import LLM, SamplingParams
5
+ import torch
6
+ import gradio as gr
7
+ import json
8
+ import os
9
+ import shutil
10
+ import requests
11
+ import chromadb
12
+ import pandas as pd
13
+ from chromadb.config import Settings
14
+ from chromadb.utils import embedding_functions
15
+
16
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
# Resolved once, before anything that needs it, so the embedding model and
# the rest of the module agree on the device.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Embedding function Chroma uses to embed the query text.
sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(
    model_name="intfloat/multilingual-e5-base",
    device=device,
)

# Open the on-disk Chroma database and its pre-built "corrected" collection.
client = chromadb.PersistentClient(path="education_corrected")
collection = client.get_collection(
    name="corrected",
    embedding_function=sentence_transformer_ef,
)

# Generation defaults.
# NOTE(review): MistralChatBot.predict builds its own SamplingParams with
# different values and does not read these names; they are kept so any other
# code that refers to them keeps working.
temperature = 0.2
max_new_tokens = 1000
top_p = 0.92
repetition_penalty = 1.7

model_name = "Pclanglais/Cassandre-Test"

# vLLM engine, created once at import time with max_model_len=4096.
llm = LLM(model_name, max_model_len=4096)
33
+
34
# Vector search over the database
def vector_search(collection, text, n_results=5):
    """Query the collection for passages related to *text*.

    Args:
        collection: Object exposing ``query(query_texts=..., n_results=...)``
            and returning a mapping with ``"metadatas"`` and ``"documents"``
            entries, each holding one list per query text. Every metadata
            item must provide the ``"identifier"`` and ``"context"`` keys.
        text: The user query.
        n_results: Number of passages to request (defaults to 5, the value
            that was previously hard-coded).

    Returns:
        A ``(document, document_html)`` tuple:

        * ``document`` -- plain text, one ``"<identifier> : <context> <passage>"``
          entry per hit, separated by blank lines (used to build the prompt).
        * ``document_html`` -- the same entries as HTML, each wrapped in a
          ``<div class="source" id="<identifier>">`` inside a single
          ``<div id="source_listing">`` container.
    """
    import html  # local import so this block does not depend on file-level imports

    results = collection.query(
        query_texts=[text],
        n_results=n_results,
    )

    # One result list per query text; a single query text was sent.
    metadatas = results["metadatas"][0]
    passages = results["documents"][0]

    plain_entries = []
    html_entries = []
    for metadata, passage in zip(metadatas, passages):
        identifier = str(metadata["identifier"])
        content = metadata["context"] + " " + passage

        plain_entries.append(identifier + " : " + content)

        # Retrieved text is data, not markup: escape it so characters such as
        # '<', '&' or '"' cannot break the attribute or the surrounding HTML.
        safe_id = html.escape(identifier, quote=True)
        safe_content = html.escape(content, quote=True)
        html_entries.append(
            '<div class="source" id="' + safe_id + '"><p><b>' + safe_id
            + "</b> : " + safe_content + "</p></div>"
        )

    document = "\n\n".join(plain_entries)
    document_html = '<div id="source_listing">' + "".join(html_entries) + "</div>"
    return document, document_html
58
+
59
# CSS for references formatting.
# - .generation: left/right margins for the generated answer container.
# - :target: background highlight on the element whose id matches the URL
#   fragment. NOTE: the rule sets a background colour (#CCF3DF); the inline
#   CSS comment that mentions "red" text does not match what the rule does.
# - .source: floats each source card left with a 17% maximum width.
# - .tooltip / .tooltip:hover::after: superscript-styled reference marker
#   that shows the content of its data-text attribute in a box on hover.
css = """
.generation {
margin-left:2em;
margin-right:2em;
}

:target {
background-color: #CCF3DF; /* Change the text color to red */
}

.source {
float:left;
max-width:17%;
margin-left:2%;
}

.tooltip {
position: relative;
cursor: pointer;
font-variant-position: super;
color: #97999b;
}

.tooltip:hover::after {
content: attr(data-text);
position: absolute;
left: 0;
top: 120%; /* Adjust this value as needed to control the vertical spacing between the text and the tooltip */
white-space: pre-wrap; /* Allows the text to wrap */
width: 500px; /* Sets a fixed maximum width for the tooltip */
max-width: 500px; /* Ensures the tooltip does not exceed the maximum width */
z-index: 1;
background-color: #f9f9f9;
color: #000;
border: 1px solid #ddd;
border-radius: 5px;
padding: 5px;
display: block;
box-shadow: 0 4px 8px rgba(0,0,0,0.1); /* Optional: Adds a subtle shadow for better visibility */
}"""
100
+
101
# Courtesy of ChatGPT
def format_references(text):
    """Replace ``<ref text="...">id</ref>`` markers with numbered HTML tooltips.

    Each well-formed marker becomes::

        <span class="tooltip" data-refid="ID" data-text="ID: PASSAGE">
          <a href="#ID">[N]</a></span>

    (emitted on a single line), where ``N`` counts references from 1 in order
    of appearance. Text outside the markers is copied through unchanged.

    Args:
        text: Generated text that may contain reference markers.

    Returns:
        The text with every well-formed marker replaced. If a marker is
        malformed (no closing ``">`` or no ``</ref>``, e.g. because generation
        was cut off), the remaining text from that marker onward is kept
        verbatim instead of being dropped.
    """
    import html  # local import so this block does not depend on file-level imports

    ref_start_marker = '<ref text="'
    attr_end_marker = '">'
    ref_end_marker = '</ref>'

    parts = []
    current_pos = 0
    ref_number = 1

    while True:
        start_pos = text.find(ref_start_marker, current_pos)
        if start_pos == -1:
            # No more references: keep the rest of the text.
            parts.append(text[current_pos:])
            break

        # Text preceding the reference.
        parts.append(text[current_pos:start_pos])

        # Locate the end of the text="..." attribute, then the closing tag.
        end_pos = text.find(attr_end_marker, start_pos)
        ref_end_pos = -1 if end_pos == -1 else text.find(ref_end_marker, end_pos)
        if ref_end_pos == -1:
            # Malformed reference: stop parsing (avoids an infinite loop) but
            # keep the unparsed remainder so no generated text is lost.
            parts.append(text[start_pos:])
            break

        ref_text = text[start_pos + len(ref_start_marker):end_pos].replace('\n', ' ').strip()
        ref_id = text[end_pos + len(attr_end_marker):ref_end_pos].strip()

        # Both values land inside double-quoted HTML attributes, so quotes
        # must be escaped along with '&', '<' and '>'.
        ref_text_encoded = html.escape(ref_text, quote=True)
        ref_id_encoded = html.escape(ref_id, quote=True)

        parts.append(
            f'<span class="tooltip" data-refid="{ref_id_encoded}" '
            f'data-text="{ref_id_encoded}: {ref_text_encoded}">'
            f'<a href="#{ref_id_encoded}">[{ref_number}]</a></span>'
        )

        # Continue after the closing </ref>.
        current_pos = ref_end_pos + len(ref_end_marker)
        ref_number += 1

    return ''.join(parts)
154
+
155
# Class to encapsulate the chatbot (retrieval + generation)
class MistralChatBot:
    """Answer a question with a sourced text built from retrieved passages."""

    def __init__(self, system_prompt="Le dialogue suivant est une conversation"):
        # Stored on the instance; predict() uses its own inline system message
        # and does not read this attribute.
        self.system_prompt = system_prompt

    def predict(self, user_message):
        """Retrieve sources for *user_message* and generate a referenced answer.

        Args:
            user_message: The question or instruction typed by the user.

        Returns:
            A ``(generated_text, fiches_html)`` tuple of HTML fragments: the
            answer (with references turned into tooltips) under a "Réponse"
            heading, and the retrieved sources under a "Sources" heading.
        """
        fiches, fiches_html = vector_search(collection, user_message)
        sampling_params = SamplingParams(temperature=.7, top_p=.95, max_tokens=2000, presence_penalty = 1.5, stop = ["``"])
        # The prompt text is kept exactly as is: it is what the model receives.
        detailed_prompt = """<|im_start|>system
Tu es Cassandre, le chatbot de l'Éducation nationale qui donne des réponses sourcées.<|im_end|>
<|im_start|>user
Ecrit un texte référencé en réponse à cette question : """ + user_message + """

Les références doivent être citées de cette manière : texte rédigé<ref text=\"[passage pertinent dans la référence]\">[\"identifiant de la référence\"]</ref>Si les références ne permettent pas de répondre, qu'il n'y a pas de réponse.

Les cinq références disponibles : """ + fiches + "<|im_end|>\n<|im_start|>assistant\n"
        prompts = [detailed_prompt]
        outputs = llm.generate(prompts, sampling_params, use_tqdm = False)
        # Single prompt, single sample: take the first completion.
        generated_text = outputs[0].outputs[0].text
        # Headings are opened and closed with the same tag (<h2>...</h2>).
        generated_text = '<h2 style="text-align:center">Réponse</h2>\n<div class="generation">' + format_references(generated_text) + "</div>"
        fiches_html = '<h2 style="text-align:center">Sources</h2>\n' + fiches_html
        return generated_text, fiches_html
177
+
178
# Create the chatbot instance
mistral_bot = MistralChatBot()

# Define the Gradio interface
# NOTE(review): title, description, examples and additional_inputs are not
# referenced by the Blocks layout below; they are kept as module-level names
# in case other code relies on them.
title = "Motta"
description = "Le LLM répond à toutes les questions sur la SDN."
examples = [
    [
        "Comment garantir la paix universelle?",  # user_message
        0.7,  # temperature
    ]
]

additional_inputs = [
    gr.Slider(
        label="Température",
        value=0.2,  # Default value
        minimum=0.05,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Des valeurs plus élevées donne plus de créativité, mais aussi d'étrangeté",
    ),
]

# Build the page: a question box, a button, and two HTML panes that receive
# the answer and the sources returned by MistralChatBot.predict.
# (A bare `demo = gr.Blocks()` used to precede this block; it was immediately
# rebound by the `with` statement and has been removed.)
with gr.Blocks(theme='JohnSmith9982/small_and_pretty', css=css) as demo:
    gr.HTML("""<h1 style="text-align:center">Motta</h1>""")
    text_input = gr.Textbox(label="Votre question ou votre instruction.", type="text", lines=1)
    text_button = gr.Button("Interroger Motta")
    text_output = gr.HTML(label="La réponse de Motta")
    embedding_output = gr.HTML(label="Les sources utilisées")
    text_button.click(mistral_bot.predict, inputs=text_input, outputs=[text_output, embedding_output])

if __name__ == "__main__":
    demo.queue().launch()
gitattributes ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ education_corrected/chroma.sqlite3 filter=lfs diff=lfs merge=lfs -text
2
+ education_corrected/e150eb41-e894-45c4-b97c-80ced9ff2123/data_level0.bin filter=lfs diff=lfs merge=lfs -text
3
+ education_corrected/a9ac8f33-9498-450a-ae99-f116efb66330/data_level0.bin filter=lfs diff=lfs merge=lfs -text
4
+ education_corrected/6af97eb5-0cfa-40b2-a4df-732ca13bd66a/data_level0.bin filter=lfs diff=lfs merge=lfs -text
requirements.txt ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ transformers
2
+ torch
3
+ einops
4
+ accelerate
5
+ tiktoken
6
+ scipy
7
+ transformers_stream_generator==0.0.4
8
+ peft
9
+ deepspeed
10
+ bitsandbytes
11
+ optimum
12
+ vllm==0.3.2
13
+ chromadb
14
+ sentence_transformers
theme_builder.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ import gradio as gr
2
+
3
+ gr.themes.builder()