remove dependencies on deepspeed and wandb
modeling_magma.py  (+0 -48)  CHANGED
@@ -24,7 +24,6 @@ import numpy as np
 import torch
 import torch.utils.checkpoint
 from torch import nn
-import wandb
 import torch.distributed as dist
 from transformers.modeling_utils import PreTrainedModel
 from transformers.activations import ACT2FN
@@ -282,12 +281,6 @@ class MagmaForCausalLM(MagmaPreTrainedModel):
         self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
         self._padding_side = "left" # set it to left by default, user can use setter to change padding_sides
 
-        try:
-            if dist.get_rank() == 0:
-                wandb.init(project=os.environ['WANDB_PROJECT'])
-        except:
-            pass
-
         self.post_init()
 
     # def from_pretrained(self, pretrained_model_name_or_path, *model_args, **kwargs):
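The wandb.init call removed above was guarded by a bare try/except, so it silently did nothing outside an initialized distributed run. Callers who still want experiment tracking can set it up in their own training script rather than inside the model; a minimal sketch, assuming the same WANDB_PROJECT environment-variable convention as the removed code (the helper name maybe_init_wandb is hypothetical):

import os

import torch.distributed as dist
import wandb


def maybe_init_wandb():
    # Only the rank-0 process should create a run; other ranks skip logging entirely.
    is_rank0 = (not dist.is_available()) or (not dist.is_initialized()) or dist.get_rank() == 0
    if is_rank0 and "WANDB_PROJECT" in os.environ:
        wandb.init(project=os.environ["WANDB_PROJECT"])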
@@ -325,40 +318,6 @@ class MagmaForCausalLM(MagmaPreTrainedModel):
 
     def tie_weights(self):
         return self.language_model.tie_weights()
-
-    def load_special_module_from_ckpt(self, ckpt_path, torch_dtype=None):
-        from deepspeed.runtime.zero import Init
-        from deepspeed import zero
-        # Defer initialization for ZeRO-3 compatibility
-        # with Init(data_parallel_group=None):
-        #     # Initialize the special module
-        #     self.vision_tower = MagmaImageTower(self.config.vision_config, require_pretrained=False)
-
-        # Load checkpoint weights into the special module
-        checkpoint = torch.load(ckpt_path, map_location='cpu')
-        state_dict = {k.replace('visual.', ''): v for k, v in checkpoint.items() if 'visual.' in k}
-
-        # Convert checkpoint weights to match model's parameter dtype
-        if torch_dtype is None:
-            model_dtype = next(self.vision_tower.clip_vision_model.parameters()).dtype
-            for k, v in state_dict.items():
-                state_dict[k] = v.to(model_dtype)
-        else:
-            for k, v in state_dict.items():
-                state_dict[k] = v.to(torch_dtype)
-
-        # Temporarily gather parameters for loading (if ZeRO-3 is active)
-        with zero.GatheredParameters(list(self.vision_tower.parameters()), modifier_rank=0):
-            # Load the state dictionary
-            self.vision_tower.clip_vision_model.load_state_dict(state_dict, strict=False)
-        # After loading, ensure the module is on the correct device
-        for param in self.vision_tower.parameters():
-            param.data = param.data.to(self.device).to(torch_dtype)
-
-        # import pdb; pdb.set_trace()
-        # If using a DeepSpeed engine, attach the updated module
-        if hasattr(self, "deepspeed_engine"):
-            self.deepspeed_engine.module = self
 
     def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding:
         model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
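The deleted load_special_module_from_ckpt needed deepspeed.zero.GatheredParameters only so that sharded ZeRO-3 parameters could be materialized before loading. In a standard (non-ZeRO-3) setup the same vision-tower weights can be loaded with plain PyTorch; a minimal sketch based on the removed code, where the standalone helper load_vision_tower_from_ckpt is hypothetical and not part of this repository:

import torch


def load_vision_tower_from_ckpt(model, ckpt_path, torch_dtype=None):
    # Load the raw checkpoint on CPU and keep only the vision ('visual.*') weights.
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    state_dict = {k.replace("visual.", ""): v for k, v in checkpoint.items() if "visual." in k}

    # Match the dtype of the existing vision tower unless one is given explicitly.
    if torch_dtype is None:
        torch_dtype = next(model.vision_tower.clip_vision_model.parameters()).dtype
    state_dict = {k: v.to(torch_dtype) for k, v in state_dict.items()}

    # strict=False mirrors the removed method: unmatched keys are tolerated.
    model.vision_tower.clip_vision_model.load_state_dict(state_dict, strict=False)
    model.vision_tower.to(model.device)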
@@ -832,13 +791,6 @@ class MagmaForCausalLM(MagmaPreTrainedModel):
             # concatenate the action accuracy across all devices
             action_accuracy = torch.cat(action_accuracy_gather)
 
-            if dist.get_rank() == 0:
-                # remove zero values
-                if action_accuracy.mean() == 0:
-                    wandb.log({"action_accuracy": action_accuracy.mean().item()})
-                else:
-                    action_accuracy = action_accuracy[action_accuracy != 0]
-                    wandb.log({"action_accuracy": action_accuracy.mean().item()})
         else:
             logits = self.language_model.lm_head(hidden_states)
             logits = logits.float()
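With the in-forward wandb.log calls gone, the model no longer reports action accuracy itself. If the metric is still wanted, it can be logged from the training loop instead; a sketch under the assumption that the loop has access to the gathered per-device accuracies (action_accuracy below is a hypothetical stand-in for that tensor, and the helper name is not part of this repository):

import torch
import torch.distributed as dist
import wandb


def log_action_accuracy(action_accuracy: torch.Tensor):
    # Log only from rank 0 and prefer the mean over non-zero entries, as the removed code did.
    if dist.is_available() and dist.is_initialized() and dist.get_rank() != 0:
        return
    nonzero = action_accuracy[action_accuracy != 0]
    value = (nonzero if nonzero.numel() > 0 else action_accuracy).mean().item()
    wandb.log({"action_accuracy": value})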