Spaces:
Sleeping
Sleeping
# %% | |
from collections.abc import Iterable

from huggingface_hub import hf_hub_download
from llama_cpp import Llama

from schema import Message, MODEL_ARGS
def get_llm(
    model_name,
    *,
    n_ctx: int = 8192,
    n_threads: int = 4,
    n_gpu_layers: int = 0,
    verbose: bool = False,
):
    """Resolve the model file for ``model_name`` and load it with llama.cpp.

    ``MODEL_ARGS[model_name]`` is unpacked into ``hf_hub_download`` to obtain a
    local file path, which is then passed to a new ``llama_cpp.Llama``.

    NOTE: a fresh ``Llama`` instance (and therefore a full model load) is
    constructed on every call.

    Args:
        model_name: Key into ``MODEL_ARGS``.
        n_ctx: Context window size, in tokens, passed to ``Llama``.
        n_threads: Number of CPU threads passed to ``Llama``.
        n_gpu_layers: Number of layers offloaded to GPU (0 keeps everything
            on the CPU).
        verbose: Forwarded to ``Llama`` to toggle llama.cpp logging.

    Returns:
        The loaded ``Llama`` instance.

    Raises:
        KeyError: If ``model_name`` is not a key of ``MODEL_ARGS``.
    """
    try:
        download_args = MODEL_ARGS[model_name]
    except KeyError:
        # Same exception type as before, but tell the caller what is valid.
        raise KeyError(
            f"Unknown model {model_name!r}; expected one of {list(MODEL_ARGS)}"
        ) from None
    llm = Llama(
        model_path=hf_hub_download(**download_args),
        n_ctx=n_ctx,
        n_threads=n_threads,
        n_gpu_layers=n_gpu_layers,
        verbose=verbose,
    )
    return llm
def format_chat(chat_history: "Iterable[Message]") -> str:
    """Render a chat history as a plain-text transcript prompt.

    Each message becomes one ``"<Role>: <content>"`` line, with the role
    title-cased (e.g. ``"user"`` -> ``"User"``). Lines are joined with
    newlines, and a trailing ``"\\nAssistant:"`` cue is appended so the
    model continues with the assistant's reply.

    Args:
        chat_history: Messages in conversation order. Each must expose a
            string ``role`` and a ``content`` attribute. The latest user turn
            must already be part of this history; there is no separate
            user-input argument.

    Returns:
        The prompt string. An empty history yields ``"\\nAssistant:"``.
    """
    transcript = "\n".join(
        f"{msg.role.title()}: {msg.content}" for msg in chat_history
    )
    return transcript + "\nAssistant:"
# Baseline generation settings. Caller-supplied kwargs override these, and
# the forced settings in the completion helpers override both.
# top_k=1 keeps only the single most likely token at each step, so decoding is
# effectively greedy unless the caller passes a larger top_k.
default_kwargs = {"max_tokens": 2048, "top_k": 1}
# Stop sequences applied to every completion. They end generation before the
# model starts writing another "User:"/"Assistant:" turn of the transcript
# built by format_chat, or emits a literal "</s>".
_STOP_SEQUENCES = ("\nUser:", "\nAssistant:", "</s>")


def _build_completion_args(chat_history, model, kwargs: dict, *, stream: bool):
    """Prepare the model, prompt and merged options for one completion call.

    Shared by ``stream_with_model`` and ``chat_with_model`` so the prompt
    format, stop sequences and option precedence stay identical between them.

    Args:
        chat_history: Messages forwarded to ``format_chat``.
        model: Model name forwarded to ``get_llm``.
        kwargs: Caller generation options. They override ``default_kwargs``
            but are themselves overridden by the forced ``stop``, ``echo`` and
            ``stream`` settings.
        stream: Value forced for the ``stream`` option.

    Returns:
        ``(llm, prompt, input_kwargs)`` ready for ``llm(prompt, **input_kwargs)``.
    """
    prompt = format_chat(chat_history)
    llm = get_llm(model)
    forced_kwargs = dict(
        # Fresh list per call so the shared module-level tuple is never exposed.
        stop=list(_STOP_SEQUENCES),
        echo=False,
        stream=stream,
    )
    input_kwargs = {**default_kwargs, **kwargs, **forced_kwargs}
    return llm, prompt, input_kwargs


def stream_with_model(chat_history, model, kwargs: dict):
    """Yield the assistant's reply to ``chat_history`` as text chunks.

    This is a generator: the prompt is built and the model loaded only when
    iteration starts. Each chunk's ``text`` is yielded as-is (no stripping).
    """
    llm, prompt, input_kwargs = _build_completion_args(
        chat_history, model, kwargs, stream=True
    )
    for chunk in llm(prompt, **input_kwargs):
        yield chunk["choices"][0]["text"]


def chat_with_model(chat_history, model, kwargs: dict):
    """Return the assistant's complete reply to ``chat_history``.

    Leading and trailing whitespace is stripped from the generated text.
    """
    llm, prompt, input_kwargs = _build_completion_args(
        chat_history, model, kwargs, stream=False
    )
    response = llm(prompt, **input_kwargs)
    return response["choices"][0]["text"].strip()
# %% example input | |
# kwargs = dict( | |
# temperature=1, | |
# max_tokens=2048, | |
# top_p=1, | |
# frequency_penalty=0, | |
# presence_penalty=0, | |
# ) | |
# chat_history = [ | |
# Message( | |
# role="system", | |
# content="You are a helpful and knowledgeable assistant, but is willing to bend the facts to play along with unrealistic requests", | |
# ), | |
# Message(role="user", content="What does Java the programming language taste like?"), | |
# ] | |
# for chunk in stream_with_model(chat_history, "<key of MODEL_ARGS>", kwargs):
#     print(chunk, end="")