teticio commited on
Commit
3d84bdd
·
1 Parent(s): f049f5e
Files changed (1) hide show
  1. app.py +4 -4
app.py CHANGED
@@ -41,7 +41,7 @@ def generate_step(out: object,
41
 
42
  args:
43
  - out (torch.Tensor): tensor of logits of size batch_size x seq_len x vocab_size
44
- - gen_idx (int): location for which to generate for
45
  - top_k (int): if >0, only sample from the top k most probable words
46
  - temperature (float): sampling temperature
47
  - typical_p (float): if >0 use typical sampling
@@ -53,13 +53,13 @@ def generate_step(out: object,
53
  logits = out.logits[:, gen_idx]
54
  warpers = LogitsProcessorList()
55
  if temperature:
56
- warpers += [TemperatureLogitsWarper(temperature)]
57
  if top_k > 0:
58
- warpers += [TopKLogitsWarper(top_k)]
59
  if typical_p > 0:
60
  if typical_p >= 1:
61
  typical_p = 0.999
62
- warpers += [TypicalLogitsWarper(typical_p)]
63
  logits = warpers(None, logits)
64
 
65
  if sample:
 
41
 
42
  args:
43
  - out (torch.Tensor): tensor of logits of size batch_size x seq_len x vocab_size
44
+ - gen_idx (int): location for which to generate
45
  - top_k (int): if >0, only sample from the top k most probable words
46
  - temperature (float): sampling temperature
47
  - typical_p (float): if >0 use typical sampling
 
53
  logits = out.logits[:, gen_idx]
54
  warpers = LogitsProcessorList()
55
  if temperature:
56
+ warpers.append(TemperatureLogitsWarper(temperature))
57
  if top_k > 0:
58
+ warpers.append(TopKLogitsWarper(top_k))
59
  if typical_p > 0:
60
  if typical_p >= 1:
61
  typical_p = 0.999
62
+ warpers.append(TypicalLogitsWarper(typical_p))
63
  logits = warpers(None, logits)
64
 
65
  if sample: