import torch
import gradio as gr
import re
import transformers
import peft
import traceback
from queue import Queue
from threading import Thread
import gc
CUDA_AVAILABLE = torch.cuda.is_available()
device = torch.device("cuda" if CUDA_AVAILABLE else "cpu")

tokenizer = transformers.AutoTokenizer.from_pretrained("cerebras/Cerebras-GPT-2.7B")
tokenizer.pad_token_id = 0  # no pad token by default; use 0 so generate() can pad

# 8-bit loading relies on bitsandbytes and only works on a GPU,
# so it is enabled only when CUDA is available.
model = transformers.AutoModelForCausalLM.from_pretrained(
    "cerebras/Cerebras-GPT-2.7B",
    load_in_8bit=CUDA_AVAILABLE,
    torch_dtype=torch.float16,
    device_map={'': 0} if CUDA_AVAILABLE else 'auto',
)

# Attach the LoRA adapter on top of the base model
model = peft.PeftModel.from_pretrained(
    model,
    'lxe/lora-cerebras-gpt2.7b-alpaca-shortprompt',
    torch_dtype=torch.float16,
)
model.half()
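# Illustrative sanity check (commented out, not part of the app): a minimal
# single-prompt generation to confirm the base model plus LoRA adapter load and
# respond, assuming the short "Human:/Assistant:" prompt format used below.
#
#   test_ids = tokenizer.encode("Human: Hello\n\nAssistant:", return_tensors="pt").to(device)
#   with torch.no_grad():
#       out = model.generate(test_ids, max_new_tokens=20, pad_token_id=tokenizer.pad_token_id)
#   print(tokenizer.decode(out[0], skip_special_tokens=True))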
# Streaming functionality taken from https://github.com/oobabooga/text-generation-webui/blob/master/modules/text_generation.py#L105
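# Stream is a StoppingCriteria that never actually stops generation (it always
# returns False); it is used purely as a per-step hook so the tokens generated
# so far can be handed to a callback and streamed to the UI.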
class Stream(transformers.StoppingCriteria):
    def __init__(self, callback_func=None):
        self.callback_func = callback_func

    def __call__(self, input_ids, scores) -> bool:
        if self.callback_func is not None:
            self.callback_func(input_ids[0])
        return False
class Iteratorize:
    """
    Transforms a function that takes a callback
    into a lazy iterator (generator).
    """
    def __init__(self, func, kwargs=None, callback=None):
        self.mfunc = func
        self.c_callback = callback
        self.q = Queue()
        self.sentinel = object()
        self.kwargs = kwargs or {}
        self.stop_now = False

        def _callback(val):
            if self.stop_now:
                raise ValueError
            self.q.put(val)

        def gentask():
            ret = None
            try:
                ret = self.mfunc(callback=_callback, **self.kwargs)
            except ValueError:
                # Raised by _callback to abort generation early
                traceback.print_exc()
            except Exception:
                traceback.print_exc()

            clear_torch_cache()
            self.q.put(self.sentinel)
            if self.c_callback:
                self.c_callback(ret)

        self.thread = Thread(target=gentask)
        self.thread.start()

    def __iter__(self):
        return self

    def __next__(self):
        obj = self.q.get(True, None)
        if obj is self.sentinel:
            raise StopIteration
        else:
            return obj

    def __del__(self):
        clear_torch_cache()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.stop_now = True
        clear_torch_cache()
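# Illustrative sketch (not executed here) of how Iteratorize is meant to be used,
# with a hypothetical callback-style producer:
#
#   def produce(callback=None, n=3):
#       for i in range(n):
#           callback(i)
#
#   with Iteratorize(produce, {"n": 3}) as it:
#       for value in it:
#           print(value)  # 0, 1, 2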
def clear_torch_cache():
    gc.collect()
    if CUDA_AVAILABLE:
        torch.cuda.empty_cache()
def generate_text(
    history,
    max_new_tokens,
    do_sample,
    temperature,
    top_p,
    top_k,
    repetition_penalty,
    typical_p,
    num_beams
):
    # Create a conversation context of the last 4 entries in the history
    inp = ''.join([
        f"Human: {h[0]}\n\nAssistant: {'' if h[1] is None else h[1]}\n\n" for h in history[-4:]
    ]).strip()
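    # Example of the short prompt this builds (illustrative history value):
    #   history = [["How old is the Earth?", None]]
    #   # inp == "Human: How old is the Earth?\n\nAssistant:"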
    input_ids = tokenizer.encode(
        inp,
        return_tensors='pt',
        truncation=True,
        add_special_tokens=False
    ).to(device)  # type: ignore

    generate_params = {
        "input_ids": input_ids,
        "max_new_tokens": max_new_tokens,
        "do_sample": do_sample,
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
        "repetition_penalty": repetition_penalty,
        "typical_p": typical_p,
        "num_beams": num_beams,
        "stopping_criteria": transformers.StoppingCriteriaList(),
        "pad_token_id": tokenizer.pad_token_id,
    }
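    # The empty StoppingCriteriaList above is the slot where generate_with_callback
    # (below) appends a Stream instance, so every decoding step pushes the partial
    # output onto the Iteratorize queue that the loop below consumes.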
    def generate_with_callback(callback=None, **kwargs):
        kwargs['stopping_criteria'].append(Stream(callback_func=callback))
        clear_torch_cache()
        with torch.no_grad():
            model.generate(**kwargs)  # type: ignore

    def generate_with_streaming(**kwargs):
        return Iteratorize(generate_with_callback, kwargs, callback=None)

    with generate_with_streaming(**generate_params) as generator:
        for output in generator:
            new_tokens = len(output) - len(input_ids[0])
            reply = tokenizer.decode(output[-new_tokens:], skip_special_tokens=True)

            # If the reply contains a line starting with 'Human:' or 'Assistant:',
            # the model has started a new turn, so the assistant's response is done
            stop_re = re.compile(r'^(Human|Assistant):', re.MULTILINE)
            if re.search(stop_re, reply):
                reply = '\n'.join(reply.split('\n')[:-1])
                history[-1][1] = reply.strip()
                yield history
                break

            # If the last token is EOS, the conversation is over
            if output[-1] in [tokenizer.eos_token_id]:
                yield history
                break

            history[-1][1] = reply.strip()
            yield history
with gr.Blocks() as demo:
    gr.Markdown("""
    ## 🐺🦙 Cerebras GPT-2.7B Alpaca-Shortprompt LoRA Chatbot

    This is a very fast and relatively coherent (but hallucination-prone) chatbot.
    It uses [Cerebras-GPT-2.7B](https://huggingface.co/cerebras/Cerebras-GPT-2.7B) with a LoRA finetuned on the [Alpaca dataset](https://github.com/tloen/alpaca-lora/blob/main/alpaca_data_cleaned.json) using a shorter prompt.
    The chatbot keeps a very short conversation context of 4 entries. It's the fastest chatbot in the west!

    More info [here](https://github.com/lxe/cerebras-lora-alpaca)
    """)
    with gr.Row():
        with gr.Column():
            chatbot = gr.Chatbot()
            msg = gr.Textbox(value="How old is the Earth?", placeholder="Type a message...")
            with gr.Row():
                clear = gr.Button("Clear")
        with gr.Column():
            max_new_tokens = gr.Slider(0, 2048, 200, step=1, label="max_new_tokens")
            do_sample = gr.Checkbox(True, label="do_sample")
            with gr.Row():
                with gr.Column():
                    temperature = gr.Slider(0, 2, 0.1, step=0.01, label="temperature")
                    top_p = gr.Slider(0, 1, 0.8, step=0.01, label="top_p")
                    top_k = gr.Slider(0, 100, 35, step=1, label="top_k")
                with gr.Column():
                    repetition_penalty = gr.Slider(0, 10, 1.1, step=0.01, label="repetition_penalty")
                    typical_p = gr.Slider(0, 1, 1, step=0.01, label="typical_p")
                    num_beams = gr.Slider(0, 10, 1, step=1, label="num_beams")
    def user(user_message, history):
        # Append the user's message with an empty bot slot and clear the textbox
        return "", history + [[user_message, None]]

    def fix_history(history):
        # Fill in any missing bot responses so the chat log stays well-formed
        for i, (user, bot) in enumerate(history):
            if bot is None:
                history[i][1] = "_silence_"
        return history

    msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
        generate_text, inputs=[
            chatbot,
            max_new_tokens,
            do_sample,
            temperature,
            top_p,
            top_k,
            repetition_penalty,
            typical_p,
            num_beams
        ], outputs=[chatbot],
    ).then(fix_history, chatbot, chatbot)

    clear.click(lambda: None, None, chatbot, queue=False)

demo.queue().launch()