nachoremer committed
Commit a436de7 · 1 Parent(s): 36f8e22

Files changed (4):
  1. app.py +51 -12
  2. config.py +9 -0
  3. import_model.py +32 -0
  4. requirements.txt +6 -5
app.py CHANGED
@@ -3,12 +3,14 @@
 import gradio as gr
 import pandas as pd
 from datetime import datetime, timedelta, timezone
-from config import groq_token, groq_model, QUESTION_PROMPT, init_google_sheets_client, groq_model, default_model_name, user_names, google_sheets_name, AtlasClient
+from config import groq_token, groq_model, QUESTION_PROMPT, init_google_sheets_client, groq_model, default_model_name, user_names, google_sheets_name, AtlasClient, custom_model
 import gspread
 from groq import Client
 import random, string, json, io
-
+#from trash_folder.alter_app import Local_llm
+from import_model import Local_llm
 import groq
+import torch
 print(groq.__version__)
 # Initialize Google Sheets client
 client = init_google_sheets_client()
@@ -18,7 +20,13 @@ stories_sheet = sheet.worksheet("Stories")
 system_prompts_sheet = sheet.worksheet("System Prompts")
 
 # Combine both model dictionaries
-all_models = {**groq_model}
+all_models = {}
+all_models.update(groq_model)
+if torch.cuda.is_available():
+    all_models.update(custom_model)
+
+# init local model as None
+local_model = None
 
 def randomize_key_order(aux):
     keys = list(aux.keys())
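The merge is gated on GPU availability so the local fine-tuned checkpoints in `custom_model` are only offered when CUDA is present; on a CPU-only host `all_models` degrades to the Groq-hosted entries. The same logic fits in one expression, a sketch assuming the names imported from config.py:

    # Equivalent one-liner (sketch): merge groq_model, plus custom_model only when a GPU exists
    all_models = {**groq_model, **(custom_model if torch.cuda.is_available() else {})}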
 
@@ -131,6 +139,24 @@ def save_comment_score(score, comment, story_name, user_name, system_prompt, mod
 
 
 
+
+
+from openai import OpenAI
+client = OpenAI(
+    base_url="https://openrouter.ai/api/v1",
+    api_key="$OPENROUTER_API_KEY",
+)
+def interact_openrouter(context, model_name):
+    completion = client.chat.completions.create(
+        model=model_name,
+        messages=context,
+    )
+    return completion.choices[0].message.content
+
+
+
+
+
 # Function to handle interaction with model
 def interact_groq(context, model_name):
     chat_completion = groq_clinet.chat.completions.create(
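Two details in this hunk are worth flagging. `api_key="$OPENROUTER_API_KEY"` hands the OpenAI client the literal string `$OPENROUTER_API_KEY`; Python performs no shell-style expansion, so every request would carry an invalid key. The module-level `client = OpenAI(...)` also rebinds the global `client` that was assigned `init_google_sheets_client()` at the top of app.py. A sketch of the likely intent, with a hypothetical `openrouter_client` name to avoid the collision (assumes the `OPENROUTER_API_KEY` environment variable is set):

    import os
    from openai import OpenAI

    # Hypothetical rename so the Google Sheets `client` global is left intact
    openrouter_client = OpenAI(
        base_url="https://openrouter.ai/api/v1",
        api_key=os.environ["OPENROUTER_API_KEY"],  # read from the environment; raises KeyError if unset
    )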
 
@@ -147,6 +173,7 @@ def interact_groq(context, model_name):
 #o=[chatbot_output, chat_history_json, data_table, selected_story_textbox])
 # Function to send selected story and initial message
 def send_selected_story(title, model_name, system_prompt):
+    global local_model
     global chat_history
     global selected_story
     global data  # Ensure data is reset
@@ -163,13 +190,21 @@ Here is the story:
 """
         combined_message = system_prompt.strip()
         if combined_message:
-            chat_history = []  # Reset chat history
-            chat_history.append({"role": "system", "content": combined_message})
-            chat_history.append({"role": "user", "content": QUESTION_PROMPT})
-
-            response = interact_groq(chat_history, model_name)
-            resp = {"role": "assistant", "content": response.strip()}
-            return resp, chat_history, story["story"]
+            chat_history = []  # Reset chat history
+            chat_history.append({"role": "system", "content": combined_message})
+            chat_history.append({"role": "user", "content": QUESTION_PROMPT})
+            if model_name in custom_model:
+                if local_model is None or local_model.model_name != custom_model[model_name]:
+                    # the model has to be loaded or swapped
+                    del local_model
+                    torch.cuda.empty_cache()
+                    torch.cuda.synchronize()  # check whether this actually helps
+                    local_model = Local_llm(custom_model[model_name])
+                response = local_model.interact(chat_history)
+            else:
+                response = interact_groq(chat_history, model_name)
+            resp = {"role": "assistant", "content": response.strip()}
+            return resp, chat_history, story["story"]
         else:
            print("Combined message is empty.")
     else:
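The reload guard avoids re-initializing the local model on every call: the wrapper is rebuilt only when nothing is resident yet or a different checkpoint is requested, and `del local_model` plus `torch.cuda.empty_cache()` releases the old weights' VRAM before the new ones load (this relies on the wrapper exposing a `model_name` attribute). The same pattern as a standalone helper, a sketch with hypothetical names:

    import torch
    from import_model import Local_llm  # as imported in app.py; the diff itself only shows ModelLoader

    _resident = None  # hypothetical module-level cache for the loaded wrapper

    def get_local_model(checkpoint):
        """Return a loaded Local_llm, reloading only when the checkpoint changes."""
        global _resident
        if _resident is None or _resident.model_name != checkpoint:
            _resident = None          # drop the old reference so its weights can be freed
            torch.cuda.empty_cache()  # hand cached VRAM blocks back to the driver
            torch.cuda.synchronize()  # wait for in-flight kernels before reloading
            _resident = Local_llm(checkpoint)
        return _resident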
 
@@ -276,7 +311,10 @@ def multiple_interact(query, models, selected_model, assistant_user_input): #, i
     # it's not models, it's ....
     random.shuffle(active_models)
     for index, model in enumerate(active_models):
-        resp = interact_groq(aux_history, model)
+        if model in custom_model:
+            resp = local_model.interact(aux_history)
+        else:
+            resp = interact_groq(aux_history, model)
         resp = {"role": "assistant", "content": resp.strip()}
         chatbot_answser_list[alphabet[index]] = {'response': resp, 'model': model}
     try:
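Unlike `send_selected_story`, this branch calls `local_model.interact(...)` without checking that `local_model` is loaded at all, or that it holds the checkpoint `model` refers to: if a custom model reaches this loop before a story was sent, `local_model` is still `None` and the call raises `AttributeError`. A defensive variant of the loop body (sketch, reusing the reload guard shown earlier):

    if model in custom_model:
        if local_model is None or local_model.model_name != custom_model[model]:
            local_model = Local_llm(custom_model[model])  # lazy (re)load before first use
        resp = local_model.interact(aux_history)
    else:
        resp = interact_groq(aux_history, model)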
 
@@ -451,4 +489,5 @@ with gr.Blocks() as demo:
         inputs=[score_input, comment_input, story_dropdown, user_dropdown, system_prompt_dropdown, model_checkbox],
         outputs=[data_table, comment_input])
 
-demo.launch()
+demo.launch(share=True)
+#demo.launch(share=True)
config.py CHANGED
@@ -56,6 +56,15 @@ groq_model = {
     "llama3-70b-8192": "llama3-70b-8192",
 }
 
+custom_model = {
+    "rodrisouza/Llama-3-8B-Finetuning-Stories": "rodrisouza/Llama-3-8B-Finetuning-Stories"
+}
+
+openai_model = {
+    "meta-llama/llama-3.1-70b-instruct:free": "meta-llama/llama-3.1-70b-instruct:free",
+    "meta-llama/llama-3.1-8b-instruct:free": "meta-llama/llama-3.1-8b-instruct:free",
+}
+
 
 # Default model (first in list)
 default_model_name = list(replicate_model.items())[0][0]
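Note that `openai_model` is defined here but nothing in this commit merges it into `all_models` or routes its entries anywhere; only `custom_model` is consumed by app.py. If the intent is to serve these entries via the new `interact_openrouter`, the dispatch might look like this (purely a sketch; the wiring is an assumption, not part of the diff):

    # Hypothetical router mirroring app.py's custom_model branch
    def interact(context, model_name):
        if model_name in custom_model:
            return local_model.interact(context)
        if model_name in openai_model:
            return interact_openrouter(context, model_name)
        return interact_groq(context, model_name)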
import_model.py ADDED
@@ -0,0 +1,32 @@
+from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
+import torch
+
+class ModelLoader:
+    def __init__(self, model_name, hugging_face_token):
+        self.model_name = model_name
+        # Configure 4-bit quantization
+        self.bnb_config = BitsAndBytesConfig(
+            load_in_4bit=True,
+            bnb_4bit_quant_type="nf4",
+            bnb_4bit_compute_dtype=torch.bfloat16,
+            llm_int8_enable_fp32_cpu_offload=True
+        )
+
+        # Load tokenizer
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            model_name,
+            token=hugging_face_token
+        )
+
+        # Load model with memory optimizations
+        self.model = AutoModelForCausalLM.from_pretrained(
+            model_name,
+            quantization_config=self.bnb_config,
+            device_map="auto",
+            low_cpu_mem_usage=True,
+            max_memory={
+                "cpu": "12GiB",
+                "cuda:0": "4GiB",
+            },
+            token=hugging_face_token
+        )
requirements.txt CHANGED
@@ -1,14 +1,15 @@
 huggingface_hub==0.25.1
 minijinja
-#transformers
-#torch
+transformers
+torch
 pandas
 gspread
 oauth2client
-#accelerate
-#bitsandbytes
+accelerate
+bitsandbytes
 replicate
 groq==0.11.0
 gradio
 google-api-python-client
-pymongo==4.6.2
+pymongo==4.6.2
+openai