# llama.py
import os
import gradio as gr
import fire
from enum import Enum
from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from auto_gptq import AutoGPTQForCausalLM
from llama_cpp import Llama
from huggingface_hub import hf_hub_download
from llama_chat_format import format_to_llama_chat_style
# Supported model formats.
class Model_Type(Enum):
    gptq = 1
    ggml = 2
    full_precision = 3
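# Guess the model format from the repo/model name (e.g. "gptq" or "ggml" in the name).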
def get_model_type(model_name):
if "gptq" in model_name.lower():
return Model_Type.gptq
elif "ggml" in model_name.lower():
return Model_Type.ggml
else:
return Model_Type.full_precision
def create_folder_if_not_exists(folder_path):
    if not os.path.exists(folder_path):
        os.makedirs(folder_path)
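# Load a GPTQ-quantized or full-precision model (plus tokenizer) onto the GPU.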
def initialize_gpu_model_and_tokenizer(model_name, model_type):
    if model_type == Model_Type.gptq:
        model = AutoGPTQForCausalLM.from_quantized(model_name, device_map="auto", use_safetensors=True, use_triton=False)
        tokenizer = AutoTokenizer.from_pretrained(model_name)
    else:
        model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", token=True)
        tokenizer = AutoTokenizer.from_pretrained(model_name, token=True)
    return model, tokenizer
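# Dispatch on model type: download a GGML file and load it with llama.cpp, or load a GPU model via transformers/auto-gptq.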
def init_auto_model_and_tokenizer(model_name, model_type, file_name=None):
    if model_type == Model_Type.ggml:
        models_folder = "./models"
        create_folder_if_not_exists(models_folder)
        file_path = hf_hub_download(repo_id=model_name, filename=file_name, local_dir=models_folder)
        model = Llama(file_path, n_ctx=4096)
        tokenizer = None
    else:
        model, tokenizer = initialize_gpu_model_and_tokenizer(model_name, model_type=model_type)
    return model, tokenizer
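# Build a minimal Gradio chat UI that streams tokens from either backend.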
def run_ui(model, tokenizer, is_chat_model, model_type):
    with gr.Blocks() as demo:
        chatbot = gr.Chatbot()
        msg = gr.Textbox()
        clear = gr.Button("Clear")

        def user(user_message, history):
            # Append the user's message to the history and clear the textbox.
            return "", history + [[user_message, None]]

        def bot(history):
            if is_chat_model:
                instruction = format_to_llama_chat_style(history)
            else:
                instruction = history[-1][0]

            history[-1][1] = ""
            kwargs = dict(temperature=0.6, top_p=0.9)
            if model_type == Model_Type.ggml:
                kwargs["max_tokens"] = 512
                for chunk in model(prompt=instruction, stream=True, **kwargs):
                    token = chunk["choices"][0]["text"]
                    history[-1][1] += token
                    yield history
            else:
                streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, timeout=5)
                inputs = tokenizer(instruction, return_tensors="pt").to(model.device)
                kwargs["max_new_tokens"] = 512
                kwargs["do_sample"] = True  # temperature/top_p only take effect when sampling is enabled
                kwargs["input_ids"] = inputs["input_ids"]
                kwargs["streamer"] = streamer
                # Run generation in a background thread and stream tokens as they arrive.
                thread = Thread(target=model.generate, kwargs=kwargs)
                thread.start()
                for token in streamer:
                    history[-1][1] += token
                    yield history

        msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(bot, chatbot, chatbot)
        clear.click(lambda: None, None, chatbot, queue=False)

    demo.queue()
    demo.launch(share=True, debug=True)
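# CLI entry point; arguments are parsed by fire.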
def main(model_name=None, file_name=None):
    assert model_name is not None, "model_name argument is missing."
    is_chat_model = 'chat' in model_name.lower()
    model_type = get_model_type(model_name)
    if model_type == Model_Type.ggml:
        assert file_name is not None, "file_name argument is required for GGML models."
    model, tokenizer = init_auto_model_and_tokenizer(model_name, model_type, file_name)
    run_ui(model, tokenizer, is_chat_model, model_type)
if __name__ == '__main__':
    fire.Fire(main)
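# Example invocations (model and file names are illustrative):
#   python llama.py --model_name meta-llama/Llama-2-7b-chat-hf
#   python llama.py --model_name TheBloke/Llama-2-7B-Chat-GGML --file_name llama-2-7b-chat.ggmlv3.q4_0.bin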