# Provides terminal-based chat interface for RWKV model.
# Usage: python chat_with_bot.py C:\rwkv.cpp-169M.bin
# Prompts and code adapted from https://github.com/BlinkDL/ChatRWKV/blob/9ca4cdba90efaee25cfec21a0bae72cbd48d8acd/chat.py

import os
import argparse
import pathlib
import copy
import json
import time
import sampling
from rwkv_cpp import rwkv_cpp_shared_library, rwkv_cpp_model
from tokenizer_util import add_tokenizer_argument, get_tokenizer
from typing import List, Dict, Optional

# ======================================== Script settings ========================================

# English, Chinese, Japanese
LANGUAGE: str = 'English'
# QA: Question and Answer prompt to talk to an AI assistant.
# Chat: chat prompt (need a large model for adequate quality, 7B+).
PROMPT_TYPE: str = 'QA'

MAX_GENERATION_LENGTH: int = 250

# Sampling temperature. It could be a good idea to increase temperature when top_p is low.
TEMPERATURE: float = 0.8
# For better Q&A accuracy and less diversity, reduce top_p (to 0.5, 0.2, 0.1 etc.)
TOP_P: float = 0.5
# Penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.
PRESENCE_PENALTY: float = 0.2
# Penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim.
FREQUENCY_PENALTY: float = 0.2

END_OF_LINE_TOKEN: int = 187
DOUBLE_END_OF_LINE_TOKEN: int = 535
END_OF_TEXT_TOKEN: int = 0

# =================================================================================================

parser = argparse.ArgumentParser(description='Provide terminal-based chat interface for RWKV model')
parser.add_argument('model_path', help='Path to RWKV model in ggml format')
add_tokenizer_argument(parser)
args = parser.parse_args()

script_dir: pathlib.Path = pathlib.Path(os.path.abspath(__file__)).parent

with open(script_dir / 'prompt' / f'{LANGUAGE}-{PROMPT_TYPE}.json', 'r', encoding='utf8') as json_file:
    prompt_data = json.load(json_file)

    user, bot, separator, init_prompt = prompt_data['user'], prompt_data['bot'], prompt_data['separator'], prompt_data['prompt']

if init_prompt == '':
    raise ValueError('Prompt must not be empty')

library = rwkv_cpp_shared_library.load_rwkv_shared_library()
print(f'System info: {library.rwkv_get_system_info_string()}')

print('Loading RWKV model')
model = rwkv_cpp_model.RWKVModel(library, args.model_path)

tokenizer_decode, tokenizer_encode = get_tokenizer(args.tokenizer, model.n_vocab)

# =================================================================================================

processed_tokens: List[int] = []
logits: Optional[rwkv_cpp_model.NumpyArrayOrPyTorchTensor] = None
state: Optional[rwkv_cpp_model.NumpyArrayOrPyTorchTensor] = None

def process_tokens(_tokens: List[int], new_line_logit_bias: float = 0.0) -> None:
    global processed_tokens, logits, state

    logits, state = model.eval_sequence_in_chunks(_tokens, state, state, logits, use_numpy=True)

    processed_tokens += _tokens

    logits[END_OF_LINE_TOKEN] += new_line_logit_bias

state_by_thread: Dict[str, Dict] = {}

def save_thread_state(_thread: str) -> None:
    state_by_thread[_thread] = {
        'tokens': copy.deepcopy(processed_tokens),
        'logits': copy.deepcopy(logits),
        'state': copy.deepcopy(state)
    }

def load_thread_state(_thread: str) -> None:
    global processed_tokens, logits, state

    thread_state = state_by_thread[_thread]

    processed_tokens = copy.deepcopy(thread_state['tokens'])
    logits = copy.deepcopy(thread_state['logits'])
    state = copy.deepcopy(thread_state['state'])
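
# The two functions above implement cheap conversation branching: a "thread" is just a
# deep-copied snapshot of (tokens, logits, state), so commands like +reset and + below can
# rewind generation without re-evaluating the prompt. A minimal optional sketch of an
# inspection helper built on the same dict (print_thread_stats is NOT part of the original
# script; it only reads state_by_thread and can be deleted without affecting behavior):
def print_thread_stats() -> None:
    # Show how many tokens each saved thread has processed so far.
    for thread_name, thread_state in state_by_thread.items():
        print(f'{thread_name}: {len(thread_state["tokens"])} tokens processed')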
# Model only saw '\n\n' as [187, 187] before, but the tokenizer outputs [535] for it at the end.
# See https://github.com/BlinkDL/ChatRWKV/pull/110/files
def split_last_end_of_line(tokens: List[int]) -> List[int]:
    if len(tokens) > 0 and tokens[-1] == DOUBLE_END_OF_LINE_TOKEN:
        tokens = tokens[:-1] + [END_OF_LINE_TOKEN, END_OF_LINE_TOKEN]

    return tokens

# =================================================================================================

processing_start: float = time.time()

prompt_tokens = tokenizer_encode(init_prompt)
prompt_token_count = len(prompt_tokens)
print(f'Processing {prompt_token_count} prompt tokens, may take a while')

process_tokens(split_last_end_of_line(prompt_tokens))

processing_duration: float = time.time() - processing_start

print(f'Processed in {int(processing_duration)} s, {int(processing_duration / prompt_token_count * 1000)} ms per token')

save_thread_state('chat_init')
save_thread_state('chat')

print(f'\nChat initialized! Your name is {user}. Write something and press Enter. Use \\n to add line breaks to your message.')

while True:
    # Read user input
    user_input: str = input(f'> {user}{separator} ')
    msg: str = user_input.replace('\\n', '\n').strip()

    temperature: float = TEMPERATURE
    top_p: float = TOP_P

    if '-temp=' in msg:
        temperature = float(msg.split('-temp=')[1].split(' ')[0])

        msg = msg.replace('-temp=' + f'{temperature:g}', '')

        if temperature <= 0.2:
            temperature = 0.2

        if temperature >= 5:
            temperature = 5

    if '-top_p=' in msg:
        top_p = float(msg.split('-top_p=')[1].split(' ')[0])

        msg = msg.replace('-top_p=' + f'{top_p:g}', '')

        if top_p <= 0:
            top_p = 0

    msg = msg.strip()

    # + reset --> reset chat
    if msg == '+reset':
        load_thread_state('chat_init')
        save_thread_state('chat')
        print(f'{bot}{separator} Chat reset.\n')
        continue
    elif msg[:5].lower() == '+gen ' or msg[:3].lower() == '+i ' or msg[:4].lower() == '+qa ' or msg[:4].lower() == '+qq ' or msg.lower() == '+++' or msg.lower() == '++':

        # +gen YOUR PROMPT --> free single-round generation with any prompt. Requires Novel model.
        if msg[:5].lower() == '+gen ':
            new = '\n' + msg[5:].strip()
            state = None
            processed_tokens = []
            process_tokens(tokenizer_encode(new))
            save_thread_state('gen_0')

        # +i YOUR INSTRUCT --> free single-round generation with any instruct. Requires Raven model.
        elif msg[:3].lower() == '+i ':
            new = f'''
Below is an instruction that describes a task. Write a response that appropriately completes the request.

# Instruction:
{msg[3:].strip()}

# Response:
'''
            state = None
            processed_tokens = []
            process_tokens(tokenizer_encode(new))
            save_thread_state('gen_0')

        # +qq YOUR QUESTION --> answer an independent question with more creativity (regardless of context).
        elif msg[:4].lower() == '+qq ':
            new = '\nQ: ' + msg[4:].strip() + '\nA:'
            state = None
            processed_tokens = []
            process_tokens(tokenizer_encode(new))
            save_thread_state('gen_0')
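
        # For illustration (assumed interaction, not part of the original script): typing
        # "+qq What is the capital of France?" resets `state` to None and feeds the model
        # the standalone prompt "\nQ: What is the capital of France?\nA:", so the answer
        # cannot be influenced by earlier chat context.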
        # +qa YOUR QUESTION --> answer an independent question (regardless of context).
        elif msg[:4].lower() == '+qa ':
            load_thread_state('chat_init')

            real_msg = msg[4:].strip()
            new = f'{user}{separator} {real_msg}\n\n{bot}{separator}'

            process_tokens(tokenizer_encode(new))
            save_thread_state('gen_0')

        # +++ --> continue last free generation (only for +gen / +i)
        elif msg.lower() == '+++':
            try:
                load_thread_state('gen_1')
                save_thread_state('gen_0')
            except Exception as e:
                print(e)
                continue

        # ++ --> retry last free generation (only for +gen / +i)
        elif msg.lower() == '++':
            try:
                load_thread_state('gen_0')
            except Exception as e:
                print(e)
                continue

        thread = 'gen_1'
    else:
        # + --> alternate chat reply
        if msg.lower() == '+':
            try:
                load_thread_state('chat_pre')
            except Exception as e:
                print(e)
                continue
        # chat with bot
        else:
            load_thread_state('chat')
            new = f'{user}{separator} {msg}\n\n{bot}{separator}'
            process_tokens(tokenizer_encode(new), new_line_logit_bias=-999999999)
            save_thread_state('chat_pre')

        thread = 'chat'

    # Print bot response
    print(f'> {bot}{separator}', end='')

    start_index: int = len(processed_tokens)
    accumulated_tokens: List[int] = []
    token_counts: Dict[int, int] = {}

    for i in range(MAX_GENERATION_LENGTH):
        for n in token_counts:
            logits[n] -= PRESENCE_PENALTY + token_counts[n] * FREQUENCY_PENALTY

        token: int = sampling.sample_logits(logits, temperature, top_p)

        if token == END_OF_TEXT_TOKEN:
            print()
            break

        if token not in token_counts:
            token_counts[token] = 1
        else:
            token_counts[token] += 1

        process_tokens([token])

        # Avoid UTF-8 display issues
        accumulated_tokens += [token]

        decoded: str = tokenizer_decode(accumulated_tokens)

        if '\uFFFD' not in decoded:
            print(decoded, end='', flush=True)

            accumulated_tokens = []

        if thread == 'chat':
            if '\n\n' in tokenizer_decode(processed_tokens[start_index:]):
                break

        if i == MAX_GENERATION_LENGTH - 1:
            print()

    save_thread_state(thread)
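
# Repetition control in the loop above, for reference: on every step the fresh logits
# returned by the model are penalized once per previously generated token n as
#   logits[n] -= PRESENCE_PENALTY + token_counts[n] * FREQUENCY_PENALTY
# e.g. with the defaults (0.2 / 0.2), a token already emitted 3 times in this reply
# has 0.2 + 3 * 0.2 = 0.8 subtracted from its logit before sampling.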