Add files using upload-large-folder tool
- config.json +0 -8
- config_molmoe.py +48 -291
- example.py +55 -0
- modeling_molmoe.py +14 -94
config.json
CHANGED
@@ -95,14 +95,6 @@
   "rope_theta": 10000.0,
   "scale_logits": false,
   "system_prompt_kind": "demo_or_style",
-  "tokenizer": {
-    "identifier": "allenai/gpt-neox-olmo-dolma-v1_5",
-    "olmo_bos_token_id": null,
-    "olmo_eos_token_id": null,
-    "tokenizer_adds_space": false,
-    "tokenizer_dir": null,
-    "truncate_direction": "right"
-  },
   "transformers_version": "4.45.0.dev0",
   "unconditioned": false,
   "use_cache": true,
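Note: with the nested "tokenizer" block gone, tokenizer settings are no longer carried in config.json; they presumably live with the checkpoint's processor/tokenizer files instead. A minimal sketch of loading the trimmed config, assuming a local checkpoint directory and transformers with trust_remote_code (the "." path and the hasattr check are illustrative assumptions, not part of this commit):

from transformers import AutoConfig

config = AutoConfig.from_pretrained(".", trust_remote_code=True)  # "." = local checkpoint dir (assumption)
print(config.rope_theta)             # 10000.0, kept from the values shown above
print(hasattr(config, "tokenizer"))  # expected False once the nested block is removed (assumption)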
config_molmoe.py
CHANGED
@@ -2,7 +2,9 @@ from __future__ import annotations
 
 import logging
 from dataclasses import asdict, dataclass, field
+from enum import Enum
 from glob import glob
+from os import PathLike
 from pathlib import Path
 from typing import (
     Any,
@@ -17,168 +19,36 @@ from typing import (
     cast,
 )
 
-import torch
 from transformers import PretrainedConfig
-
-from omegaconf import OmegaConf as om
-from omegaconf.errors import OmegaConfBaseException
-from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
-import gin
-
-#from olmo.aliases import PathOrStr
-from .aliases import PathOrStr
-#from olmo.exceptions import OLMoConfigurationError
-from .exceptions import OLMoConfigurationError
-#from olmo.util import StrEnum, resource_path
-from .util import StrEnum, resource_path
-
-#from olmo.mm_data.data_utils import build_tokenizer
-from .data_utils import build_tokenizer
-#from olmo.multimodal_preprocessor import MultiModalPreprocessor
-from .multimodal_preprocessor import MultiModalPreprocessor
-
-__all__ = [
-    "ActivationType",
-    "ActivationCheckpointingStrategy",
-    "BlockType",
-    "LayerNormType",
-    "VisionBackboneType",
-    "VisionBackboneConfig",
-    "InitFnType",
-    "ModelConfig",
-    "OptimizerType",
-    "OptimizerConfig",
-    "SchedulerType",
-    "SchedulerConfig",
-    "DataConfig",
-    "InstanceFilterConfig",
-    "EvaluatorConfig",
-    "TokenizerConfig",
-    "TrainConfig",
-    "PaddingDirection",
-    "TruncationDirection",
-    "SpeedMonitorConfig",
-    "WandbConfig",
-    "CompilerConfig",
-    "WandbConfig",
-    "FSDPPrecision",
-    "FSDPWrapStrategy",
-    "FSDPConfig",
-    "CheckpointType",
-]
+
 
 C = TypeVar("C", bound="BaseConfig")
 D = TypeVar("D", bound="DictConfig|ListConfig")
 
 
+PathOrStr = Union[str, PathLike]
+
+
+class StrEnum(str, Enum):
+    """
+    This is equivalent to Python's :class:`enum.StrEnum` since version 3.11.
+    We include this here for compatibility with older versions of Python.
+    """
+
+    def __str__(self) -> str:
+        return self.value
+
+    def __repr__(self) -> str:
+        return f"'{str(self)}'"
+
+
+
 class AttentionType(StrEnum):
     sdpa = "sdpa"
     direct = "direct"
     flash = "flash"
 
 
-class BaseConfig:
-    @classmethod
-    def _register_resolvers(cls, validate_paths: bool = True):
-        # Expands path globs into a list.
-        def path_glob(*paths) -> List[str]:
-            out = []
-            for path in paths:
-                matches = sorted(glob(path))
-                if not matches and validate_paths:
-                    raise FileNotFoundError(f"{path} does not match any files or dirs")
-                out.extend(matches)
-            return out
-
-        # Chooses the first path in the arguments that exists.
-        def path_choose(*paths) -> str:
-            from .util import is_url
-
-            for path in paths:
-                if is_url(path) or Path(path).exists():
-                    return path
-            if validate_paths:
-                raise FileNotFoundError(", ".join(paths))
-            else:
-                return ""
-
-        # Finds the latest checkpoint in a folder.
-        def path_last_checkpoint(path) -> str:
-            from .util import find_latest_checkpoint
-
-            latest_checkpoint = find_latest_checkpoint(path)
-            if latest_checkpoint is None:
-                if validate_paths:
-                    raise FileNotFoundError(f"Could not find a latest checkpoint at {path}")
-                else:
-                    return ""
-            else:
-                return str(latest_checkpoint)
-
-        om.register_new_resolver("path.glob", path_glob, replace=True)
-        om.register_new_resolver("path.choose", path_choose, replace=True)
-        om.register_new_resolver("path.last_checkpoint", path_last_checkpoint, replace=True)
-
-    @classmethod
-    def update_legacy_settings(cls, config: D) -> D:
-        """
-        Update the legacy config settings whose schemas have undergone backwards-incompatible changes.
-        """
-        return config
-
-    @classmethod
-    def new(cls: Type[C], **kwargs) -> C:
-        cls._register_resolvers()
-        conf = om.structured(cls)
-        try:
-            if kwargs:
-                conf = om.merge(conf, kwargs)
-            return cast(C, om.to_object(conf))
-        except OmegaConfBaseException as e:
-            raise OLMoConfigurationError(str(e))
-
-    @classmethod
-    def load(
-        cls: Type[C],
-        path: PathOrStr,
-        overrides: Optional[List[str]] = None,
-        key: Optional[str] = None,
-        validate_paths: bool = True,
-    ) -> C:
-        """Load from a YAML file."""
-        cls._register_resolvers(validate_paths=validate_paths)
-        schema = om.structured(cls)
-        try:
-            raw = om.load(str(path))
-
-            # Backwards compatibility hack, we need this here not in `update_legacy_settings`
-            # since it has to be applied before selecting with `key`
-            if "tokenizer" in raw and "model" in raw:
-                raw["model"]["tokenizer"] = raw.pop("tokenizer")
-
-            if key is not None:
-                raw = raw[key]  # type: ignore
-            raw = cls.update_legacy_settings(raw)
-            conf = om.merge(schema, raw)
-            if overrides:
-                conf = om.merge(conf, om.from_dotlist(overrides))
-            return cast(C, om.to_object(conf))
-        except OmegaConfBaseException as e:
-            raise OLMoConfigurationError(str(e))
-
-    def save(self, path: PathOrStr) -> None:
-        """Save to a YAML file."""
-        om.save(config=self, f=str(path))
-
-    def asdict(self, exclude: Optional[Iterable[str]] = None) -> Dict[str, Any]:
-        out = asdict(self)  # type: ignore
-        if exclude is not None:
-            for name in exclude:
-                if name in out:
-                    del out[name]
-        return out
-
-
 class LayerNormType(StrEnum):
     default = "default"
     """
@@ -290,7 +160,7 @@ class ImageProjectType(StrEnum):
 
 
 @dataclass
-class VisionBackboneConfig(BaseConfig):
+class VisionBackboneConfig:
     image_model_type: VisionBackboneType = VisionBackboneType.openai
     image_default_input_size: Tuple[int, int] = (336, 336)
     image_patch_size: int = 14
@@ -328,18 +198,7 @@ class TruncationDirection(StrEnum):
 
 
 @dataclass
-class TokenizerConfig(BaseConfig):
-    identifier: str = "gpt2"
-    truncate_direction: TruncationDirection = TruncationDirection.right
-    # Does the tokenizer automatically start input text with a space
-    tokenizer_adds_space: Optional[bool] = False
-    tokenizer_dir: Optional[str] = None  # tokenizer directory if using a seqio tokenizer
-    olmo_bos_token_id: Optional[int] = None
-    olmo_eos_token_id: Optional[int] = None
-
-
-@dataclass
-class ModelConfig(BaseConfig):
+class ModelConfig:
     """
    OLMo (model) configuration.
    """
@@ -429,11 +288,6 @@ class ModelConfig(BaseConfig):
 
    rope_impl: str = "cockatoo"
 
-    vision_backbone: Optional[VisionBackboneConfig] = None
-    """
-    Vision backbone settings for multi-modal models.
-    """
-
    vit_load_path: Optional[str] = None
    """
    Use this to load the vit model.
@@ -749,129 +603,10 @@
    Used for Gemma-2.
    """
 
-    tokenizer: TokenizerConfig = field(default_factory=TokenizerConfig)
-    """
-    Tokenizer configuration.
-    """
-
    loss_token_weighting: Optional[str] = None
 
    gin_bindings: Optional[str] = None
 
-    def get_tokenizer(self):
-        tokenizer_cfg = self.tokenizer
-        assert tokenizer_cfg.identifier.startswith("mm:")
-        kargs = {}
-        if tokenizer_cfg.identifier[3:].startswith("olmo-"):
-            kargs["olmo_bos_token_id"] = tokenizer_cfg.olmo_bos_token_id
-            kargs["olmo_eos_token_id"] = tokenizer_cfg.olmo_eos_token_id
-        return build_tokenizer(
-            tokenizer_cfg.identifier[3:],
-            adds_space=tokenizer_cfg.tokenizer_adds_space,
-            tokenizer_dir=tokenizer_cfg.tokenizer_dir,
-            pad_tokenizer_to=self.vocab_size if self.pad_tokenizer else None,
-            **kargs
-        )
-
-    def get_preprocessor(self):
-        vision_cfg = self.vision_backbone
-        h, w = self.llm_patches_per_crop()
-
-        return MultiModalPreprocessor(
-            loss_token_weighting=self.loss_token_weighting,
-            always_start_with_space=self.always_start_with_space,
-            tokenizer=self.get_tokenizer(),
-            prompt_override=self.prompt_override,
-            fix_image_input_idx=self.fix_image_input_idx,
-            prompt_templates=self.prompt_type,
-            system_prompt=self.system_prompt_kind,
-            default_inference_len=self.default_inference_len,
-            message_format=self.message_formatting,
-            unconditioned=self.unconditioned,
-            crop_mode=self.crop_mode,
-            max_crops=self.max_crops,
-            do_random_scale=self.do_random_scale,
-            base_image_input_size=vision_cfg.image_default_input_size,
-            image_patch_size=vision_cfg.image_patch_size,
-            image_token_length_h=h,
-            image_token_length_w=w,
-            use_col_tokens=self.use_col_tokens,
-            overlap_margins=self.overlap_margins,
-            image_padding_mask=self.image_padding_embed is not None
-        )
-
-    def __post_init__(self):
-        self.vit_layers = tuple(self.vit_layers)  # type: ignore[assignment]
-
-    @classmethod
-    def update_legacy_settings(cls, config: D) -> D:
-        """
-        Update the legacy config settings whose schemas have undergone backwards-incompatible changes.
-        """
-        if "flash_attention" in config:
-            is_flash = config.flash_attention
-            del config.flash_attention
-            config.attention_type = AttentionType.flash if is_flash else AttentionType.sdpa
-
-        if "bos_token_id" in config:
-            config.tokenizer.olmo_bos_token_id = config.pop("bos_token_id")
-            config.tokenizer.olmo_eos_token_id = config.pop("eos_token_id")
-
-        if "image_padding_mask" in config:
-            assert not config["image_padding_mask"]
-            del config["image_padding_mask"]
-            config["image_padding_embed"] = None
-        elif "image_padding_embed" not in config:
-            config["image_padding_embed"] = None
-        return config
-
-    @property
-    def effective_n_kv_heads(self) -> int:
-        if self.n_kv_heads is None:
-            if self.multi_query_attention is True:
-                return 1
-            else:
-                return self.n_heads
-        else:
-            if self.multi_query_attention is None:
-                return self.n_kv_heads
-            if self.multi_query_attention:
-                n_kv_heads_should_be = 1
-            else:
-                n_kv_heads_should_be = self.n_heads
-            if self.n_kv_heads == n_kv_heads_should_be:
-                return n_kv_heads_should_be
-            else:
-                raise OLMoConfigurationError(
-                    "You can't set `multi_query_attention` and `n_kv_heads` at the same time."
-                )
-
-    @property
-    def image_num_patch(self):
-        assert self.vision_backbone is not None
-        return self.vision_backbone.image_num_patch
-
-    @property
-    def image_patch_size(self):
-        assert self.vision_backbone is not None
-        return self.vision_backbone.image_patch_size
-
-    def llm_patches_per_crop(self):
-        h, w = self.image_num_patch
-        # Round up in case we need to pad the image features for pooling
-        h = (h + self.image_pooling_h - 1) // self.image_pooling_h
-        w = (w + self.image_pooling_w - 1) // self.image_pooling_w
-        return h, w
-
-    def get_max_crops(self) -> int:
-        """Max number of crops that can be built for one image"""
-        if self.crop_mode == "resize":
-            return 1
-        elif "resize" in self.crop_mode:
-            return 1 + self.max_crops
-        else:
-            return self.max_crops
-
 
 class MolmoConfig(PretrainedConfig):
     model_type = "molmo"
@@ -879,7 +614,7 @@ class MolmoConfig(PretrainedConfig):
 
     def __init__(self, use_cache: bool = False, **kwargs):
         model_config = ModelConfig()
-        all_kwargs = model_config.asdict()
+        all_kwargs = asdict(model_config)
         all_kwargs.update(kwargs)
         all_kwargs.update({"use_cache": use_cache})
         all_kwargs.update(
@@ -901,8 +636,8 @@ class MolmoConfig(PretrainedConfig):
 
     @property
     def image_num_patch(self):
-
-        return
+        h, w = (336, 336)
+        return h // 14, w // 14
 
     @property
     def llm_patches_per_crop(self):
@@ -910,4 +645,26 @@ class MolmoConfig(PretrainedConfig):
         # Round up in case we need to pad the image features for pooling
         h = (h + self.image_pooling_h - 1) // self.image_pooling_h
         w = (w + self.image_pooling_w - 1) // self.image_pooling_w
-        return h, w
+        return h, w
+
+    @property
+    def effective_n_kv_heads(self) -> int:
+        if self.n_kv_heads is None:
+            if self.multi_query_attention is True:
+                return 1
+            else:
+                return self.n_heads
+        else:
+            if self.multi_query_attention is None:
+                return self.n_kv_heads
+            if self.multi_query_attention:
+                n_kv_heads_should_be = 1
+            else:
+                n_kv_heads_should_be = self.n_heads
+            if self.n_kv_heads == n_kv_heads_should_be:
+                return n_kv_heads_should_be
+            else:
+                raise ValueError(
+                    "You can't set `multi_query_attention` and `n_kv_heads` at the same time."
+                )
+
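Taken together, these changes make config_molmoe.py standalone: the OmegaConf-backed BaseConfig, TokenizerConfig, and the training-time helpers are dropped, StrEnum and PathOrStr are defined locally, and MolmoConfig now seeds its keyword arguments from the plain ModelConfig dataclass via dataclasses.asdict. A minimal sketch of that initialization pattern, using stand-in class names rather than the repo's own (illustrative only, not the actual implementation):

from dataclasses import dataclass, asdict

@dataclass
class ExampleModelConfig:      # stand-in for the ModelConfig dataclass above
    d_model: int = 2048
    n_heads: int = 16
    rope_theta: float = 10000.0

class ExampleHFConfig:         # stand-in for MolmoConfig(PretrainedConfig)
    def __init__(self, use_cache: bool = False, **kwargs):
        all_kwargs = asdict(ExampleModelConfig())   # dataclass defaults first...
        all_kwargs.update(kwargs)                   # ...then overrides from config.json / kwargs
        all_kwargs.update({"use_cache": use_cache})
        for k, v in all_kwargs.items():
            setattr(self, k, v)

cfg = ExampleHFConfig(n_heads=32)
print(cfg.d_model, cfg.n_heads, cfg.use_cache)      # 2048 32 False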
example.py
ADDED
@@ -0,0 +1,55 @@
+from transformers import AutoProcessor, AutoModelForCausalLM, GenerationConfig
+from PIL import Image
+import requests
+
+
+def main():
+    load_path = "."
+
+    # load the processor
+    print("Loading processor")
+    processor = AutoProcessor.from_pretrained(
+        load_path,
+        trust_remote_code=True,
+        torch_dtype='auto',
+        device_map='auto'
+    )
+
+    # load the model
+    print("Loading model")
+    model = AutoModelForCausalLM.from_pretrained(
+        load_path,
+        trust_remote_code=True,
+        torch_dtype='auto',
+        device_map='auto'
+    )
+
+    # process the image and text
+    print("Processing...")
+    inputs = processor.process(
+        images=[Image.open(requests.get("https://picsum.photos/id/237/536/354", stream=True).raw)],
+        text="Describe this image."
+    )
+
+    # move inputs to the correct device and make a batch of size 1
+    inputs = {k: v.to(model.device).unsqueeze(0) for k, v in inputs.items()}
+
+    # generate output; maximum 200 new tokens; stop generation when <|endoftext|> is generated
+    print("Generating....")
+    output = model.generate_from_batch(
+        inputs,
+        GenerationConfig(max_new_tokens=200, stop_strings="<|endoftext|>"),
+        tokenizer=processor.tokenizer
+    )
+
+    # only get generated tokens; decode them to text
+    generated_tokens = output[0, inputs['input_ids'].size(1):]
+    generated_text = processor.tokenizer.decode(generated_tokens, skip_special_tokens=True)
+
+    # print the generated text
+    print(generated_text)
+
+
+
+if __name__ == '__main__':
+    main()
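The new example.py reads everything from load_path = ".", i.e. it is meant to be run from inside the checkpoint directory (for instance with python example.py). A hedged variant that points at a Hub repository instead; the repo ID below is a placeholder, not something this commit defines:

from transformers import AutoProcessor, AutoModelForCausalLM

repo_id = "<org>/<molmo-checkpoint>"   # placeholder Hub ID, substitute the real repository
processor = AutoProcessor.from_pretrained(repo_id, trust_remote_code=True, torch_dtype="auto", device_map="auto")
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True, torch_dtype="auto", device_map="auto")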
modeling_molmoe.py
CHANGED
@@ -27,7 +27,7 @@ from typing import (
     Set,
     Tuple,
     cast,
-    Union,
+    Union, Any,
 )
 from copy import deepcopy
 import torch
@@ -36,17 +36,10 @@ import torch.nn as nn
 import torch.nn.functional as F
 from torch import einsum
 import einops
-from transformers import PreTrainedModel
+from transformers import PreTrainedModel, GenerationConfig, Cache
 from transformers.modeling_outputs import CausalLMOutputWithPast
 
-from .
-from .beam_search import (
-    BeamSearch,
-    Constraint,
-    FinalSequenceScorer,
-    Sampler
-)
-from .config import (
+from .config_molmoe import (
     ActivationType,
     BlockType,
     LayerNormType,
@@ -56,10 +49,10 @@ from .config import (
     AttentionType,
 )
 
-
+
 from .config_molmoe import (
     MolmoConfig,
-    VisionBackboneConfig
+    VisionBackboneConfig, ModelConfig
 )
 
 if sys.version_info.minor > 8:
@@ -69,26 +62,14 @@ elif sys.version_info.minor == 8:
 else:
     raise SystemExit("This script supports Python 3.8 or higher")
 
-__all__ = [
-    "LayerNormBase",
-    "LayerNorm",
-    "RMSLayerNorm",
-    "RotaryEmbedding",
-    "Activation",
-    "GELU",
-    "ReLU",
-    "SwiGLU",
-    "OLMoBlock",
-    "OLMoSequentialBlock",
-    "OLMo",
-    "OLMoOutput",
-    "OLMoGenerateOutput",
-]
-
 
 log = logging.getLogger(__name__)
 
 
+class OLMoConfigurationError(Exception):
+    pass
+
+
 def activation_checkpoint_function(cfg: ModelConfig):
     preserve_rng_state = not (
         (cfg.attention_dropout == 0.0) and (cfg.embedding_dropout == 0.0) and
@@ -114,20 +95,6 @@ def ensure_finite_(x: torch.Tensor, check_neg_inf: bool = True, check_pos_inf: b
         x.masked_fill_(x == float("inf"), torch.finfo(x.dtype).max)
 
 
-def activation_checkpoint_function(cfg: MolmoConfig):
-    preserve_rng_state = not (
-        (cfg.attention_dropout == 0.0) and (cfg.embedding_dropout == 0.0) and
-        (cfg.residual_dropout == 0.0) and (cfg.response_residual_dropout == 0.0)
-    )
-    from torch.utils.checkpoint import checkpoint
-
-    return partial(
-        checkpoint,
-        preserve_rng_state=True,
-        use_reentrant=False,
-    )
-
-
 def vit_activation_checkpoint_function(cfg: MolmoConfig):
     v_cfg = cfg.vision_backbone
     preserve_rng_state = (
@@ -142,22 +109,6 @@ def vit_activation_checkpoint_function(cfg: MolmoConfig):
     )
 
 
-def should_checkpoint_block(strategy: Optional[ActivationCheckpointingStrategy], block_idx: int) -> bool:
-    if strategy is None:
-        return False
-    elif (
-        (strategy == ActivationCheckpointingStrategy.whole_layer)
-        or (strategy == ActivationCheckpointingStrategy.one_in_two and block_idx % 2 == 0)
-        or (strategy == ActivationCheckpointingStrategy.one_in_three and block_idx % 3 == 0)
-        or (strategy == ActivationCheckpointingStrategy.one_in_four and block_idx % 4 == 0)
-        or (strategy == ActivationCheckpointingStrategy.two_in_three and block_idx % 3 != 0)
-        or (strategy == ActivationCheckpointingStrategy.three_in_four and block_idx % 4 != 0)
-    ):
-        return True
-    else:
-        return False
-
-
 class BufferCache(dict, MutableMapping[str, torch.Tensor]):
     """
     Cache for attention biases and other things that would normally be stored as buffers.
@@ -1557,15 +1508,11 @@ class MolmoVisionBackbone(nn.Module):
         self.image_feature_dropout = Dropout(config.image_feature_dropout)
 
     @classmethod
-    def build(cls, config: MolmoConfig)
+    def build(cls, config: MolmoConfig):
         v_cfg = config.vision_backbone
         assert v_cfg is not None
         return MolmoPretrainedVisionBackbone(config)
 
-    @abstractmethod
-    def set_activation_checkpointing(self, strategy: Optional[ActivationCheckpointingStrategy]):
-        raise NotImplementedError()
-
     def reset_parameters(self):
         if self.image_pooling_2d is not None:
             self.image_pooling_2d.reset_parameters()
@@ -1583,9 +1530,9 @@ class MolmoVisionBackbone(nn.Module):
 
 
 class MolmoPretrainedVisionBackbone(MolmoVisionBackbone):
-    def __init__(self, config:
+    def __init__(self, config: MolmoConfig):
         super().__init__(config)
-        v_cfg = config.vision_backbone
+        v_cfg = VisionBackboneConfig()
 
         if v_cfg.image_model_type == VisionBackboneType.openai:
             self.image_vit = VisionTransformer(config)
@@ -1640,11 +1587,6 @@ class MolmoPretrainedVisionBackbone(MolmoVisionBackbone):
         if self.config.use_cls_feature:
             nn.init.xavier_uniform_(self.cls_projector.weight)
 
-    def set_activation_checkpointing(self, strategy: Optional[ActivationCheckpointingStrategy]):
-        self.grad_checkpointing = True
-        if strategy in (ActivationCheckpointingStrategy.whole_layer, ActivationCheckpointingStrategy.vit_only):
-            self.image_vit.set_grad_checkpointing()
-
     def encode_image(self, images: torch.Tensor) -> torch.Tensor:
         """
         : param images: (batch_size, num_crops, num_patch, n_pixels)
@@ -1802,9 +1744,6 @@ class MolmoModel(MolmoPretrainedModel):
                 "Embedding size is not a multiple of 128! This could hurt throughput performance.", UserWarning
             )
 
-        self.activation_checkpointing_strategy: Optional[ActivationCheckpointingStrategy] = None
-        self._activation_checkpoint_fn: Callable = activation_checkpoint_function(self.config)
-
         if not (
             0 < self.config.block_group_size <= self.config.n_layers
             and self.config.n_layers % self.config.block_group_size == 0
@@ -1846,25 +1785,14 @@ class MolmoModel(MolmoPretrainedModel):
         ]
         self.transformer.update({"blocks": nn.ModuleList(layers)})
 
-        self.vision_backbone: Optional[
+        self.vision_backbone: Optional[MolmoVisionBackbone] = None
         if config.vision_backbone is not None:
             self.vision_backbone = MolmoVisionBackbone.build(config)
 
         if self.vision_backbone is not None:
             self.vision_backbone.reset_with_pretrained_weights()
 
-    def set_activation_checkpointing(self, strategy: Optional[ActivationCheckpointingStrategy]):
-        self.activation_checkpointing_strategy = strategy
-        if self.config.block_group_size != 1:
-            for block_group in self.transformer.block_groups:
-                block_group.set_activation_checkpointing(strategy)
-        else:
-            for block in self.transformer.blocks:
-                block.set_activation_checkpointing(strategy)
 
-        if self.vision_backbone is not None:
-            self.vision_backbone.set_activation_checkpointing(strategy)
-
     @property
     def device(self) -> torch.device:
         device: torch.device = self.transformer.wte.weight.device  # type: ignore
@@ -1873,7 +1801,6 @@ class MolmoModel(MolmoPretrainedModel):
         else:
             return device
 
-
     def forward(
         self,
         input_ids: torch.LongTensor,
@@ -2069,14 +1996,7 @@ class MolmoModel(MolmoPretrainedModel):
                 all_hidden_states.append(x)
 
             layer_past = None if past_key_values is None else past_key_values[block_idx]
-            if should_checkpoint_block(self.activation_checkpointing_strategy, block_idx):
-                # shape: (batch_size, seq_len, d_model)
-                x, cache = self._activation_checkpoint_fn(
-                    layer, x, attention_bias=attention_bias, position_ids=position_ids, drop_mask=response_mask, layer_past=layer_past, use_cache=use_cache
-                )
-            else:
-                # shape: (batch_size, seq_len, d_model)
-                x, cache = layer(x, attention_bias=attention_bias, position_ids=position_ids, drop_mask=response_mask, layer_past=layer_past, use_cache=use_cache)
+            x, cache = layer(x, attention_bias=attention_bias, position_ids=position_ids, drop_mask=response_mask, layer_past=layer_past, use_cache=use_cache)
 
             if attn_key_values is not None:
                 assert cache is not None
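The modeling diff is an inference-only cleanup: the beam-search imports, the __all__ list, the duplicate activation_checkpoint_function, should_checkpoint_block, and the set_activation_checkpointing hooks are removed, and the block loop now always calls each layer directly. For reference, a self-contained sketch of the checkpoint-selection policy that was deleted (the enum below is a stand-in for the original ActivationCheckpointingStrategy, not an import from this repo):

from enum import Enum
from typing import Optional

class CheckpointStrategy(str, Enum):   # stand-in for ActivationCheckpointingStrategy
    whole_layer = "whole_layer"
    one_in_two = "one_in_two"
    one_in_three = "one_in_three"
    one_in_four = "one_in_four"
    two_in_three = "two_in_three"
    three_in_four = "three_in_four"

def should_checkpoint_block(strategy: Optional[CheckpointStrategy], block_idx: int) -> bool:
    # Checkpoint every block, every n-th block, or all but every n-th block, as in the removed helper.
    if strategy is None:
        return False
    return (
        strategy == CheckpointStrategy.whole_layer
        or (strategy == CheckpointStrategy.one_in_two and block_idx % 2 == 0)
        or (strategy == CheckpointStrategy.one_in_three and block_idx % 3 == 0)
        or (strategy == CheckpointStrategy.one_in_four and block_idx % 4 == 0)
        or (strategy == CheckpointStrategy.two_in_three and block_idx % 3 != 0)
        or (strategy == CheckpointStrategy.three_in_four and block_idx % 4 != 0)
    )

print([should_checkpoint_block(CheckpointStrategy.one_in_two, i) for i in range(4)])  # [True, False, True, False]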