# Support for NLLB and Sampling
Files changed:

- README.md (+81 / -9)
- supported_languages.md (+214 / -2)
- translate.py (+54 / -7)
**README.md** (changed)
Removed from the previous README:

- The old introduction, which mentioned only M2M100:

  > Easy-Translate is a script for translating large text files on your machine using the [M2M100 models](https://arxiv.org/pdf/2010.11125.pdf) from Facebook/Meta AI. We also provide a [script](#evaluate-translations) for Easy-Evaluation of your translations 🥳

- The M2M100 description ("**M2M100** is a multilingual encoder-decoder (seq2seq) model trained for many-to-many multilingual translation…") and the note ">M2M100 can directly translate between 9,900 directions of 100 languages.", which now live under the new "Supported Models" section.

- The inline list of supported languages:

  > Afrikaans, Amharic, Arabic, Asturian, Azerbaijani, Bashkir, Belarusian, Bulgarian, Bengali, Breton, Bosnian, Catalan, Cebuano, Czech, Welsh, Danish, German, Greek, English, Spanish, Estonian, Persian, Fulah, Finnish, French, Western Frisian, Irish, Gaelic, Galician, Gujarati, Hausa, Hebrew, Hindi, Croatian, Haitian, Hungarian, Armenian, Indonesian, Igbo, Iloko, Icelandic, Italian, Japanese, Javanese, Georgian, Kazakh, Central Khmer, Kannada, Korean, Luxembourgish, Ganda, Lingala, Lao, Lithuanian, Latvian, Malagasy, Macedonian, Malayalam, Mongolian, Marathi, Malay, Burmese, Nepali, Dutch, Norwegian, Northern Sotho, Occitan, Oriya, Panjabi, Polish, Pushto, Portuguese, Romanian, Russian, Sindhi, Sinhala, Slovak, Slovenian, Somali, Albanian, Serbian, Swati, Sundanese, Swedish, Swahili, Tamil, Thai, Tagalog, Tswana, Turkish, Ukrainian, Urdu, Uzbek, Vietnamese, Wolof, Xhosa, Yiddish, Yoruba, Chinese, Zulu

  The README keeps only the pointer to [supported_languages.md](supported_languages.md).
13 |
<br>
|
14 |
</p>
|
15 |
|
16 |
+
Easy-Translate is a script for translating large text files in your machine using the [M2M100 models](https://arxiv.org/pdf/2010.11125.pdf) and [NLLB200 models](https://research.facebook.com/publications/no-language-left-behind/) from Facebook/Meta AI. We also privide a [script](#evaluate-translations) for Easy-Evaluation of your translations 🥳
|
|
|
|
|
|
|
|
|
17 |
|
18 |
Easy-Translate is built on top of 🤗HuggingFace's [Transformers](https://huggingface.co/docs/transformers/index) and 🤗HuggingFace's [Accelerate](https://huggingface.co/docs/accelerate/index) library.
|
19 |
|
|
|
23 |
- BF16 / FP16 / FP32 precision.
- Automatic batch size finder: forget CUDA OOM errors. Set an initial batch size and, if it doesn't fit, we will automatically adjust it (see the sketch after this list).
- Sharded Data Parallel to load huge models sharded across multiple GPUs (see <https://huggingface.co/docs/accelerate/fsdp>).
- Greedy decoding / Beam Search decoding / Multinomial Sampling / Beam-Search Multinomial Sampling.
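The automatic batch size finder is built on 🤗 Accelerate's `find_executable_batch_size` decorator (the same utility `translate.py` imports). A minimal sketch of the pattern, assuming the import path `accelerate.utils`; the `sentences` list and `process` function are illustrative only and not part of the repository:

```python
from accelerate.utils import find_executable_batch_size

sentences = ["Hello world!"] * 1_000  # toy data

@find_executable_batch_size(starting_batch_size=256)
def process(batch_size):
    # Called first with starting_batch_size; if the body raises a CUDA
    # out-of-memory error, the decorator halves batch_size and retries.
    translated = 0
    for i in range(0, len(sentences), batch_size):
        batch = sentences[i : i + batch_size]
        translated += len(batch)  # translate.py tokenizes and generates here
    return translated

print(process())  # called without arguments; batch_size is injected
```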
>Test the 🔌 Online Demo here: <https://huggingface.co/spaces/Iker/Translate-100-languages>

## Supported languages

See the [Supported languages table](supported_languages.md) for a table of the supported languages and their ids.

## Supported Models
### M2M100

**M2M100** is a multilingual encoder-decoder (seq2seq) model trained for many-to-many multilingual translation, introduced in this [paper](https://arxiv.org/abs/2010.11125) and first released in [this repository](https://github.com/pytorch/fairseq/tree/master/examples/m2m_100).

>M2M100 can directly translate between 9,900 directions of 100 languages.

- **facebook/m2m100_418M**: <https://huggingface.co/facebook/m2m100_418M>
- **facebook/m2m100_1.2B**: <https://huggingface.co/facebook/m2m100_1.2B>
- **facebook/m2m100_12B**: <https://huggingface.co/facebook/m2m100-12B-avg-5-ckpt>
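For reference, a minimal sketch of loading one of these checkpoints directly with 🤗 Transformers (illustrative only; `translate.py` adds batching, Accelerate and precision handling on top of this). `get_lang_id` is specific to the M2M100 tokenizer:

```python
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/m2m100_418M")
model = AutoModelForSeq2SeqLM.from_pretrained("facebook/m2m100_418M")

tokenizer.src_lang = "en"  # source language id (see supported_languages.md)
inputs = tokenizer("Hello world!", return_tensors="pt")
generated = model.generate(
    **inputs,
    forced_bos_token_id=tokenizer.get_lang_id("es"),  # force Spanish output
    num_beams=5,
    max_length=128,
)
print(tokenizer.batch_decode(generated, skip_special_tokens=True))
```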
### NLLB200

**No Language Left Behind (NLLB)** open-sources models capable of delivering high-quality translations directly between any pair of 200+ languages, including low-resource languages like Asturian, Luganda, Urdu and more. It aims to help people communicate with anyone, anywhere, regardless of their language preferences. It was introduced in this [paper](https://research.facebook.com/publications/no-language-left-behind/) and first released in [this repository](https://github.com/facebookresearch/fairseq/tree/nllb).

>NLLB can directly translate between more than 40,000 directions of 200+ languages.

- **facebook/nllb-200-3.3B**: <https://huggingface.co/facebook/nllb-200-3.3B>
- **facebook/nllb-200-1.3B**: <https://huggingface.co/facebook/nllb-200-1.3B>
- **facebook/nllb-200-distilled-1.3B**: <https://huggingface.co/facebook/nllb-200-distilled-1.3B>
- **facebook/nllb-200-distilled-600M**: <https://huggingface.co/facebook/nllb-200-distilled-600M>
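And a hedged equivalent for NLLB200 (again illustrative, not repository code). NLLB uses FLORES-200 codes such as `eng_Latn`; on very recent transformers releases the target-id lookup may be `tokenizer.convert_tokens_to_ids("spa_Latn")` instead of `lang_code_to_id`:

```python
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    "facebook/nllb-200-distilled-600M", src_lang="eng_Latn"
)
model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")

inputs = tokenizer("Hello world!", return_tensors="pt")
generated = model.generate(
    **inputs,
    # Force Spanish as the target; translate.py resolves the id the same way.
    forced_bos_token_id=tokenizer.lang_code_to_id["spa_Latn"],
    max_length=128,
)
print(tokenizer.batch_decode(generated, skip_special_tokens=True))
```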
Any other ModelForSeq2SeqLM from HuggingFace's Hub should work with this library: <https://huggingface.co/models?pipeline_tag=text2text-generation>
## Requirements

The install snippet now ends with a note for NLLB200:

```
HuggingFace Transformers
pip install --upgrade transformers

If you find errors using NLLB200, try installing transformers from source:
pip install git+https://github.com/huggingface/transformers.git
```
## Translate a file

The existing example command (ending in `--precision fp16`) is unchanged; a new subsection follows it:
### Decoding/Sampling strategies

You can choose the decoding/sampling strategy and the number of candidate translations to output for each input sentence. By default we use beam search with `num_beams` set to 5 and output the single most likely candidate translation, but you can change this behaviour:
##### Greedy decoding

```bash
accelerate launch translate.py \
--sentences_path sample_text/en.txt \
--output_path sample_text/en2es.translation.m2m100_1.2B.txt \
--source_lang en \
--target_lang es \
--model_name facebook/m2m100_1.2B \
--num_beams 1
```

##### Multinomial Sampling

```bash
accelerate launch translate.py \
--sentences_path sample_text/en.txt \
--output_path sample_text/en2es.translation.m2m100_1.2B.txt \
--source_lang en \
--target_lang es \
--model_name facebook/m2m100_1.2B \
--num_beams 1 \
--do_sample \
--temperature 0.5 \
--top_k 100 \
--top_p 0.8 \
--num_return_sequences 1
```

##### Beam-Search decoding **(DEFAULT)**

```bash
accelerate launch translate.py \
--sentences_path sample_text/en.txt \
--output_path sample_text/en2es.translation.m2m100_1.2B.txt \
--source_lang en \
--target_lang es \
--model_name facebook/m2m100_1.2B \
--num_beams 5 \
--num_return_sequences 1
```

##### Beam-Search Multinomial Sampling

```bash
accelerate launch translate.py \
--sentences_path sample_text/en.txt \
--output_path sample_text/en2es.translation.m2m100_1.2B.txt \
--source_lang en \
--target_lang es \
--model_name facebook/m2m100_1.2B \
--num_beams 5 \
--num_return_sequences 1 \
--do_sample \
--temperature 0.5 \
--top_k 100 \
--top_p 0.8
```
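These flags map almost one-to-one onto 🤗 Transformers generation arguments; `translate.py` collects them into a `gen_kwargs` dict and forwards them to `model.generate`. A hedged, illustrative sketch of that mapping (not repository code):

```python
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/m2m100_418M")
model = AutoModelForSeq2SeqLM.from_pretrained("facebook/m2m100_418M")
tokenizer.src_lang = "en"

gen_kwargs = {
    "max_length": 128,
    "num_beams": 1,             # --num_beams 1 -> no beam search
    "num_return_sequences": 1,  # --num_return_sequences
    "do_sample": True,          # --do_sample
    "temperature": 0.5,         # --temperature
    "top_k": 100,               # --top_k
    "top_p": 0.8,               # --top_p
}

inputs = tokenizer("Hello world!", return_tensors="pt")
with torch.no_grad():
    output = model.generate(
        **inputs,
        forced_bos_token_id=tokenizer.get_lang_id("es"),
        **gen_kwargs,
    )
print(tokenizer.batch_decode(output, skip_special_tokens=True))
```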
## Evaluate translations

To run the evaluation script you need to install [bert_score](https://github.com/Tiiiger/bert_score): `pip install bert_score` and 🤗HuggingFace's [Datasets](https://huggingface.co/docs/datasets/index) library: `pip install datasets`.
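For orientation, a minimal, hedged example of scoring a translation file with bert_score outside the provided eval script; the file paths are placeholders:

```python
from bert_score import score

# Hypothetical files: one candidate translation and one reference per line.
with open("sample_text/en2es.translation.m2m100_1.2B.txt") as f:
    candidates = [line.strip() for line in f]
with open("sample_text/references.es.txt") as f:
    references = [line.strip() for line in f]

# P, R and F1 are tensors with one value per sentence pair.
P, R, F1 = score(candidates, references, lang="es")
print(f"BERTScore F1: {F1.mean().item():.4f}")
```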
**supported_languages.md** (changed)
The existing M2M100 table is kept; the file gains a title, an index, and a new NLLB200 section. The updated file:
# List of supported languages

## Index

* [M2M100 supported languages](#supported-languages-m2m100)
* [NLLB200 supported languages](#supported-languages-nllb200)

## Supported languages M2M100

| Language | Id |
|---|---|
| … | … |
| Yiddish | yi |
| Yoruba | yo |
| Chinese | zh |
| Zulu | zu |

(the middle of the M2M100 table is unchanged and not shown in the diff)

## Supported languages NLLB200

| Language id |
|-------------|
| ace_Arab |
| ace_Latn |
| acm_Arab |
| acq_Arab |
| aeb_Arab |
| afr_Latn |
| ajp_Arab |
| aka_Latn |
| amh_Ethi |
| apc_Arab |
| arb_Arab |
| ars_Arab |
| ary_Arab |
| arz_Arab |
| asm_Beng |
| ast_Latn |
| awa_Deva |
| ayr_Latn |
| azb_Arab |
| azj_Latn |
| bak_Cyrl |
| bam_Latn |
| ban_Latn |
| bel_Cyrl |
| bem_Latn |
| ben_Beng |
| bho_Deva |
| bjn_Arab |
| bjn_Latn |
| bod_Tibt |
| bos_Latn |
| bug_Latn |
| bul_Cyrl |
| cat_Latn |
| ceb_Latn |
| ces_Latn |
| cjk_Latn |
| ckb_Arab |
| crh_Latn |
| cym_Latn |
| dan_Latn |
| deu_Latn |
| dik_Latn |
| dyu_Latn |
| dzo_Tibt |
| ell_Grek |
| eng_Latn |
| epo_Latn |
| est_Latn |
| eus_Latn |
| ewe_Latn |
| fao_Latn |
| pes_Arab |
| fij_Latn |
| fin_Latn |
| fon_Latn |
| fra_Latn |
| fur_Latn |
| fuv_Latn |
| gla_Latn |
| gle_Latn |
| glg_Latn |
| grn_Latn |
| guj_Gujr |
| hat_Latn |
| hau_Latn |
| heb_Hebr |
| hin_Deva |
| hne_Deva |
| hrv_Latn |
| hun_Latn |
| hye_Armn |
| ibo_Latn |
| ilo_Latn |
| ind_Latn |
| isl_Latn |
| ita_Latn |
| jav_Latn |
| jpn_Jpan |
| kab_Latn |
| kac_Latn |
| kam_Latn |
| kan_Knda |
| kas_Arab |
| kas_Deva |
| kat_Geor |
| knc_Arab |
| knc_Latn |
| kaz_Cyrl |
| kbp_Latn |
| kea_Latn |
| khm_Khmr |
| kik_Latn |
| kin_Latn |
| kir_Cyrl |
| kmb_Latn |
| kon_Latn |
| kor_Hang |
| kmr_Latn |
| lao_Laoo |
| lvs_Latn |
| lij_Latn |
| lim_Latn |
| lin_Latn |
| lit_Latn |
| lmo_Latn |
| ltg_Latn |
| ltz_Latn |
| lua_Latn |
| lug_Latn |
| luo_Latn |
| lus_Latn |
| mag_Deva |
| mai_Deva |
| mal_Mlym |
| mar_Deva |
| min_Latn |
| mkd_Cyrl |
| plt_Latn |
| mlt_Latn |
| mni_Beng |
| khk_Cyrl |
| mos_Latn |
| mri_Latn |
| zsm_Latn |
| mya_Mymr |
| nld_Latn |
| nno_Latn |
| nob_Latn |
| npi_Deva |
| nso_Latn |
| nus_Latn |
| nya_Latn |
| oci_Latn |
| gaz_Latn |
| ory_Orya |
| pag_Latn |
| pan_Guru |
| pap_Latn |
| pol_Latn |
| por_Latn |
| prs_Arab |
| pbt_Arab |
| quy_Latn |
| ron_Latn |
| run_Latn |
| rus_Cyrl |
| sag_Latn |
| san_Deva |
| sat_Beng |
| scn_Latn |
| shn_Mymr |
| sin_Sinh |
| slk_Latn |
| slv_Latn |
| smo_Latn |
| sna_Latn |
| snd_Arab |
| som_Latn |
| sot_Latn |
| spa_Latn |
| als_Latn |
| srd_Latn |
| srp_Cyrl |
| ssw_Latn |
| sun_Latn |
| swe_Latn |
| swh_Latn |
| szl_Latn |
| tam_Taml |
| tat_Cyrl |
| tel_Telu |
| tgk_Cyrl |
| tgl_Latn |
| tha_Thai |
| tir_Ethi |
| taq_Latn |
| taq_Tfng |
| tpi_Latn |
| tsn_Latn |
| tso_Latn |
| tuk_Latn |
| tum_Latn |
| tur_Latn |
| twi_Latn |
| tzm_Tfng |
| uig_Arab |
| ukr_Cyrl |
| umb_Latn |
| urd_Arab |
| uzn_Latn |
| vec_Latn |
| vie_Latn |
| war_Latn |
| wol_Latn |
| xho_Latn |
| ydd_Hebr |
| yor_Latn |
| yue_Hant |
| zho_Hans |
| zho_Hant |
| zul_Latn |
**translate.py** (changed)
Removed or superseded lines:

- The previous `from transformers import (...)` entries and the `tokenizer = ...` / `model = ...` loader calls (their exact text is truncated in this view); they are replaced by the Auto classes shown below.
- The bare target-language lookup before `gen_kwargs` (also truncated here), replaced by a guarded lookup with a clear error message.
- The `f"Num beams: {num_beams}\n"` line in the start-up summary, superseded by printing all generation parameters.
- The last-batch truncation `tgt_text[: len(data_loader.dataset)]`, which ignored `num_return_sequences`.
The updated code, hunk by hunk. The import block now uses the Auto classes so that both M2M100 and NLLB200 checkpoints can be loaded:

```python
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    PreTrainedTokenizerBase,
    DataCollatorForSeq2Seq,
)
```
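The point of the Auto classes is that the same two calls resolve to the right concrete classes for either model family. A small illustrative check (not part of the script; the exact class names depend on the installed transformers version):

```python
from transformers import AutoTokenizer

m2m = AutoTokenizer.from_pretrained("facebook/m2m100_418M")
nllb = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
# Typically M2M100Tokenizer and NllbTokenizer(Fast), respectively.
print(type(m2m).__name__, type(nllb).__name__)
```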
`main()` gains the four sampling parameters:

```python
def main(
    ...
    max_length: int = 128,
    num_beams: int = 4,
    num_return_sequences: int = 1,
    do_sample: bool = False,
    temperature: float = 1.0,
    top_k: int = 50,
    top_p: float = 1.0,
):
    if not os.path.exists(os.path.abspath(os.path.dirname(output_path))):  # (unchanged)
        ...
```
The tokenizer and model are now loaded through the Auto classes:

```python
print(f"Loading tokenizer {model_name}...")
tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path=model_name, cache_dir=cache_dir
)
print(f"Loading model {model_name}...")
model = AutoModelForSeq2SeqLM.from_pretrained(
    pretrained_model_name_or_path=model_name, cache_dir=cache_dir
)
```
The target-language id is looked up defensively, so an unsupported code fails with a clear message:

```python
tokenizer.src_lang = source_lang
try:
    lang_code_to_idx = tokenizer.lang_code_to_id[target_lang]
except KeyError:
    raise KeyError(
        f"Language {target_lang} not found in tokenizer. Available languages: {tokenizer.lang_code_to_id.keys()}"
    )
```
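For reference, a hedged snippet for checking which codes a tokenizer accepts before launching a long job (not part of translate.py; `lang_code_to_id` exists on the M2M100 and NLLB tokenizers, though very recent transformers releases may expose the codes differently):

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
codes = sorted(tok.lang_code_to_id)
print(len(codes), codes[:5])  # e.g. 202 ['ace_Arab', 'ace_Latn', 'acm_Arab', ...]
```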
The sampling options are forwarded to generation through `gen_kwargs`:

```python
gen_kwargs = {
    "max_length": max_length,
    "num_beams": num_beams,
    "num_return_sequences": num_return_sequences,
    "do_sample": do_sample,
    "temperature": temperature,
    "top_k": top_k,
    "top_p": top_p,
}
```
The start-up summary no longer prints only `num_beams`; instead it dumps every generation parameter:

```python
    # ... tail of the existing start-up summary (inside main) ...
        f"Num. Devices: {accelerator.num_processes}\n"
        f"Distributed_type: {accelerator.distributed_type}\n"
        f"Max length: {max_length}\n"
        f"Precision: {model.dtype}\n"
        f"Model: {model_name}\n"
    )
    print("** Generation parameters **")
    print("\n".join(f"{k}: {v}" for k, v in gen_kwargs.items()))
    print("\n")

    @find_executable_batch_size(starting_batch_size=starting_batch_size)
    def inference(batch_size):
        ...
```
The last-batch truncation now accounts for multiple returned sequences per sentence (previously it truncated to `len(data_loader.dataset)` only):

```python
if accelerator.is_main_process:
    if step == len(data_loader) - 1:
        tgt_text = tgt_text[
            : len(data_loader.dataset) * num_return_sequences
            - samples_seen
        ]
    else:
        samples_seen += len(tgt_text)
```
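The factor matters because `generate` returns `batch_size * num_return_sequences` sequences. A toy, hedged illustration (not repository code):

```python
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("facebook/m2m100_418M")
model = AutoModelForSeq2SeqLM.from_pretrained("facebook/m2m100_418M")
tok.src_lang = "en"

batch = tok(["Hello world!", "How are you?"], return_tensors="pt", padding=True)
out = model.generate(
    **batch,
    forced_bos_token_id=tok.get_lang_id("es"),
    num_beams=5,
    num_return_sequences=3,
)
print(out.shape[0])  # 6 == 2 input sentences * 3 returned sequences
```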
Four new command-line flags expose the sampling options (the top_p help text previously duplicated the top_k description):

```python
parser.add_argument(
    "--do_sample",
    action="store_true",
    help="Use sampling instead of beam search.",
)

parser.add_argument(
    "--temperature",
    type=float,
    default=1.0,
    help="Temperature for sampling; only used if do_sample is True.",
)

parser.add_argument(
    "--top_k",
    type=int,
    default=50,
    help="If do_sample is True, sample from the top k most likely tokens.",
)

parser.add_argument(
    "--top_p",
    type=float,
    default=1.0,
    help="If do_sample is True, sample from the smallest set of tokens whose cumulative probability exceeds top_p (nucleus sampling).",
)
```
|
309 |
|
310 |
main(
|
|
|
319 |
num_beams=args.num_beams,
|
320 |
num_return_sequences=args.num_return_sequences,
|
321 |
precision=args.precision,
|
322 |
+
do_sample=args.do_sample,
|
323 |
+
temperature=args.temperature,
|
324 |
+
top_k=args.top_k,
|
325 |
+
top_p=args.top_p,
|
326 |
)
|