cyqm
commited on
Commit
·
396d643
1
Parent(s):
15cdada
Update handler: add random seed
Browse files- handler.py +5 -0
handler.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
import time
|
|
|
2 |
from transformers import AutoTokenizer, AutoModelForCausalLM
|
3 |
|
4 |
class EndpointHandler:
|
@@ -45,11 +46,15 @@ class EndpointHandler:
|
|
45 |
|
46 |
model_inputs = self.tokenizer([text], return_tensors="pt").to("cuda")
|
47 |
|
|
|
|
|
48 |
time_start = time.time()
|
49 |
generated_ids = self.model.generate(
|
50 |
**model_inputs,
|
51 |
max_new_tokens=max_new_tokens,
|
52 |
temperature=1.0,
|
|
|
|
|
53 |
**parameters
|
54 |
)
|
55 |
time_end = time.time()
|
|
|
1 |
import time
|
2 |
+
import random
|
3 |
from transformers import AutoTokenizer, AutoModelForCausalLM
|
4 |
|
5 |
class EndpointHandler:
|
|
|
46 |
|
47 |
model_inputs = self.tokenizer([text], return_tensors="pt").to("cuda")
|
48 |
|
49 |
+
torch.manual_seed(random.randint(0, 2 ** 32 - 1))
|
50 |
+
|
51 |
time_start = time.time()
|
52 |
generated_ids = self.model.generate(
|
53 |
**model_inputs,
|
54 |
max_new_tokens=max_new_tokens,
|
55 |
temperature=1.0,
|
56 |
+
do_sample=True,
|
57 |
+
top_p=0.9,
|
58 |
**parameters
|
59 |
)
|
60 |
time_end = time.time()
|