kisejin committed on
Commit 671e27a · 1 Parent(s): 808a032

change: update skipbert mechanism

template_FL/src/fedllm/client_app.py CHANGED
@@ -2,6 +2,7 @@
2
 
3
  import os
4
  import warnings
 
5
  from typing import Dict, Tuple
6
 
7
  import torch
@@ -13,8 +14,16 @@ from flwr.common.config import unflatten_dict
13
  from flwr.common.typing import NDArrays, Scalar
14
  from omegaconf import DictConfig
15
 
16
- from transformers import TrainingArguments, DataCollatorForSeq2Seq, Trainer, EarlyStoppingCallback, BertForSequenceClassification, GenerationConfig
17
 
 
18
  from trl import SFTTrainer, SFTConfig
19
  from deepspeed.profiling.flops_profiler import get_model_profile
20
  from deepspeed.accelerator import get_accelerator
@@ -40,6 +49,11 @@ os.environ["TOKENIZERS_PARALLELISM"] = "true"
40
  os.environ["RAY_DISABLE_DOCKER_CPU_WARNING"] = "1"
41
  warnings.filterwarnings("ignore", category=UserWarning)
42
 
43
 
44
  def input_constructor(batch_size, seq_len, tokenizer):
45
  fake_seq = ""
@@ -80,13 +94,15 @@ class FlowerClient(NumPyClient):
80
  self,
81
  model_cfg: DictConfig,
82
  train_cfg: DictConfig,
83
- mates_args: DictConfig,
 
84
  trainset,
85
  valset,
86
  num_rounds,
87
  ): # pylint: disable=too-many-arguments
88
  self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
89
  self.train_cfg = train_cfg
 
90
 
91
  self.training_arguments = TrainingArguments(**train_cfg.training_arguments)
92
  # self.training_arguments = SFTConfig(**train_cfg.training_arguments, max_seq_length=train_cfg.seq_length)
@@ -94,17 +110,18 @@ class FlowerClient(NumPyClient):
94
  self.num_rounds = num_rounds
95
  self.trainset = trainset
96
  self.valset = valset
97
- self.mates_args = mates_args
98
  self.holdoutset = None
99
  self.refset = None
100
- self.data_influence_model = None
 
101
  self.data_influence_tokenizer = None
102
 
103
  # instantiate model
104
  self.model, self.tokenizer = get_model(model_cfg)
105
 
106
- if self.mates_args.state:
107
- self.data_influence_model, self.data_influence_tokenizer = get_data_influence_model(model_cfg)
108
 
109
  # (
110
  # self.data_collator,
@@ -129,8 +146,8 @@ class FlowerClient(NumPyClient):
129
  # Replace -100 with pad token id in labels
130
  labels_ids[labels_ids == -100] = self.tokenizer.pad_token_id
131
 
132
- print(f"Shape of predictions: {np.shape(pred_ids)}")
133
- print(f"Shape of labels: {np.shape(labels_ids)}")
134
 
135
  # Decode predictions and labels
136
  pred_str = self.tokenizer.batch_decode(
@@ -165,6 +182,7 @@ class FlowerClient(NumPyClient):
165
  .map(
166
  lambda x: generate_and_tokenize_prompt(x, **tmp_dict),
167
  num_proc=8,
 
168
  )
169
  )
170
 
@@ -175,16 +193,17 @@ class FlowerClient(NumPyClient):
175
  .map(
176
  lambda x: generate_and_tokenize_prompt(x, **tmp_dict),
177
  num_proc=8,
 
178
  )
179
  )
180
 
181
  # Create holdoutset and refset if state is True
182
- if self.mates_args.state:
183
  trainset_size = len(self.trainset)
184
 
185
  # Calculate sizes for holdout and reference sets
186
- holdout_size = int(trainset_size * self.mates_args.holdout_ratio)
187
- ref_size = int(trainset_size * self.mates_args.reference_ratio)
188
 
189
  # Shuffle the trainset to ensure randomness
190
  shuffled_indices = list(range(trainset_size))
@@ -199,16 +218,17 @@ class FlowerClient(NumPyClient):
199
  self.refset = self.trainset.select(ref_indices)
200
 
201
  print(f"Holdoutset size: {len(self.holdoutset)}, Refset size: {len(self.refset)}")
 
202
 
203
 
204
  def fit(
205
  self, parameters: NDArrays, config: Dict[str, Scalar]
206
  ) -> Tuple[NDArrays, int, Dict]:
207
  """Implement distributed fit function for a given client."""
208
- if self.mates_args.state and int(config["current_round"]) != 1:
209
  main_model_params, data_influence_model_params = split_models(parameters)
210
  set_parameters(self.model, main_model_params)
211
- set_parameters_bert(self.data_influence_model, data_influence_model_params)
212
  else:
213
  set_parameters(self.model, parameters)
214
 
@@ -259,18 +279,20 @@ class FlowerClient(NumPyClient):
259
  args=self.training_arguments,
260
  data_collator=self.data_collator,
261
  compute_metrics=self.compute_metrics,
262
- mates_args=self.mates_args,
263
- data_influence_model=self.data_influence_model,
 
 
264
  data_influence_tokenizer=self.data_influence_tokenizer,
265
  )
266
 
267
  # Train the model
268
  results = trainer.train()
269
 
270
- if self.mates_args.state:
271
  # After training
272
  main_model_params = get_parameters(self.model)
273
- data_influence_model_params = model_parameters_to_ndarrays(self.data_influence_model)
274
  final_model_params = concatenate_models_with_marker(main_model_params, data_influence_model_params)
275
  else:
276
  final_model_params = get_parameters(self.model)
@@ -286,7 +308,7 @@ class FlowerClient(NumPyClient):
286
  detailed=False,
287
  )
288
  flops2, macs2, params2 = get_model_profile(
289
- self.data_influence_model,
290
  kwargs=input_constructor(batch_size, seq_len, self.data_influence_tokenizer),
291
  print_profile=True,
292
  detailed=False,
@@ -315,11 +337,16 @@ def client_fn(context: Context) -> FlowerClient:
315
  client_set = load_data_homo(partition_id, num_partitions, cfg.dataset.name)
316
  else:
317
  client_set = load_data_hete(partition_id)
318
 
319
  return FlowerClient(
320
  cfg.model,
321
  cfg.train,
322
  cfg.mates,
 
323
  client_set['train'],
324
  client_set['test'],
325
  num_rounds,
 
2
 
3
  import os
4
  import warnings
5
+ import logging
6
  from typing import Dict, Tuple
7
 
8
  import torch
 
14
  from flwr.common.typing import NDArrays, Scalar
15
  from omegaconf import DictConfig
16
 
17
+ from transformers import (
18
+ TrainingArguments,
19
+ DataCollatorForSeq2Seq,
20
+ Trainer,
21
+ EarlyStoppingCallback,
22
+ # BertForSequenceClassification,
23
+ GenerationConfig
24
+ )
25
 
26
+
27
  from trl import SFTTrainer, SFTConfig
28
  from deepspeed.profiling.flops_profiler import get_model_profile
29
  from deepspeed.accelerator import get_accelerator
 
49
  os.environ["RAY_DISABLE_DOCKER_CPU_WARNING"] = "1"
50
  warnings.filterwarnings("ignore", category=UserWarning)
51
 
52
+ logging.getLogger("flwr").setLevel(logging.INFO)
53
+ logging.getLogger("ClientAppActor").setLevel(logging.INFO)
54
+
55
+ logger = logging.getLogger(__name__)
56
+
57
 
58
  def input_constructor(batch_size, seq_len, tokenizer):
59
  fake_seq = ""
 
94
  self,
95
  model_cfg: DictConfig,
96
  train_cfg: DictConfig,
97
+ mates_cfg: DictConfig,
98
+ skipbert_cfg: DictConfig,
99
  trainset,
100
  valset,
101
  num_rounds,
102
  ): # pylint: disable=too-many-arguments
103
  self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
104
  self.train_cfg = train_cfg
105
+ self.skipbert_cfg = skipbert_cfg
106
 
107
  self.training_arguments = TrainingArguments(**train_cfg.training_arguments)
108
  # self.training_arguments = SFTConfig(**train_cfg.training_arguments, max_seq_length=train_cfg.seq_length)
 
110
  self.num_rounds = num_rounds
111
  self.trainset = trainset
112
  self.valset = valset
113
+ self.mates_cfg = mates_cfg
114
  self.holdoutset = None
115
  self.refset = None
116
+ self.teacher_data_influence_model = None
117
+ self.student_data_influence_model = None
118
  self.data_influence_tokenizer = None
119
 
120
  # instantiate model
121
  self.model, self.tokenizer = get_model(model_cfg)
122
 
123
+ if self.mates_cfg.state:
124
+ self.teacher_data_influence_model, self.student_data_influence_model ,self.data_influence_tokenizer = get_data_influence_model(model_cfg, skipbert_cfg)
125
 
126
  # (
127
  # self.data_collator,
 
146
  # Replace -100 with pad token id in labels
147
  labels_ids[labels_ids == -100] = self.tokenizer.pad_token_id
148
 
149
+ # print(f"Shape of predictions: {np.shape(pred_ids)}")
150
+ # print(f"Shape of labels: {np.shape(labels_ids)}")
151
 
152
  # Decode predictions and labels
153
  pred_str = self.tokenizer.batch_decode(
 
182
  .map(
183
  lambda x: generate_and_tokenize_prompt(x, **tmp_dict),
184
  num_proc=8,
185
+ remove_columns=['instruction', 'input', 'output']
186
  )
187
  )
188
 
 
193
  .map(
194
  lambda x: generate_and_tokenize_prompt(x, **tmp_dict),
195
  num_proc=8,
196
+ remove_columns=['instruction', 'input', 'output']
197
  )
198
  )
199
 
200
  # Create holdoutset and refset if state is True
201
+ if self.mates_cfg.state:
202
  trainset_size = len(self.trainset)
203
 
204
  # Calculate sizes for holdout and reference sets
205
+ holdout_size = int(trainset_size * self.mates_cfg.holdout_ratio)
206
+ ref_size = int(trainset_size * self.mates_cfg.reference_ratio)
207
 
208
  # Shuffle the trainset to ensure randomness
209
  shuffled_indices = list(range(trainset_size))
 
218
  self.refset = self.trainset.select(ref_indices)
219
 
220
  print(f"Holdoutset size: {len(self.holdoutset)}, Refset size: {len(self.refset)}")
221
+ # logger.info(f"Holdoutset size: {len(self.holdoutset)}, Refset size: {len(self.refset)}")
222
 
223
 
224
  def fit(
225
  self, parameters: NDArrays, config: Dict[str, Scalar]
226
  ) -> Tuple[NDArrays, int, Dict]:
227
  """Implement distributed fit function for a given client."""
228
+ if self.mates_cfg.state and int(config["current_round"]) != 1:
229
  main_model_params, data_influence_model_params = split_models(parameters)
230
  set_parameters(self.model, main_model_params)
231
+ set_parameters_bert(self.teacher_data_influence_model, data_influence_model_params)
232
  else:
233
  set_parameters(self.model, parameters)
234
 
 
279
  args=self.training_arguments,
280
  data_collator=self.data_collator,
281
  compute_metrics=self.compute_metrics,
282
+ mates_cfg=self.mates_cfg,
283
+ skipbert_cfg=self.skipbert_cfg,
284
+ teacher_data_influence_model=self.teacher_data_influence_model,
285
+ student_data_influence_model=self.student_data_influence_model,
286
  data_influence_tokenizer=self.data_influence_tokenizer,
287
  )
288
 
289
  # Train the model
290
  results = trainer.train()
291
 
292
+ if self.mates_cfg.state:
293
  # After training
294
  main_model_params = get_parameters(self.model)
295
+ data_influence_model_params = model_parameters_to_ndarrays(self.teacher_data_influence_model)
296
  final_model_params = concatenate_models_with_marker(main_model_params, data_influence_model_params)
297
  else:
298
  final_model_params = get_parameters(self.model)
 
308
  detailed=False,
309
  )
310
  flops2, macs2, params2 = get_model_profile(
311
+ self.teacher_data_influence_model,
312
  kwargs=input_constructor(batch_size, seq_len, self.data_influence_tokenizer),
313
  print_profile=True,
314
  detailed=False,
 
337
  client_set = load_data_homo(partition_id, num_partitions, cfg.dataset.name)
338
  else:
339
  client_set = load_data_hete(partition_id)
340
+
341
+
342
+ cfg.skipbert.att_layer_maps = [int(s) for s in cfg.skipbert.att_layer_maps.split(', ')]
343
+ cfg.skipbert.hid_layer_maps = [int(k) for k in cfg.skipbert.hid_layer_maps.split(', ')]
344
 
345
  return FlowerClient(
346
  cfg.model,
347
  cfg.train,
348
  cfg.mates,
349
+ cfg.skipbert,
350
  client_set['train'],
351
  client_set['test'],
352
  num_rounds,
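
For context on the `client_fn` change above: the SkipBERT layer maps arrive as comma-separated strings in the config and are converted to integer lists before the client is built. A minimal sketch of that conversion, with assumed example values (the real ones live in the project's config files, not in this diff):

```python
# Illustrative only: mirrors the string-to-list conversion done in client_fn.
# The example values for att_layer_maps / hid_layer_maps are assumptions,
# not values taken from the repository's configs.
from omegaconf import OmegaConf

cfg = OmegaConf.create({
    "skipbert": {
        "att_layer_maps": "2, 5, 8, 11",    # assumed teacher-layer mapping
        "hid_layer_maps": "0, 3, 6, 9, 12"  # assumed teacher-layer mapping
    }
})

# A slightly more forgiving parse than split(', '): int() strips whitespace itself.
cfg.skipbert.att_layer_maps = [int(s) for s in str(cfg.skipbert.att_layer_maps).split(",")]
cfg.skipbert.hid_layer_maps = [int(s) for s in str(cfg.skipbert.hid_layer_maps).split(",")]

print(cfg.skipbert.att_layer_maps)  # [2, 5, 8, 11]
print(cfg.skipbert.hid_layer_maps)  # [0, 3, 6, 9, 12]
```

The committed version splits on `', '` exactly, so the config strings must use a comma followed by a single space; splitting on `','` alone, as above, tolerates either form.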
template_FL/src/fedllm/dataset.py CHANGED
@@ -9,7 +9,7 @@ import pandas as pd
9
 
10
  FDS = None # Cache FederatedDataset
11
  client_id_ds = None
12
- global_test_set_homo = None
13
 
14
  def split_train_test(dataset, test_size):
15
  # Split the dataset into train and test sets
 
9
 
10
  FDS = None # Cache FederatedDataset
11
  client_id_ds = None
12
+ # global_test_set_homo = None
13
 
14
  def split_train_test(dataset, test_size):
15
  # Split the dataset into train and test sets
template_FL/src/fedllm/models.py CHANGED
@@ -11,7 +11,16 @@ from peft import (
11
  set_peft_model_state_dict,
12
  )
13
  from peft.utils import prepare_model_for_kbit_training
14
- from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainerCallback, BertForSequenceClassification
15
 
16
  from flwr.common.typing import NDArrays
17
  from transformers.trainer_callback import TrainerControl, TrainerState
@@ -90,24 +99,71 @@ def get_model(model_cfg: DictConfig):
90
 
91
  return get_peft_model(model, peft_config), tokenizer
92
 
93
- def get_data_influence_model(model_cfg: DictConfig):
94
  use_cuda = torch.cuda.is_available()
95
  device_map = torch.device("cuda:0" if use_cuda else "cpu")
96
-
97
  # Load model with num_labels=1
98
- model = BertForSequenceClassification.from_pretrained(
99
- "bert-base-uncased",
100
- num_labels=1, # Set number of labels to 1 for regression or single-class tasks
101
  ).to(device_map)
102
 
103
- tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
104
 
105
  if use_cuda:
106
- model = prepare_model_for_kbit_training(
107
- model, use_gradient_checkpointing=model_cfg.gradient_checkpointing
108
  )
 
109
 
110
- return model, tokenizer
111
 
112
 
113
  def set_parameters(model, parameters: NDArrays) -> None:
 
11
  set_peft_model_state_dict,
12
  )
13
  from peft.utils import prepare_model_for_kbit_training
14
+ from transformers import (
15
+ AutoModelForCausalLM,
16
+ AutoTokenizer,
17
+ BitsAndBytesConfig,
18
+ TrainerCallback,
19
+ # BertForSequenceClassification,
20
+ BertConfig,
21
+ )
22
+
23
+ from .skipbert.modeling import BertForSequenceClassification, SkipBertForSequenceClassification
24
 
25
  from flwr.common.typing import NDArrays
26
  from transformers.trainer_callback import TrainerControl, TrainerState
 
99
 
100
  return get_peft_model(model, peft_config), tokenizer
101
 
102
+ def get_custom_config(teacher_name, skipbert_cfg: DictConfig):
103
+ num_labels = 1 # Set number of labels to 1 for regression or single-class tasks
104
+
105
+
106
+ teacher_config = BertConfig.from_pretrained(teacher_name)
107
+ teacher_config.num_labels = num_labels
108
+ teacher_config.fit_size = teacher_config.hidden_size
109
+
110
+ student_config = BertConfig.from_pretrained(skipbert_cfg.student_model)
111
+ student_config.num_labels = num_labels
112
+ student_config.fit_size = teacher_config.hidden_size
113
+
114
+ if skipbert_cfg.num_layers_student > 0:
115
+ student_config.num_hidden_layers = skipbert_cfg.num_layers_student
116
+ if skipbert_cfg.num_full_hidden_layers_student > 0:
117
+ student_config.num_full_hidden_layers = skipbert_cfg.num_full_hidden_layers_student
118
+ else:
119
+ student_config.num_full_hidden_layers = student_config.num_hidden_layers
120
+
121
+ student_config.task_type = skipbert_cfg.output_mode
122
+ student_config.n_gram_left = skipbert_cfg.n_gram_left
123
+ student_config.n_gram_right = skipbert_cfg.n_gram_right
124
+ # student_config.plot_mode = 'plot_passive'
125
+ student_config.plot_mode = 'force_compute'
126
+ student_config.ngram_masking = 0.
127
+ if not hasattr(student_config, 'enter_hidden_size'):
128
+ student_config.enter_hidden_size = student_config.hidden_size
129
+ if not hasattr(student_config, 'max_num_entries'):
130
+ student_config.max_num_entries = 100000
131
+
132
+ return teacher_config, student_config
133
+
134
+
135
+ def get_data_influence_model(model_cfg: DictConfig, skipbert_cfg: DictConfig):
136
  use_cuda = torch.cuda.is_available()
137
  device_map = torch.device("cuda:0" if use_cuda else "cpu")
138
+ teacher_name = "bert-base-uncased"
139
+
140
+ teacher_config, student_config = get_custom_config(teacher_name=teacher_name, skipbert_cfg=skipbert_cfg)
141
+
142
  # Load model with num_labels=1
143
+ teacher_model = BertForSequenceClassification.from_pretrained(
144
+ teacher_name, config=teacher_config
 
145
  ).to(device_map)
146
 
147
+ student_model = SkipBertForSequenceClassification.from_pretrained(
148
+ skipbert_cfg.student_model, config=student_config,
149
+ do_fit=skipbert_cfg.do_fit, share_param=skipbert_cfg.share_param).to(device_map)
150
+
151
+ if skipbert_cfg.freeze_lower_layers:
152
+ student_model.freeze_shallow_layers()
153
+
154
+ tokenizer = AutoTokenizer.from_pretrained(teacher_name, do_lower_case=skipbert_cfg.do_lower_case, use_fast=True)
155
 
156
  if use_cuda:
157
+ teacher_model = prepare_model_for_kbit_training(
158
+ teacher_model, use_gradient_checkpointing=model_cfg.gradient_checkpointing
159
+ )
160
+
161
+ student_model = prepare_model_for_kbit_training(
162
+ student_model, use_gradient_checkpointing=model_cfg.gradient_checkpointing
163
  )
164
+
165
 
166
+ return teacher_model, student_model, tokenizer
167
 
168
 
169
  def set_parameters(model, parameters: NDArrays) -> None:
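
For context, `get_data_influence_model` now builds a teacher/student pair instead of a single BERT scorer, and `get_custom_config` pairs their configs so the student can be fitted to the teacher's hidden width. A minimal sketch of that pairing, using an assumed student checkpoint name (`prajjwal1/bert-tiny` stands in for `skipbert_cfg.student_model`):

```python
# Sketch of the teacher/student config pairing done by get_custom_config.
# The student checkpoint name is an assumption, not the repository's value.
from transformers import BertConfig

teacher_name = "bert-base-uncased"
student_name = "prajjwal1/bert-tiny"  # placeholder for skipbert_cfg.student_model

teacher_config = BertConfig.from_pretrained(teacher_name)
teacher_config.num_labels = 1                         # scalar data-influence score
teacher_config.fit_size = teacher_config.hidden_size

student_config = BertConfig.from_pretrained(student_name)
student_config.num_labels = 1
student_config.fit_size = teacher_config.hidden_size  # project student reps to teacher width

print(teacher_config.hidden_size, student_config.hidden_size)  # e.g. 768, 128
```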
template_FL/src/fedllm/server_app.py CHANGED
@@ -1,9 +1,11 @@
1
  """flowertune-llm: A Flower / FlowerTune app."""
2
 
3
  import os
 
4
  import torch
5
  import wandb
6
  import numpy as np
 
7
  from dotenv import load_dotenv
8
  from datetime import datetime
9
  from tqdm import tqdm
@@ -12,9 +14,11 @@ from transformers import DataCollatorForSeq2Seq, DataCollatorWithPadding, Traini
12
  from .trainer import ManualTrainer
13
  from transformers.integrations import WandbCallback
14
  from torch.utils.data import DataLoader
 
15
  from flwr.common import Context, ndarrays_to_parameters
16
  from flwr.common.config import unflatten_dict
17
  from flwr.server import ServerApp, ServerAppComponents, ServerConfig
 
18
  # from flwr.server.strategy import FedAvg
19
  from omegaconf import DictConfig
20
 
@@ -27,6 +31,12 @@ from .metrics import exact_match, f1, get_rouge_score
27
 
28
  from datasets import load_dataset, Dataset
29
  from sklearn.model_selection import train_test_split
30
 
31
 
32
  load_dotenv(".env")
@@ -36,6 +46,79 @@ os.environ["WANDB_NAME"] = os.getenv("WANDB_NAME")
36
  os.environ["HF_TOKEN"] = os.getenv("HF_TOKEN")
37
  # os.environ["WANDB_LOG_MODEL"] = "checkpoint"
38
 
39
  class LLMSampleCB(WandbCallback):
40
  def __init__(self, trainer, test_dataset, task, num_samples=10, max_new_tokens=256, log_model="checkpoint"):
41
  "A CallBack to log samples a wandb.Table during training"
@@ -84,7 +167,7 @@ class LLMSampleCB(WandbCallback):
84
 
85
 
86
 
87
- def test_model(dataset, model, tokenizer, train_cfg, tmp_dict, sround, mates_args, task):
88
 
89
  wandb.init(
90
  project='FL@CSS25',
@@ -157,8 +240,10 @@ def test_model(dataset, model, tokenizer, train_cfg, tmp_dict, sround, mates_arg
157
  args=training_arguments,
158
  data_collator=data_collator,
159
  compute_metrics=compute_metrics,
160
- mates_args=mates_args,
161
- data_influence_model=None,
 
 
162
  data_influence_tokenizer=None,
163
  )
164
 
@@ -185,7 +270,7 @@ def test_model(dataset, model, tokenizer, train_cfg, tmp_dict, sround, mates_arg
185
  # Get function that will be executed by the strategy's evaluate() method
186
  # Here we use it to save global model checkpoints
187
 
188
- def get_evaluate_fn(train_cfg, model_cfg, dataset_cfg, save_every_round, total_round, total_nodes, save_path, mates_args):
189
  """Return an evaluation function for saving global model."""
190
 
191
  def evaluate(server_round: int, parameters, config):
@@ -208,12 +293,14 @@ def get_evaluate_fn(train_cfg, model_cfg, dataset_cfg, save_every_round, total_r
208
  }
209
  if dataset_cfg.type == 'homo':
210
  ds = load_dataset(dataset_cfg.name)
 
 
211
  _, test = train_test_split(
212
- ds, test_size=0.09, shuffle=True, random_state=42
213
  )
214
  global_test_set_homo = Dataset.from_pandas(test).remove_columns(['__index_level_0__'])
215
 
216
- loss, metrics = test_model(global_test_set_homo, model, tokenizer, train_cfg, tmp_dict, server_round, mates_args, 'homo')
217
  total_loss = loss
218
  result_metric = {'homo_f1': metrics['homo_f1']}
219
  else:
@@ -225,7 +312,7 @@ def get_evaluate_fn(train_cfg, model_cfg, dataset_cfg, save_every_round, total_r
225
 
226
  for task in ['general', 'finance', 'math', 'medical', 'code']:
227
  ds = global_test_set_hete[task]
228
- loss, metrics = test_model(ds, model, tokenizer, train_cfg, tmp_dict, server_round, mates_args, task)
229
  list_loss.append(loss)
230
 
231
  list_f1[f'{task}_f1'] = metrics[f'{task}_f1']
@@ -273,6 +360,8 @@ def fit_weighted_average(metrics):
273
  def server_fn(context: Context):
274
  """Construct components that set the ServerApp behaviour."""
275
  # Create output directory given current timestamp
 
 
276
  current_time = datetime.now()
277
  folder_name = current_time.strftime("%Y-%m-%d_%H-%M-%S")
278
  save_path = os.path.join(os.getcwd(), f"results/{folder_name}")
@@ -296,7 +385,7 @@ def server_fn(context: Context):
296
  fit_metrics_aggregation_fn=fit_weighted_average,
297
  initial_parameters=init_model_parameters,
298
  evaluate_fn=get_evaluate_fn(
299
- cfg.train, cfg.model, cfg.dataset, cfg.train.save_every_round, num_rounds, num_nodes, save_path, cfg.mates
300
  ),
301
  use_mates=cfg.mates.state
302
  )
 
1
  """flowertune-llm: A Flower / FlowerTune app."""
2
 
3
  import os
4
+ import sys
5
  import torch
6
  import wandb
7
  import numpy as np
8
+ import pandas as pd
9
  from dotenv import load_dotenv
10
  from datetime import datetime
11
  from tqdm import tqdm
 
14
  from .trainer import ManualTrainer
15
  from transformers.integrations import WandbCallback
16
  from torch.utils.data import DataLoader
17
+ import flwr
18
  from flwr.common import Context, ndarrays_to_parameters
19
  from flwr.common.config import unflatten_dict
20
  from flwr.server import ServerApp, ServerAppComponents, ServerConfig
21
+ from flwr.common.logger import FLOWER_LOGGER
22
  # from flwr.server.strategy import FedAvg
23
  from omegaconf import DictConfig
24
 
 
31
 
32
  from datasets import load_dataset, Dataset
33
  from sklearn.model_selection import train_test_split
34
+ import logging
35
+ import uuid
36
+
37
+ logging.getLogger("flwr").setLevel(logging.INFO)
38
+ logging.getLogger("ClientAppActor").setLevel(logging.INFO)
39
+ logging.getLogger("Trainer").setLevel(logging.INFO)
40
 
41
 
42
  load_dotenv(".env")
 
46
  os.environ["HF_TOKEN"] = os.getenv("HF_TOKEN")
47
  # os.environ["WANDB_LOG_MODEL"] = "checkpoint"
48
 
49
+
50
+ class SessionIDFilter(logging.Filter):
51
+ """Adds a session_id to log records."""
52
+ def __init__(self, session_id):
53
+ super().__init__()
54
+ self.session_id = session_id
55
+
56
+ def filter(self, record):
57
+ record.session_id = self.session_id
58
+ return True
59
+
60
+
61
+ def configure_logging():
62
+ # Generate a unique session ID for this run
63
+ session_id = str(uuid.uuid4())
64
+
65
+ # Define log format with session ID and process ID
66
+ log_format = (
67
+ "%(asctime)s - %(session_id)s - %(process)d - %(name)s - "
68
+ "%(levelname)s - %(message)s"
69
+ )
70
+
71
+ # Create a FileHandler and attach the SessionIDFilter to it
72
+ file_handler = logging.FileHandler("main.log", mode='a') # Append mode
73
+ formatter = logging.Formatter(log_format)
74
+ file_handler.setFormatter(formatter)
75
+ file_handler.addFilter(SessionIDFilter(session_id)) # Add filter to the handler
76
+
77
+ # Console handler: logs to stdout (you can also log to stderr)
78
+ console_handler = logging.StreamHandler(sys.stdout)
79
+ console_handler.setLevel(logging.DEBUG)
80
+ console_handler.setFormatter(formatter)
81
+
82
+
83
+ # Configure root logger to use this handler
84
+ logging.basicConfig(
85
+ level=logging.INFO,
86
+ handlers=[
87
+ file_handler,
88
+ console_handler,
89
+ ], # Use the filtered handler
90
+ )
91
+
92
+ # if not any(
93
+ # isinstance(handler, logging.FileHandler) and handler.baseFilename == file_handler.baseFilename
94
+ # for handler in FLOWER_LOGGER.handlers
95
+ # ):
96
+ # FLOWER_LOGGER.addHandler(file_handler)
97
+
98
+ for handler in FLOWER_LOGGER.handlers:
99
+ FLOWER_LOGGER.addHandler(file_handler)
100
+
101
+ # Get the logger for the ClientAppActor module and attach the same file handler
102
+ client_actor_logger = logging.getLogger("flwr.simulation.ray_transport.ray_actor")
103
+ # if not any(
104
+ # isinstance(handler, logging.FileHandler) and handler.baseFilename == file_handler.baseFilename
105
+ # for handler in client_actor_logger.handlers
106
+ # ):
107
+ # client_actor_logger.addHandler(file_handler)
108
+
109
+ for handler in client_actor_logger.handlers:
110
+ client_actor_logger.addHandler(file_handler)
111
+
112
+ # Explicitly configure Ray's logger to propagate
113
+ ray_logger = logging.getLogger("ray") # Ray's parent logger
114
+ for handler in ray_logger.handlers:
115
+ ray_logger.addHandler(file_handler)
116
+
117
+
118
+ # Log the start of the session
119
+ logger = logging.getLogger(__name__)
120
+ logger.info("===== Application Started =====")
121
+
122
  class LLMSampleCB(WandbCallback):
123
  def __init__(self, trainer, test_dataset, task, num_samples=10, max_new_tokens=256, log_model="checkpoint"):
124
  "A CallBack to log samples a wandb.Table during training"
 
167
 
168
 
169
 
170
+ def test_model(dataset, model, tokenizer, train_cfg, tmp_dict, sround, mates_args, skipbert_cfg, task):
171
 
172
  wandb.init(
173
  project='FL@CSS25',
 
240
  args=training_arguments,
241
  data_collator=data_collator,
242
  compute_metrics=compute_metrics,
243
+ mates_cfg=mates_args,
244
+ skipbert_cfg=skipbert_cfg,
245
+ teacher_data_influence_model=None,
246
+ student_data_influence_model=None,
247
  data_influence_tokenizer=None,
248
  )
249
 
 
270
  # Get function that will be executed by the strategy's evaluate() method
271
  # Here we use it to save global model checkpoints
272
 
273
+ def get_evaluate_fn(train_cfg, model_cfg, dataset_cfg, save_every_round, total_round, total_nodes, save_path, mates_args, skipbert_cfg):
274
  """Return an evaluation function for saving global model."""
275
 
276
  def evaluate(server_round: int, parameters, config):
 
293
  }
294
  if dataset_cfg.type == 'homo':
295
  ds = load_dataset(dataset_cfg.name)
296
+ option = 'test' if 'test' in ds else 'train'
297
+ df = pd.DataFrame(ds[option])
298
  _, test = train_test_split(
299
+ df, test_size=0.09, shuffle=True, random_state=42
300
  )
301
  global_test_set_homo = Dataset.from_pandas(test).remove_columns(['__index_level_0__'])
302
 
303
+ loss, metrics = test_model(global_test_set_homo, model, tokenizer, train_cfg, tmp_dict, server_round, mates_args, skipbert_cfg, 'homo')
304
  total_loss = loss
305
  result_metric = {'homo_f1': metrics['homo_f1']}
306
  else:
 
312
 
313
  for task in ['general', 'finance', 'math', 'medical', 'code']:
314
  ds = global_test_set_hete[task]
315
+ loss, metrics = test_model(ds, model, tokenizer, train_cfg, tmp_dict, server_round, mates_args, skipbert_cfg, task)
316
  list_loss.append(loss)
317
 
318
  list_f1[f'{task}_f1'] = metrics[f'{task}_f1']
 
360
  def server_fn(context: Context):
361
  """Construct components that set the ServerApp behaviour."""
362
  # Create output directory given current timestamp
363
+ configure_logging()
364
+ logger = logging.getLogger(__name__)
365
  current_time = datetime.now()
366
  folder_name = current_time.strftime("%Y-%m-%d_%H-%M-%S")
367
  save_path = os.path.join(os.getcwd(), f"results/{folder_name}")
 
385
  fit_metrics_aggregation_fn=fit_weighted_average,
386
  initial_parameters=init_model_parameters,
387
  evaluate_fn=get_evaluate_fn(
388
+ cfg.train, cfg.model, cfg.dataset, cfg.train.save_every_round, num_rounds, num_nodes, save_path, cfg.mates, cfg.skipbert
389
  ),
390
  use_mates=cfg.mates.state
391
  )
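
The logging added to `server_app.py` relies on a `logging.Filter` that stamps a per-run session id onto every record before the formatter sees it. A self-contained sketch of that pattern (the handler target and format string here are assumptions, not the commit's exact values):

```python
import logging
import uuid

class SessionIDFilter(logging.Filter):
    """Attach a per-run session_id to every log record."""
    def __init__(self, session_id: str):
        super().__init__()
        self.session_id = session_id

    def filter(self, record: logging.LogRecord) -> bool:
        record.session_id = self.session_id
        return True  # never drop records, only annotate them

handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter(
    "%(asctime)s - %(session_id)s - %(name)s - %(levelname)s - %(message)s"))
handler.addFilter(SessionIDFilter(str(uuid.uuid4())))

logger = logging.getLogger("demo")
logger.addHandler(handler)
logger.setLevel(logging.INFO)
logger.info("session-tagged record")
```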
template_FL/src/fedllm/skipbert/modeling.py CHANGED
@@ -474,6 +474,7 @@ class ShallowSkipping(nn.Module):
474
 
475
  @torch.jit.script
476
  def merge_ngrams(input_ids, ngram_hidden_states, aux_embeddings):
 
477
  batch_size, seq_length = input_ids.shape
478
  lens = (input_ids!=0).sum(1)
479
  hidden_state = torch.zeros([batch_size, seq_length, ngram_hidden_states.size(-1)], dtype=ngram_hidden_states.dtype, device=ngram_hidden_states.device)
@@ -562,7 +563,6 @@ class ShallowSkipping(nn.Module):
562
  ):
563
 
564
  device = model.device
565
-
566
  batch_size, seq_length = input_ids.shape
567
  aux_embeddings = model.embeddings.position_embeddings2.weight[:seq_length].unsqueeze(0)
568
  aux_embeddings = aux_embeddings + model.embeddings.token_type_embeddings2(token_type_ids)
 
474
 
475
  @torch.jit.script
476
  def merge_ngrams(input_ids, ngram_hidden_states, aux_embeddings):
477
+ # batch_size, seq_length = input_ids.shape
478
  batch_size, seq_length = input_ids.shape
479
  lens = (input_ids!=0).sum(1)
480
  hidden_state = torch.zeros([batch_size, seq_length, ngram_hidden_states.size(-1)], dtype=ngram_hidden_states.dtype, device=ngram_hidden_states.device)
 
563
  ):
564
 
565
  device = model.device
 
566
  batch_size, seq_length = input_ids.shape
567
  aux_embeddings = model.embeddings.position_embeddings2.weight[:seq_length].unsqueeze(0)
568
  aux_embeddings = aux_embeddings + model.embeddings.token_type_embeddings2(token_type_ids)
template_FL/src/fedllm/skipbert/plot.py CHANGED
@@ -117,8 +117,8 @@ class Plot:
117
  self.max_num_entries = max_num_entries
118
  self.hidden_size = hidden_size
119
 
120
- self.trigram_to_id, self.id_to_trigram = self.build_hash_table('input_ids_tri_gram.memmap', max_num_entries)
121
- self.orig_trigram_hidden_states = _read_or_create_memmap("plot_hidden_states_tri_gram.memmap", dtype='float16', shape=(max_num_entries, 3, hidden_size))
122
 
123
  def build_hash_table(self, path, max_num_entries):
124
  n_gram = 3
 
117
  self.max_num_entries = max_num_entries
118
  self.hidden_size = hidden_size
119
 
120
+ self.trigram_to_id, self.id_to_trigram = self.build_hash_table('./input_ids_tri_gram.memmap', max_num_entries)
121
+ self.orig_trigram_hidden_states = _read_or_create_memmap("./plot_hidden_states_tri_gram.memmap", dtype='float16', shape=(max_num_entries, 3, hidden_size))
122
 
123
  def build_hash_table(self, path, max_num_entries):
124
  n_gram = 3
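
For context, the two memmap paths above back SkipBERT's trigram hidden-state cache. The repository's `_read_or_create_memmap` helper is not shown in this diff; a hypothetical equivalent, to illustrate the open-or-create behaviour the call sites assume:

```python
# Hypothetical helper in the spirit of _read_or_create_memmap; the actual
# implementation lives elsewhere in skipbert/plot.py and is not part of this diff.
import os
import numpy as np

def read_or_create_memmap(path: str, dtype: str, shape: tuple) -> np.memmap:
    """Open an existing memmap at `path`, or create a zero-filled one if missing."""
    mode = "r+" if os.path.exists(path) else "w+"
    return np.memmap(path, dtype=dtype, mode=mode, shape=shape)

# e.g. the trigram hidden-state cache: (max_num_entries, 3, hidden_size) in fp16
cache = read_or_create_memmap("./plot_hidden_states_tri_gram.memmap",
                              dtype="float16", shape=(1000, 3, 768))
cache[0] = 0.0
cache.flush()
```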
template_FL/src/fedllm/skipbert/trainer.py ADDED
@@ -0,0 +1,691 @@
1
+ from accelerate import Accelerator
2
+ from torch.utils.data import DataLoader
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import MSELoss, CrossEntropyLoss
6
+ import copy
7
+ import numpy as np
8
+ from transformers import (
9
+ # BertForSequenceClassification,
10
+ GenerationConfig,
11
+ AutoTokenizer,
12
+ Trainer,
13
+ get_scheduler,
14
+ EarlyStoppingCallback,
15
+ TrainingArguments
16
+ )
17
+ from transformers.trainer_utils import (
18
+ EvaluationStrategy,
19
+ IntervalStrategy,
20
+ )
21
+ from transformers.trainer_pt_utils import nested_detach
22
+ from transformers.utils import is_sagemaker_mp_enabled
23
+ from transformers.training_args import OptimizerNames
24
+ from typing import Dict, List, Optional, Any, Union, Tuple, Callable
25
+ import numpy as np
26
+ from torch.utils.data import Dataset, DataLoader
27
+ from accelerate.utils import (
28
+ AutocastKwargs,
29
+ DistributedDataParallelKwargs,
30
+ DistributedType,
31
+ )
32
+ from datasets import Dataset
33
+ from sklearn.metrics import accuracy_score, precision_recall_fscore_support
34
+ import wandb
35
+ import logging
36
+
37
+ logging.getLogger("Trainer").setLevel(logging.INFO)
38
+ logger = logging.getLogger(__name__)
39
+
40
+
41
+ def compute_metrics_skipbert(pred):
42
+ """
43
+ Compute metrics for model evaluation
44
+ """
45
+ labels = pred.label_ids
46
+
47
+ preds = pred.predictions
48
+
49
+ if len(preds[0]) >= 2:
50
+ preds = torch.tensor(preds.argmax(-1))
51
+ labels = torch.tensor(labels)
52
+
53
+ acc = accuracy_score(labels, preds)
54
+ precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='binary')
55
+ return {
56
+ 'accuracy': acc,
57
+ 'f1': f1,
58
+ 'precision': precision,
59
+ 'recall': recall
60
+ }
61
+ else:
62
+ labels = torch.tensor(pred.label_ids[:, np.newaxis])
63
+ preds = torch.tensor(pred.predictions)
64
+
65
+ # MSE
66
+ mse = nn.MSELoss()
67
+ mse_loss = mse(labels, preds)
68
+
69
+ #RMSE
70
+ rmse = torch.sqrt(mse_loss)
71
+
72
+ # MAE
73
+ mae = nn.L1Loss()
74
+ mae_loss = mae(labels, preds)
75
+
76
+
77
+ return {
78
+ 'mse': mse_loss,
79
+ 'rmse': rmse,
80
+ 'mae': mae_loss,
81
+ }
82
+
83
+
84
+ # Create custom Trainer for training SkipBERT
85
+ class SkipBertTrainer(Trainer):
86
+ def __init__(
87
+ self,
88
+ student_model: nn.Module,
89
+ teacher_model: Optional[nn.Module] = None,
90
+ train_dataset: Optional[Dataset] = None,
91
+ eval_dataset: Optional[Dataset] = None,
92
+ args: Optional[TrainingArguments] = None,
93
+ data_collator: Optional[Callable] = None,
94
+ compute_metrics: Optional[Callable] = None,
95
+ alpha: float = 0.5,
96
+ temperature: float = 2.0,
97
+ beta: float = 1.0,
98
+ use_logits: bool = True,
99
+ use_att: bool = True,
100
+ use_rep: bool = True,
101
+ use_embedding: bool = True,
102
+ att_layer_maps: Optional[List[int]] = None,
103
+ hid_layer_maps: Optional[List[int]] = None,
104
+ epochs_no_cls: int = 0,
105
+ reduce_T: int = 1,
106
+ output_mode: str = 'classification',
107
+ num_masked_layers_teacher: int = 0,
108
+ num_masked_last_layers_teacher: int = 0,
109
+ fp16: bool = False,
110
+ num_full_hidden_layers_student: int = 0,
111
+ **kwargs,
112
+ ):
113
+ """
114
+ Initialize SkipBERT Trainer with knowledge distillation capabilities.
115
+
116
+ Args:
117
+ student_model: The student model to be trained
118
+ teacher_model: The teacher model for knowledge distillation
119
+ train_dataset: Training dataset
120
+ eval_dataset: Evaluation dataset
121
+ args: Training arguments
122
+ alpha: Balance between distillation loss and cross-entropy loss
123
+ temperature: Temperature for softening probability distributions
124
+ beta: Weighting factor for different loss components
125
+ use_logits: Whether to use logits-based distillation
126
+ use_att: Whether to use attention-based distillation
127
+ use_rep: Whether to use representation-based distillation
128
+ use_embedding: Whether to use embedding-based distillation
129
+ """
130
+ # Set default training arguments if not provided
131
+ if args is None:
132
+ args = TrainingArguments(
133
+ output_dir="./results",
134
+ num_train_epochs=3,
135
+ per_device_train_batch_size=2,
136
+ per_device_eval_batch_size=2,
137
+ logging_dir='./logs',
138
+ evaluation_strategy=EvaluationStrategy.EPOCH,
139
+ save_strategy=IntervalStrategy.EPOCH,
140
+ )
141
+
142
+
143
+ # Call parent constructor
144
+ super().__init__(
145
+ model=student_model,
146
+ args=args,
147
+ train_dataset=train_dataset,
148
+ eval_dataset=eval_dataset,
149
+ data_collator=data_collator,
150
+ compute_metrics=compute_metrics,
151
+ **kwargs
152
+ )
153
+
154
+ # Store additional knowledge distillation parameters
155
+ self.teacher_model = teacher_model
156
+ self.alpha = alpha
157
+ self.temperature = temperature
158
+ self.beta = beta
159
+ self.use_logits = use_logits
160
+ self.use_att = use_att
161
+ self.use_rep = use_rep
162
+ self.use_embedding = use_embedding
163
+ self.att_layer_maps = att_layer_maps or []
164
+ self.hid_layer_maps = hid_layer_maps or []
165
+ self.epochs_no_cls = epochs_no_cls
166
+ self.reduce_T = reduce_T
167
+ self.output_mode = output_mode
168
+ self.num_masked_layers_teacher = num_masked_layers_teacher
169
+ self.num_masked_last_layers_teacher = num_masked_last_layers_teacher
170
+ self.num_full_hidden_layers_student = num_full_hidden_layers_student
171
+ self.tr_att_loss = 0
172
+ self.tr_rep_loss = 0
173
+ self.tr_cls_loss = 0
174
+ self.list_att_loss = []
175
+ self.list_rep_loss = []
176
+ self.list_embed_loss = []
177
+
178
+ # Prepare FP16 if enabled
179
+ self.fp16 = fp16
180
+ if fp16:
181
+ try:
182
+ from apex import amp
183
+ except ImportError:
184
+ raise ImportError(
185
+ "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
186
+ )
187
+
188
+ # Initialize amp
189
+ self.model, self.optimizer = amp.initialize(
190
+ self.model,
191
+ self.optimizer,
192
+ opt_level='01'
193
+ )
194
+
195
+ # Half precision for teacher model if exists
196
+ if self.teacher_model is not None:
197
+ self.teacher_model = self.teacher_model.half()
198
+
199
+ # Loss functions
200
+ self.loss_mse = MSELoss()
201
+
202
+ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
203
+ """
204
+ How the loss is computed by Trainer. By default, all models return the loss in the first element.
205
+
206
+ Subclass and override for custom behavior.
207
+ """
208
+
209
+ # Separate labels from inputs
210
+ labels = inputs.pop("labels")
211
+
212
+ if self.model_accepts_loss_kwargs:
213
+ loss_kwargs = {}
214
+ if num_items_in_batch is not None:
215
+ loss_kwargs["num_items_in_batch"] = num_items_in_batch
216
+ inputs = {**inputs, **loss_kwargs}
217
+
218
+ # Forward pass through student model
219
+ student_logits, student_atts, student_reps = model(**inputs)
220
+ student_reps = student_reps[-self.num_full_hidden_layers_student-1:]
221
+
222
+ # Forward pass through teacher model
223
+ self.teacher_model.eval()
224
+ with torch.no_grad():
225
+ # teacher_logits, teacher_atts, teacher_reps = self.teacher_model(**inputs)
226
+ teacher_outputs = self.teacher_model(**inputs, output_hidden_states=True, output_attentions=True)
227
+ teacher_logits, teacher_atts, teacher_reps = teacher_outputs.logits, teacher_outputs.attentions, teacher_outputs.hidden_states
228
+ start, end = self.num_masked_layers_teacher, -1 * self.num_masked_layers_teacher if self.num_masked_layers_teacher != 0 else None
229
+ teacher_reps = teacher_reps[start:end]
230
+
231
+ # Save past state if it exists
232
+ # TODO: this needs to be fixed and made cleaner later.
233
+ if self.args.past_index >= 0:
234
+ self._past = student_outputs[self.args.past_index]
235
+
236
+ # Compute losses
237
+ att_loss, rep_loss = 0., 0.
238
+
239
+ # ---------------------------
240
+ if labels is not None:
241
+
242
+
243
+ # ---------------------------
244
+ if self.att_layer_maps is None:
245
+ teacher_layer_num = len(teacher_atts)
246
+ student_layer_num = len(student_atts)
247
+ assert teacher_layer_num % student_layer_num == 0
248
+ layers_per_block = int(teacher_layer_num / student_layer_num)
249
+ new_teacher_atts = [
250
+ teacher_atts[(i * 1) * layers_per_block - 1]
251
+ for i in range(student_layer_num)
252
+ ]
253
+ assert len(student_atts) == len(new_teacher_atts)
254
+
255
+ else:
256
+ new_teacher_atts = []
257
+ for t2s in self.att_layer_maps:
258
+ if t2s >= 0:
259
+ new_teacher_atts.append(teacher_atts[t2s])
260
+ else:
261
+ new_teacher_atts.append(None)
262
+
263
+ # ----------------------------
264
+
265
+ for student_att, teacher_att in zip(student_atts, new_teacher_atts):
266
+ if teacher_att is None:
267
+ continue
268
+ student_att = torch.where(
269
+ student_att <= 1e-2,
270
+ torch.zeros_like(student_att),
271
+ student_att
272
+ )
273
+
274
+ teacher_att = torch.where(
275
+ teacher_att <= 1e-2,
276
+ torch.zeros_like(teacher_att),
277
+ teacher_att
278
+ )
279
+
280
+ att_loss += self.loss_mse(student_att, teacher_att)
281
+
282
+ # ---------------------------
283
+
284
+ if self.hid_layer_maps is None:
285
+ teacher_layer_num = len(teacher_atts) - 1
286
+ student_layer_num = len(student_atts) - 1
287
+ assert teacher_layer_num % student_layer_num == 0
288
+ layers_per_block = int(teacher_layer_num / student_layer_num)
289
+ new_teacher_reps = [
290
+ teacher_reps[i * layers_per_block]
291
+ for i in range(student_layer_num + 1)
292
+ ]
293
+ assert len(new_student_reps) == len(new_teacher_reps)
294
+ else:
295
+ new_student_reps = student_reps
296
+ new_teacher_reps = []
297
+ for t2s in self.hid_layer_maps:
298
+ if t2s >= 0:
299
+ new_teacher_reps.append(teacher_reps[t2s])
300
+ else:
301
+ new_teacher_reps.append(None)
302
+
303
+ # ---------------------------
304
+
305
+ for student_rep, teacher_rep in zip(new_student_reps, new_teacher_reps):
306
+ if teacher_rep is None:
307
+ continue
308
+ tmp_loss = self.loss_mse(student_rep, teacher_rep)
309
+ rep_loss += tmp_loss
310
+
311
+ self.tr_att_loss += att_loss.item()
312
+ self.tr_rep_loss += rep_loss.item()
313
+
314
+ # ---------------------------
315
+ embedding_loss = 0
316
+ if self.use_embedding:
317
+ embedding_loss = self.loss_mse(
318
+ student_reps[0], teacher_reps[0]
319
+ )
320
+
321
+ # ---------------------------
322
+
323
+ # ---------------------------
324
+
325
+ if self.use_logits and self.state.epoch >= self.epochs_no_cls:
326
+ if isinstance(student_logits, tuple) or \
327
+ isinstance(student_logits, list):
328
+ cls_loss = None
329
+ _scale = 0.
330
+ for il, logits in enumerate(student_logits):
331
+ _loss, _, _ = self._compute_distillation_loss(
332
+ student_logits, student_atts, student_reps,
333
+ teacher_logits, teacher_atts, teacher_reps,
334
+ labels
335
+ )
336
+ if cls_loss is None:
337
+ cls_loss = _loss
338
+ else:
339
+ cls_loss = _loss * (il + 1.) + cls_loss
340
+ _scale += il + 1
341
+
342
+ cls_loss = cls_loss * (1. / _scale)
343
+
344
+ else:
345
+ cls_loss, kd_loss, ce_loss = self._compute_distillation_loss(
346
+ student_logits, student_atts, student_reps,
347
+ teacher_logits, teacher_atts, teacher_reps,
348
+ labels
349
+ )
350
+ self.tr_cls_loss += cls_loss.item()
351
+
352
+ else:
353
+ cls_loss = 0
354
+
355
+ # ---------------------------
356
+
357
+
358
+ check = self.state.epoch >= self.epochs_no_cls
359
+ self.beta = self.beta * check + (1 - check) * 1.
360
+
361
+ # ---------------------------
362
+
363
+ if self.use_embedding and \
364
+ self.use_att and \
365
+ self.use_rep:
366
+ loss = self.beta * (rep_loss + att_loss + embedding_loss) + cls_loss
367
+
368
+ elif self.use_att and self.use_rep:
369
+ loss = self.beta * (rep_loss + att_loss) + cls_loss
370
+
371
+ elif self.use_embedding and self.use_att:
372
+ loss = self.beta * (att_loss + embedding_loss) + cls_loss
373
+
374
+ elif self.use_embedding and self.use_rep:
375
+ loss = self.beta * (rep_loss + embedding_loss) + cls_loss
376
+
377
+ elif self.use_att and \
378
+ not self.use_embedding and \
379
+ not self.use_rep:
380
+ loss = self.beta * att_loss + cls_loss
381
+
382
+ elif self.use_rep and \
383
+ not self.use_embedding and \
384
+ not self.use_att:
385
+ loss = self.beta * rep_loss + cls_loss
386
+
387
+ else:
388
+ loss = cls_loss
389
+
390
+
391
+ # ---------------------------
392
+
393
+ else:
394
+ if isinstance(outputs, dict) and "loss" not in outputs:
395
+ raise ValueError(
396
+ "The model did not return a loss from the inputs, only the following keys: "
397
+ f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}."
398
+ )
399
+ # We don't use .loss here since the model may return tuples instead of ModelOutput.
400
+ loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
401
+
402
+ # ---------------------------
403
+
404
+ # ---------------------------
405
+ if self.args.average_tokens_across_devices and self.model_accepts_loss_kwargs:
406
+ loss *= self.accelerator.num_processes
407
+ rep_loss *= self.accelerator.num_processes
408
+ att_loss *= self.accelerator.num_processes
409
+ embedding_loss *= self.accelerator.num_processes
410
+ self.list_att_loss.append(att_loss.item())
411
+ self.list_rep_loss.append(rep_loss.item())
412
+ self.list_embed_loss.append(embedding_loss.item())
413
+ # ---------------------------
414
+
415
+
416
+ # ---------------------------
417
+
418
+ # Ensure logits are properly formatted for evaluation metrics
419
+ logits = student_logits
420
+ if return_outputs:
421
+ # Ensure student_logits has the correct shape [batch_size, num_classes]
422
+ if isinstance(student_logits, (tuple, list)):
423
+ logits = student_logits[-1]
424
+ else:
425
+ logits = student_logits
426
+
427
+ # If logits is 1D, reshape it to 2D
428
+ if len(logits.shape) == 1:
429
+ logits = logits.unsqueeze(0)
430
+
431
+ # Ensure we have [batch_size, num_classes] shape
432
+ if len(logits.shape) != 2:
433
+ raise ValueError(f"Unexpected logits shape: {logits.shape}. Expected [batch_size, num_classes]")
434
+
435
+ if self.output_mode == "classification": # Classification
436
+ loss = nn.functional.cross_entropy(labels.view(-1), logits.view(-1, len(logits[0])), reduction="mean")
437
+
438
+ elif self.output_mode == "regression": # Regression
439
+ # print(f"Return output - student: {nn.functional.softmax(student_logits, dim=0).view(-1)}, labels: {labels.view(-1)}")
440
+ loss = self.loss_mse(labels.view(-1), logits.view(-1))
441
+
442
+
443
+ # ---------------------------
444
+ # print(f"loss: {loss}, att_loss: {att_loss}, rep_loss: {rep_loss}, embed_loss: {embedding_loss}, Train {return_outputs}")
445
+ return (loss, logits) if return_outputs else loss
446
+
447
+ def _compute_distillation_loss(
448
+ self,
449
+ student_logits, student_atts, student_reps,
450
+ teacher_logits, teacher_atts, teacher_reps,
451
+ labels
452
+ ):
453
+ """
454
+ Compute comprehensive knowledge distillation loss.
455
+
456
+ Args:
457
+ student_*: Student model's outputs
458
+ teacher_*: Teacher model's outputs
459
+ labels: Ground truth labels
460
+
461
+ Returns:
462
+ Computed loss
463
+ """
464
+
465
+ # Classification/distillation loss
466
+ if self.output_mode == "classification": # Classification
467
+ # Similar to previous implementation's distillation loss
468
+ if teacher_logits is not None:
469
+ student_likelihood = nn.functional.log_softmax(student_logits / self.temperature, dim=-1)
470
+ targets_prob = nn.functional.softmax(teacher_logits / self.temperature, dim=-1)
471
+ d_loss = (-targets_prob * student_likelihood).mean() * (self.temperature ** 2) / self.reduce_T
472
+ else:
473
+ d_loss = 0
474
+ # Standard cross-entropy/MSE loss
475
+ nll_loss = nn.functional.cross_entropy(student_logits, labels, reduction="mean")
476
+
477
+ elif self.output_mode == "regression": # Regression
478
+ # student_likelihood = nn.functional.softmax(student_logits, dim=0)
479
+ # teacher_likelihood = nn.functional.softmax(teacher_logits, dim=0)
480
+ student_likelihood = torch.tensor(student_logits)
481
+ teacher_likelihood = torch.tensor(teacher_logits)
482
+ d_loss = self.loss_mse(student_likelihood.view(-1), teacher_likelihood.view(-1))
483
+ nll_loss = self.loss_mse(teacher_likelihood.view(-1), labels.view(-1))
484
+ else:
485
+ assert output_mode in ["classification", "regression"]
486
+ d_loss = 0.
487
+ nll_loss = 0.
488
+ tol_loss = self.alpha * d_loss + (1 - self.alpha) * nll_loss
489
+ return tol_loss, d_loss, nll_loss
490
+
491
+ def train(
492
+ self,
493
+ resume_from_checkpoint: Optional[str] = None,
494
+ trial: Optional[Dict[str, Any]] = None,
495
+ ignore_keys_for_eval: Optional[List[str]] = None,
496
+ **kwargs
497
+ ):
498
+ """
499
+ Train method with explicit configuration for knowledge distillation training.
500
+
501
+ Args:
502
+ resume_from_checkpoint: Optional checkpoint to resume training
503
+ trial: Optional hyperparameter trial configuration
504
+ ignore_keys_for_eval: Keys to ignore during evaluation
505
+ """
506
+ # Prepare teacher model if exists
507
+ if self.teacher_model is not None:
508
+ self.teacher_model.to(self.args.device)
509
+ self.teacher_model.eval() # Ensure teacher is in eval mode
510
+
511
+ # Call parent train method
512
+ return super().train(
513
+ resume_from_checkpoint=resume_from_checkpoint,
514
+ trial=trial,
515
+ ignore_keys_for_eval=ignore_keys_for_eval,
516
+ **kwargs
517
+ )
518
+ def training_step(
519
+ self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], num_items_in_batch=None
520
+ ) -> torch.Tensor:
521
+ """
522
+ Perform a training step on a batch of inputs.
523
+
524
+ Subclass and override to inject custom behavior.
525
+
526
+ Args:
527
+ model (`nn.Module`):
528
+ The model to train.
529
+ inputs (`Dict[str, Union[torch.Tensor, Any]]`):
530
+ The inputs and targets of the model.
531
+
532
+ The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
533
+ argument `labels`. Check your model's documentation for all accepted arguments.
534
+
535
+ Return:
536
+ `torch.Tensor`: The tensor with training loss on this batch.
537
+ """
538
+ model.train()
539
+ if hasattr(self.optimizer, "train") and callable(self.optimizer.train):
540
+ self.optimizer.train()
541
+
542
+ inputs = self._prepare_inputs(inputs)
543
+ if is_sagemaker_mp_enabled():
544
+ loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps)
545
+ return loss_mb.reduce_mean().detach().to(self.args.device)
546
+
547
+ with self.compute_loss_context_manager():
548
+ loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
549
+
550
+ del inputs
551
+ if (
552
+ self.args.torch_empty_cache_steps is not None
553
+ and self.state.global_step % self.args.torch_empty_cache_steps == 0
554
+ ):
555
+ if is_torch_xpu_available():
556
+ torch.xpu.empty_cache()
557
+ elif is_torch_mlu_available():
558
+ torch.mlu.empty_cache()
559
+ elif is_torch_musa_available():
560
+ torch.musa.empty_cache()
561
+ elif is_torch_npu_available():
562
+ torch.npu.empty_cache()
563
+ elif is_torch_mps_available(min_version="2.0"):
564
+ torch.mps.empty_cache()
565
+ else:
566
+ torch.cuda.empty_cache()
567
+
568
+ kwargs = {}
569
+
570
+ # For LOMO optimizers you need to explicitly use the learnign rate
571
+ if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
572
+ kwargs["learning_rate"] = self._get_learning_rate()
573
+
574
+ if self.args.n_gpu > 1:
575
+ loss = loss.mean() # mean() to average on multi-gpu parallel training
576
+
577
+ if self.use_apex:
578
+ with amp.scale_loss(loss, self.optimizer) as scaled_loss:
579
+ scaled_loss.requires_grad = True
580
+ scaled_loss.backward()
581
+
582
+ if (self.state.global_step + 1) % self.args.gradient_accumulation_steps == 0:
583
+ nn.utils.clip_grad_norm_(amp.master_params(self.optimizer[0]), 1.0)
584
+
585
+
586
+ else:
587
+ # Finally we need to normalize the loss for reporting
588
+ loss.requires_grad = True
589
+ if not self.model_accepts_loss_kwargs and self.compute_loss_func is None:
590
+ loss = loss / self.args.gradient_accumulation_steps
591
+
592
+ # Turning off loss scaling w.r.t. gradient accumulation when DeepSpeed is enabled
593
+ # https://github.com/huggingface/transformers/pull/35808
594
+ if self.accelerator.distributed_type == DistributedType.DEEPSPEED:
595
+ kwargs["scale_wrt_gas"] = False
596
+
597
+ self.accelerator.backward(loss, **kwargs)
598
+
599
+ if (self.state.global_step + 1) % self.args.gradient_accumulation_steps == 0:
600
+ # nn.utils.clip_grad_norm_(student_model.parameters(), 1.0)
601
+ nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
602
+
603
+ return loss.detach()
604
+
605
+ def evaluate(
606
+ self,
607
+ eval_dataset: Optional[Dataset] = None,
608
+ ignore_keys: Optional[List[str]] = None,
609
+ metric_key_prefix: str = "eval",
610
+ **kwargs
611
+ ) -> Dict[str, float]:
612
+ """
613
+ Evaluation method with custom metrics computation.
614
+
615
+ Args:
616
+ eval_dataset: Optional evaluation dataset
617
+ ignore_keys: Keys to ignore during evaluation
618
+ metric_key_prefix: Prefix for metrics
619
+
620
+ Returns:
621
+ Dictionary of evaluation metrics
622
+ """
623
+ # Use parent's evaluate method with optional customizations
624
+ return super().evaluate(
625
+ eval_dataset=eval_dataset,
626
+ ignore_keys=ignore_keys,
627
+ metric_key_prefix=metric_key_prefix,
628
+ **kwargs
629
+ )
630
+
631
+ def prediction_step(
632
+ self,
633
+ model: nn.Module,
634
+ inputs: Dict[str, Union[torch.Tensor, Any]],
635
+ prediction_loss_only: bool,
636
+ ignore_keys: Optional[List[str]] = None,
637
+ ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
638
+ """
639
+ Override prediction step to handle the model's output format correctly.
640
+ """
641
+ has_labels = False if len(self.label_names) == 0 else all(inputs.get(k) is not None for k in self.label_names)
642
+
643
+ return_loss = inputs.get("return_loss", None)
644
+ if return_loss is None:
645
+ return_loss = self.can_return_loss
646
+ loss_without_labels = True if len(self.label_names) == 0 and return_loss else False
647
+
648
+
649
+ inputs = self._prepare_inputs(inputs)
650
+ if ignore_keys is None:
651
+ if hasattr(self.model, "config"):
652
+ ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", [])
653
+ else:
654
+ ignore_keys = []
655
+
656
+ # labels may be popped when computing the loss (label smoothing for instance) so we grab them first.
657
+ if has_labels or loss_without_labels:
658
+ labels = nested_detach(tuple(inputs.get(name) for name in self.label_names))
659
+ if len(labels) == 1:
660
+ labels = labels[0]
661
+ else:
662
+ labels = None
663
+
664
+ with torch.no_grad():
665
+ loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
666
+ loss = loss.mean().detach()
667
+
668
+ # Get logits from outputs
669
+ if isinstance(outputs, dict):
670
+ logits = outputs["logits"]
671
+ else:
672
+ # logits = outputs[0]
673
+ logits = outputs
674
+
675
+
676
+ # Ensure logits has correct shape [batch_size, num_classes]
677
+ if len(logits.shape) == 1:
678
+ logits = logits.unsqueeze(0)
679
+
680
+ if prediction_loss_only:
681
+ return (loss, None, None)
682
+
683
+ if labels is not None:
684
+ labels = labels.detach()
685
+
686
+ logits = nested_detach(logits)
687
+ if len(logits.shape) == 1:
688
+ logits = logits[0]
689
+
690
+ # print(f"Validation loss: {loss}")
691
+ return (loss, logits, labels)
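
A minimal sketch of wiring the new `SkipBertTrainer` for teacher-to-student distillation follows. All hyperparameter values, the import path, and the dataset choice are placeholders rather than values taken from this commit:

```python
# Illustrative wiring of SkipBertTrainer; every value below is a placeholder.
from transformers import TrainingArguments


def build_skipbert_trainer(student_model, teacher_model,
                           holdout_dataset, reference_dataset, data_collator):
    """Sketch: distill a BERT teacher into a SkipBERT student on a holdout set."""
    # Deferred import; the exact package path is an assumption about the layout.
    from fedllm.skipbert.trainer import SkipBertTrainer, compute_metrics_skipbert

    args = TrainingArguments(
        output_dir="./skipbert_distill",   # assumed output directory
        num_train_epochs=3,
        per_device_train_batch_size=8,
        logging_steps=50,
    )
    return SkipBertTrainer(
        student_model=student_model,        # SkipBertForSequenceClassification
        teacher_model=teacher_model,        # BertForSequenceClassification, kept in eval mode
        train_dataset=holdout_dataset,
        eval_dataset=reference_dataset,
        args=args,
        data_collator=data_collator,
        compute_metrics=compute_metrics_skipbert,
        alpha=0.5, temperature=2.0, beta=1.0,   # loss-mixing weights
        use_logits=True, use_att=True, use_rep=True, use_embedding=True,
        att_layer_maps=[2, 5, 8, 11],           # assumed attention-layer mapping
        hid_layer_maps=[0, 3, 6, 9, 12],        # assumed hidden-state mapping
        output_mode="regression",               # influence scores are scalar
        num_full_hidden_layers_student=2,
    )
```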
template_FL/src/fedllm/trainer.py CHANGED
@@ -3,14 +3,48 @@ from torch.utils.data import DataLoader
3
  import torch
4
  import copy
5
  import numpy as np
6
- from transformers import BertForSequenceClassification, GenerationConfig, AutoTokenizer
7
  import inspect
8
  import logging
9
  import wandb
10
  from tqdm import tqdm
 
11
 
12
  logger = logging.getLogger(__name__)
13
 
14
  class ManualLLMSampleCB:
15
  def __init__(self, model, tokenizer, task, num_samples=10, max_new_tokens=256):
16
  self.model = model
@@ -61,12 +95,14 @@ class ManualLLMSampleCB:
61
  def log_samples_to_wandb(self, dataset):
62
  samples_table = self.create_samples_table(dataset)
63
  wandb.log({"sample_predictions": samples_table})
64
-
 
 
65
 
66
  class ManualTrainer:
67
  def __init__(
68
  self, model, tokenizer, train_dataset, val_dataset, holdout_dataset, reference_dataset,
69
- args, data_collator, compute_metrics, mates_args, data_influence_model, data_influence_tokenizer
70
  ):
71
  self.accelerator = Accelerator()
72
  self.model = model
@@ -74,10 +110,13 @@ class ManualTrainer:
74
  self.args = args
75
  self.data_collator = data_collator
76
  self.compute_metrics = compute_metrics
77
- self.mates_args = mates_args
78
- self.data_influence_model = data_influence_model
 
 
79
  self.data_influence_tokenizer = data_influence_tokenizer
80
-
 
81
  # Remove unused columns from datasets
82
  if train_dataset:
83
  self.train_dataset = self._remove_unused_columns(train_dataset, "training")
@@ -105,13 +144,13 @@ class ManualTrainer:
105
  else:
106
  self.val_loader = None
107
 
108
- if self.mates_args.state:
109
  self.holdout_dataset = self._remove_unused_columns(holdout_dataset, "holdout")
110
  self.reference_dataset = self._remove_unused_columns(reference_dataset, "reference")
111
 
112
  self.holdout_loader = DataLoader(
113
  self.holdout_dataset,
114
- batch_size=self.mates_args.holdout_batch_size,
115
  shuffle=True,
116
  collate_fn=self.data_collator,
117
  drop_last=self.args.dataloader_drop_last
@@ -119,7 +158,7 @@ class ManualTrainer:
119
 
120
  self.reference_loader = DataLoader(
121
  self.reference_dataset,
122
- batch_size=self.mates_args.reference_batch_size,
123
  shuffle=False,
124
  collate_fn=self.data_collator,
125
  drop_last=self.args.dataloader_drop_last
@@ -136,11 +175,56 @@ class ManualTrainer:
136
  self.model, self.optimizer, self.full_train_loader, self.val_loader
137
  )
138
 
139
- if self.mates_args.state:
 
140
  # Prepare holdout and reference loaders for Accelerator
141
- self.data_influence_model, self.holdout_loader, self.reference_loader = self.accelerator.prepare(
142
- self.data_influence_model, self.holdout_loader, self.reference_loader
143
 )
144
 
145
  def _remove_unused_columns(self, dataset, description=None):
146
  """
@@ -188,15 +272,15 @@ class ManualTrainer:
188
 
189
  for epoch in range(self.args.num_train_epochs):
190
  # Check if it's time to update the data influence model and state is True
191
- if self.mates_args.state and epoch % self.mates_args.update_data_influence_model_step == 0:
192
  print("Updating the data influence model and selecting high-quality data...")
193
- logger.info("Updating the data influence model and selecting high-quality data...")
194
  self.update_data_influence_model()
195
 
196
  # Filter high-quality data using the data influence model
197
  high_quality_indices = self.select_high_quality_data(
198
  dataset_size=len(self.train_dataset),
199
- selection_fraction=self.mates_args.selection_fraction,
200
  )
201
  self.train_loader = self.accelerator.prepare(
202
  self.create_filtered_dataloader(high_quality_indices)
@@ -224,16 +308,16 @@ class ManualTrainer:
224
  epoch_loss += loss.item()
225
 
226
  if (step + 1) % self.args.logging_steps == 0:
227
- # print(f"Step {step + 1}: Train Loss = {epoch_loss / (step + 1):.4f}")
228
- logger.info(f"Step {step + 1}: Train Loss = {epoch_loss / (step + 1):.4f}")
229
 
230
  avg_epoch_loss = epoch_loss / len(self.train_loader)
231
  training_loss.append(avg_epoch_loss)
232
 
233
  val_results = self.evaluate()
234
 
235
- # print(f"Epoch {epoch + 1}: Train Loss = {avg_epoch_loss:.4f}, Val Loss = {val_results['eval_loss']:.4f}")
236
- logger,info(f"Epoch {epoch + 1}: Train Loss = {avg_epoch_loss:.4f}, Val Loss = {val_results['eval_loss']:.4f}")
237
 
238
  # Early stopping logic
239
  if val_results["eval_loss"] < best_val_loss:
@@ -243,6 +327,7 @@ class ManualTrainer:
243
  early_stopping_counter += 1
244
  if early_stopping_counter >= early_stopping_patience:
245
  print("Early stopping triggered")
 
246
  break
247
 
248
  return {"training_loss": sum(training_loss) / len(training_loss), "best_val_loss": best_val_loss}
@@ -252,14 +337,20 @@ class ManualTrainer:
252
  Use the data influence model to predict quality scores and select high-quality data indices.
253
  """
254
  print("Selecting high-quality data using the data influence model...")
 
255
 
256
  # Predict influence scores for the entire dataset
257
  influence_scores = []
258
- self.data_influence_model.eval()
259
  influence_optimizer = self.accelerator.prepare(
260
- torch.optim.AdamW(self.data_influence_model.parameters(), lr=self.args.learning_rate)
 
 
 
261
  )
262
  i = 0
 
 
263
  with torch.no_grad():
264
  for batch in self.full_train_loader: # Full dataset loader
265
  text = self.tokenizer.batch_decode(
@@ -278,12 +369,13 @@ class ManualTrainer:
278
 
279
  # Train the data influence model
280
  influence_optimizer.zero_grad()
281
- outputs = self.data_influence_model(
282
  input_ids=bert_inputs['input_ids'],
283
  attention_mask=bert_inputs['attention_mask'],
284
  )
285
 
286
- influence_scores.extend(outputs.logits.squeeze(-1).cpu().numpy())
 
287
 
288
  i += 1
289
 
@@ -293,6 +385,7 @@ class ManualTrainer:
293
  # Normalize influence scores and apply Gumbel-Top-$k$ selection
294
  influence_scores = np.array(influence_scores)
295
  print(">> Influence scores shape:", influence_scores.shape)
 
296
 
297
  # Add Gumbel noise for diversity
298
  rng = np.random.default_rng()
@@ -303,7 +396,12 @@ class ManualTrainer:
303
  selection_size = int(len(influence_scores)*selection_fraction)
304
  high_quality_indices = np.argpartition(-influence_scores, selection_size)[:selection_size]
305
  print(f"Selected {len(high_quality_indices)} high-quality samples.")
306
-
307
  return high_quality_indices
308
 
309
  def create_filtered_dataloader(self, indices):
@@ -311,6 +409,7 @@ class ManualTrainer:
311
  Create a new dataloader with only the selected high-quality data.
312
  """
313
  print("Creating a filtered dataloader with selected high-quality data...")
 
314
  subset_dataset = torch.utils.data.Subset(self.train_dataset, indices)
315
  return torch.utils.data.DataLoader(
316
  subset_dataset,
@@ -325,16 +424,18 @@ class ManualTrainer:
325
  # Train a copy of the model on holdout data and validate on reference data
326
  copied_model = copy.deepcopy(self.model)
327
  copied_model.train()
 
 
328
  optimizer = self.accelerator.prepare(
329
  torch.optim.Adam(copied_model.parameters(), lr=self.args.learning_rate)
330
  )
331
  holdout_reference_pairs = []
332
 
333
- # print("Starting to collect holdout-reference pairs...")
334
- logger.info("Starting to collect holdout-reference pairs...")
335
  for step, holdout_batch in enumerate(self.holdout_loader):
336
- # print(f"Processing holdout batch {step+1}/{len(self.holdout_loader)}...")
337
- logger.info(f"Processing holdout batch {step+1}/{len(self.holdout_loader)}...")
338
 
339
  optimizer.zero_grad()
340
  outputs = copied_model(
@@ -352,7 +453,7 @@ class ManualTrainer:
352
  optimizer.step()
353
 
354
  print(f"Evaluating reference losses at step {step}...")
355
- logger.info(f"Evaluating reference losses at step {step}...")
356
 
357
  copied_model.eval()
358
  reference_losses = []
@@ -373,42 +474,138 @@ class ManualTrainer:
373
 
374
  # Train the data influence model using the generated pairs
375
  print("Starting to train the data influence model...")
376
- logger.info("Starting to train the data influence model...")
377
 
378
- self.data_influence_model.train()
379
- influence_optimizer = torch.optim.AdamW(self.data_influence_model.parameters(), lr=self.args.learning_rate)
380
-
381
- for step, (text, score) in enumerate(holdout_reference_pairs):
382
  # Tokenize the text using the BERT tokenizer
383
- bert_inputs = self.data_influence_tokenizer(
384
- text,
385
- truncation=True,
386
- padding='max_length',
387
- max_length=256,
388
- return_tensors='pt'
389
- ).to(self.accelerator.device)
390
-
391
- # Convert score to tensor and enable gradients
392
- score_tensor = torch.tensor([score], device=self.accelerator.device, dtype=torch.float32, requires_grad=True)
 
 
393
 
394
  # Train the data influence model
395
  influence_optimizer.zero_grad()
396
- outputs = self.data_influence_model(
397
- input_ids=bert_inputs['input_ids'],
398
- attention_mask=bert_inputs['attention_mask'],
399
- labels=score_tensor
400
  )
401
- influence_loss = outputs.loss
402
-
 
 
 
403
  influence_loss.backward()
404
  influence_optimizer.step()
405
-
406
  if step % 50 == 0:
407
  print(f"[Influence Training] Step {step}: Loss = {influence_loss.item():.4f}")
408
- logger.info(f"[Influence Training] Step {step}: Loss = {influence_loss.item():.4f}")
 
409
410
 
411
- # Distillation for SkipBERT
412
 
413
 
414
 
@@ -459,7 +656,8 @@ class ManualTrainer:
459
  metrics = self.compute_metrics({"predictions": padded_preds, "label_ids": padded_labels})
460
 
461
  metrics.update({"eval_loss": val_loss / len(self.val_loader)})
462
- print("Validation Metrics:", metrics)
 
463
 
464
  if wandb_sample:
465
  # Sample Logging
 
3
  import torch
4
  import copy
5
  import numpy as np
6
+ from transformers import (
7
+ # BertForSequenceClassification,
8
+ GenerationConfig,
9
+ AutoTokenizer,
10
+ Trainer,
11
+ get_scheduler,
12
+ EarlyStoppingCallback,
13
+ TrainingArguments,
14
+ DataCollatorWithPadding
15
+ )
16
+ from datasets import Dataset
17
+ from .skipbert.trainer import compute_metrics_skipbert, SkipBertTrainer
18
+
19
  import inspect
20
  import logging
21
  import wandb
22
  from tqdm import tqdm
23
+ import time
24
+ from functools import partial
25
+
26
+ logging.getLogger("Trainer").setLevel(logging.INFO)
27
 
28
  logger = logging.getLogger(__name__)
29
 
30
+
31
+ def time_format(runtime, logger):
32
+ if runtime < 60:
33
+ logger.info(f'Runtime: {runtime:.2f} seconds')
34
+ elif runtime < 3600: # Less than one hour
35
+ minutes = runtime / 60
36
+ logger.info(f'Runtime: {minutes:.2f} minutes')
37
+ else:
38
+ hours = runtime / 3600
39
+ logger.info(f'Runtime: {hours:.2f} hours')
40
+
41
+
42
+
43
+ def convert_to_tokens_reg(data, tokenizer, max_seq_length, device):
44
+ input_tokenized = tokenizer(data['text'], truncation=True, padding=True, max_length=max_seq_length, return_tensors="pt")
45
+ input_tokenized['labels'] = torch.tensor(data['label'], dtype=torch.float32).reshape(-1, 1)
46
+ return input_tokenized
47
+
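For illustration, the convert_to_tokens_reg helper above can be exercised on a toy batch; this sketch assumes the helper as defined above and that bert-base-uncased can be loaded (texts and labels are made up):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("bert-base-uncased")
batch = {"text": ["a short example", "another toy sample"], "label": [0.12, 0.87]}
enc = convert_to_tokens_reg(batch, tokenizer=tok, max_seq_length=32, device="cpu")  # device is currently unused by the helper
print(enc["input_ids"].shape)  # (batch, seq_len), e.g. torch.Size([2, 5])
print(enc["labels"])           # tensor([[0.1200], [0.8700]])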
48
  class ManualLLMSampleCB:
49
  def __init__(self, model, tokenizer, task, num_samples=10, max_new_tokens=256):
50
  self.model = model
 
95
  def log_samples_to_wandb(self, dataset):
96
  samples_table = self.create_samples_table(dataset)
97
  wandb.log({"sample_predictions": samples_table})
98
+
99
+
100
+
101
 
102
  class ManualTrainer:
103
  def __init__(
104
  self, model, tokenizer, train_dataset, val_dataset, holdout_dataset, reference_dataset,
105
+ args, data_collator, compute_metrics, mates_cfg, skipbert_cfg, teacher_data_influence_model, student_data_influence_model, data_influence_tokenizer
106
  ):
107
  self.accelerator = Accelerator()
108
  self.model = model
 
110
  self.args = args
111
  self.data_collator = data_collator
112
  self.compute_metrics = compute_metrics
113
+ self.mates_cfg = mates_cfg
114
+ self.skipbert_cfg = skipbert_cfg
115
+ self.teacher_data_influence_model = teacher_data_influence_model
116
+ self.student_data_influence_model = student_data_influence_model
117
  self.data_influence_tokenizer = data_influence_tokenizer
118
+
119
+
120
  # Remove unused columns from datasets
121
  if train_dataset:
122
  self.train_dataset = self._remove_unused_columns(train_dataset, "training")
 
144
  else:
145
  self.val_loader = None
146
 
147
+ if self.mates_cfg.state:
148
  self.holdout_dataset = self._remove_unused_columns(holdout_dataset, "holdout")
149
  self.reference_dataset = self._remove_unused_columns(reference_dataset, "reference")
150
 
151
  self.holdout_loader = DataLoader(
152
  self.holdout_dataset,
153
+ batch_size=self.mates_cfg.holdout_batch_size,
154
  shuffle=True,
155
  collate_fn=self.data_collator,
156
  drop_last=self.args.dataloader_drop_last
 
158
 
159
  self.reference_loader = DataLoader(
160
  self.reference_dataset,
161
+ batch_size=self.mates_cfg.reference_batch_size,
162
  shuffle=False,
163
  collate_fn=self.data_collator,
164
  drop_last=self.args.dataloader_drop_last
 
175
  self.model, self.optimizer, self.full_train_loader, self.val_loader
176
  )
177
 
178
+ ### Define for MATES ###
179
+ if self.mates_cfg.state:
180
  # Prepare holdout and reference loaders for Accelerator
181
+ self.teacher_data_influence_model, self.holdout_loader, self.reference_loader = self.accelerator.prepare(
182
+ self.teacher_data_influence_model, self.holdout_loader, self.reference_loader
183
  )
184
+ self.student_data_influence_model = self.accelerator.prepare(self.student_data_influence_model)
185
+ ######
186
+
187
+ ### Define for SkipBERT ###
188
+ self.skipbert_train_args = TrainingArguments(
189
+ output_dir=self.skipbert_cfg.output_dir,
190
+ learning_rate=self.skipbert_cfg.learning_rate,
191
+ num_train_epochs=self.skipbert_cfg.num_train_epochs,
192
+ per_device_train_batch_size=self.skipbert_cfg.train_batch_size,
193
+ gradient_accumulation_steps=self.skipbert_cfg.gradient_accumulation_steps,
194
+ per_device_eval_batch_size=self.skipbert_cfg.eval_batch_size,
195
+ eval_accumulation_steps=self.skipbert_cfg.eval_accumulation_steps,
196
+ max_steps=self.skipbert_cfg.max_steps,
197
+ logging_steps = 10,
198
+ evaluation_strategy=self.skipbert_cfg.evaluation_strategy,
199
+ save_strategy=self.skipbert_cfg.save_strategy,
200
+ lr_scheduler_type=self.skipbert_cfg.lr_scheduler_type,
201
+ warmup_steps=self.skipbert_cfg.warmup_steps,
202
+ weight_decay=self.skipbert_cfg.weight_decay,
203
+ logging_dir=self.skipbert_cfg.logging_dir,
204
+ report_to='wandb',
205
+ run_name='skipbert',
206
+ do_train=self.skipbert_cfg.do_train,
207
+ do_eval=self.skipbert_cfg.do_eval,
208
+ dataloader_drop_last=False,
209
+ ddp_find_unused_parameters=False,
210
+ group_by_length=True,
211
+ load_best_model_at_end = True
212
+ )
213
+
214
+ # Prepare custom optimizer parameter groups for the student model (no weight decay on biases or LayerNorm)
215
+ if self.student_data_influence_model is not None:
216
+ no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
217
+ self.student_optimizer_grouped_parameters = [
218
+ {
219
+ 'params': [p for n, p in self.student_data_influence_model.named_parameters() if not any(nd in n for nd in no_decay)],
220
+ 'weight_decay': 0.01
221
+ },
222
+ {
223
+ 'params': [p for n, p in self.student_data_influence_model.named_parameters() if any(nd in n for nd in no_decay)],
224
+ 'weight_decay': 0.0
225
+ }
226
+ ]
227
+ ######
228
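The parameter grouping above applies the usual convention of exempting biases and LayerNorm weights from weight decay. A self-contained sketch of the same pattern on a toy module (the module and names are illustrative, not the SkipBERT student):

import torch
from torch import nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.dense = nn.Linear(8, 8)
        self.LayerNorm = nn.LayerNorm(8)

toy = Toy()
no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
groups = [
    {"params": [p for n, p in toy.named_parameters() if not any(nd in n for nd in no_decay)], "weight_decay": 0.01},
    {"params": [p for n, p in toy.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
]
opt = torch.optim.AdamW(groups, lr=2e-5)
print([len(g["params"]) for g in groups])  # [1, 3]: only dense.weight is decayed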
 
229
  def _remove_unused_columns(self, dataset, description=None):
230
  """
 
272
 
273
  for epoch in range(self.args.num_train_epochs):
274
  # Check if it's time to update the data influence model and state is True
275
+ if self.mates_cfg.state and epoch % self.mates_cfg.update_data_influence_model_step == 0:
276
  print("Updating the data influence model and selecting high-quality data...")
277
+ # logger.info("Updating the data influence model and selecting high-quality data...")
278
  self.update_data_influence_model()
279
 
280
  # Filter high-quality data using the data influence model
281
  high_quality_indices = self.select_high_quality_data(
282
  dataset_size=len(self.train_dataset),
283
+ selection_fraction=self.mates_cfg.selection_fraction,
284
  )
285
  self.train_loader = self.accelerator.prepare(
286
  self.create_filtered_dataloader(high_quality_indices)
 
308
  epoch_loss += loss.item()
309
 
310
  if (step + 1) % self.args.logging_steps == 0:
311
+ print(f"Step {step + 1}: Train Loss = {epoch_loss / (step + 1):.4f}")
312
+ # logger.info(f"Step {step + 1}: Train Loss = {epoch_loss / (step + 1):.4f}")
313
 
314
  avg_epoch_loss = epoch_loss / len(self.train_loader)
315
  training_loss.append(avg_epoch_loss)
316
 
317
  val_results = self.evaluate()
318
 
319
+ print(f"Epoch {epoch + 1}: Train Loss = {avg_epoch_loss:.4f}, Val Loss = {val_results['eval_loss']:.4f}")
320
+ # logger.info(f"Epoch {epoch + 1}: Train Loss = {avg_epoch_loss:.4f}, Val Loss = {val_results['eval_loss']:.4f}")
321
 
322
  # Early stopping logic
323
  if val_results["eval_loss"] < best_val_loss:
 
327
  early_stopping_counter += 1
328
  if early_stopping_counter >= early_stopping_patience:
329
  print("Early stopping triggered")
330
+ # logger.info("Early stopping triggered")
331
  break
332
 
333
  return {"training_loss": sum(training_loss) / len(training_loss), "best_val_loss": best_val_loss}
 
337
  Use the data influence model to predict quality scores and select high-quality data indices.
338
  """
339
  print("Selecting high-quality data using the data influence model...")
340
+ # logger.info("Selecting high-quality data using the data influence model...")
341
 
342
  # Predict influence scores for the entire dataset
343
  influence_scores = []
344
+ self.student_data_influence_model.eval()
345
  influence_optimizer = self.accelerator.prepare(
346
+ torch.optim.AdamW(
347
+ self.student_optimizer_grouped_parameters,
348
+ lr=self.skipbert_train_args.learning_rate,
349
+ )
350
  )
351
  i = 0
352
+
353
+ start_time = time.perf_counter()
354
  with torch.no_grad():
355
  for batch in self.full_train_loader: # Full dataset loader
356
  text = self.tokenizer.batch_decode(
 
369
 
370
  # Train the data influence model
371
  influence_optimizer.zero_grad()
372
+ logits, attn_outputs, hidn_output = self.student_data_influence_model(
373
  input_ids=bert_inputs['input_ids'],
374
  attention_mask=bert_inputs['attention_mask'],
375
  )
376
 
377
+
378
+ influence_scores.extend(logits.squeeze(-1).cpu().numpy())
379
 
380
  i += 1
381
 
 
385
  # Normalize influence scores and apply Gumbel-Top-$k$ selection
386
  influence_scores = np.array(influence_scores)
387
  print(">> Influence scores shape:", influence_scores.shape)
388
+ # logger.info(">> Influence scores shape:", influence_scores.shape)
389
 
390
  # Add Gumbel noise for diversity
391
  rng = np.random.default_rng()
 
396
  selection_size = int(len(influence_scores)*selection_fraction)
397
  high_quality_indices = np.argpartition(-influence_scores, selection_size)[:selection_size]
398
  print(f"Selected {len(high_quality_indices)} high-quality samples.")
399
+ # logger.info(f"Selected {len(high_quality_indices)} high-quality samples.")
400
+
401
+ end_time = time.perf_counter()
402
+ runtime = round((end_time - start_time), 2)
403
+ time_format(runtime, logger)
404
+
405
  return high_quality_indices
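For reference, the Gumbel-Top-k selection used above can be reproduced in isolation. A minimal numpy sketch with made-up scores (the noise scale and seeding in the repository's code may differ):

import numpy as np

rng = np.random.default_rng(0)
scores = rng.normal(size=10)                               # stand-in influence scores
gumbel = -np.log(-np.log(rng.uniform(size=scores.shape)))  # Gumbel(0, 1) noise
noisy = scores + gumbel
k = 4
selected = np.argpartition(-noisy, k)[:k]                  # indices of the k largest noisy scores
print(np.sort(selected))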
406
 
407
  def create_filtered_dataloader(self, indices):
 
409
  Create a new dataloader with only the selected high-quality data.
410
  """
411
  print("Creating a filtered dataloader with selected high-quality data...")
412
+ # logger.info("Creating a filtered dataloader with selected high-quality data...")
413
  subset_dataset = torch.utils.data.Subset(self.train_dataset, indices)
414
  return torch.utils.data.DataLoader(
415
  subset_dataset,
 
424
  # Train a copy of the model on holdout data and validate on reference data
425
  copied_model = copy.deepcopy(self.model)
426
  copied_model.train()
427
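+ # Reset the global Accelerator state and rebuild it so the fresh optimizer for the copied model can be prepared on a clean state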
+ self.accelerator.state._reset_state()
428
+ self.accelerator = Accelerator()
429
  optimizer = self.accelerator.prepare(
430
  torch.optim.Adam(copied_model.parameters(), lr=self.args.learning_rate)
431
  )
432
  holdout_reference_pairs = []
433
 
434
+ print("Starting to collect holdout-reference pairs...")
435
+ # logger.info("Starting to collect holdout-reference pairs...")
436
  for step, holdout_batch in enumerate(self.holdout_loader):
437
+ print(f"Processing holdout batch {step+1}/{len(self.holdout_loader)}...")
438
+ # logger.info(f"Processing holdout batch {step+1}/{len(self.holdout_loader)}...")
439
 
440
  optimizer.zero_grad()
441
  outputs = copied_model(
 
453
  optimizer.step()
454
 
455
  print(f"Evaluating reference losses at step {step}...")
456
+ # logger.info(f"Evaluating reference losses at step {step}...")
457
 
458
  copied_model.eval()
459
  reference_losses = []
 
474
 
475
  # Train the data influence model using the generated pairs
476
  print("Starting to train the data influence model...")
477
+ # logger.info("Starting to train the data influence model...")
478
 
479
+ self.teacher_data_influence_model.train()
480
+ influence_optimizer = torch.optim.AdamW(self.teacher_data_influence_model.parameters(), lr=self.args.learning_rate)
481
+
482
+ list_texts, list_score = [], []
483
+ batch_size = 0
484
+ # Convert the holdout-reference pairs to a Dataset object
485
+ for texts, score in holdout_reference_pairs:
486
+ if batch_size == 0:
487
+ batch_size = len(texts)
488
+ list_texts.extend(texts)
489
+ list_score.extend([score] * len(texts))
490
+
491
+
492
+ holdout_reference_pairs = {'text': list_texts, 'label': list_score}
493
+ holdout_reference_pairs = Dataset.from_dict(holdout_reference_pairs)
494
+
495
+ # Wrap the function with partial
496
+ convert_func = partial(
497
+ convert_to_tokens_reg,
498
+ tokenizer=self.data_influence_tokenizer,
499
+ max_seq_length=self.skipbert_cfg.max_seq_length,
500
+ device=self.accelerator.device
501
+ )
502
+ holdout_reference_pairs_loader = DataLoader(
503
+ holdout_reference_pairs.map(
504
+ convert_func,
505
+ batched=True,
506
+ num_proc=8,
507
+ remove_columns=holdout_reference_pairs.column_names
508
+ ),
509
+ batch_size=batch_size,
510
+ collate_fn=DataCollatorWithPadding(tokenizer=self.data_influence_tokenizer, padding=True, max_length=self.skipbert_cfg.max_seq_length), # Use the same collate function
511
+ drop_last=self.args.dataloader_drop_last
512
+ )
513
+ loss_mse = torch.nn.MSELoss()
514
+
515
+ for step, batch_input in enumerate(holdout_reference_pairs_loader):
516
  # Tokenize the text using the BERT tokenizer
517
+ batch_input = {k: v.to('cuda:0') for k, v in batch_input.items()}  # hard-coded to cuda:0; self.accelerator.device would be more portable
518
+ # text, score = row['text'], row['label']
519
+ # bert_inputs = self.data_influence_tokenizer(
520
+ # text,
521
+ # truncation=True,
522
+ # padding='max_length',
523
+ # max_length=256,
524
+ # return_tensors='pt'
525
+ # ).to(self.accelerator.device)
526
+
527
+ # # Convert score to tensor and enable gradients
528
+ # score_tensor = torch.tensor([score] * len(text), device=self.accelerator.device, dtype=torch.float32, requires_grad=True)
529
 
530
  # Train the data influence model
531
  influence_optimizer.zero_grad()
532
+ # outputs = self.teacher_data_influence_model(
533
+ # input_ids=bert_inputs['input_ids'],
534
+ # attention_mask=bert_inputs['attention_mask'],
535
+ # labels=score_tensor
536
+ # )
537
+
538
+ outputs = self.teacher_data_influence_model(
539
+ **batch_input
540
  )
541
+
542
+ influence_loss = loss_mse(batch_input['labels'].view(-1), outputs.logits.view(-1))
543
+ # print(f"Loss: {influence_loss} - require_grad: {influence_loss.grad_fn}")
544
+
545
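+ # NOTE: setting requires_grad on an already-computed loss does not rebuild the autograd graph; backward() only updates the teacher when the loss still carries a grad_fn from the model's parameters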
+ influence_loss.requires_grad = True
546
  influence_loss.backward()
547
  influence_optimizer.step()
548
+
549
  if step % 50 == 0:
550
  print(f"[Influence Training] Step {step}: Loss = {influence_loss.item():.4f}")
551
+ # logger.info(f"[Influence Training] Step {step}: Loss = {influence_loss.item():.4f}")
552
+
553
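For comparison, the same teacher-update objective (regressing reference-loss scores from text with an MSE head) can be written with a stock Hugging Face regression head, where the loss already comes out of the model attached to the graph. A generic sketch assuming bert-base-uncased, not the repository's SkipBERT models:

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

tok = AutoTokenizer.from_pretrained("bert-base-uncased")
reg = AutoModelForSequenceClassification.from_pretrained(
    "bert-base-uncased", num_labels=1, problem_type="regression"
)
opt = torch.optim.AdamW(reg.parameters(), lr=2e-5)

batch = tok(["low quality text", "useful training text"], padding=True, return_tensors="pt")
batch["labels"] = torch.tensor([[0.1], [0.9]])  # made-up reference-loss targets
out = reg(**batch)                              # out.loss is MSE because problem_type="regression"
opt.zero_grad()
out.loss.backward()                             # grad_fn is present; no requires_grad workaround needed
opt.step()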
 
554
+ ### Distillation for SkipBERT ###
555
+ train_converted = holdout_reference_pairs.map(
556
+ convert_func,
557
+ batched=True,
558
+ num_proc=8,
559
+ remove_columns=holdout_reference_pairs.column_names
560
+ )
561
 
562
+
563
+ # Build the custom optimizer and scheduler to pass into SkipBertTrainer
564
+ optimizer = torch.optim.AdamW(
565
+ self.student_optimizer_grouped_parameters,
566
+ lr=self.skipbert_train_args.learning_rate,
567
+ )
568
+
569
+ scheduler = get_scheduler(
570
+ name=self.skipbert_train_args.lr_scheduler_type,
571
+ optimizer=optimizer,
572
+ num_warmup_steps=self.skipbert_train_args.warmup_steps,
573
+ # num_training_steps=training_args.max_steps
574
+ num_training_steps=100/(self.skipbert_train_args.per_device_train_batch_size * self.skipbert_train_args.gradient_accumulation_steps)
575
+ )
576
+
577
+ # Initialize the trainer
578
+ trainer = SkipBertTrainer(
579
+ student_model=self.student_data_influence_model,
580
+ teacher_model=self.teacher_data_influence_model,
581
+ args=self.skipbert_train_args,
582
+ train_dataset=train_converted,
583
+ eval_dataset=train_converted.shuffle().select(range(min(len(train_converted),10))),
584
+ compute_metrics=compute_metrics_skipbert,
585
+ # SkipBERT specific arguments
586
+ alpha=0.5,
587
+ temperature=2.0,
588
+ beta=1.0,
589
+ use_logits=self.skipbert_cfg.use_logits,
590
+ use_att=self.skipbert_cfg.use_att,
591
+ use_rep=self.skipbert_cfg.use_rep,
592
+ use_embedding=self.skipbert_cfg.use_embedding,
593
+ att_layer_maps=self.skipbert_cfg.att_layer_maps,
594
+ hid_layer_maps=self.skipbert_cfg.hid_layer_maps,
595
+ epochs_no_cls=self.skipbert_cfg.epochs_no_cls,
596
+ reduce_T=self.skipbert_cfg.reduce_T,
597
+ output_mode=self.skipbert_cfg.output_mode, # 'classification' or 'regression'
598
+ num_masked_layers_teacher=self.skipbert_cfg.num_masked_layers_teacher,
599
+ num_masked_last_layers_teacher=self.skipbert_cfg.num_masked_last_layers_teacher,
600
+ fp16=self.skipbert_cfg.fp16,
601
+ num_full_hidden_layers_student=self.skipbert_cfg.num_full_hidden_layers_student,
602
+ tokenizer=self.data_influence_tokenizer,
603
+ optimizers=(optimizer,scheduler),
604
+ callbacks=[EarlyStoppingCallback(early_stopping_patience=5)]
605
+ )
606
+
607
+ # Train the model
608
+ trainer.train()
609
 
610
 
611
 
 
656
  metrics = self.compute_metrics({"predictions": padded_preds, "label_ids": padded_labels})
657
 
658
  metrics.update({"eval_loss": val_loss / len(self.val_loader)})
659
+ print(f"Validation Metrics: {metrics}")
660
+ # logger.info(f"Validation Metrics: {metrics}")
661
 
662
  if wandb_sample:
663
  # Sample Logging
template_FL/src/pyproject.toml CHANGED
@@ -37,14 +37,14 @@ num-server-rounds = 2
37
  num-supernodes = 10
38
 
39
  # Define dataset
40
- dataset.type = 'hete' # type = ['homo','hete']
41
  dataset.name = "vicgalle/alpaca-gpt4"
42
 
43
  # Define model settings
44
  model.name = "Qwen/Qwen2.5-1.5B-Instruct"
45
  model.quantization = 4
46
  model.gradient-checkpointing = true
47
- model.flash_attention = false
48
 
49
  ### Use MATES ###
50
  mates.state = true
@@ -60,67 +60,66 @@ mates.selection-fraction = 0.4
60
 
61
  # Model setting
62
  skipbert.student-model = "bert-base-uncased"
63
- skipbert.num_layers_student = 12
64
- skipbert.num_full_hidden_layers_student = 6
65
- skipbert.num_masked_layers_teacher = 0
66
- skipbert.num_masked_last_layers_teacher = 0
 
67
 
68
  # Training hyperparameters
69
- skipbert.train_batch_size = 8
70
- skipbert.gradient_accumulation_steps = 2
71
- skipbert.eval_batch_size = 8
72
- skipbert.eval_accumulation_steps = 2
73
- skipbert.learning_rate = 2.0e-5
74
- skipbert.num_train_epochs = 10
75
- skipbert.eval_step = 10
76
- skipbert.max_seq_length = 128
77
- skipbert.weight_decay = 1.0e-4
78
- skipbert.warmup_steps = 100 # 500
79
- skipbert.do_train = true
80
- skipbert.do_eval = true
81
- skipbert.max_steps = -1
82
- skipbert.evaluation_strategy = "epoch"
83
- skipbert.save_strategy = "epoch"
84
- skipbert.lr_scheduler_type = "cosine" # or 'warmup_linear'
85
- skipbert.logging_dir = './skipbert_logs'
86
- skipbert.output_dir = "./skipbert_results"
87
- skipbert.report_to = 'wandb'
 
88
 
89
  # Knowledge distillation parameters
90
  skipbert.beta = 0.01
91
  skipbert.T = 1.0
92
  skipbert.alpha = 1.0
93
- skipbert.reduce_T = 1.0
94
- skipbert.epochs_no_cls = 5
95
 
96
  # Training schedule and features
97
- skipbert.freeze_lower_layers = true
98
 
99
  # Feature usage flags
100
- skipbert.use_logits = true
101
- skipbert.use_att = true
102
- skipbert.use_rep = true
103
- skipbert.use_embedding = false
104
 
105
  # Training modes
106
- skipbert.do_train = true
107
- skipbert.do_eval = true
108
- skipbert.do_predict = false
109
- skipbert.do_fit = false
110
  skipbert.fp16 = false
111
- skipbert.no_pretrain = false
112
- skipbert.use_init_weight = false
113
- skipbert.share_param = true
114
- skipbert.do_lower_case = true
115
- skipbert.no_cuda = false
116
 
117
  # N-gram settings
118
- skipbert.n_gram_left = 1
119
- skipbert.n_gram_right = 1
120
 
121
  # Layer mappings
122
- skipbert.att_layer_maps: [1, 3, 5, 7, 9, 11]
123
- skipbert.hid_layer_maps: [6, 7, 8, 9, 10, 11, 12]
124
 
125
  ### END ###
126
 
@@ -138,8 +137,8 @@ train.save-every-round = 5
138
  train.learning-rate-max = 5e-5
139
  train.learning-rate-min = 1e-6
140
  train.seq-length = 256
141
- train.prompt_template_name = "alpaca"
142
- train.train_on_inputs = true
143
  train.verbose = false
144
 
145
  # Define training arguments for HF Trainer
@@ -164,7 +163,7 @@ train.training-arguments.eval-strategy = "epoch"
164
  train.training-arguments.save-strategy = "epoch"
165
  train.training-arguments.ddp-find-unused-parameters = false
166
  train.training-arguments.group-by-length = true
167
- train.training-arguments.load_best_model_at_end = true
168
  train.training-arguments.report-to = "wandb"
169
 
170
  # Define local training settings
 
37
  num-supernodes = 10
38
 
39
  # Define dataset
40
+ dataset.type = 'homo' # type = ['homo','hete']
41
  dataset.name = "vicgalle/alpaca-gpt4"
42
 
43
  # Define model settings
44
  model.name = "Qwen/Qwen2.5-1.5B-Instruct"
45
  model.quantization = 4
46
  model.gradient-checkpointing = true
47
+ model.flash-attention = false
48
 
49
  ### Use MATES ###
50
  mates.state = true
 
60
 
61
  # Model setting
62
  skipbert.student-model = "bert-base-uncased"
63
+ skipbert.output-mode = "regression"
64
+ skipbert.num-layers-student = 12
65
+ skipbert.num-full-hidden-layers-student = 6
66
+ skipbert.num-masked-layers-teacher = 0
67
+ skipbert.num-masked-last-layers-teacher = 0
68
 
69
  # Training hyperparameters
70
+ skipbert.train-batch-size = 4
71
+ skipbert.gradient-accumulation-steps = 1
72
+ skipbert.eval-batch-size = 4
73
+ skipbert.eval-accumulation-steps = 1
74
+ skipbert.learning-rate = 2.0e-5
75
+ skipbert.num-train-epochs = 10
76
+ skipbert.eval-step = 10
77
+ skipbert.max-seq-length = 256
78
+ skipbert.weight-decay = 1.0e-4
79
+ skipbert.warmup-steps = 100 # 500
80
+ skipbert.do-train = true
81
+ skipbert.do-eval = true
82
+ skipbert.do-predict = false
83
+ skipbert.max-steps = -1
84
+ skipbert.evaluation-strategy = "epoch"
85
+ skipbert.save-strategy = "epoch"
86
+ skipbert.lr-scheduler-type = "cosine" # or "linear"
87
+ skipbert.logging-dir = './skipbert-logs'
88
+ skipbert.output-dir = "./skipbert-results"
89
+ skipbert.report-to = 'wandb'
90
 
91
  # Knowledge distillation parameters
92
  skipbert.beta = 0.01
93
  skipbert.T = 1.0
94
  skipbert.alpha = 1.0
95
+ skipbert.reduce-T = 1.0
96
+ skipbert.epochs-no-cls = 5
97
 
98
  # Training schedule and features
99
+ skipbert.freeze-lower-layers = true
100
 
101
  # Feature usage flags
102
+ skipbert.use-logits = true
103
+ skipbert.use-att = true
104
+ skipbert.use-rep = true
105
+ skipbert.use-embedding = false
106
 
107
  # Training modes
108
+ skipbert.do-fit = false
 
 
 
109
  skipbert.fp16 = false
110
+ skipbert.no-pretrain = false
111
+ skipbert.use-init-weight = false
112
+ skipbert.share-param = true
113
+ skipbert.do-lower-case = true
114
+ skipbert.no-cuda = false
115
 
116
  # N-gram settings
117
+ skipbert.n-gram-left = 1
118
+ skipbert.n-gram-right = 1
119
 
120
  # Layer mappings
121
+ skipbert.att-layer-maps = "1, 3, 5, 7, 9, 11"
122
+ skipbert.hid-layer-maps = "6, 7, 8, 9, 10, 11, 12"
123
 
124
  ### END ###
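Since the layer maps are now stored as comma-separated strings rather than TOML arrays, the consuming code presumably has to split them back into integers. A minimal sketch (parse_layer_map is a hypothetical helper, not something defined in the repository):

def parse_layer_map(raw: str) -> list[int]:
    # "1, 3, 5, 7, 9, 11" -> [1, 3, 5, 7, 9, 11]
    return [int(part) for part in raw.split(",") if part.strip()]

print(parse_layer_map("1, 3, 5, 7, 9, 11"))
print(parse_layer_map("6, 7, 8, 9, 10, 11, 12"))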
125
 
 
137
  train.learning-rate-max = 5e-5
138
  train.learning-rate-min = 1e-6
139
  train.seq-length = 256
140
+ train.prompt-template-name = "alpaca"
141
+ train.train-on-inputs = true
142
  train.verbose = false
143
 
144
  # Define training agruments for HF Trainer
 
163
  train.training-arguments.save-strategy = "epoch"
164
  train.training-arguments.ddp-find-unused-parameters = false
165
  train.training-arguments.group-by-length = true
166
+ train.training-arguments.load-best-model-at-end = true
167
  train.training-arguments.report-to = "wandb"
168
 
169
  # Define local training settings