CognitiveScience committed on
Commit 3516070 · 1 Parent(s): 98a3583

Update app.py

Files changed (1):
  app.py +0 -48
app.py CHANGED
@@ -29,51 +29,6 @@ from huggingface_hub import hf_hub_download
 from huggingface_hub import login
 from datasets import load_dataset
 
-import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
-from threading import Thread
-
-tokenizer = AutoTokenizer.from_pretrained("togethercomputer/RedPajama-INCITE-Chat-3B-v1")
-model = AutoModelForCausalLM.from_pretrained("togethercomputer/RedPajama-INCITE-Chat-3B-v1", torch_dtype=torch.float16)
-model = model.to('cuda:0')
-
-class StopOnTokens(StoppingCriteria):
-    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
-        stop_ids = [29, 0]
-        for stop_id in stop_ids:
-            if input_ids[0][-1] == stop_id:
-                return True
-        return False
-
-def predict(message, history):
-
-    history_transformer_format = history + [[message, ""]]
-    stop = StopOnTokens()
-
-    messages = "".join(["".join(["\n<human>:"+item[0], "\n<bot>:"+item[1]])  #curr_system_message +
-                        for item in history_transformer_format])
-
-    model_inputs = tokenizer([messages], return_tensors="pt").to("cuda")
-    streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
-    generate_kwargs = dict(
-        model_inputs,
-        streamer=streamer,
-        max_new_tokens=1024,
-        do_sample=True,
-        top_p=0.95,
-        top_k=1000,
-        temperature=1.0,
-        num_beams=1,
-        stopping_criteria=StoppingCriteriaList([stop])
-    )
-    t = Thread(target=model.generate, kwargs=generate_kwargs)
-    t.start()
-
-    partial_message = ""
-    for new_token in streamer:
-        if new_token != '<':
-            partial_message += new_token
-            yield partial_message
 #dataset = load_dataset("csv", data_files="./data.csv")
 
 
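For reference, the removed block is the standard Transformers token-streaming pattern: model.generate blocks until completion, so it runs in a worker thread while a TextIteratorStreamer hands decoded tokens back to the consuming generator as they are produced. A minimal, self-contained sketch of that pattern (model name taken from the removed code; stream_reply is a hypothetical helper name, and a CUDA device with enough memory for the 3B fp16 checkpoint is assumed):

    import torch
    from threading import Thread
    from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

    tokenizer = AutoTokenizer.from_pretrained("togethercomputer/RedPajama-INCITE-Chat-3B-v1")
    model = AutoModelForCausalLM.from_pretrained(
        "togethercomputer/RedPajama-INCITE-Chat-3B-v1", torch_dtype=torch.float16
    ).to("cuda:0")

    def stream_reply(prompt):  # hypothetical name, not from app.py
        # Tokenize the prompt; skip_prompt=True keeps the prompt echo out of the stream.
        model_inputs = tokenizer([prompt], return_tensors="pt").to("cuda:0")
        streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
        # generate() blocks until done, so it runs in a thread while this
        # generator drains the streamer one decoded chunk at a time.
        generate_kwargs = dict(model_inputs, streamer=streamer, max_new_tokens=256, do_sample=True, top_p=0.95)
        Thread(target=model.generate, kwargs=generate_kwargs).start()
        partial = ""
        for new_text in streamer:
            partial += new_text
            yield partial  # each yield replaces the partial reply shown so far

    for text in stream_reply("\n<human>:What is streaming?\n<bot>:"):
        print(text)

Yielding the growing prefix rather than individual tokens is what lets a UI consumer (such as the gr.ChatInterface removed in the second hunk) redraw one message in place instead of appending fragments.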
 
@@ -255,9 +210,6 @@ with gr.Blocks() as demo:
     #@celsci.change(inputs=celsci, outputs=rate,_js="window.location.reload()")
     #def secwork(name):
     #    load_data()
-    with gr.Row():
-        with gr.Column():
-            gr.ChatInterface(predict)
     def backup_db():
         shutil.copyfile(DB_FILE, "./reviews.db")
         db = sqlite3.connect(DB_FILE)
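The second hunk unmounts the chat UI that consumed predict. For context, gr.ChatInterface accepts a generator function directly and can be embedded inside a Blocks layout exactly as the removed lines did; a minimal sketch (the character-by-character echo below is a hypothetical stand-in for the removed streaming predict):

    import gradio as gr

    def predict(message, history):
        # Hypothetical stand-in: yielding successive prefixes streams
        # the reply into the chat bubble, like the removed function did.
        partial = ""
        for ch in "You said: " + message:
            partial += ch
            yield partial

    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column():
                gr.ChatInterface(predict)

    demo.launch()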