IEIT-Yuan committed · verified
Commit b1e52c0 · 1 Parent(s): 87f8f2a

Update app.py

Files changed (1):
  app.py  +32 -29
app.py CHANGED
@@ -10,11 +10,11 @@ sys.path.append(
 from transformers import AutoModelForCausalLM,AutoTokenizer,LlamaTokenizer
 
 print("Creat tokenizer...")
-tokenizer = LlamaTokenizer.from_pretrained('IEITYuan/Yuan2-2B-hf', add_eos_token=False, add_bos_token=False, eos_token='<eod>')
+tokenizer = LlamaTokenizer.from_pretrained('IEITYuan/Yuan2-2B-Janus-hf', add_eos_token=False, add_bos_token=False, eos_token='<eod>')
 tokenizer.add_tokens(['<sep>', '<pad>', '<mask>', '<predict>', '<FIM_SUFFIX>', '<FIM_PREFIX>', '<FIM_MIDDLE>','<commit_before>','<commit_msg>','<commit_after>','<jupyter_start>','<jupyter_text>','<jupyter_code>','<jupyter_output>','<empty_output>'], special_tokens=True)
 
 print("Creat model...")
-model = AutoModelForCausalLM.from_pretrained('IEITYuan/Yuan2-2B-hf', device_map='auto', torch_dtype=torch.bfloat16, trust_remote_code=True)
+model = AutoModelForCausalLM.from_pretrained('IEITYuan/Yuan2-2B-Janus-hf', device_map='auto', torch_dtype=torch.bfloat16, trust_remote_code=True)
 # using CUDA for an optimal experience
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 model = model.to(device)
@@ -31,33 +31,36 @@ class StopOnTokens(StoppingCriteria):
 
 # Function to generate model predictions.
 def predict(message, history):
-    history_transformer_format = history + [[message, ""]]
-    stop = StopOnTokens()
-
-    # Formatting the input for the model.
-    messages = "</s>".join(["</s>".join(["\n<|user|>:" + item[0], "\n<|assistant|>:" + item[1]])
-                            for item in history_transformer_format])
-    model_inputs = tokenizer([messages], return_tensors="pt").to(device)
-    streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
-    generate_kwargs = dict(
-        model_inputs,
-        streamer=streamer,
-        max_new_tokens=1024,
-        do_sample=True,
-        top_p=0.95,
-        top_k=50,
-        temperature=0.7,
-        num_beams=1,
-        stopping_criteria=StoppingCriteriaList([stop])
-    )
-    t = Thread(target=model.generate, kwargs=generate_kwargs)
-    t.start()  # Starting the generation in a separate thread.
-    partial_message = ""
-    for new_token in streamer:
-        partial_message += new_token
-        if '</s>' in partial_message:  # Breaking the loop if the stop token is generated.
-            break
-        yield partial_message
+    # history_transformer_format = history + [[message, ""]]
+    # stop = StopOnTokens()
+    #
+    # # Formatting the input for the model.
+    # messages = "</s>".join(["</s>".join(["\n<|user|>:" + item[0], "\n<|assistant|>:" + item[1]])
+    #                         for item in history_transformer_format])
+    # model_inputs = tokenizer([messages], return_tensors="pt").to(device)
+    # streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
+    # generate_kwargs = dict(
+    #     model_inputs,
+    #     streamer=streamer,
+    #     max_new_tokens=1024,
+    #     do_sample=True,
+    #     top_p=0.95,
+    #     top_k=50,
+    #     temperature=0.7,
+    #     num_beams=1,
+    #     stopping_criteria=StoppingCriteriaList([stop])
+    # )
+    # t = Thread(target=model.generate, kwargs=generate_kwargs)
+    # t.start()  # Starting the generation in a separate thread.
+    # partial_message = ""
+    # for new_token in streamer:
+    #     partial_message += new_token
+    #     if '</s>' in partial_message:  # Breaking the loop if the stop token is generated.
+    #         break
+    #     yield partial_message
+    inputs = tokenizer(message, return_tensors="pt")["input_ids"].to("cuda:0")
+    outputs = model.generate(inputs, do_sample=False, max_length=100)
+    return(tokenizer.decode(outputs[0]))
 
 
 # Setting up the Gradio chat interface.
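The diff ends at the comment that introduces the Gradio chat interface; the interface wiring itself sits outside the changed hunks and is not shown above. For context only, a minimal sketch of how the revised non-streaming predict(message, history) could be hooked up with gr.ChatInterface; the title string and the queue()/launch() calls are illustrative assumptions, not part of this commit.

import gradio as gr

# Hypothetical wiring (not from the commit): gr.ChatInterface passes
# (message, history) to predict and renders the returned string as the reply.
demo = gr.ChatInterface(fn=predict, title="Yuan2-2B-Janus-hf chat demo")
demo.queue().launch()

Because the revised predict returns a single decoded string instead of yielding partial messages, a setup like this shows each reply all at once rather than streaming it token by token.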