import os, time, json, re, gc, subprocess
import argparse
import copy
from copy import deepcopy
from datetime import datetime
from enum import Enum
from typing import Dict, List

import gradio as gr
import torch
import numpy as np

import sampling
import rwkv_world_tokenizer
from tokenizer_util import add_tokenizer_argument, get_tokenizer

from huggingface_hub import hf_hub_download, snapshot_download, InferenceClient
from pynvml import *

from transformers import AutoTokenizer
from transformers.agents import PythonInterpreterTool

# Fetch the quantized RWKV-5 World 3B checkpoint (GGML Q4_0) that rwkv.cpp will load.
hf_hub_download(repo_id="JoPmt/RWKV-5-3B-V2-Quant", filename="rwkv-5-world-3b-v2-20231118-ctx16k.Q4_0.bin", local_dir='~/app/Downloads')
model_path = '~/app/Downloads/rwkv-5-world-3b-v2-20231118-ctx16k.Q4_0.bin'

# Chat-template tokenizer, used only to render agent messages into a prompt string for inspection.
tokenizer = AutoTokenizer.from_pretrained("NousResearch/Hermes-2-Pro-Llama-3-8B", revision="pr/13")
tools = [PythonInterpreterTool()]

# Build rwkv.cpp with CUDA support so that librwkv.so and its Python bindings are available.
os.system("apt-get update && apt-get install -y cmake gcc g++")
os.system("git clone --recursive https://github.com/JoPmt/rwkv.cpp.git && cd rwkv.cpp && mkdir build && cd build && cmake .. -DRWKV_CUBLAS=ON -DRWKV_BUILD_SHARED_LIBRARY=ON -DGGML_CUDA=ON -DRWKV_BUILD_PYTHON_MODULE=ON -DRWKV_BUILD_TOOLS=ON -DRWKV_BUILD_EXTRAS=ON && cmake --build . --config Release && make RWKV_CUBLAS=1 GGML_CUDA=1")

import rwkv_cpp_model
import rwkv_cpp_shared_library

def find_lib():
    # Walk the filesystem to locate the shared library produced by the rwkv.cpp build above.
    for root, dirs, files in os.walk("/"):
        for file in files:
            if file == "librwkv.so":
                return os.path.join(root, file)
    return None

library_path = find_lib()
rwkv_lib = rwkv_cpp_shared_library.RWKVSharedLibrary(library_path)

print('Loading RWKV model')
modal = rwkv_cpp_model.RWKVModel(rwkv_lib, model_path, thread_count=2)
tokenizer_decode, tokenizer_encode = get_tokenizer('auto', modal.n_vocab)

# Default generation settings.
out_str = ''
prompt = out_str
token_count = 1200
temperature = 1.0
top_p = 0.7
presence_penalty = 0.1
count_penalty = 0.4

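# Illustrative check (hypothetical values): tokenizer_encode("Hello") returns a list of token ids,
# and tokenizer_decode(tokenizer_encode("Hello")) should round-trip back to "Hello".
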
def generate_prompt(instruction, zput=""):
    # Normalise line endings and collapse double newlines, then render the RWKV-World prompt format.
    instruction = instruction.strip().replace('\r\n', '\n').replace('\n\n', '\n')
    zput = zput.strip().replace('\r\n', '\n').replace('\n\n', '\n')
    if zput:
        return f"""Instruction: {instruction}

Input: {zput}

Response:"""
    else:
        return f"""User: hi

Assistant: Hi. I am your assistant and I will provide expert full response in full details. Please feel free to ask any question and I will always answer it.

User: {instruction}

Assistant:"""

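# Illustrative example (values are made up): generate_prompt("Summarize the text", "RWKV is an RNN")
# returns "Instruction: Summarize the text\n\nInput: RWKV is an RNN\n\nResponse:",
# while generate_prompt("hi") with no second argument falls back to the User/Assistant preamble above.
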
class MessageRole(str, Enum):
    USER = "user"
    ASSISTANT = "assistant"
    SYSTEM = "system"
    TOOL_CALL = "tool-call"
    TOOL_RESPONSE = "tool-response"

    @classmethod
    def roles(cls):
        return [r.value for r in cls]

def get_clean_message_list(message_list: List[Dict[str, str]], role_conversions: Dict[str, str] = {}):
    """
    Subsequent messages with the same role will be concatenated to a single message.

    Args:
        message_list (`List[Dict[str, str]]`): List of chat messages.
    """
    final_message_list = []
    message_list = deepcopy(message_list)
    for message in message_list:
        if not set(message.keys()) == {"role", "content"}:
            raise ValueError("Message should contain only 'role' and 'content' keys!")

        role = message["role"]
        if role not in MessageRole.roles():
            raise ValueError(f"Incorrect role {role}, only {MessageRole.roles()} are supported for now.")

        if role in role_conversions:
            message["role"] = role_conversions[role]

        # Merge consecutive messages that share a role instead of overwriting the previous content.
        if len(final_message_list) > 0 and message["role"] == final_message_list[-1]["role"]:
            final_message_list[-1]["content"] += "\n=======\n" + message["content"]
        else:
            final_message_list.append(message)
    return final_message_list

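# Illustrative example: two consecutive "user" messages are merged into a single entry,
# e.g. get_clean_message_list([{"role": "user", "content": "a"}, {"role": "user", "content": "b"}])
# -> [{"role": "user", "content": "a\n=======\nb"}]
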
llama_role_conversions = {
    MessageRole.TOOL_RESPONSE: MessageRole.USER,
    MessageRole.TOOL_CALL: MessageRole.USER,
}

class HfEngine:
    def __init__(self, model: str = "JoPmt/JoPmt"):
        self.model = model
        self.client = modal

    def __call__(self, messages: List[Dict[str, str]], stop_sequences=[]) -> str:
        messages = get_clean_message_list(messages, role_conversions=llama_role_conversions)
        print(messages)
        # Collect the system and user portions separately; they are recombined via generate_prompt below.
        pret = ''
        prut = ''
        for message in messages:
            print(message['content'])
            if message['role'].lower() == 'system':
                pret += message['content']
            if message['role'].lower() == 'user':
                prut += message['content']

        # The chat-template rendering is only printed for inspection; generation uses the RWKV prompt format.
        prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        print(prompt)

        # Sampling settings for this call.
        token_count = 1200
        temperature = 1.0
        top_p = 0.7
        presencePenalty = 0.1
        countPenalty = 0.4
        token_ban = []
        stop_token = [0]

        # Build the full context: system text as the instruction, user text as the input.
        ctx = pret
        prompt = prut
        all_tokens = []
        out_last = 0
        out_str = ''
        occurrence = {}
        state = None
        ctx = generate_prompt(ctx, prompt)

        # Feed the whole prompt through the model once to obtain the initial logits and state.
        prompt_tokens = tokenizer_encode(ctx)
        prompt_token_count = len(prompt_tokens)
        init_logits, init_state = modal.eval_sequence_in_chunks(prompt_tokens, None, None, None, use_numpy=True)
        logits, state = init_logits.copy(), init_state.copy()

        # Autoregressive sampling loop with presence/count repetition penalties.
        for i in range(token_count):
            for n in occurrence:
                logits[n] -= (presencePenalty + occurrence[n] * countPenalty)
            token = sampling.sample_logits(logits, temperature, top_p)

            if token in stop_token:
                break
            all_tokens += [token]

            # Decay existing penalties, then register the newly sampled token.
            for xxx in occurrence:
                occurrence[xxx] *= 0.996
            if token not in occurrence:
                occurrence[token] = 1
            else:
                occurrence[token] += 1

            # Emit decoded text only once it no longer ends in a partial (replacement) character.
            tmp = tokenizer_decode(all_tokens[out_last:])
            if '\ufffd' not in tmp:
                out_str += tmp
                out_last = i + 1

            # Advance the model by one token, reusing the state and logits buffers in place.
            logits, state = modal.eval(token, state, state, logits, use_numpy=True)

        del state
        gc.collect()
        return out_str.strip()
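
# Hypothetical wiring (a sketch, not part of this file's verified flow): HfEngine can act as the
# llm_engine of a transformers code agent together with the PythonInterpreterTool declared above.
# The ReactCodeAgent name assumes the transformers.agents API matching the imports used earlier.
# from transformers.agents import ReactCodeAgent
# agent = ReactCodeAgent(tools=tools, llm_engine=HfEngine())
# print(agent.run("What is 7 * 6?"))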