reorder LogitsWarpers and tidy
app.py
CHANGED
@@ -51,16 +51,16 @@ def generate_step(out: object,
         - list: batch_size tokens
     """
     logits = out.logits[:, gen_idx]
-
-    if top_k > 0:
-        logit_warpers += [TopKLogitsWarper(top_k)]
+    warpers = LogitsProcessorList()
     if temperature:
-
+        warpers += [TemperatureLogitsWarper(temperature)]
+    if top_k > 0:
+        warpers += [TopKLogitsWarper(top_k)]
     if typical_p > 0:
         if typical_p >= 1:
             typical_p = 0.999
-
-        logits =
+        warpers += [TypicalLogitsWarper(typical_p)]
+    logits = warpers(None, logits)
 
     if sample:
         probs = torch.nn.functional.softmax(logits, dim=-1)
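
For context, a minimal, self-contained sketch of what the reordered warper chain does, using the same transformers classes as the diff. This is not the Space's full app.py: the parameter values, the dummy logits tensor, and the vocabulary size are illustrative assumptions.

# Sketch of the warper chain introduced in this commit: temperature first,
# then top-k, then typical-p filtering, applied to a batch of logits.
import torch
from transformers import (
    LogitsProcessorList,
    TemperatureLogitsWarper,
    TopKLogitsWarper,
    TypicalLogitsWarper,
)

# Example settings (assumptions, not values from the Space).
temperature, top_k, typical_p = 0.8, 50, 0.95

warpers = LogitsProcessorList()
if temperature:
    warpers += [TemperatureLogitsWarper(temperature)]
if top_k > 0:
    warpers += [TopKLogitsWarper(top_k)]
if typical_p > 0:
    # Mirror the diff's clamp: TypicalLogitsWarper requires mass < 1.
    warpers += [TypicalLogitsWarper(min(typical_p, 0.999))]

logits = torch.randn(2, 32_000)   # (batch_size, vocab_size) dummy scores
logits = warpers(None, logits)    # input_ids are unused by these warpers

probs = torch.nn.functional.softmax(logits, dim=-1)
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(-1)
print(next_tokens.shape)          # torch.Size([2]) -> one token per batch row

Passing None for input_ids works here because none of these three warpers inspect the prompt; the commit's own logits = warpers(None, logits) call relies on the same fact.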