import json
import time

import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM

# model_name = "facebook/opt-350m"  # smaller checkpoint, handy for quick local tests
model_name = "NousResearch/Llama-2-7b-chat-hf"

device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)


def predict(message, chatbot, temperature=0.9, max_new_tokens=256, top_p=0.6, repetition_penalty=1.0):
    system_message = (
        "\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, "
        "while staying safe. Your answers should not include any harmful, unethical, racist, sexist, "
        "toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased "
        "and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, "
        "explain why instead of answering something that is not correct. If you don't know the answer "
        "to a question, please don't share false information."
    )

    # Llama-2 chat format: the system message goes inside <<SYS>> tags in the first [INST] block.
    input_system = f"[INST] <<SYS>>\n{system_message}\n<</SYS>>\n\n "

    # Rebuild the conversation so far, then append the new user message.
    input_history = input_system
    for interaction in chatbot:
        input_history += str(interaction[0]) + " [/INST] " + str(interaction[1]) + " [INST] "
    input_prompt = input_history + str(message) + " [/INST] "

    inputs = tokenizer.encode(input_prompt, return_tensors="pt").to(device)

    temperature = float(temperature)
    if temperature < 1e-2:
        temperature = 1e-2
    top_p = float(top_p)

    generate_kwargs = dict(
        input_ids=inputs,
        do_sample=True,  # temperature/top_p only take effect when sampling is enabled
        temperature=temperature,
        top_p=top_p,
        max_new_tokens=max_new_tokens,
        repetition_penalty=repetition_penalty,
    )

    outputs = model.generate(**generate_kwargs)
    generated_included_full_text = tokenizer.decode(outputs[0])
    print("generated_included_full_text:", generated_included_full_text)

    # The decoded output still contains the prompt: keep only the text after the last
    # [/INST] marker and drop anything after the end-of-sequence token.
    generated_text = generated_included_full_text.split('[/INST] ')[-1]
    if '</s>' in generated_text:
        generated_text = generated_text.split('</s>')[0]

    # Pack the reply into a small JSON structure (one entry per line), round-trip it
    # through json, and then replay it character by character to simulate streaming.
    tokens = generated_text.split('\n')
    token_list = []
    for idx, token in enumerate(tokens):
        token_list.append({"id": idx + 1, "text": token})
    response = json.loads(json.dumps({"data": {"token": token_list}}, indent=4))

    data_dict = response.get('data', {})
    token_list = data_dict.get('token', [])

    partial_message = ""
    for token_entry in token_list:
        if token_entry:
            try:
                token_id = token_entry.get('id', None)
                token_text = token_entry.get('text', None)
                if token_text:
                    for char in token_text:
                        partial_message += char
                        yield partial_message
                        time.sleep(0.01)
                else:
                    gr.Warning(f"The key 'text' does not exist or is None in this token entry: {token_entry}")
            except KeyError as e:
                gr.Warning(f"KeyError: {e} occurred for token entry: {token_entry}")
                continue


title = "TheBloke/Llama-2-7b-Chat-GPTQ model chatbot"
description = """
This is the TheBloke/Llama-2-7b-Chat-GPTQ model.
"""
css = """.toast-wrap { display: none !important } """
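
# --- Optional: true token streaming (sketch) ---------------------------------
# predict() above generates the full reply and only then replays it character by
# character. A possible alternative is to stream tokens as they are produced, using
# transformers' TextIteratorStreamer and running generate() in a background thread.
# This is a minimal, single-turn sketch and is not wired into the UI below; the
# function name predict_streaming_sketch is hypothetical, and it reuses the
# `tokenizer`, `model`, and `device` objects defined above.
from threading import Thread

from transformers import TextIteratorStreamer


def predict_streaming_sketch(message, max_new_tokens=256, temperature=0.9, top_p=0.6):
    prompt = f"[INST] {message} [/INST] "  # single-turn prompt, for brevity
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        do_sample=True,
        temperature=max(float(temperature), 1e-2),
        top_p=float(top_p),
        max_new_tokens=max_new_tokens,
    )
    # generate() blocks, so run it in a thread and read chunks from the streamer.
    Thread(target=model.generate, kwargs=generation_kwargs).start()
    partial_message = ""
    for new_text in streamer:
        partial_message += new_text
        yield partial_message
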
examples = [
    ['Hello there! How are you doing?'],
    ['Can you explain to me briefly what is Python programming language?'],
    ['Explain the plot of Cinderella in a sentence.'],
    ['How many hours does it take a man to eat a Helicopter?'],
    ["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
]


def vote(data: gr.LikeData):
    # Log like/dislike feedback from the chatbot UI.
    if data.liked:
        print("You upvoted this response: " + data.value)
    else:
        print("You downvoted this response: " + data.value)


additional_inputs = [
    gr.Slider(
        label="Temperature",
        value=0.9,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values produce more diverse outputs",
    ),
    gr.Slider(
        label="Max new tokens",
        value=256,
        minimum=0,
        maximum=4096,
        step=64,
        interactive=True,
        info="The maximum number of new tokens",
    ),
    gr.Slider(
        label="Top-p (nucleus sampling)",
        value=0.6,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values sample more low-probability tokens",
    ),
    gr.Slider(
        label="Repetition penalty",
        value=1.2,
        minimum=1.0,
        maximum=2.0,
        step=0.05,
        interactive=True,
        info="Penalize repeated tokens",
    ),
]

chatbot_stream = gr.Chatbot(avatar_images=('user.png', 'bot2.png'), bubble_full_width=False)
chat_interface_stream = gr.ChatInterface(
    predict,
    title=title,
    description=description,
    chatbot=chatbot_stream,
    css=css,
    examples=examples,
    cache_examples=False,
    additional_inputs=additional_inputs,
)

with gr.Blocks() as demo:
    with gr.Tab("Streaming"):
        chatbot_stream.like(vote, None, None)
        chat_interface_stream.render()

demo.queue(concurrency_count=75, max_size=100).launch(debug=True)
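
# --- Note on the GPTQ checkpoint named in the title ---------------------------
# The title and description refer to TheBloke/Llama-2-7b-Chat-GPTQ, while the code
# above loads the fp16 NousResearch/Llama-2-7b-chat-hf checkpoint. A possible way to
# load the quantized model instead (assuming the `accelerate`, `optimum`, and
# `auto-gptq` packages are installed) would be:
#
#     tokenizer = AutoTokenizer.from_pretrained("TheBloke/Llama-2-7b-Chat-GPTQ")
#     model = AutoModelForCausalLM.from_pretrained(
#         "TheBloke/Llama-2-7b-Chat-GPTQ", device_map="auto"
#     )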