mxmax commited on
Commit
0dec4f2
·
1 Parent(s): 480bc06

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +2 -2
README.md CHANGED
@@ -53,13 +53,13 @@ model_trained.to(device)
53
  def postprocess(text):
54
  return text.replace(".", "").replace('</>','')
55
 
56
- def answer_fn(text, sample=False, top_p=0.6):
57
  encoding = tokenizer(text=[text], truncation=True, padding=True, max_length=256, return_tensors="pt").to(device)
58
  out = model.generate(**encoding, return_dict_in_generate=True, output_scores=False, max_length=512,temperature=0.5,do_sample=True,repetition_penalty=6.0 ,top_p=top_p)
59
  result = tokenizer.batch_decode(out["sequences"], skip_special_tokens=True)
60
  return postprocess(result[0])
61
  text="宫颈癌的早期会有哪些危险信号"
62
- result=answer_fn(text, sample=True, top_p=0.6)
63
  print('prompt:',text)
64
  print("result:",result)
65
  ```
 
53
  def postprocess(text):
54
  return text.replace(".", "").replace('</>','')
55
 
56
+ def answer_fn(text, top_p=0.6):
57
  encoding = tokenizer(text=[text], truncation=True, padding=True, max_length=256, return_tensors="pt").to(device)
58
  out = model.generate(**encoding, return_dict_in_generate=True, output_scores=False, max_length=512,temperature=0.5,do_sample=True,repetition_penalty=6.0 ,top_p=top_p)
59
  result = tokenizer.batch_decode(out["sequences"], skip_special_tokens=True)
60
  return postprocess(result[0])
61
  text="宫颈癌的早期会有哪些危险信号"
62
+ result=answer_fn(text, top_p=0.6)
63
  print('prompt:',text)
64
  print("result:",result)
65
  ```