import spaces
import gradio as gr
from pathlib import Path
import re
import torch
import gc
from typing import Any
from huggingface_hub import hf_hub_download, HfApi
from llama_cpp import Llama
from llama_cpp_agent import LlamaCppAgent, MessagesFormatterType
from llama_cpp_agent.providers import LlamaCppPythonProvider
from llama_cpp_agent.chat_history import BasicChatHistory
from llama_cpp_agent.chat_history.messages import Roles
from ja_to_danbooru.ja_to_danbooru import jatags_to_danbooru_tags
import wrapt_timeout_decorator
from llama_cpp_agent.messages_formatter import MessagesFormatter
from formatter import mistral_v1_formatter, mistral_v2_formatter, mistral_v3_tekken_formatter
from llmenv import llm_models, llm_models_dir, llm_formats, llm_languages, dolphin_system_prompt


# (label, value) pairs for the model dropdown; rebuilt by update_llm_model_tupled_list().
llm_models_tupled_list = []
default_llm_model_filename = list(llm_models.keys())[0]
device = "cuda" if torch.cuda.is_available() else "cpu"


def to_list(s: str):
    """Split a comma-separated string into stripped, non-empty items."""
    # Filter per item: testing the whole input string let empty tags through
    # for inputs such as "a,,b".
    stripped = (x.strip() for x in s.split(","))
    return [item for item in stripped if item]


def list_uniq(l: list):
    """Return the items of ``l`` without duplicates, in first-occurrence order."""
    return list(dict.fromkeys(l))


# Fallback values used by get_state() when a key is missing from the session state.
DEFAULT_STATE = {
    "dolphin_sysprompt_mode": "Default",
    "dolphin_output_language": llm_languages[0],
}


def get_state(state: dict, key: str):
    """Look up ``key`` in ``state``, falling back to ``DEFAULT_STATE``.

    Returns ``None`` (after printing a notice) when the key is in neither.
    """
    if key in state:
        return state[key]
    if key in DEFAULT_STATE:
        print(f"State '{key}' not found. Use dedault value.")
        return DEFAULT_STATE[key]
    print(f"State '{key}' not found.")
    return None


def set_state(state: dict, key: str, value: Any):
    """Store ``value`` under ``key`` in ``state`` (mutates ``state`` in place)."""
    state[key] = value


@wrapt_timeout_decorator.timeout(dec_timeout=3.5)
def to_list_ja(s: str):
    """Like ``to_list`` but also treats the Japanese marks '、' and '。' as separators."""
    s = re.sub(r'[、。]', ',', s)
    stripped = (x.strip() for x in s.split(","))
    return [item for item in stripped if item]


def is_japanese(s: str):
    """Return True if any character's Unicode name contains CJK UNIFIED, HIRAGANA or KATAKANA."""
    import unicodedata
    for ch in s:
        name = unicodedata.name(ch, "")
        if "CJK UNIFIED" in name or "HIRAGANA" in name or "KATAKANA" in name:
            return True
    return False


def update_llm_model_tupled_list():
    """Rebuild and return the global (label, value) model list.

    The list contains every key of ``llm_models`` followed by the file name of
    every ``*.gguf`` file found in ``llm_models_dir``, with duplicates removed.
    """
    global llm_models_tupled_list
    entries = [(k, k) for k in llm_models]
    entries.extend((path.name, path.name) for path in Path(llm_models_dir).glob('*.gguf'))
    llm_models_tupled_list = list_uniq(entries)
    return llm_models_tupled_list


def download_llm_models():
    """Try to download every model in ``llm_models`` and list the ones that succeeded.

    Failed downloads are reported and skipped so one bad entry does not stop the rest.
    """
    global llm_models_tupled_list
    llm_models_tupled_list = []
    for k, v in llm_models.items():
        try:
            hf_hub_download(repo_id=v[0], filename=k, local_dir=llm_models_dir)
        except Exception as e:
            print(e)
            continue
        llm_models_tupled_list.append((k, k))


def download_llm_model(filename: str):
    """Download one registered model and return the file name that is usable afterwards.

    Returns ``default_llm_model_filename`` when ``filename`` is not a key of
    ``llm_models`` or when the download fails.
    """
    if filename not in llm_models:
        return default_llm_model_filename
    try:
        hf_hub_download(repo_id=llm_models[filename][0], filename=filename, local_dir=llm_models_dir)
    except Exception as e:
        print(e)
        return default_llm_model_filename
    update_llm_model_tupled_list()
    return filename


def get_dolphin_model_info(filename: str):
    """Return a markdown link to the model's repo, or the string "None" if unknown."""
    md = "None"
    items = llm_models.get(filename, None)
    if items:
        md = f'Repo: [{items[0]}](https://huggingface.co/{items[0]})'
    return md


def select_dolphin_model(filename: str, state: dict, progress=gr.Progress(track_tqdm=True)):
    """Download the chosen model and return updates for the model, format and info widgets.

    Also clears any format override stored in ``state``.
    """
    set_state(state, "override_llm_format", None)
    progress(0, desc="Loading model...")
    value = download_llm_model(filename)
    progress(1, desc="Model loaded.")
    # Describe the model that was actually selected; download_llm_model may
    # have fallen back to the default one.
    md = get_dolphin_model_info(value)
    return (
        gr.update(value=value, choices=get_dolphin_models()),
        gr.update(value=get_dolphin_model_format(value)),
        gr.update(value=md),
        state,
    )


def select_dolphin_format(format_name: str, state: dict):
    """Store the formatter registered under ``format_name`` as the override in ``state``."""
    set_state(state, "override_llm_format", llm_formats[format_name])
    return gr.update(value=format_name), state


# Make sure the default model is available as soon as the module is imported.
download_llm_model(default_llm_model_filename)


def get_dolphin_models():
    """Return the refreshed (label, value) model list."""
    return update_llm_model_tupled_list()


def get_llm_formats():
    """Return the names of all known chat formats."""
    return list(llm_formats.keys())


def get_key_from_value(d, val):
    """Return the first key of ``d`` whose value equals ``val``, or None."""
    for k, v in d.items():
        if v == val:
            return k
    return None


def get_dolphin_model_format(filename: str):
    """Return the format name registered for ``filename`` (default model if unknown)."""
    if filename not in llm_models:
        filename = default_llm_model_filename
    model_format = llm_models[filename][1]
    return get_key_from_value(llm_formats, model_format)


def add_dolphin_models(query: str, format_name: str):
    """Register GGUF models from a Hugging Face repo id or file URL.

    ``query`` is matched against a pattern that extracts ``owner/repo`` and,
    optionally, a ``*.gguf`` file name. With only a repo, every ``.gguf`` file
    listed in that repo is registered; with a file, just that file. Returns a
    dropdown update selecting the last choice, or an empty ``gr.update()`` when
    the query cannot be resolved.
    """
    global llm_models
    api = HfApi()
    add_models = {}
    model_format = llm_formats[format_name]
    try:
        s = list(re.findall(r'^(?:https?://huggingface.co/)?(.+?/.+?)(?:/.*/(.+?.gguf).*?)?$', query)[0])
        if s and "" in s:
            s.remove("")
        if len(s) == 1:
            repo = s[0]
            if not api.repo_exists(repo_id=repo):
                return gr.update()
            for file in api.list_repo_files(repo_id=repo):
                if str(file).endswith(".gguf"):
                    # Key by the discovered file name, not an empty placeholder.
                    add_models[str(file)] = [repo, model_format]
        elif len(s) >= 2:
            repo = s[0]
            filename = s[1]
            if not api.repo_exists(repo_id=repo) or not api.file_exists(repo_id=repo, filename=filename):
                return gr.update()
            add_models[filename] = [repo, model_format]
        else:
            return gr.update()
    except Exception as e:
        print(e)
        return gr.update()
    llm_models = llm_models | add_models
    choices = get_dolphin_models()
    return gr.update(choices=choices, value=choices[-1][1])


def get_dolphin_sysprompt(state: dict={}):
    """Build the system prompt for the mode and output language stored in ``state``.

    ``state`` is only read here, never mutated.
    """
    dolphin_sysprompt_mode = get_state(state, "dolphin_sysprompt_mode")
    dolphin_output_language = get_state(state, "dolphin_output_language")
    language = dolphin_output_language if dolphin_output_language else llm_languages[0]
    fallback_prompt = dolphin_system_prompt[list(dolphin_system_prompt.keys())[0]]
    template = dolphin_system_prompt.get(dolphin_sysprompt_mode, fallback_prompt)
    # NOTE(review): the pattern is an empty string as written, so re.sub inserts
    # the language at every position of the template. Presumably a placeholder
    # token was intended here - confirm against dolphin_system_prompt in llmenv.
    prompt = re.sub('', language, template)
    return prompt


def get_dolphin_sysprompt_mode():
    """Return the names of all system prompt modes."""
    return list(dolphin_system_prompt.keys())


def select_dolphin_sysprompt(key: str, state: dict):
    """Store the chosen prompt mode ("Default" if unknown) and return the new prompt."""
    dolphin_sysprompt_mode = key if key in dolphin_system_prompt else "Default"
    set_state(state, "dolphin_sysprompt_mode", dolphin_sysprompt_mode)
    return gr.update(value=get_dolphin_sysprompt(state)), state


def get_dolphin_languages():
    """Return the list of selectable output languages."""
    return llm_languages


def select_dolphin_language(lang: str, state: dict):
    """Store the chosen output language and return the new prompt."""
    set_state(state, "dolphin_output_language", lang)
    return gr.update(value=get_dolphin_sysprompt(state)), state


@wrapt_timeout_decorator.timeout(dec_timeout=5.0)
def get_raw_prompt(msg: str):
    """Extract the text between /GENBEGIN/ and /GENEND/ markers from ``msg``.

    All matches are joined with ", ", the characters ``* / : _ " #`` and
    newlines are replaced by spaces, and the result is lower-cased. Returns ""
    when no marker pair is found.
    """
    m = re.findall(r'/GENBEGIN/(.+?)/GENEND/', msg, re.DOTALL)
    return re.sub(r'[*/:_"#\n]', ' ', ", ".join(m)).lower() if m else ""


def _resolve_chat_template(model: str, state: dict):
    """Return the format override stored in ``state``, else the model's registered format."""
    override_llm_format = get_state(state, "override_llm_format")
    if override_llm_format:
        return override_llm_format
    return llm_models[model][1]


def _create_agent(model_path: Path, chat_template, system_message: str):
    """Load the GGUF model and wrap it in a provider and a chat agent."""
    llm = Llama(
        model_path=str(model_path),
        flash_attn=True,
        n_gpu_layers=81,
        n_batch=1024,
        n_ctx=8192,
    )
    provider = LlamaCppPythonProvider(llm)
    # A MessagesFormatter instance is passed as a custom formatter; anything
    # else is passed as a predefined formatter type.
    is_custom = isinstance(chat_template, MessagesFormatter)
    agent = LlamaCppAgent(
        provider,
        system_prompt=f"{system_message}",
        predefined_messages_formatter_type=None if is_custom else chat_template,
        custom_messages_formatter=chat_template if is_custom else None,
        debug_output=False
    )
    return provider, agent


def _create_sampling_settings(provider, max_tokens, temperature, top_p, top_k, repeat_penalty):
    """Return the provider's default settings with the given sampling values and streaming on."""
    settings = provider.get_provider_default_settings()
    settings.temperature = temperature
    settings.top_k = top_k
    settings.top_p = top_p
    settings.max_tokens = max_tokens
    settings.repeat_penalty = repeat_penalty
    settings.stream = True
    return settings


def _history_to_messages(history):
    """Convert (user, assistant) history pairs into a BasicChatHistory."""
    messages = BasicChatHistory()
    for msn in history:
        messages.add_message({'role': Roles.user, 'content': msn[0]})
        messages.add_message({'role': Roles.assistant, 'content': msn[1]})
    return messages


def _start_chat_stream(agent, message, settings, messages):
    """Ask the agent for a streaming response to ``message``."""
    return agent.get_chat_response(
        message,
        llm_sampling_settings=settings,
        chat_history=messages,
        returns_streaming_generator=True,
        print_output=False
    )


def _prompt_to_tags(raw_prompt: str, sysprompt_mode: str, extra_tags: list):
    """Split ``raw_prompt`` into tags, append ``extra_tags`` and drop duplicates."""
    if sysprompt_mode == "Japanese to Danbooru Dictionary" and is_japanese(raw_prompt):
        tags = jatags_to_danbooru_tags(to_list_ja(raw_prompt))
    else:
        tags = to_list(raw_prompt)
    return list_uniq(tags + extra_tags)


@torch.inference_mode()
@spaces.GPU(duration=59)
def dolphin_respond(
    message: str,
    history: list[tuple[str, str]],
    model: str = default_llm_model_filename,
    system_message: str = get_dolphin_sysprompt(),
    max_tokens: int = 1024,
    temperature: float = 0.7,
    top_p: float = 0.95,
    top_k: int = 40,
    repeat_penalty: float = 1.1,
    state: dict = {},
    progress=gr.Progress(track_tqdm=True),
):
    """Stream a chat response as a growing one-row chatbot history.

    Yields ``[(text_so_far, None)]`` for every streamed chunk. Any failure,
    including a missing model file, is printed and re-raised as ``gr.Error``.
    ``state`` is only read here.
    """
    try:
        model_path = Path(f"{llm_models_dir}/{model}")
        if not model_path.exists():
            raise gr.Error(f"Model file not found: {str(model_path)}")
        progress(0, desc="Processing...")
        chat_template = _resolve_chat_template(model, state)
        provider, agent = _create_agent(model_path, chat_template, system_message)
        settings = _create_sampling_settings(provider, max_tokens, temperature, top_p, top_k, repeat_penalty)
        messages = _history_to_messages(history)
        stream = _start_chat_stream(agent, message, settings, messages)

        progress(0.5, desc="Processing...")

        outputs = ""
        for output in stream:
            outputs += output
            yield [(outputs, None)]
    except Exception as e:
        print(e)
        raise gr.Error(f"Error: {e}") from e
    finally:
        torch.cuda.empty_cache()
        gc.collect()


def dolphin_parse(
    history: list[tuple[str, str]],
    state: dict,
):
    """Turn the latest history entry into a comma-separated tag prompt.

    Returns ``(prompt, update, update)``. The prompt is "" and the updates are
    empty in "Chat with LLM" mode, when ``history`` is empty, or on any error;
    otherwise both updates set ``interactive=True``.
    """
    try:
        dolphin_sysprompt_mode = get_state(state, "dolphin_sysprompt_mode")
        if dolphin_sysprompt_mode == "Chat with LLM" or not history:
            return "", gr.update(), gr.update()
        msg = history[-1][0]
        raw_prompt = get_raw_prompt(msg)
        prompts = _prompt_to_tags(raw_prompt, dolphin_sysprompt_mode, ["nsfw", "explicit"])
        return ", ".join(prompts), gr.update(interactive=True), gr.update(interactive=True)
    except Exception as e:
        print(e)
        return "", gr.update(), gr.update()


@torch.inference_mode()
@spaces.GPU(duration=59)
def dolphin_respond_auto(
    message: str,
    history: list[tuple[str, str]],
    model: str = default_llm_model_filename,
    system_message: str = get_dolphin_sysprompt(),
    max_tokens: int = 1024,
    temperature: float = 0.7,
    top_p: float = 0.95,
    top_k: int = 40,
    repeat_penalty: float = 1.1,
    state: dict = {},
    progress=gr.Progress(track_tqdm=True),
):
    """Stream a chat response together with two empty component updates.

    Yields ``([(text_so_far, None)], gr.update(), gr.update())`` per chunk. On
    failure the error is printed and a single empty row is yielded instead of
    raising. ``state`` is only read here.
    """
    try:
        model_path = Path(f"{llm_models_dir}/{model}")
        #if not is_japanese(message): return [(None, None)]
        progress(0, desc="Processing...")
        chat_template = _resolve_chat_template(model, state)
        provider, agent = _create_agent(model_path, chat_template, system_message)
        settings = _create_sampling_settings(provider, max_tokens, temperature, top_p, top_k, repeat_penalty)
        messages = _history_to_messages(history)

        progress(0, desc="Translating...")
        stream = _start_chat_stream(agent, message, settings, messages)

        progress(0.5, desc="Processing...")

        outputs = ""
        for output in stream:
            outputs += output
            yield [(outputs, None)], gr.update(), gr.update()
    except Exception as e:
        print(e)
        yield [("", None)], gr.update(), gr.update()
    finally:
        torch.cuda.empty_cache()
        gc.collect()


def dolphin_parse_simple(
    message: str,
    history: list[tuple[str, str]],
    state: dict,
):
    """Turn the latest history entry into a comma-separated tag prompt string.

    Returns ``message`` unchanged in "Chat with LLM" mode or when ``history``
    is empty, and "" on any error.
    """
    try:
        #if not is_japanese(message): return message
        dolphin_sysprompt_mode = get_state(state, "dolphin_sysprompt_mode")
        if dolphin_sysprompt_mode == "Chat with LLM" or not history:
            return message
        msg = history[-1][0]
        raw_prompt = get_raw_prompt(msg)
        prompts = _prompt_to_tags(raw_prompt, dolphin_sysprompt_mode, ["nsfw", "explicit", "rating_explicit"])
        return ", ".join(prompts)
    except Exception as e:
        print(e)
        return ""


# https://huggingface.co/spaces/CaioXapelaum/GGUF-Playground
import cv2
cv2.setNumThreads(1)


@torch.inference_mode()
@spaces.GPU(duration=59)
def respond_playground(
    message: str,
    history: list[tuple[str, str]],
    model: str = default_llm_model_filename,
    system_message: str = get_dolphin_sysprompt(),
    max_tokens: int = 1024,
    temperature: float = 0.7,
    top_p: float = 0.95,
    top_k: int = 40,
    repeat_penalty: float = 1.1,
    state: dict = {},
    progress=gr.Progress(track_tqdm=True),
):
    """Stream a chat response as plain text.

    Yields the accumulated response string for every streamed chunk. Any
    failure, including a missing model file, is printed and re-raised as
    ``gr.Error``. ``state`` is only read here.
    """
    try:
        model_path = Path(f"{llm_models_dir}/{model}")
        if not model_path.exists():
            raise gr.Error(f"Model file not found: {str(model_path)}")
        chat_template = _resolve_chat_template(model, state)
        provider, agent = _create_agent(model_path, chat_template, system_message)
        settings = _create_sampling_settings(provider, max_tokens, temperature, top_p, top_k, repeat_penalty)

        # Replay the earlier user/assistant turns into the agent's chat history.
        messages = _history_to_messages(history)

        # Stream the response
        stream = _start_chat_stream(agent, message, settings, messages)

        outputs = ""
        for output in stream:
            outputs += output
            yield outputs
    except Exception as e:
        print(e)
        raise gr.Error(f"Error: {e}") from e
    finally:
        torch.cuda.empty_cache()
        gc.collect()