jw2yang committed
Commit 6a828c2 · 1 Parent(s): f7db9ca

remove dependencies on deepspeed and wandb

Files changed (1): modeling_magma.py  +0 -48
modeling_magma.py CHANGED
@@ -24,7 +24,6 @@ import numpy as np
 import torch
 import torch.utils.checkpoint
 from torch import nn
-import wandb
 import torch.distributed as dist
 from transformers.modeling_utils import PreTrainedModel
 from transformers.activations import ACT2FN
@@ -282,12 +281,6 @@ class MagmaForCausalLM(MagmaPreTrainedModel):
         self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
         self._padding_side = "left" # set it to left by default, user can use setter to change padding_sides
 
-        try:
-            if dist.get_rank() == 0:
-                wandb.init(project=os.environ['WANDB_PROJECT'])
-        except:
-            pass
-
         self.post_init()
 
     # def from_pretrained(self, pretrained_model_name_or_path, *model_args, **kwargs):
@@ -325,40 +318,6 @@ class MagmaForCausalLM(MagmaPreTrainedModel):
 
     def tie_weights(self):
         return self.language_model.tie_weights()
-
-    def load_special_module_from_ckpt(self, ckpt_path, torch_dtype=None):
-        from deepspeed.runtime.zero import Init
-        from deepspeed import zero
-        # Defer initialization for ZeRO-3 compatibility
-        # with Init(data_parallel_group=None):
-        #     # Initialize the special module
-        #     self.vision_tower = MagmaImageTower(self.config.vision_config, require_pretrained=False)
-
-        # Load checkpoint weights into the special module
-        checkpoint = torch.load(ckpt_path, map_location='cpu')
-        state_dict = {k.replace('visual.', ''): v for k, v in checkpoint.items() if 'visual.' in k}
-
-        # Convert checkpoint weights to match model's parameter dtype
-        if torch_dtype is None:
-            model_dtype = next(self.vision_tower.clip_vision_model.parameters()).dtype
-            for k, v in state_dict.items():
-                state_dict[k] = v.to(model_dtype)
-        else:
-            for k, v in state_dict.items():
-                state_dict[k] = v.to(torch_dtype)
-
-        # Temporarily gather parameters for loading (if ZeRO-3 is active)
-        with zero.GatheredParameters(list(self.vision_tower.parameters()), modifier_rank=0):
-            # Load the state dictionary
-            self.vision_tower.clip_vision_model.load_state_dict(state_dict, strict=False)
-            # After loading, ensure the module is on the correct device
-            for param in self.vision_tower.parameters():
-                param.data = param.data.to(self.device).to(torch_dtype)
-
-        # import pdb; pdb.set_trace()
-        # If using a DeepSpeed engine, attach the updated module
-        if hasattr(self, "deepspeed_engine"):
-            self.deepspeed_engine.module = self
 
     def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding:
         model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
@@ -832,13 +791,6 @@ class MagmaForCausalLM(MagmaPreTrainedModel):
             # concatenate the action accuracy across all devices
             action_accuracy = torch.cat(action_accuracy_gather)
 
-            if dist.get_rank() == 0:
-                # remove zero values
-                if action_accuracy.mean() == 0:
-                    wandb.log({"action_accuracy": action_accuracy.mean().item()})
-                else:
-                    action_accuracy = action_accuracy[action_accuracy != 0]
-                    wandb.log({"action_accuracy": action_accuracy.mean().item()})
         else:
             logits = self.language_model.lm_head(hidden_states)
             logits = logits.float()
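For users who relied on the removed load_special_module_from_ckpt, the same weight loading can be done without DeepSpeed. The sketch below is a minimal, hypothetical helper (load_vision_weights_from_ckpt is not part of this repo) that assumes the model still exposes vision_tower.clip_vision_model as the removed code did; without ZeRO-3 sharding there are no partitioned parameters to gather, so a plain load_state_dict suffices.

```python
import torch


def load_vision_weights_from_ckpt(model, ckpt_path, torch_dtype=None):
    """DeepSpeed-free sketch of the removed load_special_module_from_ckpt logic."""
    # Load the raw checkpoint on CPU and keep only the vision-encoder weights,
    # stripping the 'visual.' prefix exactly as the removed method did.
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    state_dict = {k.replace("visual.", ""): v for k, v in checkpoint.items() if "visual." in k}

    # Match the dtype of the existing vision tower unless one is given explicitly.
    if torch_dtype is None:
        torch_dtype = next(model.vision_tower.clip_vision_model.parameters()).dtype
    state_dict = {k: v.to(torch_dtype) for k, v in state_dict.items()}

    # No ZeRO-3 sharding means no GatheredParameters context is needed.
    model.vision_tower.clip_vision_model.load_state_dict(state_dict, strict=False)

    # Move the vision tower onto the model's device in the requested dtype.
    model.vision_tower.to(device=model.device, dtype=torch_dtype)
    return model
```

Called after from_pretrained, e.g. load_vision_weights_from_ckpt(model, "path/to/ckpt.pt", torch.bfloat16), this reproduces the CPU load, dtype cast, and strict=False load of the deleted method.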
 
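The wandb calls in __init__ and in the action-accuracy branch are also gone, so metric logging now has to live in user code. A hedged sketch of an optional, import-guarded replacement follows (maybe_log_action_accuracy is a hypothetical helper; the removed code read the project name from the WANDB_PROJECT environment variable and logged only on rank 0):

```python
import os

import torch.distributed as dist

try:
    import wandb  # optional: if wandb is not installed, logging is silently skipped
except ImportError:
    wandb = None


def maybe_log_action_accuracy(action_accuracy):
    # Mirror the removed rank-0 guard so only one process logs.
    if wandb is None or not dist.is_initialized() or dist.get_rank() != 0:
        return
    if wandb.run is None:
        # The removed code pulled the project name from WANDB_PROJECT.
        wandb.init(project=os.environ.get("WANDB_PROJECT"))
    # As in the removed block, drop zero entries before averaging when any remain.
    nonzero = action_accuracy[action_accuracy != 0]
    value = (nonzero if nonzero.numel() > 0 else action_accuracy).mean().item()
    wandb.log({"action_accuracy": value})
```

Keeping the import behind a try/except preserves the old behavior for users who still set WANDB_PROJECT, while the model itself no longer requires wandb at import time.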