import json
import time
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM

model_name = "NousResearch/Llama-2-7b-chat-hf"
# model_name = "facebook/opt-350m"  # smaller model, handy for quick local tests
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)
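# Optional alternative (a sketch, assuming the `accelerate` package is installed):
# load the weights in float16 and let device_map="auto" place them on the GPU,
# roughly halving memory use compared to the default float32 load above.
# model = AutoModelForCausalLM.from_pretrained(
#     model_name, torch_dtype=torch.float16, device_map="auto"
# )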
def predict(message, chatbot, temperature=0.9, max_new_tokens=256, top_p=0.6, repetition_penalty=1.0):
    system_message = "\nYou are a helpful, respectful, and honest assistant. Always answer as helpfully as possible while staying safe. Your answers must not contain any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Your answers should be socially unbiased and positive.\n\nIf a question does not make sense or is not factually coherent, explain why instead of answering something incorrect. If you do not know the answer to a question, do not share false information."
    # Llama-2 chat format: the system prompt sits inside <<SYS>> tags of the first [INST] block.
    input_system = f"[INST] <<SYS>>\n{system_message}\n<</SYS>>\n\n "
    # Accumulate the conversation history, one "user [/INST] assistant" pair per past turn.
    input_history = input_system
    for interaction in chatbot:
        input_history += str(interaction[0]) + " [/INST] " + str(interaction[1]) + " [INST] "
    input_prompt = input_history + str(message) + " [/INST] "
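    # For reference, a first-turn prompt assembled above looks roughly like:
    #   [INST] <<SYS>>\n{system_message}\n<</SYS>>\n\n {message} [/INST]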
    inputs = tokenizer.encode(input_prompt, return_tensors="pt").to(device)
    temperature = float(temperature)
    if temperature < 1e-2:
        temperature = 1e-2
    top_p = float(top_p)
    generate_kwargs = dict(
        input_ids=inputs,
        do_sample=True,  # sample (rather than greedy-decode) so temperature/top_p apply
        temperature=temperature,
        top_p=top_p,
        max_new_tokens=max_new_tokens,
        repetition_penalty=repetition_penalty,
    )
    outputs = model.generate(**generate_kwargs)
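    # An alternative streaming sketch (not used here): transformers' TextIteratorStreamer
    # yields text as it is generated, instead of replaying the finished answer below, e.g.
    #   from transformers import TextIteratorStreamer
    #   streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    #   threading.Thread(target=model.generate, kwargs={**generate_kwargs, "streamer": streamer}).start()
    #   for new_text in streamer: yield ...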
    generated_included_full_text = tokenizer.decode(outputs[0])
    print("generated_included_full_text:", generated_included_full_text)
    # Keep only the newly generated answer (everything after the last [/INST]) and
    # drop anything after the </s> end-of-sequence token.
    generated_text = generated_included_full_text.split('[/INST] ')[-1]
    if '</s>' in generated_text:
        generated_text = generated_text.split('</s>')[0]
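    # The block below wraps the answer in a small JSON-style payload; for a two-line
    # answer it would look like (illustrative values):
    #   {"data": {"token": [{"id": 1, "text": "First line"}, {"id": 2, "text": "Second line"}]}}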
    tokens = generated_text.split('\n')
    token_list = []
    for idx, token in enumerate(tokens):
        token_dict = {"id": idx + 1, "text": token}
        token_list.append(token_dict)
    response = {"data": {"token": token_list}}
    # Round-trip through JSON (structurally a no-op), mimicking a serialized API response.
    response = json.dumps(response, indent=4)
    response = json.loads(response)
    data_dict = response.get('data', {})
    token_list = data_dict.get('token', [])
    # Stream the answer back to the UI one character at a time.
    partial_message = ""
    for token_entry in token_list:
        if token_entry:
            try:
                token_id = token_entry.get('id', None)
                token_text = token_entry.get('text', None)
                if token_text:
                    for char in token_text:
                        partial_message += char
                        yield partial_message
                        time.sleep(0.01)
                else:
                    gr.Warning(f"The key 'text' does not exist or is None in this token entry: {token_entry}")
            except KeyError as e:
                gr.Warning(f"KeyError: {e} occurred for token entry: {token_entry}")
                continue
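# Quick sanity check outside the UI (hypothetical call; `predict` is a generator, so
# take the last yielded string as the full answer):
#   *_, final_answer = predict("Hello there! How are you doing?", chatbot=[])
#   print(final_answer)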
title = "TheBloke/Llama-2-7b-Chat-GPTQ model chatbot"
description = """
This is the TheBloke/Llama-2-7b-Chat-GPTQ model.
"""
css = """.toast-wrap { display: none !important } """
examples = [
    ['Hello there! How are you doing?'],
    ['Can you explain to me briefly what is Python programming language?'],
    ['Explain the plot of Cinderella in a sentence.'],
    ['How many hours does it take a man to eat a Helicopter?'],
    ["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
]
def vote(data: gr.LikeData):
    if data.liked:
        print("You upvoted this response: " + data.value)
    else:
        print("You downvoted this response: " + data.value)
additional_inputs = [
    gr.Slider(
        label="Temperature",
        value=0.9,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values produce more diverse outputs",
    ),
    gr.Slider(
        label="Max new tokens",
        value=256,
        minimum=0,
        maximum=4096,
        step=64,
        interactive=True,
        info="The maximum number of new tokens",
    ),
    gr.Slider(
        label="Top-p (nucleus sampling)",
        value=0.6,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values sample more low-probability tokens",
    ),
    gr.Slider(
        label="Repetition penalty",
        value=1.2,
        minimum=1.0,
        maximum=2.0,
        step=0.05,
        interactive=True,
        info="Penalize repeated tokens",
    ),
]
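# Note: gr.ChatInterface passes these sliders, in order, as the extra arguments of
# `predict` after (message, chatbot): temperature, max_new_tokens, top_p, repetition_penalty.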
chatbot_stream = gr.Chatbot(avatar_images=('user.png', 'bot2.png'), bubble_full_width=False)
chat_interface_stream = gr.ChatInterface(
    predict,
    title=title,
    description=description,
    chatbot=chatbot_stream,
    css=css,
    examples=examples,
    cache_examples=False,
    additional_inputs=additional_inputs,
)
with gr.Blocks() as demo:
    with gr.Tab("Streaming"):
        chatbot_stream.like(vote, None, None)
        chat_interface_stream.render()

demo.queue(concurrency_count=75, max_size=100).launch(debug=True)
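# Note: `concurrency_count` is a Gradio 3.x queue argument; on Gradio 4.x the rough
# equivalent would be demo.queue(default_concurrency_limit=75, max_size=100).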