|
import spaces
|
|
import gradio as gr
|
|
from pathlib import Path
|
|
import re
|
|
import torch
|
|
import gc
|
|
from typing import Any
|
|
from huggingface_hub import hf_hub_download, HfApi
|
|
from llama_cpp import Llama
|
|
from llama_cpp_agent import LlamaCppAgent, MessagesFormatterType
|
|
from llama_cpp_agent.providers import LlamaCppPythonProvider
|
|
from llama_cpp_agent.chat_history import BasicChatHistory
|
|
from llama_cpp_agent.chat_history.messages import Roles
|
|
from ja_to_danbooru.ja_to_danbooru import jatags_to_danbooru_tags
|
|
import wrapt_timeout_decorator
|
|
from llama_cpp_agent.messages_formatter import MessagesFormatter
|
|
from formatter import mistral_v1_formatter, mistral_v2_formatter, mistral_v3_tekken_formatter
|
|
from llmenv import llm_models, llm_models_dir, llm_formats, llm_languages, dolphin_system_prompt
|
|
import subprocess
|
|
# Clear the contents of /data-nvme/zerogpu-offload at import time. A shell is used so
# the trailing `*` glob is expanded; env={} runs the command with an empty environment.
subprocess.run("rm -rf /data-nvme/zerogpu-offload/*", env={}, shell=True)

# (label, value) pairs for the model dropdown; rebuilt by update_llm_model_tupled_list().
llm_models_tupled_list = []

# The first key of llmenv.llm_models is treated as the default model filename.
default_llm_model_filename = [*llm_models][0]

device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
|
|
|
def to_list(s: str):
    """Split a comma-separated string into stripped, non-empty items.

    Empty segments (from ``""``, ``"a,,b"`` or a trailing comma) are dropped so they
    do not turn into blank tags when the list is joined back into a prompt.
    """
    return [item for item in (part.strip() for part in s.split(",")) if item]
|
|
|
|
|
|
def list_uniq(l: list):
    """Return the distinct items of *l* as a new list, keeping first-seen order.

    Items must be hashable. ``dict.fromkeys`` preserves insertion order, so this runs in
    O(n) instead of the O(n^2) ``sorted(set(l), key=l.index)`` it replaces.
    """
    return list(dict.fromkeys(l))
|
|
|
|
|
|
# Fallback values returned by get_state() when a key is missing from a session's state.
DEFAULT_STATE = dict(
    dolphin_sysprompt_mode="Default",
    dolphin_output_language=llm_languages[0],
)
|
|
|
|
|
|
def get_state(state: dict, key: str):
    """Look up *key* in the per-session *state* dict.

    Returns ``state[key]`` when present. Otherwise prints a notice and returns
    ``DEFAULT_STATE[key]`` if a default exists, or ``None`` if it does not.
    """
    if key in state:
        return state[key]
    if key in DEFAULT_STATE:
        print(f"State '{key}' not found. Use default value.")
        return DEFAULT_STATE[key]
    print(f"State '{key}' not found.")
    return None
|
|
|
|
|
|
def set_state(state: dict, key: str, value: Any):
    """Store *value* under *key* in the session *state* dict (mutates it in place)."""
    state.update({key: value})
|
|
|
|
|
|
@wrapt_timeout_decorator.timeout(dec_timeout=3.5)
def to_list_ja(s: str):
    """Split Japanese text into stripped, non-empty items.

    The Japanese punctuation ``、`` and ``。`` are treated as commas. Empty segments
    (e.g. from a trailing ``。``) are dropped. The call is wrapped in a 3.5 s timeout.
    """
    normalized = re.sub(r'[、。]', ',', s)
    return [item for item in (part.strip() for part in normalized.split(",")) if item]
|
|
|
|
|
|
def is_japanese(s: str):
    """Return True if *s* contains any CJK ideograph, hiragana or katakana character.

    Detection is based on the Unicode character name, so Chinese-only text (CJK
    unified ideographs) also returns True.
    """
    import unicodedata

    markers = ("CJK UNIFIED", "HIRAGANA", "KATAKANA")
    return any(
        any(marker in unicodedata.name(ch, "") for marker in markers)
        for ch in s
    )
|
|
|
|
|
|
def update_llm_model_tupled_list():
    """Rebuild and return the ``(label, value)`` choices for the model dropdown.

    Registered models (keys of ``llm_models``) come first, in registration order. They
    are followed by any ``*.gguf`` file already present in ``llm_models_dir``. Duplicates
    keep their first position. The result is also stored in the module-level
    ``llm_models_tupled_list``.
    """
    global llm_models_tupled_list
    choices = [(name, name) for name in llm_models]
    # Path.glob yields files in filesystem order; sort so the dropdown is stable.
    choices.extend(
        (path.name, path.name) for path in sorted(Path(llm_models_dir).glob('*.gguf'))
    )
    llm_models_tupled_list = list_uniq(choices)
    return llm_models_tupled_list
|
|
|
|
|
|
def download_llm_models():
    """Download every registered model and record which downloads succeeded.

    Resets the module-level ``llm_models_tupled_list`` to ``(filename, filename)`` pairs
    for each ``llm_models`` entry that downloaded into ``llm_models_dir``. A failure is
    printed and skipped, so one bad repo does not stop the remaining downloads.
    """
    global llm_models_tupled_list
    llm_models_tupled_list = []
    for filename, entry in llm_models.items():
        try:
            hf_hub_download(repo_id=entry[0], filename=filename, local_dir=llm_models_dir)
        except Exception as e:
            # Best-effort: report the failure instead of swallowing it silently.
            print(f"Failed to download {filename}: {e}")
            continue
        llm_models_tupled_list.append((filename, filename))
|
|
|
|
|
|
def download_llm_model(filename: str):
    """Download one registered model and return the filename that should be selected.

    Returns *filename* after a successful download (also refreshing the dropdown choice
    list). Returns ``default_llm_model_filename`` when *filename* is not registered in
    ``llm_models`` or the download raises; in the latter case the error is printed.
    """
    if filename not in llm_models:
        return default_llm_model_filename
    repo_id_lookup = lambda: llm_models[filename][0]
    try:
        hf_hub_download(repo_id=repo_id_lookup(), filename=filename, local_dir=llm_models_dir)
    except Exception as e:
        print(e)
        return default_llm_model_filename
    else:
        update_llm_model_tupled_list()
        return filename
|
|
|
|
|
|
def get_dolphin_model_info(filename: str):
    """Return a Markdown link to the Hub repo of *filename*, or "None" if unregistered."""
    entry = llm_models.get(filename, None)
    if not entry:
        return "None"
    repo_id = entry[0]
    return f'Repo: [{repo_id}](https://huggingface.co/{repo_id})'
|
|
|
|
|
|
def select_dolphin_model(filename: str, state: dict, progress=gr.Progress(track_tqdm=True)):
    """Handle a model-dropdown change: download the model and refresh related widgets.

    Clears any per-session chat-format override, downloads *filename* (falling back to
    the default model when it is unknown or the download fails), and returns updates
    for the dropdown, the format selector, and the model-info Markdown, plus the state.
    """
    set_state(state, "override_llm_format", None)
    progress(0, desc="Loading model...")
    value = download_llm_model(filename)
    progress(1, desc="Model loaded.")
    # Describe the model that is actually selected. `value` differs from `filename`
    # when download_llm_model fell back to the default model.
    md = get_dolphin_model_info(value)
    return (
        gr.update(value=value, choices=get_dolphin_models()),
        gr.update(value=get_dolphin_model_format(value)),
        gr.update(value=md),
        state,
    )
|
|
|
|
|
|
def select_dolphin_format(format_name: str, state: dict):
    """Override the chat format for this session with ``llm_formats[format_name]``.

    Raises KeyError if *format_name* is not a known format.
    """
    chosen_format = llm_formats[format_name]
    set_state(state, "override_llm_format", chosen_format)
    selector_update = gr.update(value=format_name)
    return selector_update, state
|
|
|
|
|
|
# Fetch the default model at import time; download_llm_model prints and ignores failures.
download_llm_model(default_llm_model_filename)
|
|
|
|
|
|
def get_dolphin_models():
    """Refresh and return the ``(label, value)`` choices for the model dropdown."""
    choices = update_llm_model_tupled_list()
    return choices
|
|
|
|
|
|
def get_llm_formats():
    """Return the names of all selectable chat formats (keys of ``llm_formats``)."""
    return [*llm_formats]
|
|
|
|
|
|
def get_key_from_value(d, val):
    """Return the first key of *d* whose value equals *val*, or None if there is none."""
    return next((key for key, value in d.items() if value == val), None)
|
|
|
|
|
|
def get_dolphin_model_format(filename: str):
    """Return the chat-format name registered for *filename*.

    Unregistered filenames fall back to the default model's format. Returns None if
    that format object has no name in ``llm_formats``.
    """
    key = filename if filename in llm_models else default_llm_model_filename
    fmt = llm_models[key][1]
    return get_key_from_value(llm_formats, fmt)
|
|
|
|
|
|
def add_dolphin_models(query: str, format_name: str):
    """Register GGUF model(s) from a Hugging Face repo id or URL.

    *query* may take one of two forms:

    - ``owner/repo`` (optionally prefixed by ``https://huggingface.co/``): every
      ``.gguf`` file in the repo is registered.
    - A URL or path ending in a ``.gguf`` file: only that file is registered.

    Each new entry maps the file path to ``[repo_id, llm_formats[format_name]]`` in the
    module-level ``llm_models``.

    Returns a ``gr.update`` that refreshes the dropdown choices and selects the last
    added model. Returns a bare ``gr.update()`` when nothing is added: the query does
    not parse, the repo or file is missing, the repo has no GGUF files, or the Hub
    call raises. Raises KeyError for an unknown *format_name*.
    """
    global llm_models
    api = HfApi()
    add_models = {}
    chat_format = llm_formats[format_name]
    try:
        matches = re.findall(
            r'^(?:https?://huggingface\.co/)?(.+?/.+?)(?:/.*/(.+?\.gguf).*?)?$', query
        )
        if not matches:
            return gr.update()
        # Group 2 is "" when the query names a repo rather than a specific file.
        repo, filename = matches[0]
        if not api.repo_exists(repo_id=repo):
            return gr.update()
        if filename:
            if not api.file_exists(repo_id=repo, filename=filename):
                return gr.update()
            add_models[filename] = [repo, chat_format]
        else:
            for file in api.list_repo_files(repo_id=repo):
                if str(file).endswith(".gguf"):
                    add_models[str(file)] = [repo, chat_format]
    except Exception as e:
        print(e)
        return gr.update()
    if not add_models:
        return gr.update()
    llm_models = llm_models | add_models
    choices = get_dolphin_models()  # also refreshes llm_models_tupled_list
    # Select the model just added. choices[-1] may be an unrelated local .gguf file.
    return gr.update(choices=choices, value=list(add_models)[-1])
|
|
|
|
|
|
def get_dolphin_sysprompt(state: "dict | None" = None):
    """Build the system prompt for the session's prompt mode and output language.

    The template is ``dolphin_system_prompt[mode]``, falling back to the first template
    when the mode is unknown. Every ``<LANGUAGE>`` placeholder is replaced with the
    session's output language, or ``llm_languages[0]`` if that is unset or empty.
    *state* defaults to an empty dict, meaning DEFAULT_STATE values are used.
    """
    if state is None:
        state = {}
    mode = get_state(state, "dolphin_sysprompt_mode")
    language = get_state(state, "dolphin_output_language") or llm_languages[0]
    if mode in dolphin_system_prompt:
        template = dolphin_system_prompt[mode]
    else:
        template = next(iter(dolphin_system_prompt.values()))
    # str.replace (not re.sub) inserts the language literally: backslashes or "\1"
    # in the replacement are not interpreted as regex escapes or group references.
    return template.replace("<LANGUAGE>", language)
|
|
|
|
|
|
def get_dolphin_sysprompt_mode():
    """Return the names of all system-prompt modes (keys of ``dolphin_system_prompt``)."""
    return [*dolphin_system_prompt]
|
|
|
|
|
|
def select_dolphin_sysprompt(key: str, state: dict):
    """Switch the session's system-prompt mode and return the rebuilt prompt.

    Unknown *key* values fall back to the "Default" mode. Returns a ``gr.update`` for
    the system-prompt textbox and the (mutated) *state*.
    """
    mode = key if key in dolphin_system_prompt else "Default"
    set_state(state, "dolphin_sysprompt_mode", mode)
    return gr.update(value=get_dolphin_sysprompt(state)), state
|
|
|
|
|
|
def get_dolphin_languages():
    """Return the selectable output languages (the ``llm_languages`` list itself)."""
    languages = llm_languages
    return languages
|
|
|
|
|
|
def select_dolphin_language(lang: str, state: dict):
    """Store *lang* as the session's output language and return the rebuilt prompt."""
    set_state(state, "dolphin_output_language", lang)
    prompt_update = gr.update(value=get_dolphin_sysprompt(state))
    return prompt_update, state
|
|
|
|
|
|
@wrapt_timeout_decorator.timeout(dec_timeout=5.0)
def get_raw_prompt(msg: str):
    """Extract the text between ``/GENBEGIN/`` and ``/GENEND/`` markers in *msg*.

    Multiple segments are joined with ", ". The characters ``* / : _ " #`` and
    newlines are replaced by spaces, and the result is lower-cased. Returns "" when no
    marker pair is found. The call is wrapped in a 5 s timeout.
    """
    segments = re.findall(r'/GENBEGIN/(.+?)/GENEND/', msg, re.DOTALL)
    if not segments:
        return ""
    joined = ", ".join(segments)
    cleaned = re.sub(r'[*/:_"#\n]', ' ', joined)
    return cleaned.lower()
|
|
|
|
|
|
@torch.inference_mode()
@spaces.GPU(duration=59)
def dolphin_respond(
    message: str,
    history: list[tuple[str, str]],
    model: str = default_llm_model_filename,
    system_message: str = get_dolphin_sysprompt(),
    max_tokens: int = 1024,
    temperature: float = 0.7,
    top_p: float = 0.95,
    top_k: int = 40,
    repeat_penalty: float = 1.1,
    state: "dict | None" = None,
    progress=gr.Progress(track_tqdm=True),
):
    """Stream an LLM reply to *message* using the GGUF file *model*.

    A fresh ``Llama`` instance is loaded from ``llm_models_dir`` on every call. The chat
    template is the session override (``state["override_llm_format"]``) if set,
    otherwise the one registered for *model*. Each ``(user, assistant)`` pair in
    *history* is replayed as prior turns.

    Yields:
        ``[(output_so_far, None)]`` after every streamed chunk.

    Raises:
        gr.Error: if the model file does not exist or generation fails.
    """
    if state is None:
        state = {}
    try:
        model_path = Path(f"{llm_models_dir}/{model}")
        if not model_path.exists():
            raise gr.Error(f"Model file not found: {str(model_path)}")
        progress(0, desc="Processing...")

        override_llm_format = get_state(state, "override_llm_format")
        chat_template = override_llm_format if override_llm_format else llm_models[model][1]
        is_custom_formatter = isinstance(chat_template, MessagesFormatter)

        llm = Llama(
            model_path=str(model_path),
            flash_attn=True,
            n_gpu_layers=81,
            n_batch=1024,
            n_ctx=8192,
        )
        provider = LlamaCppPythonProvider(llm)
        agent = LlamaCppAgent(
            provider,
            system_prompt=f"{system_message}",
            predefined_messages_formatter_type=None if is_custom_formatter else chat_template,
            custom_messages_formatter=chat_template if is_custom_formatter else None,
            debug_output=False,
        )

        settings = provider.get_provider_default_settings()
        settings.temperature = temperature
        settings.top_k = top_k
        settings.top_p = top_p
        settings.max_tokens = max_tokens
        settings.repeat_penalty = repeat_penalty
        settings.stream = True

        messages = BasicChatHistory()
        for turn in history:
            messages.add_message({'role': Roles.user, 'content': turn[0]})
            messages.add_message({'role': Roles.assistant, 'content': turn[1]})

        stream = agent.get_chat_response(
            message,
            llm_sampling_settings=settings,
            chat_history=messages,
            returns_streaming_generator=True,
            print_output=False,
        )

        progress(0.5, desc="Processing...")

        outputs = ""
        for output in stream:
            outputs += output
            yield [(outputs, None)]
    except gr.Error:
        # Already a user-facing error; re-raise as-is instead of re-wrapping it.
        raise
    except Exception as e:
        print(e)
        raise gr.Error(f"Error: {e}") from e
    finally:
        # The generator frame keeps its locals alive while this block runs. Drop the
        # model references first so gc.collect() can actually reclaim the Llama object.
        llm = provider = agent = stream = None
        gc.collect()
        torch.cuda.empty_cache()
|
|
|
|
|
|
def dolphin_parse(
    history: list[tuple[str, str]],
    state: dict,
):
    """Turn the ``/GENBEGIN/…/GENEND/`` text of the latest history entry into tags.

    In "Japanese to Danbooru Dictionary" mode, Japanese text is mapped through
    ``jatags_to_danbooru_tags``; otherwise the text is split on commas. "nsfw" and
    "explicit" are appended and duplicates removed.

    Returns:
        ``(prompt, gr.update(interactive=True), gr.update(interactive=True))`` on
        success; ``("", gr.update(), gr.update())`` in "Chat with LLM" mode, for an
        empty history, or when any step raises (the error is printed).
    """
    try:
        mode = get_state(state, "dolphin_sysprompt_mode")
        if mode == "Chat with LLM" or not history:
            return "", gr.update(), gr.update()
        raw_prompt = get_raw_prompt(history[-1][0])
        use_dictionary = mode == "Japanese to Danbooru Dictionary" and is_japanese(raw_prompt)
        if use_dictionary:
            tags = jatags_to_danbooru_tags(to_list_ja(raw_prompt))
        else:
            tags = to_list(raw_prompt)
        prompt = ", ".join(list_uniq(tags + ["nsfw", "explicit"]))
        return prompt, gr.update(interactive=True), gr.update(interactive=True)
    except Exception as e:
        print(e)
        return "", gr.update(), gr.update()
|
|
|
|
|
|
@torch.inference_mode()
@spaces.GPU(duration=59)
def dolphin_respond_auto(
    message: str,
    history: list[tuple[str, str]],
    model: str = default_llm_model_filename,
    system_message: str = get_dolphin_sysprompt(),
    max_tokens: int = 1024,
    temperature: float = 0.7,
    top_p: float = 0.95,
    top_k: int = 40,
    repeat_penalty: float = 1.1,
    state: "dict | None" = None,
    progress=gr.Progress(track_tqdm=True),
):
    """Stream an LLM reply for the automatic pipeline, never raising to the UI.

    Loads *model* from ``llm_models_dir`` and chooses the chat template exactly like
    ``dolphin_respond``.

    Yields:
        ``([(output_so_far, None)], gr.update(), gr.update())`` after every streamed
        chunk. On any error the error is printed and
        ``([("", None)], gr.update(), gr.update())`` is yielded instead.
    """
    if state is None:
        state = {}
    try:
        model_path = Path(f"{llm_models_dir}/{model}")
        # Same check as the sibling handlers; the error is caught below and reported
        # via print, like every other failure in this function.
        if not model_path.exists():
            raise gr.Error(f"Model file not found: {str(model_path)}")

        progress(0, desc="Processing...")

        override_llm_format = get_state(state, "override_llm_format")
        chat_template = override_llm_format if override_llm_format else llm_models[model][1]
        is_custom_formatter = isinstance(chat_template, MessagesFormatter)

        llm = Llama(
            model_path=str(model_path),
            flash_attn=True,
            n_gpu_layers=81,
            n_batch=1024,
            n_ctx=8192,
        )
        provider = LlamaCppPythonProvider(llm)
        agent = LlamaCppAgent(
            provider,
            system_prompt=f"{system_message}",
            predefined_messages_formatter_type=None if is_custom_formatter else chat_template,
            custom_messages_formatter=chat_template if is_custom_formatter else None,
            debug_output=False,
        )

        settings = provider.get_provider_default_settings()
        settings.temperature = temperature
        settings.top_k = top_k
        settings.top_p = top_p
        settings.max_tokens = max_tokens
        settings.repeat_penalty = repeat_penalty
        settings.stream = True

        messages = BasicChatHistory()
        for turn in history:
            messages.add_message({'role': Roles.user, 'content': turn[0]})
            messages.add_message({'role': Roles.assistant, 'content': turn[1]})

        progress(0, desc="Translating...")
        stream = agent.get_chat_response(
            message,
            llm_sampling_settings=settings,
            chat_history=messages,
            returns_streaming_generator=True,
            print_output=False,
        )

        progress(0.5, desc="Processing...")

        outputs = ""
        for output in stream:
            outputs += output
            yield [(outputs, None)], gr.update(), gr.update()
    except Exception as e:
        print(e)
        yield [("", None)], gr.update(), gr.update()
    finally:
        # Drop the model references held by this generator frame so gc.collect()
        # can reclaim the Llama object now rather than when the generator is destroyed.
        llm = provider = agent = stream = None
        gc.collect()
        torch.cuda.empty_cache()
|
|
|
|
|
|
def dolphin_parse_simple(
    message: str,
    history: list[tuple[str, str]],
    state: dict,
):
    """Return a tag prompt built from the latest history entry's ``/GENBEGIN/`` text.

    Same extraction as ``dolphin_parse``, but appends "nsfw", "explicit" and
    "rating_explicit", and returns only the prompt string.

    Returns:
        The comma-joined tags. Returns *message* unchanged in "Chat with LLM" mode or
        for an empty history. Returns "" if any step raises (the error is printed).
    """
    try:
        mode = get_state(state, "dolphin_sysprompt_mode")
        if mode == "Chat with LLM" or not history:
            return message
        raw_prompt = get_raw_prompt(history[-1][0])
        use_dictionary = mode == "Japanese to Danbooru Dictionary" and is_japanese(raw_prompt)
        if use_dictionary:
            tags = jatags_to_danbooru_tags(to_list_ja(raw_prompt))
        else:
            tags = to_list(raw_prompt)
        return ", ".join(list_uniq(tags + ["nsfw", "explicit", "rating_explicit"]))
    except Exception as e:
        print(e)
        return ""
|
|
|
|
|
|
|
|
import cv2
|
|
# Limit OpenCV to a single worker thread. NOTE(review): presumably to avoid CPU
# oversubscription alongside torch / llama.cpp in this process — confirm.
cv2.setNumThreads(1)
|
|
|
|
|
|
@torch.inference_mode()
@spaces.GPU(duration=59)
def respond_playground(
    message: str,
    history: list[tuple[str, str]],
    model: str = default_llm_model_filename,
    system_message: str = get_dolphin_sysprompt(),
    max_tokens: int = 1024,
    temperature: float = 0.7,
    top_p: float = 0.95,
    top_k: int = 40,
    repeat_penalty: float = 1.1,
    state: "dict | None" = None,
    progress=gr.Progress(track_tqdm=True),
):
    """Stream an LLM reply for the playground chat, yielding plain text.

    Loads *model* from ``llm_models_dir`` and chooses the chat template exactly like
    ``dolphin_respond``.

    Yields:
        The accumulated output string after every streamed chunk.

    Raises:
        gr.Error: if the model file does not exist or generation fails.
    """
    if state is None:
        state = {}
    try:
        model_path = Path(f"{llm_models_dir}/{model}")
        if not model_path.exists():
            raise gr.Error(f"Model file not found: {str(model_path)}")

        override_llm_format = get_state(state, "override_llm_format")
        chat_template = override_llm_format if override_llm_format else llm_models[model][1]
        is_custom_formatter = isinstance(chat_template, MessagesFormatter)

        llm = Llama(
            model_path=str(model_path),
            flash_attn=True,
            n_gpu_layers=81,
            n_batch=1024,
            n_ctx=8192,
        )
        provider = LlamaCppPythonProvider(llm)
        agent = LlamaCppAgent(
            provider,
            system_prompt=f"{system_message}",
            predefined_messages_formatter_type=None if is_custom_formatter else chat_template,
            custom_messages_formatter=chat_template if is_custom_formatter else None,
            debug_output=False,
        )

        settings = provider.get_provider_default_settings()
        settings.temperature = temperature
        settings.top_k = top_k
        settings.top_p = top_p
        settings.max_tokens = max_tokens
        settings.repeat_penalty = repeat_penalty
        settings.stream = True

        messages = BasicChatHistory()
        for turn in history:
            messages.add_message({'role': Roles.user, 'content': turn[0]})
            messages.add_message({'role': Roles.assistant, 'content': turn[1]})

        stream = agent.get_chat_response(
            message,
            llm_sampling_settings=settings,
            chat_history=messages,
            returns_streaming_generator=True,
            print_output=False,
        )

        outputs = ""
        for output in stream:
            outputs += output
            yield outputs
    except gr.Error:
        # Already a user-facing error; re-raise as-is instead of re-wrapping it.
        raise
    except Exception as e:
        print(e)
        raise gr.Error(f"Error: {e}") from e
    finally:
        # Drop the model references held by this generator frame so gc.collect()
        # can reclaim the Llama object now rather than when the generator is destroyed.
        llm = provider = agent = stream = None
        gc.collect()
        torch.cuda.empty_cache()
|
|
|