import torch
import gradio as gr
import re
import transformers
import peft
import traceback
from queue import Queue
from threading import Thread
import gc

CUDA_AVAILABLE = torch.cuda.is_available()
device = torch.device("cuda" if CUDA_AVAILABLE else "cpu")

tokenizer = transformers.AutoTokenizer.from_pretrained("cerebras/Cerebras-GPT-2.7B")
# The GPT-2-style tokenizer defines no pad token; assign id 0 so generation can pad
tokenizer.pad_token_id = 0

# Load the base model in 8-bit and attach the Alpaca-shortprompt LoRA adapter
model = transformers.AutoModelForCausalLM.from_pretrained(
    "cerebras/Cerebras-GPT-2.7B",
    load_in_8bit=True,
    torch_dtype=torch.float16,
    device_map={'': 0} if CUDA_AVAILABLE else 'auto',
)

model = peft.PeftModel.from_pretrained(
    model,
    'lxe/lora-cerebras-gpt2.7b-alpaca-shortprompt',
    torch_dtype=torch.float16
)
model.half()


# Streaming functionality taken from https://github.com/oobabooga/text-generation-webui/blob/master/modules/text_generation.py#L105
class Stream(transformers.StoppingCriteria):
    def __init__(self, callback_func=None):
        self.callback_func = callback_func

    def __call__(self, input_ids, scores) -> bool:
        # Report each partial sequence to the callback; never actually stop generation
        if self.callback_func is not None:
            self.callback_func(input_ids[0])
        return False


class Iteratorize:
    """
    Transforms a function that takes a callback
    into a lazy iterator (generator).
    """

    def __init__(self, func, kwargs={}, callback=None):
        self.mfunc = func
        self.c_callback = callback
        self.q = Queue()
        self.sentinel = object()
        self.kwargs = kwargs
        self.stop_now = False

        def _callback(val):
            # Raising ValueError aborts the generation thread when the iterator is closed
            if self.stop_now:
                raise ValueError
            self.q.put(val)

        def gentask():
            try:
                ret = self.mfunc(callback=_callback, **self.kwargs)
            except ValueError:
                traceback.print_exc()
                pass
            except:
                traceback.print_exc()
                pass

            clear_torch_cache()
            self.q.put(self.sentinel)
            if self.c_callback:
                self.c_callback(ret)

        self.thread = Thread(target=gentask)
        self.thread.start()

    def __iter__(self):
        return self

    def __next__(self):
        obj = self.q.get(True, None)
        if obj is self.sentinel:
            raise StopIteration
        else:
            return obj

    def __del__(self):
        clear_torch_cache()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.stop_now = True
        clear_torch_cache()


def clear_torch_cache():
    # Free Python garbage and, if a GPU is present, the CUDA allocator cache
    gc.collect()
    if CUDA_AVAILABLE:
        torch.cuda.empty_cache()


def generate_text(
    history,
    max_new_tokens,
    do_sample,
    temperature,
    top_p,
    top_k,
    repetition_penalty,
    typical_p,
    num_beams
):
    # Create a conversation context of the last 4 entries in the history
    inp = ''.join([
        f"Human: {h[0]}\n\nAssistant: {'' if h[1] is None else h[1]}\n\n"
        for h in history[-4:]
    ]).strip()

    input_ids = tokenizer.encode(
        inp,
        return_tensors='pt',
        truncation=True,
        add_special_tokens=False
    ).to(device)  # type: ignore

    generate_params = {
        "input_ids": input_ids,
        "max_new_tokens": max_new_tokens,
        "do_sample": do_sample,
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
        "repetition_penalty": repetition_penalty,
        "typical_p": typical_p,
        "num_beams": num_beams,
        "stopping_criteria": transformers.StoppingCriteriaList(),
        "pad_token_id": tokenizer.pad_token_id,
    }

    def generate_with_callback(callback=None, **kwargs):
        kwargs['stopping_criteria'].append(Stream(callback_func=callback))
        clear_torch_cache()
        with torch.no_grad():
            model.generate(**kwargs)  # type: ignore

    def generate_with_streaming(**kwargs):
        return Iteratorize(generate_with_callback, kwargs, callback=None)

    with generate_with_streaming(**generate_params) as generator:
        for output in generator:
            new_tokens = len(output) - len(input_ids[0])
            reply = tokenizer.decode(output[-new_tokens:], skip_special_tokens=True)

            # If reply contains '^Human:' or '^Assistant:'
            # then we have reached the end of the assistant's response
            stop_re = re.compile(r'^(Human|Assistant):', re.MULTILINE)
            if re.search(stop_re, reply):
                reply = ''.join(reply.split('\n')[:-1])
                history[-1][1] = reply.strip()
                yield history
                break

            # If the last token is EOS, we have reached the end of the conversation
            if output[-1] in [tokenizer.eos_token_id]:
                yield history
                break

            history[-1][1] = reply.strip()
            yield history


with gr.Blocks() as demo:
    gr.Markdown("""
    ## πŸΊπŸ¦™ Cerebras GPT-2.7B Alpaca-Shortprompt LoRA Chatbot

    This is a very fast and relatively coherent (but hallucinating) chatbot.
    It uses [Cerebras-GPT-2.7B](https://huggingface.co/cerebras/Cerebras-GPT-2.7B)
    with a LoRA finetuned on the [Alpaca dataset](https://github.com/tloen/alpaca-lora/blob/main/alpaca_data_cleaned.json) using a shorter prompt.
    The chatbot keeps a very short conversation context of 4 entries.
    It's the fastest chatbot in the west!

    More info [here](https://github.com/lxe/cerebras-lora-alpaca)
    """)

    with gr.Row():
        with gr.Column():
            chatbot = gr.Chatbot()
            msg = gr.Textbox(value="How old is the Earth?", placeholder="Type a message...")
            with gr.Row():
                clear = gr.Button("Clear")

        with gr.Column():
            max_new_tokens = gr.Slider(0, 2048, 200, step=1, label="max_new_tokens")
            do_sample = gr.Checkbox(True, label="do_sample")
            with gr.Row():
                with gr.Column():
                    temperature = gr.Slider(0, 2, 0.1, step=0.01, label="temperature")
                    top_p = gr.Slider(0, 1, 0.8, step=0.01, label="top_p")
                    top_k = gr.Slider(0, 100, 35, step=1, label="top_k")
                with gr.Column():
                    repetition_penalty = gr.Slider(0, 10, 1.1, step=0.01, label="repetition_penalty")
                    typical_p = gr.Slider(0, 1, 1, step=0.01, label="typical_p")
                    num_beams = gr.Slider(0, 10, 1, step=1, label="num_beams")

    def user(user_message, history):
        # Append the new user message with an empty bot slot and clear the textbox
        return "", history + [[user_message, None]]

    def fix_history(history):
        # Fill any missing bot replies with a placeholder so the history stays well-formed
        update_history = False
        for i, (user, bot) in enumerate(history):
            if bot is None:
                update_history = True
                history[i][1] = "_silence_"
        if update_history:
            chatbot.update(history)

    msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
        generate_text,
        inputs=[
            chatbot,
            max_new_tokens,
            do_sample,
            temperature,
            top_p,
            top_k,
            repetition_penalty,
            typical_p,
            num_beams
        ],
        outputs=[chatbot],
    ).then(fix_history, chatbot)

    clear.click(lambda: None, None, chatbot, queue=False)

demo.queue().launch()