seanpedrickcase committed · Commit f5a842c · 1 Parent(s): a10d388
Enhanced local model support by adding model loading functionality in chatfuncs.py and updating llm_api_call.py to utilize local models for topic extraction and summarization. Improved model path handling and ensured compatibility with GPU configurations.
Files changed:
- app.py (+1, -0)
- tools/chatfuncs.py (+22, -25)
- tools/llm_api_call.py (+29, -14)
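For context, a minimal sketch of how the pieces changed in this commit fit together, based on the diffs below. The function names and the empty-list default for an unloaded model come from the new signatures; everything else here is illustrative rather than code from the repository.

# Illustrative only: how the local GGUF model introduced in this commit flows
# through the topic-extraction code. Signatures follow the diffs below.
from tools.chatfuncs import load_model, RUN_LOCAL_MODEL

local_model = []   # an empty list is the "no local model" default in the new signatures
tokenizer = []

if RUN_LOCAL_MODEL == "1":
    # load_model() resolves or downloads the GGUF file and returns (model, tokenizer)
    local_model, tokenizer = load_model()

# The loaded model is then threaded down the call chain:
#   extract_topics / summarise_output_topics
#     -> process_requests(..., local_model=local_model)
#       -> send_request(..., local_model=local_model)
#         -> call_llama_cpp_model(prompt, gen_config, model=local_model)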
app.py
CHANGED
@@ -6,6 +6,7 @@ from tools.aws_functions import upload_file_to_s3, RUN_AWS_FUNCTIONS
 from tools.llm_api_call import extract_topics, load_in_data_file, load_in_previous_data_files, sample_reference_table_summaries, summarise_output_topics, batch_size_default
 from tools.auth import authenticate_user
 from tools.prompts import initial_table_prompt, prompt2, prompt3, system_prompt, add_existing_topics_system_prompt, add_existing_topics_prompt
+from tools.chatfuncs import load_model
 #from tools.aws_functions import load_data_from_aws
 import gradio as gr
 import pandas as pd
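The app.py change is only the import; load_model is not called anywhere in this diff. One hypothetical way it could be wired into the Gradio app (the state names and the page-load event are assumptions, not part of the commit):

# Hypothetical wiring, not part of this commit: load the local model once when
# the Gradio app starts and keep it in session state for later calls.
import gradio as gr
from tools.chatfuncs import load_model

with gr.Blocks() as app:
    local_model_state = gr.State([])   # holds the loaded llama_cpp model
    tokenizer_state = gr.State([])

    def load_local_model():
        model, tokenizer = load_model()   # uses the defaults defined in tools/chatfuncs.py
        return model, tokenizer

    # Blocks.load fires when the page first loads
    app.load(load_local_model, inputs=None, outputs=[local_model_state, tokenizer_state])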
tools/chatfuncs.py
CHANGED
@@ -2,7 +2,6 @@ from typing import TypeVar
 import torch.cuda
 import os
 import time
-import spaces
 from llama_cpp import Llama
 from huggingface_hub import hf_hub_download
 from tools.helper_functions import RUN_LOCAL_MODEL
@@ -111,9 +110,23 @@ class LlamaCPPGenerationConfig:
 ###
 # Load local model
 ###
+def get_model_path():
+    repo_id = os.environ.get("REPO_ID", "lmstudio-community/gemma-2-2b-it-GGUF") # "bartowski/Llama-3.2-3B-Instruct-GGUF") # "QuantFactory/Phi-3-mini-128k-instruct-GGUF")
+    filename = os.environ.get("MODEL_FILE", "gemma-2-2b-it-Q8_0.gguf") # "Llama-3.2-3B-Instruct-Q5_K_M.gguf") # "Phi-3-mini-128k-instruct.Q4_K_M.gguf")
+    model_dir = "model/gemma" # "model/phi" # Assuming this is your intended directory

+    # Construct the expected local path
+    local_path = os.path.join(model_dir, filename)
+
+    if os.path.exists(local_path):
+        print(f"Model already exists at: {local_path}")
+        return local_path
+    else:
+        print(f"Checking default Hugging Face folder. Downloading model from Hugging Face Hub if not found")
+        return hf_hub_download(repo_id=repo_id, filename=filename)
+
+
+def load_model(local_model_type:str=local_model_type, gpu_layers:int=gpu_layers, max_context_length:int=context_length, gpu_config:llama_cpp_init_config_gpu=gpu_config, cpu_config:llama_cpp_init_config_cpu=cpu_config, torch_device:str=torch_device):
     '''
     Load in a model from Hugging Face hub via the transformers package, or using llama_cpp_python by downloading a GGUF file from Huggingface Hub.
     '''
@@ -136,26 +149,11 @@ def load_model(local_model_type:str, gpu_layers:int, max_context_length:int, gpu
 
     #print(vars(gpu_config))
     #print(vars(cpu_config))
-
-    def get_model_path():
-        repo_id = os.environ.get("REPO_ID", "lmstudio-community/gemma-2-2b-it-GGUF") # "bartowski/Llama-3.2-3B-Instruct-GGUF") # "QuantFactory/Phi-3-mini-128k-instruct-GGUF")
-        filename = os.environ.get("MODEL_FILE", "gemma-2-2b-it-Q8_0.gguf") # "Llama-3.2-3B-Instruct-Q5_K_M.gguf") # "Phi-3-mini-128k-instruct.Q4_K_M.gguf")
-        model_dir = "model/gemma" # "model/phi" # Assuming this is your intended directory
-
-        # Construct the expected local path
-        local_path = os.path.join(model_dir, filename)
-
-        if os.path.exists(local_path):
-            print(f"Model already exists at: {local_path}")
-            return local_path
-        else:
-            print(f"Checking default Hugging Face folder. Downloading model from Hugging Face Hub if not found")
-            return hf_hub_download(repo_id=repo_id, filename=filename)
 
     model_path = get_model_path()
 
     try:
-        print(vars(gpu_config))
+        print("GPU load variables:", vars(gpu_config))
         llama_model = Llama(model_path=model_path, **vars(gpu_config)) # type_k=8, type_v = 8, flash_attn=True,
 
     except Exception as e:
@@ -172,15 +170,15 @@ def load_model(local_model_type:str, gpu_layers:int, max_context_length:int, gpu
     load_confirmation = "Finished loading model: " + local_model_type
 
     print(load_confirmation)
-    return
+    return model, tokenizer
 
 ###
 # Load local model
 ###
-if RUN_LOCAL_MODEL == "1":
-    print("Loading model")
-    local_model_type, load_confirmation, local_model_type, model, tokenizer = load_model(local_model_type, gpu_layers, context_length, gpu_config, cpu_config, torch_device)
-    print("model loaded:", model)
+# if RUN_LOCAL_MODEL == "1":
+#     print("Loading model")
+#     local_model_type, load_confirmation, local_model_type, model, tokenizer = load_model(local_model_type, gpu_layers, context_length, gpu_config, cpu_config, torch_device)
+#     print("model loaded:", model)
 
 
 def llama_cpp_streaming(history, full_prompt, temperature=temperature):
@@ -216,7 +214,6 @@ def llama_cpp_streaming(history, full_prompt, temperature=temperature):
     print(f'Tokens per secound: {NUM_TOKENS/time_generate}')
     print(f'Time per token: {(time_generate/NUM_TOKENS)*1000}ms')
 
-@spaces.GPU
 def call_llama_cpp_model(formatted_string:str, gen_config:str, model=model):
     """
     Calls your generation model with parameters from the LlamaCPPGenerationConfig object.
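For reference, a condensed sketch of the load path that the refactored get_model_path/load_model implement: prefer a pre-downloaded GGUF file, fall back to hf_hub_download, then initialise llama_cpp. The n_gpu_layers and n_ctx values below are illustrative assumptions; the real code builds these from the gpu_config/cpu_config objects defined in chatfuncs.py.

# Sketch only: resolve-then-load flow, with assumed GPU settings.
import os
from llama_cpp import Llama
from huggingface_hub import hf_hub_download

repo_id = os.environ.get("REPO_ID", "lmstudio-community/gemma-2-2b-it-GGUF")
filename = os.environ.get("MODEL_FILE", "gemma-2-2b-it-Q8_0.gguf")
local_path = os.path.join("model/gemma", filename)

# Prefer a file already present in the repo; otherwise pull it into the Hugging Face cache.
model_path = local_path if os.path.exists(local_path) else hf_hub_download(repo_id=repo_id, filename=filename)

# n_gpu_layers=-1 offloads every layer when a GPU is available; use 0 for CPU-only.
llama_model = Llama(model_path=model_path, n_gpu_layers=-1, n_ctx=4096)

result = llama_model("Summarise the following responses in one sentence: ...", max_tokens=128)
print(result["choices"][0]["text"])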
tools/llm_api_call.py
CHANGED
@@ -9,6 +9,7 @@ import boto3
 import json
 import string
 import re
+import spaces
 from rapidfuzz import process, fuzz
 from tqdm import tqdm
 from gradio import Progress
@@ -19,7 +20,7 @@ GradioFileData = gr.FileData
 
 from tools.prompts import initial_table_prompt, prompt2, prompt3, system_prompt, summarise_topic_descriptions_prompt, summarise_topic_descriptions_system_prompt, add_existing_topics_system_prompt, add_existing_topics_prompt
 from tools.helper_functions import output_folder, detect_file_type, get_file_path_end, read_file, get_or_create_env_var, model_name_map, put_columns_in_df
-from tools.chatfuncs import LlamaCPPGenerationConfig, call_llama_cpp_model
+from tools.chatfuncs import LlamaCPPGenerationConfig, call_llama_cpp_model, load_model, RUN_LOCAL_MODEL
 
 # ResponseObject class for AWS Bedrock calls
 class ResponseObject:
@@ -331,7 +332,7 @@ def call_aws_claude(prompt: str, system_prompt: str, temperature: float, max_tok
     return response
 
 # Function to send a request and update history
-def send_request(prompt: str, conversation_history: List[dict], model: object, config: dict, model_choice: str, system_prompt: str, temperature: float, progress=Progress(track_tqdm=True)) -> Tuple[str, List[dict]]:
+def send_request(prompt: str, conversation_history: List[dict], model: object, config: dict, model_choice: str, system_prompt: str, temperature: float, local_model=[], progress=Progress(track_tqdm=True)) -> Tuple[str, List[dict]]:
     """
     This function sends a request to a language model with the given prompt, conversation history, model configuration, model choice, system prompt, and temperature.
     It constructs the full prompt by appending the new user prompt to the conversation history, generates a response from the model, and updates the conversation history with the new prompt and response.
@@ -412,7 +413,7 @@ def send_request(prompt: str, conversation_history: List[dict], model: object, c
         gen_config = LlamaCPPGenerationConfig()
         gen_config.update_temp(temperature)
 
-        response = call_llama_cpp_model(prompt, gen_config)
+        response = call_llama_cpp_model(prompt, gen_config, model=local_model)
 
         #progress_bar.close()
         #tqdm._instances.clear()
@@ -449,7 +450,7 @@ def send_request(prompt: str, conversation_history: List[dict], model: object, c
 
     return response, conversation_history
 
-def process_requests(prompts: List[str], system_prompt: str, conversation_history: List[dict], whole_conversation: List[str], whole_conversation_metadata: List[str], model: object, config: dict, model_choice: str, temperature: float, batch_no:int = 1, master:bool = False) -> Tuple[List[ResponseObject], List[dict], List[str], List[str]]:
+def process_requests(prompts: List[str], system_prompt: str, conversation_history: List[dict], whole_conversation: List[str], whole_conversation_metadata: List[str], model: object, config: dict, model_choice: str, temperature: float, batch_no:int = 1, local_model = [], master:bool = False) -> Tuple[List[ResponseObject], List[dict], List[str], List[str]]:
     """
     Processes a list of prompts by sending them to the model, appending the responses to the conversation history, and updating the whole conversation and metadata.
 
@@ -464,6 +465,7 @@ def process_requests(prompts: List[str], system_prompt: str, conversation_histor
         model_choice (str): The choice of model to use.
         temperature (float): The temperature parameter for the model.
         batch_no (int): Batch number of the large language model request.
+        local_model: Local gguf model (if loaded)
        master (bool): Is this request for the master table.
 
     Returns:
@@ -478,7 +480,7 @@ def process_requests(prompts: List[str], system_prompt: str, conversation_histor
 
        #print("prompt to LLM:", prompt)
 
-        response, conversation_history = send_request(prompt, conversation_history, model=model, config=config, model_choice=model_choice, system_prompt=system_prompt, temperature=temperature)
+        response, conversation_history = send_request(prompt, conversation_history, model=model, config=config, model_choice=model_choice, system_prompt=system_prompt, temperature=temperature, local_model=local_model)
 
        if isinstance(response, ResponseObject):
            responses.append(response)
@@ -872,7 +874,7 @@ def write_llm_output_and_logs(responses: List[ResponseObject],
     return topic_table_out_path, reference_table_out_path, unique_topics_df_out_path, topic_with_response_df, markdown_table, out_reference_df, out_unique_topics_df, batch_file_path_details, is_error
 
 
-
+@spaces.GPU
 def extract_topics(in_data_file,
                    file_data:pd.DataFrame,
                    existing_topics_table:pd.DataFrame,
@@ -991,6 +993,11 @@ def extract_topics(in_data_file,
     out_file_paths = []
     print("model_choice_clean:", model_choice_clean)
 
+    if (model_choice == "gemma_2b_it_local") & (RUN_LOCAL_MODEL == "1"):
+        progress(0.1, "Loading in Gemma 2b model")
+        local_model, tokenizer = load_model()
+        print("Local model loaded:", local_model)
+
     #print("latest_batch_completed:", str(latest_batch_completed))
 
     # If we have already redacted the last file, return the input out_message and file list to the relevant components
@@ -1096,6 +1103,8 @@ def extract_topics(in_data_file,
     topics_loop_description = "Extracting topics from response batches (each batch of " + str(batch_size) + " responses)."
     topics_loop = tqdm(range(latest_batch_completed, num_batches), desc = topics_loop_description, unit="batches remaining")
 
+
+
     for i in topics_loop:
 
        #for latest_batch_completed in range(num_batches):
@@ -1207,7 +1216,7 @@ def extract_topics(in_data_file,
            summary_whole_conversation = []
 
            # Process requests to large language model
-            master_summary_response, summary_conversation_history, whole_summary_conversation, whole_conversation_metadata = process_requests(summary_prompt_list, add_existing_topics_system_prompt, summary_conversation_history, summary_whole_conversation, whole_conversation_metadata, model, config, model_choice, temperature, reported_batch_no, master = True)
+            master_summary_response, summary_conversation_history, whole_summary_conversation, whole_conversation_metadata = process_requests(summary_prompt_list, add_existing_topics_system_prompt, summary_conversation_history, summary_whole_conversation, whole_conversation_metadata, model, config, model_choice, temperature, reported_batch_no, local_model, master = True)
 
            # print("master_summary_response:", master_summary_response[-1].text)
            # print("Whole conversation metadata:", whole_conversation_metadata)
@@ -1299,7 +1308,7 @@ def extract_topics(in_data_file,
            whole_conversation = [system_prompt]
 
            # Process requests to large language model
-            responses, conversation_history, whole_conversation, whole_conversation_metadata = process_requests(batch_prompts, system_prompt, conversation_history, whole_conversation, whole_conversation_metadata, model, config, model_choice, temperature, reported_batch_no)
+            responses, conversation_history, whole_conversation, whole_conversation_metadata = process_requests(batch_prompts, system_prompt, conversation_history, whole_conversation, whole_conversation_metadata, model, config, model_choice, temperature, reported_batch_no, local_model)
 
            # print("Whole conversation metadata before:", whole_conversation_metadata)
 
@@ -1533,7 +1542,7 @@ def sample_reference_table_summaries(reference_df:pd.DataFrame,
 
     return summarised_references, summarised_references_markdown, reference_df, unique_topics_df
 
-def summarise_output_topics_query(model_choice:str, in_api_key:str, temperature:float, formatted_summary_prompt:str, summarise_topic_descriptions_system_prompt:str):
+def summarise_output_topics_query(model_choice:str, in_api_key:str, temperature:float, formatted_summary_prompt:str, summarise_topic_descriptions_system_prompt:str, local_model=[]):
     conversation_history = []
     whole_conversation_metadata = []
 
@@ -1549,7 +1558,7 @@ def summarise_output_topics_query(model_choice:str, in_api_key:str, temperature:
     whole_conversation = [summarise_topic_descriptions_system_prompt]
 
     # Process requests to large language model
-    responses, conversation_history, whole_conversation, whole_conversation_metadata = process_requests(formatted_summary_prompt, system_prompt, conversation_history, whole_conversation, whole_conversation_metadata, model, config, model_choice, temperature)
+    responses, conversation_history, whole_conversation, whole_conversation_metadata = process_requests(formatted_summary_prompt, system_prompt, conversation_history, whole_conversation, whole_conversation_metadata, model, config, model_choice, temperature, local_model=local_model)
 
     print("Finished summary query")
 
@@ -1569,6 +1578,7 @@ def summarise_output_topics_query(model_choice:str, in_api_key:str, temperature:
 
     return latest_response_text, conversation_history, whole_conversation_metadata
 
+@spaces.GPU
 def summarise_output_topics(summarised_references:pd.DataFrame,
                             unique_table_df:pd.DataFrame,
                             reference_table_df:pd.DataFrame,
@@ -1646,11 +1656,16 @@ def summarise_output_topics(summarised_references:pd.DataFrame,
 
     tic = time.perf_counter()
 
-    print("Starting with:", latest_summary_completed)
-    print("Last summary number:", length_all_summaries)
+    #print("Starting with:", latest_summary_completed)
+    #print("Last summary number:", length_all_summaries)
+
+    if (model_choice == "gemma_2b_it_local") & (RUN_LOCAL_MODEL == "1"):
+        progress(0.1, "Loading in Gemma 2b model")
+        local_model, tokenizer = load_model()
+        print("Local model loaded:", local_model)
 
     summary_loop_description = "Creating summaries. " + str(latest_summary_completed) + " summaries completed so far."
-    summary_loop = tqdm(range(latest_summary_completed, length_all_summaries), desc="Creating summaries", unit="summaries")
+    summary_loop = tqdm(range(latest_summary_completed, length_all_summaries), desc="Creating summaries", unit="summaries")
 
     for summary_no in summary_loop:
 
@@ -1661,7 +1676,7 @@ def summarise_output_topics(summarised_references:pd.DataFrame,
        formatted_summary_prompt = [summarise_topic_descriptions_prompt.format(summaries=summary_text)]
 
        try:
-            response, conversation_history, metadata = summarise_output_topics_query(model_choice, in_api_key, temperature, formatted_summary_prompt, summarise_topic_descriptions_system_prompt)
+            response, conversation_history, metadata = summarise_output_topics_query(model_choice, in_api_key, temperature, formatted_summary_prompt, summarise_topic_descriptions_system_prompt, local_model)
            summarised_output = response
            summarised_output = re.sub(r'\n{2,}', '\n', summarised_output) # Replace multiple line breaks with a single line break
            summarised_output = re.sub(r'^\n{1,}', '', summarised_output) # Remove one or more line breaks at the start
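Taken together, the llm_api_call.py changes thread an explicitly passed local_model through the request path instead of relying on a module-level model in chatfuncs. A simplified, hypothetical view of the local-model branch (the real send_request also handles the Gemini and AWS Claude paths, and the conversation-history structure shown here is an assumption):

# Simplified sketch of the local-model branch that the new local_model
# parameter enables; not the repository's exact send_request implementation.
from tools.chatfuncs import LlamaCPPGenerationConfig, call_llama_cpp_model

def send_request_local(prompt: str, conversation_history: list, temperature: float, local_model=[]):
    gen_config = LlamaCPPGenerationConfig()
    gen_config.update_temp(temperature)

    # The llama_cpp model loaded in extract_topics/summarise_output_topics is
    # passed straight through rather than read from a global in chatfuncs.
    response = call_llama_cpp_model(prompt, gen_config, model=local_model)

    # Record the exchange; the exact history format is assumed for this sketch.
    conversation_history.append({"role": "user", "content": prompt})
    conversation_history.append({"role": "assistant", "content": str(response)})
    return response, conversation_history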