Update app.py
app.py
CHANGED
@@ -48,7 +48,11 @@ Assistant: girlfriend;mother;father;friend
 torch.set_grad_enabled(False)
 model_name = "TheBloke/OpenHermes-2.5-Mistral-7B-GPTQ"
 pipe = pipeline("text-generation", model=model_name, device='cuda')
-generate_kwargs = {
+generate_kwargs = {
+    'max_new_tokens': 20,
+    'temperature': 0.8,
+    'repetition_penalty': 1.1
+}
 
 def past_kv_to_device(past_kv, device, dtype):
     return tuple((torch.tensor(k).to(device).to(dtype), torch.tensor(v).to(device).to(dtype)) for k, v in past_kv)
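The first hunk expands the one-line generate_kwargs assignment into an explicit dict. A minimal sketch of how these settings reach the model, assuming they are forwarded into the pipe(...) call with **generate_kwargs, as the generate() hunk below suggests with **cur_generate_kwargs; the prompt string here is a placeholder, not taken from app.py:

from transformers import pipeline

generate_kwargs = {
    'max_new_tokens': 20,      # stop after at most 20 newly generated tokens
    'temperature': 0.8,        # sampling temperature (only used when sampling is enabled)
    'repetition_penalty': 1.1  # mildly penalize tokens that were already generated
}

pipe = pipeline("text-generation", model="TheBloke/OpenHermes-2.5-Mistral-7B-GPTQ", device='cuda')
# Extra keyword arguments on the pipeline call are forwarded to model.generate().
out = pipe("Hello, who are you?", **generate_kwargs)
print(out[0]['generated_text'])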
@@ -84,14 +88,15 @@ def generate(text, past_key_values):
         prompt_format.format(system_message=system_prompt, prompt=text), **cur_generate_kwargs
     )[0]['generated_text']
     print(response)
-    return response[-1]['content']
+    return response.split('<|im_start|>assistant')[1]
+    # return response[-1]['content']
 
 
 if __name__ == "__main__":
     with torch.no_grad():
-        past_key_values = set_past_key_values()
+        # past_key_values = set_past_key_values()
         demo = gr.Interface(
-            partial(generate, past_key_values=past_key_values),
+            partial(generate, past_key_values=None),
             inputs="textbox", outputs="textbox"
         )
         demo.launch()
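The new return line assumes the prompt is built in ChatML, the format OpenHermes-2.5 is trained on, so everything after the final <|im_start|>assistant tag is the model's reply. A small self-contained illustration; the prompt_format template and sample completion below are assumptions, only the .split() call comes from the diff:

# ChatML-style template assumed to match app.py's prompt_format (not shown in this diff).
prompt_format = (
    "<|im_start|>system\n{system_message}<|im_end|>\n"
    "<|im_start|>user\n{prompt}<|im_end|>\n"
    "<|im_start|>assistant"
)

prompt = prompt_format.format(system_message="You are a helpful assistant.",
                              prompt="Who should I call?")

# A text-generation pipeline returns the prompt followed by the completion:
generated_text = prompt + "\ngirlfriend;mother;father;friend"

# Splitting on the assistant tag and taking index 1 drops the echoed prompt
# and keeps only the model's reply.
reply = generated_text.split('<|im_start|>assistant')[1]
print(reply.strip())  # girlfriend;mother;father;friend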
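With set_past_key_values() commented out, the Gradio wiring now pre-binds past_key_values=None through functools.partial, so gr.Interface only has to supply the textbox value. A minimal sketch of that wiring with a stub generate; the stub body is an assumption, the real one calls the pipeline as in the hunk above:

from functools import partial
import gradio as gr

def generate(text, past_key_values):
    # Stub standing in for app.py's real pipeline call.
    cache = "reused" if past_key_values is not None else "not reused"
    return f"echo: {text} (past_key_values {cache})"

if __name__ == "__main__":
    demo = gr.Interface(
        # partial leaves `text` as the only free argument, which Gradio fills
        # from the textbox; past_key_values is fixed to None up front.
        partial(generate, past_key_values=None),
        inputs="textbox", outputs="textbox"
    )
    demo.launch()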