# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: MIT
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.

import os
import time
import json
import logging
import gc
import torch
from pathlib import Path
from collections import defaultdict

from trt_llama_api import TrtLlmAPI
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from llama_index import ServiceContext, set_global_service_context
from llama_index.llms.llama_utils import messages_to_prompt, completion_to_prompt

from faiss_vector_storage import FaissEmbeddingStorage
from ui.user_interface import MainInterface
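
# Configuration files (relative paths, Windows-style separators):
#   app_config.json   - streaming/chat-engine flags, similarity_top_k and embedding model settings
#   config.json       - supported TRT-LLM models and the default dataset path
#   preferences.json  - optional user overrides: selected model and dataset path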
app_config_file = 'config\\app_config.json'
model_config_file = 'config\\config.json'
preference_config_file = 'config\\preferences.json'
data_source = 'directory'

def read_config(file_name):
    try:
        with open(file_name, 'r') as file:
            return json.load(file)
    except FileNotFoundError:
        print(f"The file {file_name} was not found.")
    except json.JSONDecodeError:
        print(f"There was an error decoding the JSON from the file {file_name}.")
    except Exception as e:
        print(f"An unexpected error occurred: {e}")
    return None

def get_model_config(config, model_name=None):
    models = config["models"]["supported"]
    selected_model = next((model for model in models if model["name"] == model_name), models[0])
    return {
        "model_path": os.path.join(os.getcwd(), selected_model["metadata"]["model_path"]),
        "engine": selected_model["metadata"]["engine"],
        "tokenizer_path": os.path.join(os.getcwd(), selected_model["metadata"]["tokenizer_path"]),
        "max_new_tokens": selected_model["metadata"]["max_new_tokens"],
        "max_input_token": selected_model["metadata"]["max_input_token"],
        "temperature": selected_model["metadata"]["temperature"]
    }

def get_data_path(config):
    return os.path.join(os.getcwd(), config["dataset"]["path"])
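
# Illustrative sketch of the config/config.json shape consumed by get_model_config()
# and get_data_path(); the values below are placeholders, not the shipped defaults:
# {
#   "models": {
#     "selected": "<model name>",
#     "supported": [
#       {
#         "name": "<model name>",
#         "metadata": {
#           "model_path": "model/<engine dir>",
#           "engine": "<engine file>",
#           "tokenizer_path": "model/<tokenizer dir>",
#           "max_new_tokens": 512,
#           "max_input_token": 2048,
#           "temperature": 0.1
#         }
#       }
#     ]
#   },
#   "dataset": { "path": "<dataset dir>" }
# }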

# read the app specific config
app_config = read_config(app_config_file)
streaming = app_config["streaming"]
similarity_top_k = app_config["similarity_top_k"]
is_chat_engine = app_config["is_chat_engine"]
embedded_model_name = app_config["embedded_model"]
embedded_model = os.path.join(os.getcwd(), "model", embedded_model_name)
embedded_dimension = app_config["embedded_dimension"]

# read model specific config and any saved user preferences
selected_model_name = None
selected_data_directory = None
config = read_config(model_config_file)
if os.path.exists(preference_config_file):
    perf_config = read_config(preference_config_file)
    selected_model_name = perf_config.get('models', {}).get('selected')
    selected_data_directory = perf_config.get('dataset', {}).get('path')

if selected_model_name is None:
    selected_model_name = config["models"].get("selected")

model_config = get_model_config(config, selected_model_name)
trt_engine_path = model_config["model_path"]
trt_engine_name = model_config["engine"]
tokenizer_dir_path = model_config["tokenizer_path"]
data_dir = config["dataset"]["path"] if selected_data_directory is None else selected_data_directory

# create trt_llm engine object
llm = TrtLlmAPI(
    model_path=model_config["model_path"],
    engine_name=model_config["engine"],
    tokenizer_dir=model_config["tokenizer_path"],
    temperature=model_config["temperature"],
    max_new_tokens=model_config["max_new_tokens"],
    context_window=model_config["max_input_token"],
    messages_to_prompt=messages_to_prompt,
    completion_to_prompt=completion_to_prompt,
    verbose=False
)

# create embeddings model object
embed_model = HuggingFaceEmbeddings(model_name=embedded_model)
service_context = ServiceContext.from_defaults(llm=llm, embed_model=embed_model,
                                               context_window=model_config["max_input_token"],
                                               chunk_size=512, chunk_overlap=200)
set_global_service_context(service_context)
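
# Note: setting the global ServiceContext makes this llm/embed_model pair the default for
# index construction and querying below; chunk_size/chunk_overlap control how dataset
# documents are split into nodes before they are embedded into the FAISS index.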

def generate_inference_engine(data, force_rewrite=False):
    """
    Initialize a FAISS-based query/chat engine over the given data directory and
    store it in the module-level `engine` global.

    Args:
        data: The directory where the data for the inference engine is located.
        force_rewrite (bool): If True, force rewriting the FAISS index.

    Raises:
        RuntimeError: If unable to generate the inference engine.
    """
    try:
        # keep faiss_storage global so reset_chat_handler can reuse it later
        global engine, faiss_storage
        faiss_storage = FaissEmbeddingStorage(data_dir=data,
                                              dimension=embedded_dimension)
        faiss_storage.initialize_index(force_rewrite=force_rewrite)
        engine = faiss_storage.get_engine(is_chat_engine=is_chat_engine, streaming=streaming,
                                          similarity_top_k=similarity_top_k)
    except Exception as e:
        raise RuntimeError(f"Unable to generate the inference engine: {e}") from e

# load the vectorstore index
generate_inference_engine(data_dir)

def call_llm_streamed(query):
    partial_response = ""
    response = llm.stream_complete(query)
    for token in response:
        partial_response += token.delta
        yield partial_response
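
# Non-streaming chat handler: runs the query through the RAG engine, aggregates the
# retrieval scores of the returned source nodes per file, and appends the top-scoring
# file name as a reference link. Falls back to a plain LLM completion when no dataset
# is selected or no source file scored.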
def chatbot(query, chat_history, session_id):
    if data_source == "nodataset":
        yield llm.complete(query).text
        return

    if is_chat_engine:
        response = engine.chat(query)
    else:
        response = engine.query(query)

    # Aggregate scores by file
    file_scores = defaultdict(float)
    for node in response.source_nodes:
        metadata = node.metadata
        if 'filename' in metadata:
            file_name = metadata['filename']
            file_scores[file_name] += node.score

    # Find the file with the highest aggregated score
    highest_aggregated_score_file = None
    if file_scores:
        highest_aggregated_score_file = max(file_scores, key=file_scores.get)

    file_links = []
    seen_files = set()  # Set to track unique file names

    # Generate links for the file with the highest aggregated score
    if highest_aggregated_score_file:
        abs_path = Path(os.path.join(os.getcwd(), highest_aggregated_score_file.replace('\\', '/')))
        file_name = os.path.basename(abs_path)
        file_name_without_ext = abs_path.stem
        if file_name not in seen_files:  # Ensure the file hasn't already been processed
            if data_source == 'directory':
                file_link = file_name
            else:
                exit("Wrong data_source type")
            file_links.append(file_link)
            seen_files.add(file_name)  # Mark file as processed

    response_txt = str(response)
    if file_links:
        response_txt += "<br>Reference files:<br>" + "<br>".join(file_links)
    if not highest_aggregated_score_file:  # If no file with a high score was found
        response_txt = llm.complete(query).text
    yield response_txt
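
# Streaming variant of the chat handler: yields progressively larger partial responses
# so the UI can render tokens as they arrive, then appends the reference file link(s)
# once generation has finished.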
def stream_chatbot(query, chat_history, session_id):
    if data_source == "nodataset":
        for response in call_llm_streamed(query):
            yield response
        return

    if is_chat_engine:
        response = engine.stream_chat(query)
    else:
        response = engine.query(query)

    partial_response = ""
    if len(response.source_nodes) == 0:
        response = llm.stream_complete(query)
        for token in response:
            partial_response += token.delta
            yield partial_response
    else:
        # Aggregate scores by file
        file_scores = defaultdict(float)
        for node in response.source_nodes:
            if 'filename' in node.metadata:
                file_name = node.metadata['filename']
                file_scores[file_name] += node.score

        # Find the file with the highest aggregated score
        highest_score_file = max(file_scores, key=file_scores.get, default=None)

        file_links = []
        seen_files = set()
        for token in response.response_gen:
            partial_response += token
            yield partial_response
            time.sleep(0.05)
        time.sleep(0.2)

        if highest_score_file:
            abs_path = Path(os.path.join(os.getcwd(), highest_score_file.replace('\\', '/')))
            file_name = os.path.basename(abs_path)
            file_name_without_ext = abs_path.stem
            if file_name not in seen_files:  # Check if file_name is already seen
                if data_source == 'directory':
                    file_link = file_name
                else:
                    exit("Wrong data_source type")
                file_links.append(file_link)
                seen_files.add(file_name)  # Add file_name to the set

            if file_links:
                partial_response += "<br>Reference files:<br>" + "<br>".join(file_links)
                yield partial_response

    # call garbage collector after inference
    torch.cuda.empty_cache()
    gc.collect()
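
# Wire the chat handlers into the UI: use the streaming handler when streaming is
# enabled in app_config.json, otherwise the blocking one.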
interface = MainInterface(chatbot=stream_chatbot if streaming else chatbot, streaming=streaming)

def on_shutdown_handler(session_id):
    global llm, service_context, embed_model, faiss_storage, engine
    if llm is not None:
        llm.unload_model()
        del llm
    # Force a garbage collection cycle
    gc.collect()

interface.on_shutdown(on_shutdown_handler)
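
# Reset-chat handler: asks the FAISS storage helper to reset the engine when running
# in chat-engine mode, clearing the conversation state for a fresh session.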
def reset_chat_handler(session_id):
    global faiss_storage
    global engine
    print('reset chat called', session_id)
    if is_chat_engine:
        faiss_storage.reset_engine(engine)

interface.on_reset_chat(reset_chat_handler)

def on_dataset_path_updated_handler(source, new_directory, video_count, session_id):
    print('data set path updated to ', source, new_directory, video_count, session_id)
    global engine
    global data_dir
    if source == 'directory':
        if data_dir != new_directory:
            data_dir = new_directory
            generate_inference_engine(data_dir)

interface.on_dataset_path_updated(on_dataset_path_updated_handler)
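
# Model-change handler: unloads the current TRT-LLM engine, builds a new TrtLlmAPI
# instance from the selected model's metadata, refreshes the global ServiceContext,
# and rebuilds the query engine over the current dataset.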
def on_model_change_handler(model, metadata, session_id):
    # validate the metadata before joining paths so a missing entry does not raise in os.path.join
    model_path = metadata.get('model_path')
    engine_name = metadata.get('engine')
    tokenizer_path = metadata.get('tokenizer_path')
    if not model_path or not engine_name:
        print("Model path or engine not provided in metadata")
        return

    model_path = os.path.join(os.getcwd(), model_path)
    tokenizer_path = os.path.join(os.getcwd(), tokenizer_path) if tokenizer_path else None

    global llm, embedded_model, engine, data_dir, service_context
    if llm is not None:
        llm.unload_model()
        del llm

    llm = TrtLlmAPI(
        model_path=model_path,
        engine_name=engine_name,
        tokenizer_dir=tokenizer_path,
        temperature=metadata.get('temperature', 0.1),
        max_new_tokens=metadata.get('max_new_tokens', 512),
        context_window=metadata.get('max_input_token', 512),
        messages_to_prompt=messages_to_prompt,
        completion_to_prompt=completion_to_prompt,
        verbose=False
    )

    service_context = ServiceContext.from_service_context(service_context=service_context, llm=llm)
    set_global_service_context(service_context)
    generate_inference_engine(data_dir)

interface.on_model_change(on_model_change_handler)

def on_dataset_source_change_handler(source, path, session_id):
    global data_source, data_dir, engine
    data_source = source

    if data_source == "nodataset":
        print('No dataset source selected', session_id)
        return

    print('dataset source updated ', source, path, session_id)
    if data_source == "directory":
        data_dir = path
    else:
        print("Wrong data type selected")
    generate_inference_engine(data_dir)

interface.on_dataset_source_updated(on_dataset_source_change_handler)
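
# Regenerate-index handler: forces re-embedding of the dataset directory and a full
# rebuild of the FAISS index.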
def handle_regenerate_index(source, path, session_id):
    generate_inference_engine(path, force_rewrite=True)
    print("on regenerate index", source, path, session_id)

interface.on_regenerate_index(handle_regenerate_index)

# render the interface
interface.render()