Commit 671e27a (kisejin committed)
Parent(s): 808a032

change: update skipbert mechanism

Files changed:
- template_FL/src/fedllm/client_app.py (+45, -18)
- template_FL/src/fedllm/dataset.py (+1, -1)
- template_FL/src/fedllm/models.py (+66, -10)
- template_FL/src/fedllm/server_app.py (+97, -8)
- template_FL/src/fedllm/skipbert/modeling.py (+1, -1)
- template_FL/src/fedllm/skipbert/plot.py (+2, -2)
- template_FL/src/fedllm/skipbert/trainer.py (+691, -0)
- template_FL/src/fedllm/trainer.py (+252, -54)
- template_FL/src/pyproject.toml (+47, -48)
template_FL/src/fedllm/client_app.py
CHANGED

@@ -2,6 +2,7 @@
 
 import os
 import warnings
+import logging
 from typing import Dict, Tuple
 
 import torch
@@ -13,8 +14,16 @@ from flwr.common.config import unflatten_dict
 from flwr.common.typing import NDArrays, Scalar
 from omegaconf import DictConfig
 
-from transformers import
+from transformers import (
+    TrainingArguments,
+    DataCollatorForSeq2Seq,
+    Trainer,
+    EarlyStoppingCallback,
+    # BertForSequenceClassification,
+    GenerationConfig
+)
 
+
 from trl import SFTTrainer, SFTConfig
 from deepspeed.profiling.flops_profiler import get_model_profile
 from deepspeed.accelerator import get_accelerator
@@ -40,6 +49,11 @@ os.environ["TOKENIZERS_PARALLELISM"] = "true"
 os.environ["RAY_DISABLE_DOCKER_CPU_WARNING"] = "1"
 warnings.filterwarnings("ignore", category=UserWarning)
 
+logging.getLogger("flwr").setLevel(logging.INFO)
+logging.getLogger("ClientAppActor").setLevel(logging.INFO)
+
+logger = logging.getLogger(__name__)
+
 
 def input_constructor(batch_size, seq_len, tokenizer):
     fake_seq = ""
@@ -80,13 +94,15 @@ class FlowerClient(NumPyClient):
         self,
         model_cfg: DictConfig,
         train_cfg: DictConfig,
-
+        mates_cfg: DictConfig,
+        skipbert_cfg: DictConfig,
         trainset,
         valset,
         num_rounds,
     ):  # pylint: disable=too-many-arguments
         self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
         self.train_cfg = train_cfg
+        self.skipbert_cfg = skipbert_cfg
 
         self.training_arguments = TrainingArguments(**train_cfg.training_arguments)
         # self.training_arguments = SFTConfig(**train_cfg.training_arguments, max_seq_length=train_cfg.seq_length)
@@ -94,17 +110,18 @@ class FlowerClient(NumPyClient):
         self.num_rounds = num_rounds
         self.trainset = trainset
         self.valset = valset
-        self.
+        self.mates_cfg = mates_cfg
         self.holdoutset = None
         self.refset = None
-        self.
+        self.teacher_data_influence_model = None
+        self.student_data_influence_model = None
         self.data_influence_tokenizer = None
 
         # instantiate model
         self.model, self.tokenizer = get_model(model_cfg)
 
-        if self.
-            self.
+        if self.mates_cfg.state:
+            self.teacher_data_influence_model, self.student_data_influence_model, self.data_influence_tokenizer = get_data_influence_model(model_cfg, skipbert_cfg)
 
         # (
         #     self.data_collator,
@@ -129,8 +146,8 @@
         # Replace -100 with pad token id in labels
         labels_ids[labels_ids == -100] = self.tokenizer.pad_token_id
 
-        print(f"Shape of predictions: {np.shape(pred_ids)}")
-        print(f"Shape of labels: {np.shape(labels_ids)}")
+        # print(f"Shape of predictions: {np.shape(pred_ids)}")
+        # print(f"Shape of labels: {np.shape(labels_ids)}")
 
         # Decode predictions and labels
         pred_str = self.tokenizer.batch_decode(
@@ -165,6 +182,7 @@
             .map(
                 lambda x: generate_and_tokenize_prompt(x, **tmp_dict),
                 num_proc=8,
+                remove_columns=['instruction', 'input', 'output']
             )
         )
 
@@ -175,16 +193,17 @@
             .map(
                 lambda x: generate_and_tokenize_prompt(x, **tmp_dict),
                 num_proc=8,
+                remove_columns=['instruction', 'input', 'output']
             )
         )
 
         # Create holdoutset and refset if state is True
-        if self.
+        if self.mates_cfg.state:
             trainset_size = len(self.trainset)
 
             # Calculate sizes for holdout and reference sets
-            holdout_size = int(trainset_size * self.
-            ref_size = int(trainset_size * self.
+            holdout_size = int(trainset_size * self.mates_cfg.holdout_ratio)
+            ref_size = int(trainset_size * self.mates_cfg.reference_ratio)
 
             # Shuffle the trainset to ensure randomness
             shuffled_indices = list(range(trainset_size))
@@ -199,16 +218,17 @@
             self.refset = self.trainset.select(ref_indices)
 
             print(f"Holdoutset size: {len(self.holdoutset)}, Refset size: {len(self.refset)}")
+            # logger.info(f"Holdoutset size: {len(self.holdoutset)}, Refset size: {len(self.refset)}")
 
 
     def fit(
         self, parameters: NDArrays, config: Dict[str, Scalar]
     ) -> Tuple[NDArrays, int, Dict]:
        """Implement distributed fit function for a given client."""
-        if self.
+        if self.mates_cfg.state and int(config["current_round"]) != 1:
             main_model_params, data_influence_model_params = split_models(parameters)
             set_parameters(self.model, main_model_params)
-            set_parameters_bert(self.
+            set_parameters_bert(self.teacher_data_influence_model, data_influence_model_params)
         else:
             set_parameters(self.model, parameters)
 
@@ -259,18 +279,20 @@
             args=self.training_arguments,
             data_collator=self.data_collator,
             compute_metrics=self.compute_metrics,
-
-
+            mates_cfg=self.mates_cfg,
+            skipbert_cfg=self.skipbert_cfg,
+            teacher_data_influence_model=self.teacher_data_influence_model,
+            student_data_influence_model=self.student_data_influence_model,
             data_influence_tokenizer=self.data_influence_tokenizer,
         )
 
         # Train the model
         results = trainer.train()
 
-        if self.
+        if self.mates_cfg.state:
             # After training
             main_model_params = get_parameters(self.model)
-            data_influence_model_params = model_parameters_to_ndarrays(self.
+            data_influence_model_params = model_parameters_to_ndarrays(self.teacher_data_influence_model)
             final_model_params = concatenate_models_with_marker(main_model_params, data_influence_model_params)
         else:
             final_model_params = get_parameters(self.model)
@@ -286,7 +308,7 @@
             detailed=False,
         )
         flops2, macs2, params2 = get_model_profile(
-            self.
+            self.teacher_data_influence_model,
             kwargs=input_constructor(batch_size, seq_len, self.data_influence_tokenizer),
             print_profile=True,
             detailed=False,
@@ -315,11 +337,16 @@ def client_fn(context: Context) -> FlowerClient:
         client_set = load_data_homo(partition_id, num_partitions, cfg.dataset.name)
     else:
         client_set = load_data_hete(partition_id)
+
+
+    cfg.skipbert.att_layer_maps = [int(s) for s in cfg.skipbert.att_layer_maps.split(', ')]
+    cfg.skipbert.hid_layer_maps = [int(k) for k in cfg.skipbert.hid_layer_maps.split(', ')]
 
     return FlowerClient(
         cfg.model,
         cfg.train,
         cfg.mates,
+        cfg.skipbert,
         client_set['train'],
         client_set['test'],
         num_rounds,
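Note on the fit() path above: when cfg.mates.state is enabled, the client returns the LLM parameters and the teacher data-influence model's parameters packed into a single NDArrays list via the project helpers concatenate_models_with_marker and split_models, which are not part of this diff. The following is only a minimal sketch of one way such a marker-based pack/unpack can be written; the sentinel value and function names are illustrative assumptions, not the repo's implementation.

    import numpy as np
    from typing import List, Tuple

    NDArrays = List[np.ndarray]                        # same shape of data Flower passes around
    _MARKER = np.array([-12345.0], dtype=np.float32)   # hypothetical sentinel array

    def concatenate_with_marker(main: NDArrays, aux: NDArrays) -> NDArrays:
        # Pack main-model and data-influence-model arrays into one flat list.
        return main + [_MARKER] + aux

    def split_with_marker(packed: NDArrays) -> Tuple[NDArrays, NDArrays]:
        # Recover both lists by locating the sentinel array.
        for i, arr in enumerate(packed):
            if arr.shape == _MARKER.shape and np.allclose(arr, _MARKER):
                return packed[:i], packed[i + 1:]
        raise ValueError("marker not found; parameters were not packed")

    if __name__ == "__main__":
        main = [np.ones((2, 2)), np.zeros(3)]
        aux = [np.full((4,), 7.0)]
        a, b = split_with_marker(concatenate_with_marker(main, aux))
        assert len(a) == 2 and len(b) == 1

Any such scheme works as long as the server-side strategy and the client agree on how the combined list is split back apart between rounds.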
template_FL/src/fedllm/dataset.py
CHANGED

@@ -9,7 +9,7 @@ import pandas as pd
 
 FDS = None  # Cache FederatedDataset
 client_id_ds = None
-global_test_set_homo = None
+# global_test_set_homo = None
 
 def split_train_test(dataset, test_size):
     # Split the dataset into train and test sets
template_FL/src/fedllm/models.py
CHANGED

@@ -11,7 +11,16 @@ from peft import (
     set_peft_model_state_dict,
 )
 from peft.utils import prepare_model_for_kbit_training
-from transformers import
+from transformers import (
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    BitsAndBytesConfig,
+    TrainerCallback,
+    # BertForSequenceClassification,
+    BertConfig,
+)
+
+from .skipbert.modeling import BertForSequenceClassification, SkipBertForSequenceClassification
 
 from flwr.common.typing import NDArrays
 from transformers.trainer_callback import TrainerControl, TrainerState
@@ -90,24 +99,71 @@ def get_model(model_cfg: DictConfig):
 
     return get_peft_model(model, peft_config), tokenizer
 
-def
+def get_custom_config(teacher_name, skipbert_cfg: DictConfig):
+    num_labels = 1  # Set number of labels to 1 for regression or single-class tasks
+
+
+    teacher_config = BertConfig.from_pretrained(teacher_name)
+    teacher_config.num_labels = num_labels
+    teacher_config.fit_size = teacher_config.hidden_size
+
+    student_config = BertConfig.from_pretrained(skipbert_cfg.student_model)
+    student_config.num_labels = num_labels
+    student_config.fit_size = teacher_config.hidden_size
+
+    if skipbert_cfg.num_layers_student > 0:
+        student_config.num_hidden_layers = skipbert_cfg.num_layers_student
+    if skipbert_cfg.num_full_hidden_layers_student > 0:
+        student_config.num_full_hidden_layers = skipbert_cfg.num_full_hidden_layers_student
+    else:
+        student_config.num_full_hidden_layers = student_config.num_hidden_layers
+
+    student_config.task_type = skipbert_cfg.output_mode
+    student_config.n_gram_left = skipbert_cfg.n_gram_left
+    student_config.n_gram_right = skipbert_cfg.n_gram_right
+    # student_config.plot_mode = 'plot_passive'
+    student_config.plot_mode = 'force_compute'
+    student_config.ngram_masking = 0.
+    if not hasattr(student_config, 'enter_hidden_size'):
+        student_config.enter_hidden_size = student_config.hidden_size
+    if not hasattr(student_config, 'max_num_entries'):
+        student_config.max_num_entries = 100000
+
+    return teacher_config, student_config
+
+
+def get_data_influence_model(model_cfg: DictConfig, skipbert_cfg: DictConfig):
     use_cuda = torch.cuda.is_available()
     device_map = torch.device("cuda:0" if use_cuda else "cpu")
-
+    teacher_name = "bert-base-uncased"
+
+    teacher_config, student_config = get_custom_config(teacher_name=teacher_name, skipbert_cfg=skipbert_cfg)
+
     # Load model with num_labels=1
-
-
-        num_labels=1, # Set number of labels to 1 for regression or single-class tasks
+    teacher_model = BertForSequenceClassification.from_pretrained(
+        teacher_name, config=teacher_config
     ).to(device_map)
 
-
+    student_model = SkipBertForSequenceClassification.from_pretrained(
+        skipbert_cfg.student_model, config=student_config,
+        do_fit=skipbert_cfg.do_fit, share_param=skipbert_cfg.share_param).to(device_map)
+
+    if skipbert_cfg.freeze_lower_layers:
+        student_model.freeze_shallow_layers()
+
+    tokenizer = AutoTokenizer.from_pretrained(teacher_name, do_lower_case=skipbert_cfg.do_lower_case, use_fast=True)
 
     if use_cuda:
-
-
+        teacher_model = prepare_model_for_kbit_training(
+            teacher_model, use_gradient_checkpointing=model_cfg.gradient_checkpointing
+        )
+
+        student_model = prepare_model_for_kbit_training(
+            student_model, use_gradient_checkpointing=model_cfg.gradient_checkpointing
         )
+
 
-    return
+    return teacher_model, student_model, tokenizer
 
 
 def set_parameters(model, parameters: NDArrays) -> None:
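Both get_custom_config and get_data_influence_model above pin num_labels to 1, so the teacher and student heads emit a single score per example for data-influence estimation. A small self-contained illustration of that output shape, using a tiny randomly initialised stock BertForSequenceClassification (an assumption made here only so the snippet runs without downloads; the repo's skipbert classes additionally expose attentions and hidden states for distillation):

    import torch
    from transformers import BertConfig, BertForSequenceClassification

    # Tiny config so the example is fast and needs no pretrained checkpoint.
    cfg = BertConfig(hidden_size=64, num_hidden_layers=2, num_attention_heads=2,
                     intermediate_size=128, num_labels=1)
    model = BertForSequenceClassification(cfg)

    input_ids = torch.randint(0, cfg.vocab_size, (4, 16))
    attention_mask = torch.ones_like(input_ids)
    with torch.no_grad():
        out = model(input_ids=input_ids, attention_mask=attention_mask)
    print(out.logits.shape)  # torch.Size([4, 1]): one scalar score per example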
template_FL/src/fedllm/server_app.py
CHANGED

@@ -1,9 +1,11 @@
 """flowertune-llm: A Flower / FlowerTune app."""
 
 import os
+import sys
 import torch
 import wandb
 import numpy as np
+import pandas as pd
 from dotenv import load_dotenv
 from datetime import datetime
 from tqdm import tqdm
@@ -12,9 +14,11 @@ from transformers import DataCollatorForSeq2Seq, DataCollatorWithPadding, Traini
 from .trainer import ManualTrainer
 from transformers.integrations import WandbCallback
 from torch.utils.data import DataLoader
+import flwr
 from flwr.common import Context, ndarrays_to_parameters
 from flwr.common.config import unflatten_dict
 from flwr.server import ServerApp, ServerAppComponents, ServerConfig
+from flwr.common.logger import FLOWER_LOGGER
 # from flwr.server.strategy import FedAvg
 from omegaconf import DictConfig
 
@@ -27,6 +31,12 @@ from .metrics import exact_match, f1, get_rouge_score
 
 from datasets import load_dataset, Dataset
 from sklearn.model_selection import train_test_split
+import logging
+import uuid
+
+logging.getLogger("flwr").setLevel(logging.INFO)
+logging.getLogger("ClientAppActor").setLevel(logging.INFO)
+logging.getLogger("Trainer").setLevel(logging.INFO)
 
 
 load_dotenv(".env")
@@ -36,6 +46,79 @@ os.environ["WANDB_NAME"] = os.getenv("WANDB_NAME")
 os.environ["HF_TOKEN"] = os.getenv("HF_TOKEN")
 # os.environ["WANDB_LOG_MODEL"] = "checkpoint"
 
+
+class SessionIDFilter(logging.Filter):
+    """Adds a session_id to log records."""
+    def __init__(self, session_id):
+        super().__init__()
+        self.session_id = session_id
+
+    def filter(self, record):
+        record.session_id = self.session_id
+        return True
+
+
+def configure_logging():
+    # Generate a unique session ID for this run
+    session_id = str(uuid.uuid4())
+
+    # Define log format with session ID and process ID
+    log_format = (
+        "%(asctime)s - %(session_id)s - %(process)d - %(name)s - "
+        "%(levelname)s - %(message)s"
+    )
+
+    # Create a FileHandler and attach the SessionIDFilter to it
+    file_handler = logging.FileHandler("main.log", mode='a')  # Append mode
+    formatter = logging.Formatter(log_format)
+    file_handler.setFormatter(formatter)
+    file_handler.addFilter(SessionIDFilter(session_id))  # Add filter to the handler
+
+    # Console handler: logs to stdout (you can also log to stderr)
+    console_handler = logging.StreamHandler(sys.stdout)
+    console_handler.setLevel(logging.DEBUG)
+    console_handler.setFormatter(formatter)
+
+
+    # Configure root logger to use this handler
+    logging.basicConfig(
+        level=logging.INFO,
+        handlers=[
+            file_handler,
+            console_handler,
+        ],  # Use the filtered handler
+    )
+
+    # if not any(
+    #     isinstance(handler, logging.FileHandler) and handler.baseFilename == file_handler.baseFilename
+    #     for handler in FLOWER_LOGGER.handlers
+    # ):
+    #     FLOWER_LOGGER.addHandler(file_handler)
+
+    for handler in FLOWER_LOGGER.handlers:
+        FLOWER_LOGGER.addHandler(file_handler)
+
+    # Get the logger for the ClientAppActor module and attach the same file handler
+    client_actor_logger = logging.getLogger("flwr.simulation.ray_transport.ray_actor")
+    # if not any(
+    #     isinstance(handler, logging.FileHandler) and handler.baseFilename == file_handler.baseFilename
+    #     for handler in client_actor_logger.handlers
+    # ):
+    #     client_actor_logger.addHandler(file_handler)
+
+    for handler in client_actor_logger.handlers:
+        client_actor_logger.addHandler(file_handler)
+
+    # Explicitly configure Ray's logger to propagate
+    ray_logger = logging.getLogger("ray")  # Ray's parent logger
+    for handler in ray_logger.handlers:
+        ray_logger.addHandler(file_handler)
+
+
+    # Log the start of the session
+    logger = logging.getLogger(__name__)
+    logger.info("===== Application Started =====")
+
 class LLMSampleCB(WandbCallback):
     def __init__(self, trainer, test_dataset, task, num_samples=10, max_new_tokens=256, log_model="checkpoint"):
         "A CallBack to log samples a wandb.Table during training"
@@ -84,7 +167,7 @@ class LLMSampleCB(WandbCallback):
 
 
 
-def test_model(dataset, model, tokenizer, train_cfg, tmp_dict, sround, mates_args, task):
+def test_model(dataset, model, tokenizer, train_cfg, tmp_dict, sround, mates_args, skipbert_cfg, task):
 
     wandb.init(
         project='FL@CSS25',
@@ -157,8 +240,10 @@
         args=training_arguments,
         data_collator=data_collator,
        compute_metrics=compute_metrics,
-
-
+        mates_cfg=mates_args,
+        skipbert_cfg=skipbert_cfg,
+        teacher_data_influence_model=None,
+        student_data_influence_model=None,
         data_influence_tokenizer=None,
     )
 
@@ -185,7 +270,7 @@
 # Get function that will be executed by the strategy's evaluate() method
 # Here we use it to save global model checkpoints
 
-def get_evaluate_fn(train_cfg, model_cfg, dataset_cfg, save_every_round, total_round, total_nodes, save_path, mates_args):
+def get_evaluate_fn(train_cfg, model_cfg, dataset_cfg, save_every_round, total_round, total_nodes, save_path, mates_args, skipbert_cfg):
     """Return an evaluation function for saving global model."""
 
     def evaluate(server_round: int, parameters, config):
@@ -208,12 +293,14 @@
             }
             if dataset_cfg.type == 'homo':
                 ds = load_dataset(dataset_cfg.name)
+                option = 'test' if 'test' in ds else 'train'
+                df = pd.DataFrame(ds[option])
                 _, test = train_test_split(
-
+                    df, test_size=0.09, shuffle=True, random_state=42
                 )
                 global_test_set_homo = Dataset.from_pandas(test).remove_columns(['__index_level_0__'])
 
-                loss, metrics = test_model(global_test_set_homo, model, tokenizer, train_cfg, tmp_dict, server_round, mates_args, 'homo')
+                loss, metrics = test_model(global_test_set_homo, model, tokenizer, train_cfg, tmp_dict, server_round, mates_args, skipbert_cfg, 'homo')
                 total_loss = loss
                 result_metric = {'homo_f1': metrics['homo_f1']}
             else:
@@ -225,7 +312,7 @@
 
                 for task in ['general', 'finance', 'math', 'medical', 'code']:
                     ds = global_test_set_hete[task]
-                    loss, metrics = test_model(ds, model, tokenizer, train_cfg, tmp_dict, server_round, mates_args, task)
+                    loss, metrics = test_model(ds, model, tokenizer, train_cfg, tmp_dict, server_round, mates_args, skipbert_cfg, task)
                     list_loss.append(loss)
 
                     list_f1[f'{task}_f1'] = metrics[f'{task}_f1']
@@ -273,6 +360,8 @@
 def server_fn(context: Context):
     """Construct components that set the ServerApp behaviour."""
     # Create output directory given current timestamp
+    configure_logging()
+    logger = logging.getLogger(__name__)
     current_time = datetime.now()
     folder_name = current_time.strftime("%Y-%m-%d_%H-%M-%S")
     save_path = os.path.join(os.getcwd(), f"results/{folder_name}")
@@ -296,7 +385,7 @@
         fit_metrics_aggregation_fn=fit_weighted_average,
         initial_parameters=init_model_parameters,
         evaluate_fn=get_evaluate_fn(
-            cfg.train, cfg.model, cfg.dataset, cfg.train.save_every_round, num_rounds, num_nodes, save_path, cfg.mates
+            cfg.train, cfg.model, cfg.dataset, cfg.train.save_every_round, num_rounds, num_nodes, save_path, cfg.mates, cfg.skipbert
         ),
         use_mates=cfg.mates.state
     )
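configure_logging above hinges on a handler-level logging.Filter that stamps every record with a per-run session ID before formatting. A standalone sketch of that pattern using only the standard library, with an in-memory stream standing in for main.log:

    import io
    import logging
    import uuid

    class SessionIDFilter(logging.Filter):
        """Attach a fixed session_id to every record passing through the handler."""
        def __init__(self, session_id: str):
            super().__init__()
            self.session_id = session_id

        def filter(self, record: logging.LogRecord) -> bool:
            record.session_id = self.session_id
            return True

    stream = io.StringIO()
    handler = logging.StreamHandler(stream)
    handler.setFormatter(logging.Formatter("%(asctime)s - %(session_id)s - %(name)s - %(message)s"))
    handler.addFilter(SessionIDFilter(str(uuid.uuid4())))

    demo_logger = logging.getLogger("demo")
    demo_logger.setLevel(logging.INFO)
    demo_logger.addHandler(handler)
    demo_logger.info("application started")
    print(stream.getvalue().strip())

Because the filter is attached to the handler (not the logger), every record routed through that handler gets the session_id, including records from third-party loggers such as Flower's or Ray's.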
template_FL/src/fedllm/skipbert/modeling.py
CHANGED

@@ -474,6 +474,7 @@ class ShallowSkipping(nn.Module):
 
     @torch.jit.script
     def merge_ngrams(input_ids, ngram_hidden_states, aux_embeddings):
+        # batch_size, seq_length = input_ids.shape
         batch_size, seq_length = input_ids.shape
         lens = (input_ids!=0).sum(1)
         hidden_state = torch.zeros([batch_size, seq_length, ngram_hidden_states.size(-1)], dtype=ngram_hidden_states.dtype, device=ngram_hidden_states.device)
@@ -562,7 +563,6 @@ class ShallowSkipping(nn.Module):
     ):
 
         device = model.device
-
         batch_size, seq_length = input_ids.shape
         aux_embeddings = model.embeddings.position_embeddings2.weight[:seq_length].unsqueeze(0)
         aux_embeddings = aux_embeddings + model.embeddings.token_type_embeddings2(token_type_ids)
template_FL/src/fedllm/skipbert/plot.py
CHANGED

@@ -117,8 +117,8 @@ class Plot:
         self.max_num_entries = max_num_entries
         self.hidden_size = hidden_size
 
-        self.trigram_to_id, self.id_to_trigram = self.build_hash_table('input_ids_tri_gram.memmap', max_num_entries)
-        self.orig_trigram_hidden_states = _read_or_create_memmap("plot_hidden_states_tri_gram.memmap", dtype='float16', shape=(max_num_entries, 3, hidden_size))
+        self.trigram_to_id, self.id_to_trigram = self.build_hash_table('./input_ids_tri_gram.memmap', max_num_entries)
+        self.orig_trigram_hidden_states = _read_or_create_memmap("./plot_hidden_states_tri_gram.memmap", dtype='float16', shape=(max_num_entries, 3, hidden_size))
 
     def build_hash_table(self, path, max_num_entries):
         n_gram = 3
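The two changed lines only make the memmap paths explicitly relative to the current working directory. For context, the Plot cache stores trigram hidden states in numpy memmaps; the sketch below shows what a read-or-create helper for such files typically looks like. The repo's _read_or_create_memmap is not shown in this diff, so treat the helper name and behaviour here as an assumption:

    import os
    import numpy as np

    def read_or_create_memmap(path: str, dtype: str, shape: tuple) -> np.memmap:
        # Reuse the file if it already exists, otherwise create it with the given shape.
        mode = "r+" if os.path.exists(path) else "w+"
        return np.memmap(path, dtype=dtype, mode=mode, shape=shape)

    if __name__ == "__main__":
        mm = read_or_create_memmap("./plot_hidden_states_demo.memmap",
                                   dtype="float16", shape=(8, 3, 4))
        mm[0] = 1.0
        mm.flush()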
template_FL/src/fedllm/skipbert/trainer.py
ADDED

@@ -0,0 +1,691 @@
from accelerate import Accelerator
from torch.utils.data import DataLoader
import torch
from torch import nn
from torch.nn import MSELoss, CrossEntropyLoss
import copy
import numpy as np
from transformers import (
    # BertForSequenceClassification,
    GenerationConfig,
    AutoTokenizer,
    Trainer,
    get_scheduler,
    EarlyStoppingCallback,
    TrainingArguments
)
from transformers.trainer_utils import (
    EvaluationStrategy,
    IntervalStrategy,
)
from transformers.trainer_pt_utils import nested_detach
from transformers.utils import is_sagemaker_mp_enabled
from transformers.training_args import OptimizerNames
from typing import Dict, List, Optional, Any, Union, Tuple, Callable
import numpy as np
from torch.utils.data import Dataset, DataLoader
from accelerate.utils import (
    AutocastKwargs,
    DistributedDataParallelKwargs,
    DistributedType,
)
from datasets import Dataset
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
import wandb
import logging

logging.getLogger("Trainer").setLevel(logging.INFO)
logger = logging.getLogger(__name__)


def compute_metrics_skipbert(pred):
    """
    Compute metrics for model evaluation
    """
    labels = pred.label_ids

    preds = pred.predictions

    if len(preds[0]) >= 2:
        preds = torch.tensor(preds.argmax(-1))
        labels = torch.tensor(labels)

        acc = accuracy_score(labels, preds)
        precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='binary')
        return {
            'accuracy': acc,
            'f1': f1,
            'precision': precision,
            'recall': recall
        }
    else:
        labels = torch.tensor(pred.label_ids[:, np.newaxis])
        preds = torch.tensor(pred.predictions)

        # MSE
        mse = nn.MSELoss()
        mse_loss = mse(labels, preds)

        # RMSE
        rmse = torch.sqrt(mse_loss)

        # MAE
        mae = nn.L1Loss()
        mae_loss = mae(labels, preds)


        return {
            'mse': mse_loss,
            'rmse': rmse,
            'mae': mae_loss,
        }


# Create custom Trainer for training SkipBERT
class SkipBertTrainer(Trainer):
    def __init__(
        self,
        student_model: nn.Module,
        teacher_model: Optional[nn.Module] = None,
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Dataset] = None,
        args: Optional[TrainingArguments] = None,
        data_collator: Optional[Callable] = None,
        compute_metrics: Optional[Callable] = None,
        alpha: float = 0.5,
        temperature: float = 2.0,
        beta: float = 1.0,
        use_logits: bool = True,
        use_att: bool = True,
        use_rep: bool = True,
        use_embedding: bool = True,
        att_layer_maps: Optional[List[int]] = None,
        hid_layer_maps: Optional[List[int]] = None,
        epochs_no_cls: int = 0,
        reduce_T: int = 1,
        output_mode: str = 'classification',
        num_masked_layers_teacher: int = 0,
        num_masked_last_layers_teacher: int = 0,
        fp16: bool = False,
        num_full_hidden_layers_student: int = 0,
        **kwargs,
    ):
        """
        Initialize SkipBERT Trainer with knowledge distillation capabilities.

        Args:
            student_model: The student model to be trained
            teacher_model: The teacher model for knowledge distillation
            train_dataset: Training dataset
            eval_dataset: Evaluation dataset
            args: Training arguments
            alpha: Balance between distillation loss and cross-entropy loss
            temperature: Temperature for softening probability distributions
            beta: Weighting factor for different loss components
            use_logits: Whether to use logits-based distillation
            use_att: Whether to use attention-based distillation
            use_rep: Whether to use representation-based distillation
            use_embedding: Whether to use embedding-based distillation
        """
        # Set default training arguments if not provided
        if args is None:
            args = TrainingArguments(
                output_dir="./results",
                num_train_epochs=3,
                per_device_train_batch_size=2,
                per_device_eval_batch_size=2,
                logging_dir='./logs',
                evaluation_strategy=EvaluationStrategy.EPOCH,
                save_strategy=IntervalStrategy.EPOCH,
            )


        # Call parent constructor
        super().__init__(
            model=student_model,
            args=args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            data_collator=data_collator,
            compute_metrics=compute_metrics,
            **kwargs
        )

        # Store additional knowledge distillation parameters
        self.teacher_model = teacher_model
        self.alpha = alpha
        self.temperature = temperature
        self.beta = beta
        self.use_logits = use_logits
        self.use_att = use_att
        self.use_rep = use_rep
        self.use_embedding = use_embedding
        self.att_layer_maps = att_layer_maps or []
        self.hid_layer_maps = hid_layer_maps or []
        self.epochs_no_cls = epochs_no_cls
        self.reduce_T = reduce_T
        self.output_mode = output_mode
        self.num_masked_layers_teacher = num_masked_layers_teacher
        self.num_masked_last_layers_teacher = num_masked_last_layers_teacher
        self.num_full_hidden_layers_student = num_full_hidden_layers_student
        self.tr_att_loss = 0
        self.tr_rep_loss = 0
        self.tr_cls_loss = 0
        self.list_att_loss = []
        self.list_rep_loss = []
        self.list_embed_loss = []

        # Prepare FP16 if enabled
        self.fp16 = fp16
        if fp16:
            try:
                from apex import amp
            except ImportError:
                raise ImportError(
                    "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
                )

            # Initialize amp
            self.model, self.optimizer = amp.initialize(
                self.model,
                self.optimizer,
                opt_level='01'
            )

            # Half precision for teacher model if exists
            if self.teacher_model is not None:
                self.teacher_model = self.teacher_model.half()

        # Loss functions
        self.loss_mse = MSELoss()

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """
        How the loss is computed by Trainer. By default, all models return the loss in the first element.

        Subclass and override for custom behavior.
        """

        # Separate labels from inputs
        labels = inputs.pop("labels")

        if self.model_accepts_loss_kwargs:
            loss_kwargs = {}
            if num_items_in_batch is not None:
                loss_kwargs["num_items_in_batch"] = num_items_in_batch
            inputs = {**inputs, **loss_kwargs}

        # Forward pass through student model
        student_logits, student_atts, student_reps = model(**inputs)
        student_reps = student_reps[-self.num_full_hidden_layers_student-1:]

        # Forward pass through teacher model
        self.teacher_model.eval()
        with torch.no_grad():
            # teacher_logits, teacher_atts, teacher_reps = self.teacher_model(**inputs)
            teacher_outputs = self.teacher_model(**inputs, output_hidden_states=True, output_attentions=True)
            teacher_logits, teacher_atts, teacher_reps = teacher_outputs.logits, teacher_outputs.attentions, teacher_outputs.hidden_states
            start, end = self.num_masked_layers_teacher, -1 * self.num_masked_layers_teacher if self.num_masked_layers_teacher != 0 else None
            teacher_reps = teacher_reps[start:end]

        # Save past state if it exists
        # TODO: this needs to be fixed and made cleaner later.
        if self.args.past_index >= 0:
            self._past = student_outputs[self.args.past_index]

        # Compute losses
        att_loss, rep_loss = 0., 0.

        # ---------------------------
        if labels is not None:


            # ---------------------------
            if self.att_layer_maps is None:
                teacher_layer_num = len(teacher_atts)
                student_layer_num = len(student_atts)
                assert teacher_layer_num % student_layer_num == 0
                layers_per_block = int(teacher_layer_num / student_layer_num)
                new_teacher_atts = [
                    teacher_atts[(i * 1) * layers_per_block - 1]
                    for i in range(student_layer_num)
                ]
                assert len(student_atts) == len(new_teacher_atts)

            else:
                new_teacher_atts = []
                for t2s in self.att_layer_maps:
                    if t2s >= 0:
                        new_teacher_atts.append(teacher_atts[t2s])
                    else:
                        new_teacher_atts.append(None)

            # ----------------------------

            for student_att, teacher_att in zip(student_atts, new_teacher_atts):
                if teacher_att is None:
                    continue
                student_att = torch.where(
                    student_att <= 1e-2,
                    torch.zeros_like(student_att),
                    student_att
                )

                teacher_att = torch.where(
                    teacher_att <= 1e-2,
                    torch.zeros_like(teacher_att),
                    teacher_att
                )

                att_loss += self.loss_mse(student_att, teacher_att)

            # ---------------------------

            if self.hid_layer_maps is None:
                teacher_layer_num = len(teacher_atts) - 1
                student_layer_num = len(student_atts) - 1
                assert teacher_layer_num % student_layer_num == 0
                layers_per_block = int(teacher_layer_num / student_layer_num)
                new_teacher_reps = [
                    teacher_reps[i * layers_per_block]
                    for i in range(student_layer_num + 1)
                ]
                assert len(new_student_reps) == len(new_teacher_reps)
            else:
                new_student_reps = student_reps
                new_teacher_reps = []
                for t2s in self.hid_layer_maps:
                    if t2s >= 0:
                        new_teacher_reps.append(teacher_reps[t2s])
                    else:
                        new_teacher_reps.append(None)

            # ---------------------------

            for student_rep, teacher_rep in zip(new_student_reps, new_teacher_reps):
                if teacher_rep is None:
                    continue
                tmp_loss = self.loss_mse(student_rep, teacher_rep)
                rep_loss += tmp_loss

            self.tr_att_loss += att_loss.item()
            self.tr_rep_loss += rep_loss.item()

            # ---------------------------
            embedding_loss = 0
            if self.use_embedding:
                embedding_loss = self.loss_mse(
                    student_reps[0], teacher_reps[0]
                )

            # ---------------------------

            # ---------------------------

            if self.use_logits and self.state.epoch >= self.epochs_no_cls:
                if isinstance(student_logits, tuple) or \
                        isinstance(student_logits, list):
                    cls_loss = None
                    _scale = 0.
                    for il, logits in enumerate(student_logits):
                        _loss, _, _ = self._compute_distillation_loss(
                            student_logits, student_atts, student_reps,
                            teacher_logits, teacher_atts, teacher_reps,
                            labels
                        )
                        if cls_loss is None:
                            cls_loss = _loss
                        else:
                            cls_loss = _loss * (il + 1.) + cls_loss
                        _scale += il + 1

                    cls_loss = cls_loss * (1. / _scale)

                else:
                    cls_loss, kd_loss, ce_loss = self._compute_distillation_loss(
                        student_logits, student_atts, student_reps,
                        teacher_logits, teacher_atts, teacher_reps,
                        labels
                    )
                self.tr_cls_loss += cls_loss.item()

            else:
                cls_loss = 0

            # ---------------------------


            check = self.state.epoch >= self.epochs_no_cls
            self.beta = self.beta * check + (1 - check) * 1.

            # ---------------------------

            if self.use_embedding and \
                    self.use_att and \
                    self.use_rep:
                loss = self.beta * (rep_loss + att_loss + embedding_loss) + cls_loss

            elif self.use_att and self.use_rep:
                loss = self.beta * (rep_loss + att_loss) + cls_loss

            elif self.use_embedding and self.use_att:
                loss = self.beta * (att_loss + embedding_loss) + cls_loss

            elif self.use_embedding and self.use_rep:
                loss = self.beta * (rep_loss + embedding_loss) + cls_loss

            elif self.use_att and \
                    not self.use_embedding and \
                    not self.use_rep:
                loss = self.beta * att_loss + cls_loss

            elif self.use_rep and \
                    not self.use_embedding and \
                    not self.use_att:
                loss = self.beta * rep_loss + cls_loss

            else:
                loss = cls_loss


            # ---------------------------

        else:
            if isinstance(outputs, dict) and "loss" not in outputs:
                raise ValueError(
                    "The model did not return a loss from the inputs, only the following keys: "
                    f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}."
                )
            # We don't use .loss here since the model may return tuples instead of ModelOutput.
            loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]

        # ---------------------------

        # ---------------------------
        if self.args.average_tokens_across_devices and self.model_accepts_loss_kwargs:
            loss *= self.accelerator.num_processes
            rep_loss *= self.accelerator.num_processes
            att_loss *= self.accelerator.num_processes
            embedding_loss *= self.accelerator.num_processes
        self.list_att_loss.append(att_loss.item())
        self.list_rep_loss.append(rep_loss.item())
        self.list_embed_loss.append(embedding_loss.item())
        # ---------------------------


        # ---------------------------

        # Ensure logits are properly formatted for evaluation metrics
        logits = student_logits
        if return_outputs:
            # Ensure student_logits has the correct shape [batch_size, num_classes]
            if isinstance(student_logits, (tuple, list)):
                logits = student_logits[-1]
            else:
                logits = student_logits

            # If logits is 1D, reshape it to 2D
            if len(logits.shape) == 1:
                logits = logits.unsqueeze(0)

            # Ensure we have [batch_size, num_classes] shape
            if len(logits.shape) != 2:
                raise ValueError(f"Unexpected logits shape: {logits.shape}. Expected [batch_size, num_classes]")

            if self.output_mode == "classification":  # Classification
                loss = nn.functional.cross_entropy(labels.view(-1), logits.view(-1, len(logits[0])), reduction="mean")

            elif self.output_mode == "regression":  # Regression
                # print(f"Return output - student: {nn.functional.softmax(student_logits, dim=0).view(-1)}, labels: {labels.view(-1)}")
                loss = self.loss_mse(labels.view(-1), logits.view(-1))


        # ---------------------------
        # print(f"loss: {loss}, att_loss: {att_loss}, rep_loss: {rep_loss}, embed_loss: {embedding_loss}, Train {return_outputs}")
        return (loss, logits) if return_outputs else loss

    def _compute_distillation_loss(
        self,
        student_logits, student_atts, student_reps,
        teacher_logits, teacher_atts, teacher_reps,
        labels
    ):
        """
        Compute comprehensive knowledge distillation loss.

        Args:
            student_*: Student model's outputs
            teacher_*: Teacher model's outputs
            labels: Ground truth labels

        Returns:
            Computed loss
        """

        # Classification/distillation loss
        if self.output_mode == "classification":  # Classification
            # Similar to previous implementation's distillation loss
            if teacher_logits is not None:
                student_likelihood = nn.functional.log_softmax(student_logits / self.temperature, dim=-1)
                targets_prob = nn.functional.softmax(teacher_logits / self.temperature, dim=-1)
                d_loss = (-targets_prob * student_likelihood).mean() * (self.temperature ** 2) / self.reduce_T
            else:
                d_loss = 0
            # Standard cross-entropy/MSE loss
            nll_loss = nn.functional.cross_entropy(student_logits, labels, reduction="mean")

        elif self.output_mode == "regression":  # Regression
            # student_likelihood = nn.functional.softmax(student_logits, dim=0)
            # teacher_likelihood = nn.functional.softmax(teacher_logits, dim=0)
            student_likelihood = torch.tensor(student_logits)
            teacher_likelihood = torch.tensor(teacher_logits)
            d_loss = self.loss_mse(student_likelihood.view(-1), teacher_likelihood.view(-1))
            nll_loss = self.loss_mse(teacher_likelihood.view(-1), labels.view(-1))
        else:
            assert output_mode in ["classification", "regression"]
            d_loss = 0.
            nll_loss = 0.
        tol_loss = self.alpha * d_loss + (1 - self.alpha) * nll_loss
        return tol_loss, d_loss, nll_loss

    def train(
        self,
        resume_from_checkpoint: Optional[str] = None,
        trial: Optional[Dict[str, Any]] = None,
        ignore_keys_for_eval: Optional[List[str]] = None,
        **kwargs
    ):
        """
        Train method with explicit configuration for knowledge distillation training.

        Args:
            resume_from_checkpoint: Optional checkpoint to resume training
            trial: Optional hyperparameter trial configuration
            ignore_keys_for_eval: Keys to ignore during evaluation
        """
        # Prepare teacher model if exists
        if self.teacher_model is not None:
            self.teacher_model.to(self.args.device)
            self.teacher_model.eval()  # Ensure teacher is in eval mode

        # Call parent train method
        return super().train(
            resume_from_checkpoint=resume_from_checkpoint,
            trial=trial,
            ignore_keys_for_eval=ignore_keys_for_eval,
            **kwargs
        )

    def training_step(
        self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], num_items_in_batch=None
    ) -> torch.Tensor:
        """
        Perform a training step on a batch of inputs.

        Subclass and override to inject custom behavior.

        Args:
            model (`nn.Module`):
                The model to train.
            inputs (`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
                argument `labels`. Check your model's documentation for all accepted arguments.

        Return:
            `torch.Tensor`: The tensor with training loss on this batch.
        """
        model.train()
        if hasattr(self.optimizer, "train") and callable(self.optimizer.train):
            self.optimizer.train()

        inputs = self._prepare_inputs(inputs)
        if is_sagemaker_mp_enabled():
            loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps)
            return loss_mb.reduce_mean().detach().to(self.args.device)

        with self.compute_loss_context_manager():
            loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)

        del inputs
        if (
            self.args.torch_empty_cache_steps is not None
            and self.state.global_step % self.args.torch_empty_cache_steps == 0
        ):
            if is_torch_xpu_available():
                torch.xpu.empty_cache()
            elif is_torch_mlu_available():
                torch.mlu.empty_cache()
            elif is_torch_musa_available():
                torch.musa.empty_cache()
            elif is_torch_npu_available():
                torch.npu.empty_cache()
            elif is_torch_mps_available(min_version="2.0"):
                torch.mps.empty_cache()
            else:
                torch.cuda.empty_cache()

        kwargs = {}

        # For LOMO optimizers you need to explicitly use the learnign rate
        if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
            kwargs["learning_rate"] = self._get_learning_rate()

        if self.args.n_gpu > 1:
            loss = loss.mean()  # mean() to average on multi-gpu parallel training

        if self.use_apex:
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.requires_grad = True
                scaled_loss.backward()

            if (self.state.global_step + 1) % self.args.gradient_accumulation_steps == 0:
                nn.utils.clip_grad_norm_(amp.master_params(self.optimizer[0]), 1.0)


        else:
            # Finally we need to normalize the loss for reporting
            loss.requires_grad = True
            if not self.model_accepts_loss_kwargs and self.compute_loss_func is None:
                loss = loss / self.args.gradient_accumulation_steps

            # Turning off loss scaling w.r.t. gradient accumulation when DeepSpeed is enabled
            # https://github.com/huggingface/transformers/pull/35808
            if self.accelerator.distributed_type == DistributedType.DEEPSPEED:
                kwargs["scale_wrt_gas"] = False

            self.accelerator.backward(loss, **kwargs)

            if (self.state.global_step + 1) % self.args.gradient_accumulation_steps == 0:
                # nn.utils.clip_grad_norm_(student_model.parameters(), 1.0)
                nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)

        return loss.detach()

    def evaluate(
        self,
        eval_dataset: Optional[Dataset] = None,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "eval",
        **kwargs
    ) -> Dict[str, float]:
        """
        Evaluation method with custom metrics computation.

        Args:
            eval_dataset: Optional evaluation dataset
            ignore_keys: Keys to ignore during evaluation
            metric_key_prefix: Prefix for metrics

        Returns:
            Dictionary of evaluation metrics
        """
        # Use parent's evaluate method with optional customizations
        return super().evaluate(
            eval_dataset=eval_dataset,
            ignore_keys=ignore_keys,
            metric_key_prefix=metric_key_prefix,
            **kwargs
        )

    def prediction_step(
        self,
        model: nn.Module,
        inputs: Dict[str, Union[torch.Tensor, Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[List[str]] = None,
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
        """
        Override prediction step to handle the model's output format correctly.
        """
        has_labels = False if len(self.label_names) == 0 else all(inputs.get(k) is not None for k in self.label_names)

        return_loss = inputs.get("return_loss", None)
        if return_loss is None:
            return_loss = self.can_return_loss
        loss_without_labels = True if len(self.label_names) == 0 and return_loss else False


        inputs = self._prepare_inputs(inputs)
        if ignore_keys is None:
            if hasattr(self.model, "config"):
                ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", [])
            else:
                ignore_keys = []

        # labels may be popped when computing the loss (label smoothing for instance) so we grab them first.
        if has_labels or loss_without_labels:
            labels = nested_detach(tuple(inputs.get(name) for name in self.label_names))
            if len(labels) == 1:
                labels = labels[0]
        else:
            labels = None

        with torch.no_grad():
            loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
            loss = loss.mean().detach()

            # Get logits from outputs
            if isinstance(outputs, dict):
                logits = outputs["logits"]
            else:
                # logits = outputs[0]
                logits = outputs
674 |
+
|
675 |
+
|
676 |
+
# Ensure logits has correct shape [batch_size, num_classes]
|
677 |
+
if len(logits.shape) == 1:
|
678 |
+
logits = logits.unsqueeze(0)
|
679 |
+
|
680 |
+
if prediction_loss_only:
|
681 |
+
return (loss, None, None)
|
682 |
+
|
683 |
+
if labels is not None:
|
684 |
+
labels = labels.detach()
|
685 |
+
|
686 |
+
logits = nested_detach(logits)
|
687 |
+
if len(logits.shape) == 1:
|
688 |
+
logits = logits[0]
|
689 |
+
|
690 |
+
# print(f"Validation loss: {loss}")
|
691 |
+
return (loss, logits, labels)
|
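Note: `training_step` above delegates the actual loss to `compute_loss`, and the trainer is later constructed with `alpha`, `temperature`, and `beta` knobs (see the `trainer.py` changes below). The following is only a generic sketch of how a temperature-scaled soft-target term is typically mixed with a hard-label term in student-teacher training; it is not this repository's `compute_loss`, and the function name and weighting scheme are assumptions.

import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, labels, temperature=2.0, alpha=0.5):
    # Soft-target term: KL divergence between temperature-softened student and
    # teacher distributions, scaled by T^2 to keep gradient magnitudes comparable.
    soft_targets = F.softmax(teacher_logits / temperature, dim=-1)
    log_student = F.log_softmax(student_logits / temperature, dim=-1)
    kd_term = F.kl_div(log_student, soft_targets, reduction="batchmean") * (temperature ** 2)
    # Hard-label term: ordinary cross-entropy against the ground-truth labels.
    ce_term = F.cross_entropy(student_logits, labels)
    return alpha * kd_term + (1.0 - alpha) * ce_term

# Example with random tensors (2 samples, 3 classes):
student = torch.randn(2, 3, requires_grad=True)
teacher = torch.randn(2, 3)
labels = torch.tensor([0, 2])
print(distillation_loss(student, teacher, labels))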
template_FL/src/fedllm/trainer.py
CHANGED
@@ -3,14 +3,48 @@ from torch.utils.data import DataLoader
 import torch
 import copy
 import numpy as np
-from transformers import
+from transformers import (
+    # BertForSequenceClassification,
+    GenerationConfig,
+    AutoTokenizer,
+    Trainer,
+    get_scheduler,
+    EarlyStoppingCallback,
+    TrainingArguments,
+    DataCollatorWithPadding
+)
+from datasets import Dataset
+from .skipbert.trainer import compute_metrics_skipbert, SkipBertTrainer
+
 import inspect
 import logging
 import wandb
 from tqdm import tqdm
+import time
+from functools import partial
+
+logging.getLogger("Trainer").setLevel(logging.INFO)
 
 logger = logging.getLogger(__name__)
 
+
+def time_format(runtime, logger):
+    if runtime < 60:
+        logger.info(f'Runtime: {runtime:.2f} seconds')
+    elif runtime < 3600:  # Less than one hour
+        minutes = runtime / 60
+        logger.info(f'Runtime: {minutes:.2f} minutes')
+    else:
+        hours = runtime / 3600
+        logger.info(f'Runtime: {hours:.2f} hours')
+
+
+
+def convert_to_tokens_reg(data, tokenizer, max_seq_length, device):
+    input_tokenzied = tokenizer(data['text'], truncation=True, padding=True, max_length=max_seq_length, return_tensors="pt")
+    input_tokenzied['labels'] = torch.tensor(data['label'], dtype=torch.float32).reshape(-1, 1)
+    return input_tokenzied
+
 class ManualLLMSampleCB:
     def __init__(self, model, tokenizer, task, num_samples=10, max_new_tokens=256):
         self.model = model
@@ -61,12 +95,14 @@ class ManualLLMSampleCB:
     def log_samples_to_wandb(self, dataset):
         samples_table = self.create_samples_table(dataset)
         wandb.log({"sample_predictions": samples_table})
-
+
+
+
 
 class ManualTrainer:
     def __init__(
         self, model, tokenizer, train_dataset, val_dataset, holdout_dataset, reference_dataset,
-        args, data_collator, compute_metrics,
+        args, data_collator, compute_metrics, mates_cfg, skipbert_cfg, teacher_data_influence_model, student_data_influence_model, data_influence_tokenizer
     ):
         self.accelerator = Accelerator()
         self.model = model
@@ -74,10 +110,13 @@ class ManualTrainer:
         self.args = args
         self.data_collator = data_collator
         self.compute_metrics = compute_metrics
-        self.
-        self.
+        self.mates_cfg = mates_cfg
+        self.skipbert_cfg = skipbert_cfg
+        self.teacher_data_influence_model = teacher_data_influence_model
+        self.student_data_influence_model = student_data_influence_model
         self.data_influence_tokenizer = data_influence_tokenizer
-
+
+
         # Remove unused columns from datasets
         if train_dataset:
             self.train_dataset = self._remove_unused_columns(train_dataset, "training")
@@ -105,13 +144,13 @@ class ManualTrainer:
         else:
            self.val_loader = None
 
-        if self.
+        if self.mates_cfg.state:
             self.holdout_dataset = self._remove_unused_columns(holdout_dataset, "holdout")
             self.reference_dataset = self._remove_unused_columns(reference_dataset, "reference")
 
             self.holdout_loader = DataLoader(
                 self.holdout_dataset,
-                batch_size=self.
+                batch_size=self.mates_cfg.holdout_batch_size,
                 shuffle=True,
                 collate_fn=self.data_collator,
                 drop_last=self.args.dataloader_drop_last
@@ -119,7 +158,7 @@ class ManualTrainer:
 
             self.reference_loader = DataLoader(
                 self.reference_dataset,
-                batch_size=self.
+                batch_size=self.mates_cfg.reference_batch_size,
                 shuffle=False,
                 collate_fn=self.data_collator,
                 drop_last=self.args.dataloader_drop_last
@@ -136,11 +175,56 @@ class ManualTrainer:
             self.model, self.optimizer, self.full_train_loader, self.val_loader
         )
 
-
+        ### Define for MATEs ###
+        if self.mates_cfg.state:
            # Prepare holdout and reference loaders for Accelerator
-        self.
-        self.
+            self.teacher_data_influence_model, self.holdout_loader, self.reference_loader = self.accelerator.prepare(
+                self.teacher_data_influence_model, self.holdout_loader, self.reference_loader
            )
+            self.student_data_influence_model = self.accelerator.prepare(self.student_data_influence_model)
+        ######
+
+        ### Define for SkipBERT ###
+        self.skipbert_train_args = TrainingArguments(
+            output_dir=self.skipbert_cfg.output_dir,
+            learning_rate=self.skipbert_cfg.learning_rate,
+            num_train_epochs=self.skipbert_cfg.num_train_epochs,
+            per_device_train_batch_size=self.skipbert_cfg.train_batch_size,
+            gradient_accumulation_steps=self.skipbert_cfg.gradient_accumulation_steps,
+            per_device_eval_batch_size=self.skipbert_cfg.eval_batch_size,
+            eval_accumulation_steps=self.skipbert_cfg.eval_accumulation_steps,
+            max_steps=self.skipbert_cfg.max_steps,
+            logging_steps = 10,
+            evaluation_strategy=self.skipbert_cfg.evaluation_strategy,
+            save_strategy=self.skipbert_cfg.save_strategy,
+            lr_scheduler_type=self.skipbert_cfg.lr_scheduler_type,
+            warmup_steps=self.skipbert_cfg.warmup_steps,
+            weight_decay=self.skipbert_cfg.weight_decay,
+            logging_dir=self.skipbert_cfg.logging_dir,
+            report_to='wandb',
+            run_name='skipbert',
+            do_train=self.skipbert_cfg.do_train,
+            do_eval=self.skipbert_cfg.do_eval,
+            dataloader_drop_last=False,
+            ddp_find_unused_parameters=False,
+            group_by_length=True,
+            load_best_model_at_end = True
+        )
+
+        # Prepare custom optimizer student model's parameters
+        if self.student_data_influence_model is not None:
+            no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
+            self.student_optimizer_grouped_parameters = [
+                {
+                    'params': [p for n, p in self.student_data_influence_model.named_parameters() if not any(nd in n for nd in no_decay)],
+                    'weight_decay': 0.01
+                },
+                {
+                    'params': [p for n, p in self.student_data_influence_model.named_parameters() if any(nd in n for nd in no_decay)],
+                    'weight_decay': 0.0
+                }
+            ]
+        ######
 
     def _remove_unused_columns(self, dataset, description=None):
         """
@@ -188,15 +272,15 @@ class ManualTrainer:
 
         for epoch in range(self.args.num_train_epochs):
             # Check if it's time to update the data influence model and state is True
-            if self.
+            if self.mates_cfg.state and epoch % self.mates_cfg.update_data_influence_model_step == 0:
                 print("Updating the data influence model and selecting high-quality data...")
-                logger.info("Updating the data influence model and selecting high-quality data...")
+                # logger.info("Updating the data influence model and selecting high-quality data...")
                 self.update_data_influence_model()
 
             # Filter high-quality data using the data influence model
             high_quality_indices = self.select_high_quality_data(
                 dataset_size=len(self.train_dataset),
-                selection_fraction=self.
+                selection_fraction=self.mates_cfg.selection_fraction,
             )
             self.train_loader = self.accelerator.prepare(
                 self.create_filtered_dataloader(high_quality_indices)
@@ -224,16 +308,16 @@ class ManualTrainer:
             epoch_loss += loss.item()
 
             if (step + 1) % self.args.logging_steps == 0:
-
-                logger.info(f"Step {step + 1}: Train Loss = {epoch_loss / (step + 1):.4f}")
+                print(f"Step {step + 1}: Train Loss = {epoch_loss / (step + 1):.4f}")
+                # logger.info(f"Step {step + 1}: Train Loss = {epoch_loss / (step + 1):.4f}")
 
         avg_epoch_loss = epoch_loss / len(self.train_loader)
         training_loss.append(avg_epoch_loss)
 
         val_results = self.evaluate()
 
-
-        logger
+        print(f"Epoch {epoch + 1}: Train Loss = {avg_epoch_loss:.4f}, Val Loss = {val_results['eval_loss']:.4f}")
+        # logger.info(f"Epoch {epoch + 1}: Train Loss = {avg_epoch_loss:.4f}, Val Loss = {val_results['eval_loss']:.4f}")
 
         # Early stopping logic
         if val_results["eval_loss"] < best_val_loss:
@@ -243,6 +327,7 @@ class ManualTrainer:
             early_stopping_counter += 1
             if early_stopping_counter >= early_stopping_patience:
                 print("Early stopping triggered")
+                # logger.info("Early stopping triggered")
                 break
 
         return {"training_loss": sum(training_loss) / len(training_loss), "best_val_loss": best_val_loss}
@@ -252,14 +337,20 @@ class ManualTrainer:
         Use the data influence model to predict quality scores and select high-quality data indices.
         """
         print("Selecting high-quality data using the data influence model...")
+        # logger.info("Selecting high-quality data using the data influence model...")
 
         # Predict influence scores for the entire dataset
        influence_scores = []
-        self.
+        self.student_data_influence_model.eval()
        influence_optimizer = self.accelerator.prepare(
-            torch.optim.AdamW(
+            torch.optim.AdamW(
+                self.student_optimizer_grouped_parameters,
+                lr=self.skipbert_train_args.learning_rate,
+            )
        )
        i = 0
+
+        start_time = time.perf_counter()
        with torch.no_grad():
            for batch in self.full_train_loader:  # Full dataset loader
                text = self.tokenizer.batch_decode(
@@ -278,12 +369,13 @@ class ManualTrainer:
 
                # Train the data influence model
                influence_optimizer.zero_grad()
-
+                logits, attn_outputs, hidn_output = self.student_data_influence_model(
                    input_ids=bert_inputs['input_ids'],
                    attention_mask=bert_inputs['attention_mask'],
                )
 
-
+
+                influence_scores.extend(logits.squeeze(-1).cpu().numpy())
 
                i += 1
 
@@ -293,6 +385,7 @@ class ManualTrainer:
        # Normalize influence scores and apply Gumbel-Top-$k$ selection
        influence_scores = np.array(influence_scores)
        print(">> Influence scores shape:", influence_scores.shape)
+        # logger.info(">> Influence scores shape:", influence_scores.shape)
 
        # Add Gumbel noise for diversity
        rng = np.random.default_rng()
@@ -303,7 +396,12 @@ class ManualTrainer:
        selection_size = int(len(influence_scores)*selection_fraction)
        high_quality_indices = np.argpartition(-influence_scores, selection_size)[:selection_size]
        print(f"Selected {len(high_quality_indices)} high-quality samples.")
-
+        # logger.info(f"Selected {len(high_quality_indices)} high-quality samples.")
+
+        end_time = time.perf_counter()
+        runtime = round((end_time - start_time), 2)
+        time_format(runtime, logger)
+
        return high_quality_indices
 
    def create_filtered_dataloader(self, indices):
@@ -311,6 +409,7 @@ class ManualTrainer:
        Create a new dataloader with only the selected high-quality data.
        """
        print("Creating a filtered dataloader with selected high-quality data...")
+        # logger.info("Creating a filtered dataloader with selected high-quality data...")
        subset_dataset = torch.utils.data.Subset(self.train_dataset, indices)
        return torch.utils.data.DataLoader(
            subset_dataset,
@@ -325,16 +424,18 @@ class ManualTrainer:
        # Train a copy of the model on holdout data and validate on reference data
        copied_model = copy.deepcopy(self.model)
        copied_model.train()
+        self.accelerator.state._reset_state()
+        self.accelerator = Accelerator()
        optimizer = self.accelerator.prepare(
            torch.optim.Adam(copied_model.parameters(), lr=self.args.learning_rate)
        )
        holdout_reference_pairs = []
 
-
-        logger.info("Starting to collect holdout-reference pairs...")
+        print("Starting to collect holdout-reference pairs...")
+        # logger.info("Starting to collect holdout-reference pairs...")
        for step, holdout_batch in enumerate(self.holdout_loader):
-
-            logger.info(f"Processing holdout batch {step+1}/{len(self.holdout_loader)}...")
+            print(f"Processing holdout batch {step+1}/{len(self.holdout_loader)}...")
+            # logger.info(f"Processing holdout batch {step+1}/{len(self.holdout_loader)}...")
 
            optimizer.zero_grad()
            outputs = copied_model(
@@ -352,7 +453,7 @@ class ManualTrainer:
            optimizer.step()
 
            print(f"Evaluating reference losses at step {step}...")
-            logger.info(f"Evaluating reference losses at step {step}...")
+            # logger.info(f"Evaluating reference losses at step {step}...")
 
            copied_model.eval()
            reference_losses = []
@@ -373,42 +474,138 @@ class ManualTrainer:
 
        # Train the data influence model using the generated pairs
        print("Starting to train the data influence model...")
-        logger.info("Starting to train the data influence model...")
+        # logger.info("Starting to train the data influence model...")
 
-        self.
-        influence_optimizer = torch.optim.AdamW(self.
-
-
+        self.teacher_data_influence_model.train()
+        influence_optimizer = torch.optim.AdamW(self.teacher_data_influence_model.parameters(), lr=self.args.learning_rate)
+
+        list_texts, list_score = [], []
+        batch_size = 0
+        # Convert to Dataset objective
+        for texts, score in holdout_reference_pairs:
+            if batch_size == 0:
+                batch_size = len(texts)
+            list_texts.extend(texts)
+            list_score.extend([score] * len(texts))
+
+
+        holdout_reference_pairs = {'text': list_texts, 'label': list_score}
+        holdout_reference_pairs = Dataset.from_dict(holdout_reference_pairs)
+
+        # Wrap the function with partial
+        convert_func = partial(
+            convert_to_tokens_reg,
+            tokenizer=self.data_influence_tokenizer,
+            max_seq_length=self.skipbert_cfg.max_seq_length,
+            device=self.accelerator.device
+        )
+        holdout_reference_pairs_loader = DataLoader(
+            holdout_reference_pairs.map(
+                convert_func,
+                batched=True,
+                num_proc=8,
+                remove_columns=holdout_reference_pairs.column_names
+            ),
+            batch_size=batch_size,
+            collate_fn=DataCollatorWithPadding(tokenizer=self.data_influence_tokenizer, padding=True, max_length=self.skipbert_cfg.max_seq_length),  # Use the same collate function
+            drop_last=self.args.dataloader_drop_last
+        )
+        loss_mse = torch.nn.MSELoss()
+
+        for step, batch_input in enumerate(holdout_reference_pairs_loader):
            # Tokenize the text using the BERT tokenizer
-
-
-
-
-
-
-
-
-
-
+            batch_input = {k: v.to('cuda:0') for k, v in batch_input.items()}  # cuda:0
+            # text, score = row['text'], row['label']
+            # bert_inputs = self.data_influence_tokenizer(
+            #     text,
+            #     truncation=True,
+            #     padding='max_length',
+            #     max_length=256,
+            #     return_tensors='pt'
+            # ).to(self.accelerator.device)
+
+            # # Convert score to tensor and enable gradients
+            # score_tensor = torch.tensor([score] * len(text), device=self.accelerator.device, dtype=torch.float32, requires_grad=True)
 
            # Train the data influence model
            influence_optimizer.zero_grad()
-
-
-
-
+            # outputs = self.teacher_data_influence_model(
+            #     input_ids=bert_inputs['input_ids'],
+            #     attention_mask=bert_inputs['attention_mask'],
+            #     labels=score_tensor
+            # )
+
+            outputs = self.teacher_data_influence_model(
+                **batch_input
            )
-
-
+
+            influence_loss = loss_mse(batch_input['labels'].view(-1), outputs.logits.view(-1))
+            # print(f"Loss: {influence_loss} - require_grad: {influence_loss.grad_fn}")
+
+            influence_loss.requires_grad = True
            influence_loss.backward()
            influence_optimizer.step()
-
+
            if step % 50 == 0:
                print(f"[Influence Training] Step {step}: Loss = {influence_loss.item():.4f}")
-                logger.info(f"[Influence Training] Step {step}: Loss = {influence_loss.item():.4f}")
+                # logger.info(f"[Influence Training] Step {step}: Loss = {influence_loss.item():.4f}")
+
 
+        ### Distillation for SkipBERT ###
+        train_converted = holdout_reference_pairs.map(
+            convert_func,
+            batched=True,
+            num_proc=8,
+            remove_columns=holdout_reference_pairs.column_names
+        )
 
-
+
+        # Call parent constructor with custom optimizer
+        optimizer = torch.optim.AdamW(
+            self.student_optimizer_grouped_parameters,
+            lr=self.skipbert_train_args.learning_rate,
+        )
+
+        scheduler = get_scheduler(
+            name=self.skipbert_train_args.lr_scheduler_type,
+            optimizer=optimizer,
+            num_warmup_steps=self.skipbert_train_args.warmup_steps,
+            # num_training_steps=training_args.max_steps
+            num_training_steps=100/(self.skipbert_train_args.per_device_train_batch_size * self.skipbert_train_args.gradient_accumulation_steps)
+        )
+
+        # Initialize the trainer
+        trainer = SkipBertTrainer(
+            student_model=self.student_data_influence_model,
+            teacher_model=self.teacher_data_influence_model,
+            args=self.skipbert_train_args,
+            train_dataset=train_converted,
+            eval_dataset=train_converted.shuffle().select(range(min(len(train_converted), 10))),
+            compute_metrics=compute_metrics_skipbert,
+            # SkipBERT specific arguments
+            alpha=0.5,
+            temperature=2.0,
+            beta=1.0,
+            use_logits=self.skipbert_cfg.use_logits,
+            use_att=self.skipbert_cfg.use_att,
+            use_rep=self.skipbert_cfg.use_rep,
+            use_embedding=self.skipbert_cfg.use_embedding,
+            att_layer_maps=self.skipbert_cfg.att_layer_maps,
+            hid_layer_maps=self.skipbert_cfg.hid_layer_maps,
+            epochs_no_cls=self.skipbert_cfg.epochs_no_cls,
+            reduce_T=self.skipbert_cfg.reduce_T,
+            output_mode=self.skipbert_cfg.output_mode,  # 'classification' or 'regression'
+            num_masked_layers_teacher=self.skipbert_cfg.num_masked_layers_teacher,
+            num_masked_last_layers_teacher=self.skipbert_cfg.num_masked_last_layers_teacher,
+            fp16=self.skipbert_cfg.fp16,
+            num_full_hidden_layers_student=self.skipbert_cfg.num_full_hidden_layers_student,
+            tokenizer=self.data_influence_tokenizer,
+            optimizers=(optimizer, scheduler),
+            callbacks=[EarlyStoppingCallback(early_stopping_patience=5)]
+        )
+
+        # Train the model
+        trainer.train()
 
 
 
@@ -459,7 +656,8 @@ class ManualTrainer:
        metrics = self.compute_metrics({"predictions": padded_preds, "label_ids": padded_labels})
 
        metrics.update({"eval_loss": val_loss / len(self.val_loader)})
-        print("Validation Metrics:
+        print(f"Validation Metrics: {metrics}")
+        # logger.info(f"Validation Metrics: {metrics}")
 
        if wandb_sample:
            # Sample Logging
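For reference, the Gumbel-Top-k step used by `select_high_quality_data` above (perturb the predicted influence scores with Gumbel noise, then keep the top fraction via `np.argpartition`) can be reproduced in isolation. A minimal self-contained sketch; the unit noise scale and the fixed seed are assumptions for the example:

import numpy as np

def gumbel_top_fraction(scores, selection_fraction=0.4, seed=0):
    # Add Gumbel noise so repeated selections stay diverse, then keep the
    # indices of the largest `selection_fraction` of the noisy scores.
    rng = np.random.default_rng(seed)
    noisy = np.asarray(scores, dtype=np.float64) + rng.gumbel(size=len(scores))
    k = int(len(noisy) * selection_fraction)
    return np.argpartition(-noisy, k)[:k]

scores = [0.1, 0.9, 0.3, 0.7, 0.2, 0.8, 0.4, 0.6, 0.5, 0.0]
print(gumbel_top_fraction(scores))  # indices of the 4 highest noisy scores

In the trainer these indices come from the student data-influence model's scores and are then passed to `create_filtered_dataloader` to build the next epoch's training loader.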
template_FL/src/pyproject.toml
CHANGED
@@ -37,14 +37,14 @@ num-server-rounds = 2
 num-supernodes = 10
 
 # Define dataset
-dataset.type = '
+dataset.type = 'homo' # type = ['homo','hete']
 dataset.name = "vicgalle/alpaca-gpt4"
 
 # Define model settings
 model.name = "Qwen/Qwen2.5-1.5B-Instruct"
 model.quantization = 4
 model.gradient-checkpointing = true
-model.
+model.flash-attention = false
 
 ### Use MATES ###
 mates.state = true
@@ -60,67 +60,66 @@ mates.selection-fraction = 0.4
 
 # Model setting
 skipbert.student-model = "bert-base-uncased"
-skipbert.
-skipbert.
-skipbert.
-skipbert.
+skipbert.output-mode = "regression"
+skipbert.num-layers-student = 12
+skipbert.num-full-hidden-layers-student = 6
+skipbert.num-masked-layers-teacher = 0
+skipbert.num-masked-last-layers-teacher = 0
 
 # Training hyperparameters
-skipbert.
-skipbert.
-skipbert.
-skipbert.
-skipbert.
-skipbert.
-skipbert.
-skipbert.
-skipbert.
-skipbert.
-skipbert.
-skipbert.
-skipbert.
-skipbert.
-skipbert.
-skipbert.
-skipbert.
-skipbert.
-skipbert.
+skipbert.train-batch-size = 4
+skipbert.gradient-accumulation-steps = 1
+skipbert.eval-batch-size = 4
+skipbert.eval-accumulation-steps = 1
+skipbert.learning-rate = 2.0e-5
+skipbert.num-train-epochs = 10
+skipbert.eval-step = 10
+skipbert.max-seq-length = 256
+skipbert.weight-decay = 1.0e-4
+skipbert.warmup-steps = 100 # 500
+skipbert.do-train = true
+skipbert.do-eval = true
+skipbert.do-predict = false
+skipbert.max-steps = -1
+skipbert.evaluation-strategy = "epoch"
+skipbert.save-strategy = "epoch"
+skipbert.lr-scheduler-type = "cosine" # or 'warmup-linear'
+skipbert.logging-dir = './skipbert-logs'
+skipbert.output-dir = "./skipbert-results"
+skipbert.report-to = 'wandb'
 
 # Knowledge distillation parameters
 skipbert.beta = 0.01
 skipbert.T = 1.0
 skipbert.alpha = 1.0
-skipbert.
-skipbert.
+skipbert.reduce-T = 1.0
+skipbert.epochs-no-cls = 5
 
 # Training schedule and features
-skipbert.
+skipbert.freeze-lower-layers = true
 
 # Feature usage flags
-skipbert.
-skipbert.
-skipbert.
-skipbert.
+skipbert.use-logits = true
+skipbert.use-att = true
+skipbert.use-rep = true
+skipbert.use-embedding = false
 
 # Training modes
-skipbert.
-skipbert.do_eval = true
-skipbert.do_predict = false
-skipbert.do_fit = false
+skipbert.do-fit = false
 skipbert.fp16 = false
-skipbert.
-skipbert.
-skipbert.
-skipbert.
-skipbert.
+skipbert.no-pretrain = false
+skipbert.use-init-weight = false
+skipbert.share-param = true
+skipbert.do-lower-case = true
+skipbert.no-cuda = false
 
 # N-gram settings
-skipbert.
-skipbert.
+skipbert.n-gram-left = 1
+skipbert.n-gram-right = 1
 
 # Layer mappings
-skipbert.
-skipbert.
+skipbert.att-layer-maps = "1, 3, 5, 7, 9, 11"
+skipbert.hid-layer-maps = "6, 7, 8, 9, 10, 11, 12"
 
 ### END ###
 
@@ -138,8 +137,8 @@ train.save-every-round = 5
 train.learning-rate-max = 5e-5
 train.learning-rate-min = 1e-6
 train.seq-length = 256
-train.
-train.
+train.prompt-template-name = "alpaca"
+train.train-on-inputs = true
 train.verbose = false
 
 # Define training agruments for HF Trainer
@@ -164,7 +163,7 @@ train.training-arguments.eval-strategy = "epoch"
 train.training-arguments.save-strategy = "epoch"
 train.training-arguments.ddp-find-unused-parameters = false
 train.training-arguments.group-by-length = true
-train.training-arguments.
+train.training-arguments.load-best-model-at-end = true
 train.training-arguments.report-to = "wandb"
 
 # Define local training settings
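The layer-map keys added above (`skipbert.att-layer-maps`, `skipbert.hid-layer-maps`) are stored as comma-separated strings, while layer indices are naturally integers. If the consuming code needs integer lists, a conversion along these lines would do it; the helper name is hypothetical and not part of this repository:

def parse_layer_map(value):
    # "1, 3, 5, 7, 9, 11" -> [1, 3, 5, 7, 9, 11]
    return [int(part) for part in value.split(",") if part.strip()]

print(parse_layer_map("1, 3, 5, 7, 9, 11"))
print(parse_layer_map("6, 7, 8, 9, 10, 11, 12"))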