dar-tau committed on
Commit
906564b
·
verified ·
1 Parent(s): 14c86d4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -4
app.py CHANGED
@@ -48,7 +48,11 @@ Assistant: girlfriend;mother;father;friend
48
  torch.set_grad_enabled(False)
49
  model_name = "TheBloke/OpenHermes-2.5-Mistral-7B-GPTQ"
50
  pipe = pipeline("text-generation", model=model_name, device='cuda')
51
- generate_kwargs = {'max_new_tokens': 20}
 
 
 
 
52
 
53
  def past_kv_to_device(past_kv, device, dtype):
54
  return tuple((torch.tensor(k).to(device).to(dtype), torch.tensor(v).to(device).to(dtype)) for k, v in past_kv)
@@ -84,14 +88,15 @@ def generate(text, past_key_values):
84
  prompt_format.format(system_message=system_prompt, prompt=text), **cur_generate_kwargs
85
  )[0]['generated_text']
86
  print(response)
87
- return response[-1]['content']
 
88
 
89
 
90
  if __name__ == "__main__":
91
  with torch.no_grad():
92
- past_key_values = set_past_key_values()
93
  demo = gr.Interface(
94
- partial(generate, past_key_values=past_key_values),
95
  inputs="textbox", outputs="textbox"
96
  )
97
  demo.launch()
 
48
  torch.set_grad_enabled(False)
49
  model_name = "TheBloke/OpenHermes-2.5-Mistral-7B-GPTQ"
50
  pipe = pipeline("text-generation", model=model_name, device='cuda')
51
+ generate_kwargs = {
52
+ 'max_new_tokens': 20,
53
+ 'temperature': 0.8,
54
+ 'repetition_penalty': 1.1
55
+ }
56
 
57
  def past_kv_to_device(past_kv, device, dtype):
58
  return tuple((torch.tensor(k).to(device).to(dtype), torch.tensor(v).to(device).to(dtype)) for k, v in past_kv)
 
88
  prompt_format.format(system_message=system_prompt, prompt=text), **cur_generate_kwargs
89
  )[0]['generated_text']
90
  print(response)
91
+ return response.split('<|im_start|>assistant')[1]
92
+ # return response[-1]['content']
93
 
94
 
95
  if __name__ == "__main__":
96
  with torch.no_grad():
97
+ # past_key_values = set_past_key_values()
98
  demo = gr.Interface(
99
+ partial(generate, past_key_values=None),
100
  inputs="textbox", outputs="textbox"
101
  )
102
  demo.launch()