teticio committed on
Commit ec1ff5d · 1 Parent(s): 6425ca8

reorder LogitsWarpers and tidy

Files changed (1)
  1. app.py +6 -6
app.py CHANGED
@@ -51,16 +51,16 @@ def generate_step(out: object,
     - list: batch_size tokens
     """
     logits = out.logits[:, gen_idx]
-    logit_warpers = []
-    if top_k > 0:
-        logit_warpers += [TopKLogitsWarper(top_k)]
+    warpers = LogitsProcessorList()
     if temperature:
-        logit_warpers += [TemperatureLogitsWarper(temperature)]
+        warpers += [TemperatureLogitsWarper(temperature)]
+    if top_k > 0:
+        warpers += [TopKLogitsWarper(top_k)]
     if typical_p > 0:
         if typical_p >= 1:
             typical_p = 0.999
-        logit_warpers += [TypicalLogitsWarper(typical_p)]
-    logits = LogitsProcessorList(logit_warpers)(None, logits)
+        warpers += [TypicalLogitsWarper(typical_p)]
+    logits = warpers(None, logits)
 
     if sample:
         probs = torch.nn.functional.softmax(logits, dim=-1)
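
The warpers touched here are the standard ones from Hugging Face transformers, and the commit's reordering means temperature scaling is applied before the top-k and typical-p filters, so those filters operate on the rescaled scores. As a rough, self-contained illustration only (not the full generate_step from app.py: the function name warp_and_sample, its default parameter values, and the tensor shapes are assumptions for the example), the sampling step as it reads after this commit looks roughly like this:

# Minimal sketch of the post-commit sampling step; shapes and defaults are assumed.
import torch
from transformers import (
    LogitsProcessorList,
    TemperatureLogitsWarper,
    TopKLogitsWarper,
    TypicalLogitsWarper,
)

def warp_and_sample(logits: torch.Tensor,
                    temperature: float = 0.7,
                    top_k: int = 50,
                    typical_p: float = 0.95,
                    sample: bool = True) -> torch.Tensor:
    # Build the warpers in the order introduced by this commit:
    # temperature first, then top-k, then typical-p filtering.
    warpers = LogitsProcessorList()
    if temperature:
        warpers += [TemperatureLogitsWarper(temperature)]
    if top_k > 0:
        warpers += [TopKLogitsWarper(top_k)]
    if typical_p > 0:
        if typical_p >= 1:
            typical_p = 0.999
        warpers += [TypicalLogitsWarper(typical_p)]
    # These warpers ignore input_ids, so None is passed, as in the commit.
    logits = warpers(None, logits)

    if sample:
        probs = torch.nn.functional.softmax(logits, dim=-1)
        return torch.multinomial(probs, num_samples=1).squeeze(-1)
    return torch.argmax(logits, dim=-1)

# Example: a batch of 2 sequences over a hypothetical 10-token vocabulary.
tokens = warp_and_sample(torch.randn(2, 10))
print(tokens.shape)  # torch.Size([2])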