bragovo committed
Commit c95a95b
1 Parent(s): d902713

Create TRAIN.md

Files changed (1): TRAIN.md +97 -0
TRAIN.md ADDED
https://colab.research.google.com/drive/1rT472vEOPjYCKdZ1CEg0IWm-h0dployN?usp=sharing

The script below fine-tunes `cointegrated/rut5-base-multitask` for dialogue summarization on interleaved English and Russian datasets, then pushes the result to the Hugging Face Hub.

```python
# Install training and evaluation dependencies.
!pip install datasets transformers evaluate wandb py7zr sentencepiece huggingface_hub rouge_score accelerate

# Log training runs to Weights & Biases.
import wandb
wandb.login()

# Authenticate with the Hugging Face Hub so the model can be pushed later.
from huggingface_hub import interpreter_login

interpreter_login()

from datasets import interleave_datasets, load_dataset

# English and Russian dialogue-summarization datasets.
samsum_dataset = load_dataset("bragovo/dsum_en", split="train")
samsum_ru_dataset = load_dataset("bragovo/dsum_ru", split="train")

# Alternate examples from both languages, then hold out 20% for evaluation.
dataset = interleave_datasets([samsum_dataset, samsum_ru_dataset])
dataset = dataset.train_test_split(test_size=0.2)
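
# Optional sanity check: each record is assumed to carry "dialogue" and
# "summary" fields, which the preprocessing below relies on.
print(dataset["train"][0]["dialogue"][:200])
print(dataset["train"][0]["summary"])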

from transformers import AutoTokenizer

checkpoint = "cointegrated/rut5-base-multitask"
# checkpoint = "t5-small"
tokenizer = AutoTokenizer.from_pretrained(checkpoint, legacy=False)

# T5-style models expect a task prefix on the input text.
prefix = "summarize: "

def preprocess_function(examples):
    inputs = [prefix + doc for doc in examples["dialogue"]]

    # Tokenize dialogues as inputs and summaries as targets; truncation
    # caps sequences at the tokenizer's model_max_length.
    model_inputs = tokenizer(inputs, truncation=True)
    labels = tokenizer(text_target=examples["summary"], truncation=True)

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

tokenized_dataset = dataset.map(preprocess_function, batched=True)
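
# Optional: confirm the mapped splits now carry input_ids, attention_mask,
# and labels alongside the original columns.
print(tokenized_dataset["train"].column_names)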

from transformers import DataCollatorForSeq2Seq

# Dynamically pad inputs and labels to the longest sequence in each batch.
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=checkpoint)

import evaluate
import numpy as np

rouge = evaluate.load("rouge")

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    # Replace the -100 loss-masking value with the pad token before decoding.
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    result = rouge.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)

    # Average generated-summary length in non-pad tokens.
    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in predictions]
    result["gen_len"] = np.mean(prediction_lens)

    return {k: round(v, 4) for k, v in result.items()}

from transformers import AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments, Seq2SeqTrainer

model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)

training_args = Seq2SeqTrainingArguments(
    output_dir="bragovo/flux-mt5-base-multitask-model",
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=4,
    predict_with_generate=True,  # generate summaries during evaluation so ROUGE can be computed
    fp16=True,  # mixed precision; requires a CUDA GPU
    push_to_hub=True,
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["test"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

# Train, evaluate each epoch, and upload the final checkpoint to the Hub.
trainer.train()
trainer.push_to_hub()
```
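
Once training finishes, the pushed checkpoint can be loaded back for inference. A minimal sketch, assuming the model ends up on the Hub under the `output_dir` name above and that inputs use the same `summarize: ` prefix as in training:

```python
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Assumed Hub id: matches the output_dir used during training.
model_id = "bragovo/flux-mt5-base-multitask-model"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSeq2SeqLM.from_pretrained(model_id)

dialogue = "A: Hi! Are we still meeting at 6pm? B: Yes, by the park entrance."
inputs = tokenizer("summarize: " + dialogue, return_tensors="pt", truncation=True)

# Beam search with a modest length cap keeps summaries short and fluent.
summary_ids = model.generate(**inputs, max_new_tokens=64, num_beams=4)
print(tokenizer.decode(summary_ids[0], skip_special_tokens=True))
```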