Update README.md
Browse files
README.md
CHANGED
@@ -11,7 +11,9 @@ tags:
|
|
11 |
Метрики обучения bleu:100.0 sari:28.699 fkgl:31.931 (из файла "train.logs")
|
12 |
|
13 |
```
|
14 |
-
input_text='Война Советского Союза против фашистской Германии и её союзников
|
|
|
|
|
15 |
|
16 |
def example(source, model, tokenizer):
|
17 |
"""
|
@@ -24,7 +26,9 @@ def example(source, model, tokenizer):
|
|
24 |
print(f'SOURCE: {source}')
|
25 |
input_ids, attention_mask = tokenizer(source, return_tensors = 'pt').values()
|
26 |
with torch.no_grad():
|
27 |
-
output = model.generate(input_ids = input_ids.to(model.device),
|
|
|
|
|
28 |
return tokenizer.decode(output.squeeze(0), skip_special_tokens = True)
|
29 |
|
30 |
example(input_text, model, tokenizer)
|
|
|
11 |
Метрики обучения bleu:100.0 sari:28.699 fkgl:31.931 (из файла "train.logs")
|
12 |
|
13 |
```
|
14 |
+
input_text='''Война Советского Союза против фашистской Германии и её союзников
|
15 |
+
(Венгрии, Италии, Румынии, Словакии, Хорватии, Финляндии, Японии);
|
16 |
+
составная часть Второй мировой войны 1939-1945 гг.'''
|
17 |
|
18 |
def example(source, model, tokenizer):
|
19 |
"""
|
|
|
26 |
print(f'SOURCE: {source}')
|
27 |
input_ids, attention_mask = tokenizer(source, return_tensors = 'pt').values()
|
28 |
with torch.no_grad():
|
29 |
+
output = model.generate(input_ids = input_ids.to(model.device),
|
30 |
+
attention_mask = attention_mask.to(model.device),
|
31 |
+
max_new_tokens=input_ids.size(1)*2, min_length=0)
|
32 |
return tokenizer.decode(output.squeeze(0), skip_special_tokens = True)
|
33 |
|
34 |
example(input_text, model, tokenizer)
|