# Yuan2-2B-demo / app.py
# Hugging Face Space by IEIT-Yuan (commit 38ce3d1, "Update app.py", verified).
# Listed file size at that commit: 3.44 kB.
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
from threading import Thread
import torch, transformers
import sys, os
sys.path.append(
os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir)))
from transformers import AutoModelForCausalLM,AutoTokenizer,LlamaTokenizer
# Hugging Face Hub repository that provides both the tokenizer and the weights.
MODEL_PATH = 'IEITYuan/Yuan2-2B-Janus-hf'

# Extra tokens registered as special tokens on top of the base vocabulary
# (list kept verbatim from the original setup).
EXTRA_SPECIAL_TOKENS = [
    '<sep>', '<pad>', '<mask>', '<predict>', '<FIM_SUFFIX>', '<FIM_PREFIX>',
    '<FIM_MIDDLE>', '<commit_before>', '<commit_msg>', '<commit_after>',
    '<jupyter_start>', '<jupyter_text>', '<jupyter_code>', '<jupyter_output>',
    '<empty_output>',
]

print("Creating tokenizer...")
# BOS/EOS are not inserted automatically; '<eod>' is registered as the
# end-of-sequence token.
tokenizer = LlamaTokenizer.from_pretrained(
    MODEL_PATH, add_eos_token=False, add_bos_token=False, eos_token='<eod>'
)
tokenizer.add_tokens(EXTRA_SPECIAL_TOKENS, special_tokens=True)

print("Creating model...")
# device_map='auto' lets accelerate place the weights (GPU when available,
# otherwise CPU); trust_remote_code allows the repository's own modelling
# code to be executed.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    device_map='auto',
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
)
# The model is already placed by device_map, so it is not moved again with
# model.to(): that is redundant, accelerate warns against moving a dispatched
# model, and it raises if any weights were offloaded. Inputs must be sent to
# the device of the model's first parameters, which model.device reports.
device = model.device
# Defining a custom stopping criteria class for the model's text generation.
class StopOnTokens(StoppingCriteria):
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
stop_ids = [2] # IDs of tokens where the generation should stop.
for stop_id in stop_ids:
if input_ids[0][-1] == stop_id: # Checking if the last generated token is a stop token.
return True
return False
# Function to generate model predictions.
def predict(message, history):
# history_transformer_format = history + [[message, ""]]
# stop = StopOnTokens()
#
# # Formatting the input for the model.
# messages = "</s>".join(["</s>".join(["\n<|user|>:" + item[0], "\n<|assistant|>:" + item[1]])
# for item in history_transformer_format])
# model_inputs = tokenizer([messages], return_tensors="pt").to(device)
# streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
# generate_kwargs = dict(
# model_inputs,
# streamer=streamer,
# max_new_tokens=1024,
# do_sample=True,
# top_p=0.95,
# top_k=50,
# temperature=0.7,
# num_beams=1,
# stopping_criteria=StoppingCriteriaList([stop])
# )
# t = Thread(target=model.generate, kwargs=generate_kwargs)
# t.start() # Starting the generation in a separate thread.
# partial_message = ""
# for new_token in streamer:
# partial_message += new_token
# if '</s>' in partial_message: # Breaking the loop if the stop token is generated.
# break
# yield partial_message
inputs = tokenizer(message, return_tensors="pt")["input_ids"].to(device)
outputs = model.generate(inputs, do_sample=False, max_length=500)
print(tokenizer.decode(outputs[0]))
return(tokenizer.decode(outputs[0]))
# Setting up the Gradio chat interface.
gr.ChatInterface(predict,
title="Yuan2_2b_chatBot",
description="่ฏทๆ้—ฎ",
examples=['่ฏท้—ฎ็›ฎๅ‰ๆœ€ๅ…ˆ่ฟ›็š„ๆœบๅ™จๅญฆไน ็ฎ—ๆณ•ๆœ‰ๅ“ชไบ›๏ผŸ','ไฝœไธ€้ฆ–ๅ…ณไบŽๆ–ฐๅนดๅฟซไน็š„่ฏ—','ๅŒ—ไบฌ็ƒค้ธญๆ€Žไนˆๅš๏ผŸ']
).launch() # Launching the web interface.