jupyterjazz makram93 commited on
Commit
6a92924
·
verified ·
1 Parent(s): 955fea2

feat-routing (#26)

Browse files

- feat: set adapter based on prompt (71b163e3134eb40c1a6da775ae7d945134e986a7)
- fix: read prompts from config (51411ff21ad7889871b6465f876eb0a0ade0a7d0)
- fix: check for exact task names (b3c540ce62778ffa8b9e0fa69bbbaa1042fc337e)


Co-authored-by: Mohammad Kalim Akram <[email protected]>

configuration_xlm_roberta.py CHANGED
@@ -23,6 +23,7 @@ class XLMRobertaFlashConfig(PretrainedConfig):
23
  use_cache=True,
24
  classifier_dropout=None,
25
  lora_adaptations=None,
 
26
  lora_rank=4,
27
  lora_dropout_p=0.0,
28
  lora_alpha=1,
@@ -55,6 +56,7 @@ class XLMRobertaFlashConfig(PretrainedConfig):
55
  self.classifier_dropout = classifier_dropout
56
  self.load_trained_adapters = load_trained_adapters
57
  self.lora_adaptations = lora_adaptations
 
58
  self.lora_rank = lora_rank
59
  self.lora_dropout_p = lora_dropout_p
60
  self.lora_alpha = lora_alpha
 
23
  use_cache=True,
24
  classifier_dropout=None,
25
  lora_adaptations=None,
26
+ lora_prompts=None,
27
  lora_rank=4,
28
  lora_dropout_p=0.0,
29
  lora_alpha=1,
 
56
  self.classifier_dropout = classifier_dropout
57
  self.load_trained_adapters = load_trained_adapters
58
  self.lora_adaptations = lora_adaptations
59
+ self.lora_prompts = lora_prompts
60
  self.lora_rank = lora_rank
61
  self.lora_dropout_p = lora_dropout_p
62
  self.lora_alpha = lora_alpha
modeling_lora.py CHANGED
@@ -14,9 +14,6 @@ from transformers import PretrainedConfig
14
  from .modeling_xlm_roberta import XLMRobertaFlashConfig, XLMRobertaModel, XLMRobertaPreTrainedModel
15
 
16
 
17
- LORA_NO_UPDATE = '__lora_no_update__'
18
-
19
-
20
  def initialized_weights(
21
  shape: Tuple[int], num_adaptations: int, init: str = "kaiming"
22
  ) -> torch.Tensor:
@@ -231,6 +228,16 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
231
  raise ValueError(
232
  f'`lora_adaptations` must be a list and contain at least one element'
233
  )
 
 
 
 
 
 
 
 
 
 
234
  self._adaptation_map = {
235
  name: idx for idx, name in enumerate(self._lora_adaptations)
236
  }
@@ -332,9 +339,18 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
332
  partial(LoRAParametrization.select_task_for_layer, task_idx=task_idx)
333
  )
334
 
335
- def forward(self, *args, task: Union[str, None] = LORA_NO_UPDATE, **kwargs):
336
- if task != LORA_NO_UPDATE:
337
- self.current_task = task
 
 
 
 
 
 
 
 
 
338
 
339
  return self.roberta(*args, **kwargs)
340
 
@@ -355,7 +371,7 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
355
  def encode(
356
  self,
357
  *args,
358
- task: Union[str, None] = LORA_NO_UPDATE,
359
  **kwargs,
360
  ) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
361
  """
@@ -364,18 +380,24 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
364
  task(`str`, *optional*, defaults to `LORA_NO_UPDATE`):
365
  Specifies the task for which the encoding is intended. This parameter controls the
366
  use of specialized LoRA adapters that are tuned for specific tasks. If `task` is set
367
- to `LORA_NO_UPDATE`, there will be no update to the current task, retaining the
368
- existing adapter configuration. If `task` is explicitly set to `None`, all LoRA
369
- adapters are disabled, and the model reverts to its original, general-purpose weights.
370
- If `task` is set to a specific LoRA adaptation, that adaptation is activated.
371
  """
372
- if task != LORA_NO_UPDATE:
373
- if not task:
 
 
 
 
 
 
 
374
  warnings.warn(
375
  f"Task-specific embeddings are disabled. To enable, specify the `task` "
376
  f"argument with one of the supported tasks: {', '.join(self.config.lora_adaptations)}",
377
  category=UserWarning,
378
  )
379
- self.current_task = task
380
 
381
  return self.roberta.encode(*args, **kwargs)
 
14
  from .modeling_xlm_roberta import XLMRobertaFlashConfig, XLMRobertaModel, XLMRobertaPreTrainedModel
15
 
16
 
 
 
 
17
  def initialized_weights(
18
  shape: Tuple[int], num_adaptations: int, init: str = "kaiming"
19
  ) -> torch.Tensor:
 
228
  raise ValueError(
229
  f'`lora_adaptations` must be a list and contain at least one element'
230
  )
231
+ self._lora_prompts = config.lora_prompts
232
+ if (
233
+ not isinstance(self._lora_prompts, dict)
234
+ or len(self._lora_prompts) != len(self._lora_adaptations)
235
+ or not all([v in self._lora_adaptations for v in self._lora_prompts.keys()])
236
+ ):
237
+ raise ValueError(
238
+ f'`lora_prompts` must be a dict and contain the same number of elements '
239
+ f'as `lora_adaptations` with all keys in `lora_prompts` present in `lora_adaptations`.'
240
+ )
241
  self._adaptation_map = {
242
  name: idx for idx, name in enumerate(self._lora_adaptations)
243
  }
 
339
  partial(LoRAParametrization.select_task_for_layer, task_idx=task_idx)
340
  )
341
 
342
+ def forward(self, *args, task_type: Union[str, None] = None, **kwargs):
343
+ if task_type:
344
+ self.current_task = task_type
345
+ else:
346
+ input_ids = kwargs["input_ids"]
347
+ input_text = self.roberta.tokenizer.decode(input_ids[0], skip_special_tokens=True)
348
+ for task_name, prompt in self._lora_prompts.items():
349
+ if input_text.startswith(prompt):
350
+ self.current_task = task_name
351
+ break
352
+ else:
353
+ self.current_task = None # No task-specific adapter is found, just use the general-purpose weights
354
 
355
  return self.roberta(*args, **kwargs)
356
 
 
371
  def encode(
372
  self,
373
  *args,
374
+ task_type: Union[str, None] = None,
375
  **kwargs,
376
  ) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
377
  """
 
380
  task(`str`, *optional*, defaults to `LORA_NO_UPDATE`):
381
  Specifies the task for which the encoding is intended. This parameter controls the
382
  use of specialized LoRA adapters that are tuned for specific tasks. If `task` is set
383
+ to `None`, all LoRA adapters are disabled, and the model reverts to its original,
384
+ general-purpose weights. If `task` is set to a specific LoRA adaptation, that adaptation
385
+ is activated.
 
386
  """
387
+ if task_type:
388
+ self.current_task = task_type
389
+ else: # infer the task from the input text
390
+ input_text = args[0][0] if isinstance(args[0], list) else args[0] # take only the first sentence
391
+ for task_name, prompt in self._lora_prompts.items():
392
+ if input_text.startswith(prompt):
393
+ self.current_task = task_name
394
+ break
395
+ else:
396
  warnings.warn(
397
  f"Task-specific embeddings are disabled. To enable, specify the `task` "
398
  f"argument with one of the supported tasks: {', '.join(self.config.lora_adaptations)}",
399
  category=UserWarning,
400
  )
401
+ self.current_task = None # No task-specific adapter is found, just use the general-purpose weights
402
 
403
  return self.roberta.encode(*args, **kwargs)
modeling_xlm_roberta.py CHANGED
@@ -21,7 +21,7 @@ import torch.nn.functional as F
21
  import torch.utils.checkpoint
22
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
23
  from einops import rearrange
24
- from transformers import PretrainedConfig
25
  from transformers.modeling_utils import PreTrainedModel
26
  from transformers.modeling_outputs import MaskedLMOutput,SequenceClassifierOutput
27
  from transformers.models.xlm_roberta.modeling_xlm_roberta import XLMRobertaLMHead
@@ -440,7 +440,7 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
440
  self.pooler = XLMRobertaPooler(config) if add_pooling_layer else None
441
 
442
  self.apply(partial(_init_weights, initializer_range=config.initializer_range))
443
-
444
 
445
  @torch.inference_mode()
446
  def encode(
@@ -492,12 +492,6 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
492
  If convert_to_tensor, a stacked tensor is returned.
493
  If convert_to_numpy, a numpy matrix is returned.
494
  """
495
- from transformers import AutoTokenizer
496
-
497
- self.tokenizer = AutoTokenizer.from_pretrained(
498
- self.name_or_path, trust_remote_code=True
499
- )
500
-
501
  is_training = self.training
502
  self.eval()
503
 
@@ -1278,4 +1272,4 @@ class XLMRobertaForSequenceClassification(XLMRobertaPreTrainedModel):
1278
  logits=logits,
1279
  hidden_states=outputs.hidden_states,
1280
  attentions=outputs.attentions,
1281
- )
 
21
  import torch.utils.checkpoint
22
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
23
  from einops import rearrange
24
+ from transformers import PretrainedConfig, AutoTokenizer
25
  from transformers.modeling_utils import PreTrainedModel
26
  from transformers.modeling_outputs import MaskedLMOutput,SequenceClassifierOutput
27
  from transformers.models.xlm_roberta.modeling_xlm_roberta import XLMRobertaLMHead
 
440
  self.pooler = XLMRobertaPooler(config) if add_pooling_layer else None
441
 
442
  self.apply(partial(_init_weights, initializer_range=config.initializer_range))
443
+ self.tokenizer = AutoTokenizer.from_pretrained(self.name_or_path, trust_remote_code=True)
444
 
445
  @torch.inference_mode()
446
  def encode(
 
492
  If convert_to_tensor, a stacked tensor is returned.
493
  If convert_to_numpy, a numpy matrix is returned.
494
  """
 
 
 
 
 
 
495
  is_training = self.training
496
  self.eval()
497
 
 
1272
  logits=logits,
1273
  hidden_states=outputs.hidden_states,
1274
  attentions=outputs.attentions,
1275
+ )