"""Gradio chat demo that streams replies from a Baichuan-Vicuna causal LM."""

from threading import Thread

import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer

model_id = "fireballoon/baichuan-vicuna-chinese-7b"
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
print("Running on device:", torch_device)
print("CPU threads:", torch.get_num_threads())

# Half precision is only requested on the GPU path; the CPU path keeps the
# library's default dtype.
if torch_device == "cuda":
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16).cuda()
else:
    model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)


def _build_prompt(history):
    """Render the chat history as a single Vicuna-style prompt string.

    Args:
        history: List of ``[user_message, assistant_reply]`` pairs. The last
            pair is the turn being answered; its reply slot is ignored.

    Returns:
        The instruction preamble, every completed turn, and the pending user
        message followed by an open ``ASSISTANT:`` tag.
    """
    instruction = "A chat between a curious user and an artificial intelligence assistant. " \
                  "The assistant gives helpful, detailed, and polite answers to the user's questions."
    # A previous turn whose generation failed before producing any text still
    # has ``None`` in its reply slot; treat that as an empty reply instead of
    # crashing on ``None.strip()``.
    context = ''.join(
        f" USER: {question.strip()} ASSISTANT: {(answer or '').strip()} "
        for question, answer in history[:-1]
    )
    return instruction + context + f" USER: {history[-1][0].strip()} ASSISTANT:"


def run_generation(history, *args, **kwargs):
    """Stream the assistant's reply for the last turn of ``history``.

    Args:
        history: List of ``[user_message, assistant_reply]`` pairs; the last
            pair holds the pending user message. It is updated in place.
        *args: Ignored.
        **kwargs: Ignored.

    Yields:
        ``history`` after each newly decoded chunk has been appended to the
        last turn's reply.
    """
    prompt = _build_prompt(history)
    # Put the ids on the device selected at start-up. An unconditional
    # ``.cuda()`` here would fail on CPU-only hosts.
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(torch_device)
    print()
    print(prompt)
    print('##', input_ids.size())

    # Start generation on a separate thread, so that we don't block the UI. The text is pulled from the streamer
    # in the main thread. Adds timeout to the streamer to handle exceptions in the generation thread.
    streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=2048,
        do_sample=True,
        temperature=0.7,
        repetition_penalty=1.1,
        top_p=0.85,
    )
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    # Pull the generated text from the streamer, and update the model output.
    history[-1][1] = ""
    print("")
    for new_text in streamer:
        history[-1][1] += new_text
        print(new_text, end="", flush=True)
        yield history
    print('')


def reset_textbox():
    """Return a Gradio update that clears a textbox."""
    return gr.update(value='')


with gr.Blocks() as demo:
    gr.Markdown(
        "# Baichuan Vicuna Chinese\n"
        f"[{model_id}](https://huggingface.co/{model_id}):使用中英双语sharegpt数据全参数微调的对话模型,基于baichuan-7b"
    )
    chatbot = gr.Chatbot().style(height=600)
    msg = gr.Textbox()
    clear = gr.ClearButton([msg, chatbot])

    def user(user_message, history):
        """Queue the submitted message as a new turn and lock the textbox.

        Returns:
            A textbox update (cleared and non-interactive) and a new history
            list with ``[user_message, None]`` appended.
        """
        return gr.update(value="", interactive=False), history + [[user_message, None]]

    # Record the user's message first, then stream the model's reply into the
    # chatbot, and finally unlock the textbox again.
    response = msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
        run_generation, chatbot, chatbot
    )
    response.then(lambda: gr.update(interactive=True), None, [msg], queue=False)

demo.queue()
demo.launch(server_name='0.0.0.0')