dqnguyen commited on
Commit
543657e
1 Parent(s): d605930

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -6
app.py CHANGED
@@ -60,11 +60,15 @@ def translate_vi2en(vi_text: str) -> str:
60
  input_ids = tokenizer_vi2en(vi_text, return_tensors="pt").input_ids
61
  output_ids = model_vi2en.generate(
62
  input_ids,
63
- do_sample=True,
64
- top_k=100,
65
- top_p=0.8,
66
  decoder_start_token_id=tokenizer_vi2en.lang_code_to_id["en_XX"],
67
  num_return_sequences=1,
 
 
 
 
 
 
 
68
  )
69
  en_text = tokenizer_vi2en.batch_decode(output_ids, skip_special_tokens=True)
70
  en_text = " ".join(en_text)
@@ -77,11 +81,15 @@ def translate_en2vi(en_text: str) -> str:
77
  input_ids = tokenizer_en2vi(en_text, return_tensors="pt").input_ids
78
  output_ids = model_en2vi.generate(
79
  input_ids,
80
- do_sample=True,
81
- top_k=100,
82
- top_p=0.8,
83
  decoder_start_token_id=tokenizer_en2vi.lang_code_to_id["vi_VN"],
84
  num_return_sequences=1,
 
 
 
 
 
 
 
85
  )
86
  vi_text = tokenizer_en2vi.batch_decode(output_ids, skip_special_tokens=True)
87
  vi_text = " ".join(vi_text)
 
60
  input_ids = tokenizer_vi2en(vi_text, return_tensors="pt").input_ids
61
  output_ids = model_vi2en.generate(
62
  input_ids,
 
 
 
63
  decoder_start_token_id=tokenizer_vi2en.lang_code_to_id["en_XX"],
64
  num_return_sequences=1,
65
+ # # With sampling
66
+ # do_sample=True,
67
+ # top_k=100,
68
+ # top_p=0.8,
69
+ # With beam search
70
+ num_beams=5,
71
+ early_stopping=True
72
  )
73
  en_text = tokenizer_vi2en.batch_decode(output_ids, skip_special_tokens=True)
74
  en_text = " ".join(en_text)
 
81
  input_ids = tokenizer_en2vi(en_text, return_tensors="pt").input_ids
82
  output_ids = model_en2vi.generate(
83
  input_ids,
 
 
 
84
  decoder_start_token_id=tokenizer_en2vi.lang_code_to_id["vi_VN"],
85
  num_return_sequences=1,
86
+ # # With sampling
87
+ # do_sample=True,
88
+ # top_k=100,
89
+ # top_p=0.8,
90
+ # With beam search
91
+ num_beams=5,
92
+ early_stopping=True
93
  )
94
  vi_text = tokenizer_en2vi.batch_decode(output_ids, skip_special_tokens=True)
95
  vi_text = " ".join(vi_text)