CognitiveScience committed on
Commit 34417e8
1 Parent(s): 3b31d45

Update app.py

Files changed (1): app.py (+49 -1)
app.py CHANGED
@@ -29,6 +29,51 @@ from huggingface_hub import hf_hub_download
 from huggingface_hub import login
 from datasets import load_dataset
 
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
+from threading import Thread
+
+tokenizer = AutoTokenizer.from_pretrained("togethercomputer/RedPajama-INCITE-Chat-3B-v1")
+model = AutoModelForCausalLM.from_pretrained("togethercomputer/RedPajama-INCITE-Chat-3B-v1", torch_dtype=torch.float16)
+model = model.to('cuda:0')  # assumes a CUDA device is available
+
+class StopOnTokens(StoppingCriteria):
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+        stop_ids = [29, 0]  # token ids treated as end-of-turn markers
+        for stop_id in stop_ids:
+            if input_ids[0][-1] == stop_id:
+                return True
+        return False
+
+def predict(message, history):
+
+    history_transformer_format = history + [[message, ""]]
+    stop = StopOnTokens()
+
+    messages = "".join(["".join(["\n<human>:"+item[0], "\n<bot>:"+item[1]])  #curr_system_message +
+                        for item in history_transformer_format])
+
+    model_inputs = tokenizer([messages], return_tensors="pt").to("cuda")
+    streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
+    generate_kwargs = dict(
+        model_inputs,
+        streamer=streamer,
+        max_new_tokens=1024,
+        do_sample=True,
+        top_p=0.95,
+        top_k=1000,
+        temperature=1.0,
+        num_beams=1,
+        stopping_criteria=StoppingCriteriaList([stop])
+    )
+    t = Thread(target=model.generate, kwargs=generate_kwargs)  # generate in the background
+    t.start()
+
+    partial_message = ""
+    for new_token in streamer:
+        if new_token != '<':  # drop the model's stop marker from the visible reply
+            partial_message += new_token
+            yield partial_message
 #dataset = load_dataset("csv", data_files="./data.csv")
 
 
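The added predict is a generator: it flattens Gradio's chat history into RedPajama's <human>:/<bot>: prompt format, then streams tokens from a background model.generate call via TextIteratorStreamer. A minimal sketch of just the prompt flattening, using a hypothetical two-turn history (no model required):

# Sketch of the prompt flattening in predict(); the history contents are made up.
history = [["Hi!", "Hello, how can I help?"]]  # one completed [user, bot] pair
message = "Name a color."                      # the new user message
history_transformer_format = history + [[message, ""]]

messages = "".join(["".join(["\n<human>:"+item[0], "\n<bot>:"+item[1]])
                    for item in history_transformer_format])

# messages now reads:
# <human>:Hi!
# <bot>:Hello, how can I help?
# <human>:Name a color.
# <bot>:
# The trailing "<bot>:" cues the model to generate the assistant's next turn.

Because model.generate blocks until generation finishes, it runs on a separate Thread; the streamer yields decoded fragments as they are produced, and each yield hands Gradio a progressively longer partial_message to render.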
 
 
@@ -209,7 +254,10 @@ with gr.Blocks() as demo:
 #@rate.change(inputs=rate, outputs=name,_js="window.location.reload()")
 #@celsci.change(inputs=celsci, outputs=rate,_js="window.location.reload()")
 #def secwork(name):
-# load_data()
+# load_data()
+with gr.Row():
+    with gr.Column():
+        dem = gr.ChatInterface(predict).queue()
 def backup_db():
     shutil.copyfile(DB_FILE, "./reviews.db")
     db = sqlite3.connect(DB_FILE)
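The second hunk mounts the chat UI inside the existing Blocks layout. A self-contained sketch of that wiring, with an echo stand-in for predict (whether a ChatInterface can be nested directly inside Blocks, and where queue() belongs, varies across Gradio versions):

# Standalone sketch of the new UI wiring; the echo predict is a stand-in
# for the streaming generator added in the first hunk.
import gradio as gr

def predict(message, history):
    yield "echo: " + message  # placeholder for the streamed model reply

with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            dem = gr.ChatInterface(predict).queue()

demo.launch()

In most Gradio versions, generator callbacks like predict require the queue to be enabled in order to stream partial results to the browser, which is why the commit calls .queue() on the new interface.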