thliang01's picture
docs: change license
71c9089 unverified
raw
history blame
4.43 kB
import os
from threading import Thread
from typing import Iterator

import gradio as gr
import spaces
import spaces
import torch
from transformers import GemmaTokenizer, AutoModelForCausalLM
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
# Read the optional Hugging Face access token from the environment (None when unset).
HF_TOKEN = os.environ.get("HF_TOKEN")

# Diagnostic: place a one-element tensor on CUDA at import time and report its device.
zero = torch.Tensor([0]).cuda()
print(zero.device)  # <-- 'cpu' 🤔
# Markdown rendered at the top of the page (currently just a blank line).
DESCRIPTION = "\n"

# Markdown footer rendered below the chat interface.
LICENSE = (
    "\n"
    "<p/>\n"
    "---\n"
    "Built with TAIDE-LX-8B-Chat\n"
)

# Custom stylesheet: centers the page heading and styles the duplicate-space button.
css = """
h1 {
text-align: center;
display: block;
}
#duplicate-button {
margin: auto;
color: white;
background: #1565c0;
border-radius: 100vh;
}
"""
# Load the tokenizer and model once at import time so every request reuses them.
MODEL_ID = "taide/Llama3-TAIDE-LX-8B-Chat-Alpha1"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Place the weights on the GPU. Without an explicit device the model stays on
# the CPU, so the GPU reserved by @spaces.GPU would never be used (inputs are
# sent to `model.device`). This file already requires CUDA at import time.
# NOTE(review): weights load in the library's default dtype; consider passing
# a half-precision torch_dtype if GPU memory is tight — confirm before changing.
model = AutoModelForCausalLM.from_pretrained(MODEL_ID).to("cuda")

# Token ids that end generation: the tokenizer's EOS token plus the
# "<|eot_id|>" end-of-turn token.
terminators = [
    tokenizer.eos_token_id,
    tokenizer.convert_tokens_to_ids("<|eot_id|>"),
]
@spaces.GPU(duration=120)
def chat_taide_8b(message: str,
                  history: list,
                  temperature: float,
                  max_new_tokens: int
                  ) -> Iterator[str]:
    """Stream a chat reply from the TAIDE-LX-8B-Chat model.

    Args:
        message: The latest user message.
        history: Earlier turns as ``(user, assistant)`` pairs, as supplied by
            ChatInterface.
        temperature: Sampling temperature. Values ``<= 0`` select greedy
            decoding instead of sampling.
        max_new_tokens: Maximum number of new tokens to generate.

    Yields:
        The reply accumulated so far; each yielded string extends the
        previous one by the newly streamed text.
    """
    # Rebuild the conversation in the role/content format the chat template expects.
    conversation = []
    for user, assistant in history:
        conversation.extend([
            {"role": "user", "content": user},
            {"role": "assistant", "content": assistant},
        ])
    conversation.append({"role": "user", "content": message})

    # add_generation_prompt=True ends the prompt with the assistant header so
    # the model starts its reply instead of continuing the user's turn.
    input_ids = tokenizer.apply_chat_template(
        conversation,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(model.device)

    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)

    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        eos_token_id=terminators,
    )
    if temperature > 0:
        generate_kwargs.update(do_sample=True, temperature=temperature)
    else:
        # A non-positive temperature is invalid for sampling; decode greedily
        # and leave the temperature out so generate() is not handed an unused,
        # invalid value.
        generate_kwargs["do_sample"] = False

    # generate() blocks until completion, so run it in a worker thread and
    # consume the streamer from this generator.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)
# Gradio UI: chat window, tunable generation parameters and example prompts.
chatbot = gr.Chatbot(height=450, label='Gradio ChatInterface')

# Example prompts offered in the UI; each inner list is one single-message example.
example_prompts = [
    ['請以以下內容為基礎,寫一篇文章:撰寫一篇作文,題目為《一張舊照片》,內容要求為:選擇一張令你印象深刻的照片,說明令你印象深刻的原因,並描述照片中的影像及背後的故事。記錄成長的過程、與他人的情景、環境變遷和美麗的景色。'],
    ['請以品牌經理的身份,給廣告公司的創意總監寫一封信,提出對於新產品廣告宣傳活動的創意建議。'],
    ['以下提供英文內容,請幫我翻譯成中文。Dongshan coffee is famous for its unique position, and the constant refinement of production methods. The flavor is admired by many caffeine afficionados.'],
]

with gr.Blocks(fill_height=True, css=css) as demo:
    gr.Markdown(DESCRIPTION)
    gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")

    # The extra inputs are created with render=False and handed to the
    # ChatInterface below; their order matches the handler's
    # (temperature, max_new_tokens) parameters.
    parameters_accordion = gr.Accordion(label="⚙️ Parameters", open=False, render=False)
    temperature_slider = gr.Slider(
        minimum=0,
        maximum=1,
        step=0.1,
        value=0.95,
        label="Temperature",
        render=False,
    )
    max_tokens_slider = gr.Slider(
        minimum=128,
        maximum=4096,
        step=1,
        value=512,
        label="Max new tokens",
        render=False,
    )

    gr.ChatInterface(
        fn=chat_taide_8b,
        chatbot=chatbot,
        fill_height=True,
        additional_inputs_accordion=parameters_accordion,
        additional_inputs=[temperature_slider, max_tokens_slider],
        examples=example_prompts,
        cache_examples=False,
    )

    gr.Markdown(LICENSE)

# Start the app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()