# --------------------------------------------------------
# Eagle2
# Copyright (c) 2025 NVIDIA
# Licensed under The Apache License [see LICENSE for details]
# --------------------------------------------------------

import os
from typing import Union

from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
from .configuration_siglip import SiglipVisionConfig
logger = logging.get_logger(__name__)


class MultiBackboneChannelConcatenationVisionModelConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a
    [`MultiBackboneChannelConcatenationVisionModel`]. It is used to instantiate a vision
    encoder according to the specified arguments, defining the model architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vision_path (str): Path to the vision model or its configuration, prefixed with
                           `MOB:` (the prefix is stripped to obtain `vision_tower`).
        mm_vision_select_layer (int, optional): The layer to select from the vision model
                                                for multi-modal processing. Defaults to -2.
        grid_size (int, optional): The size of the grid for vision processing. Defaults to 32.
        input_image_size (int, optional): Size of the input images. Defaults to 1024.
        hidden_size (int or str, optional): Hidden size of the fused vision features; the
                                            default `'lazy_calculation'` defers the value
                                            until the backbones are loaded.
        image_size (int, optional): Image size; must equal `input_image_size`. Defaults to 1024.
        freeze_backbones (optional): Specification of which backbones to freeze. Defaults to None.
        moe_version_type (str, optional): MoE version identifier. Defaults to None.
        delay_load (bool, optional): Whether to delay loading the backbone weights. Defaults to False.
        convnext_img_size (int, optional): Input image size for the ConvNeXt backbone. Defaults to 1024.
        vision_tower_siglip_path (str, optional): Path to the SigLIP vision tower. Defaults to None.
        vision_tower_convnext_path (str, optional): Path or timm name of the ConvNeXt vision tower.
                                                    Defaults to 'convnext_xxlarge.clip_laion2b_soup'.
        normalize_type (str, optional): Input normalization scheme. Defaults to 'siglip'.
        **kwargs: Additional keyword arguments passed to the parent `PretrainedConfig`.
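
    Example (a minimal sketch; the `vision_path` value below is an illustrative placeholder):

    ```python
    >>> config = MultiBackboneChannelConcatenationVisionModelConfig(
    ...     vision_path='MOB:siglip;convnext',
    ...     input_image_size=1024,
    ...     image_size=1024,
    ... )
    >>> config.model_type
    'MOB'
    >>> config.vision_tower  # `MOB:` prefix stripped from `vision_path`
    'siglip;convnext'
    ```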
    """

    model_type = 'MOB'

    def __init__(
            self,
            vision_path,
            mm_vision_select_layer=-2,
            grid_size=32,
            input_image_size=1024,
            hidden_size='lazy_calculation',
            image_size=1024,
            freeze_backbones=None,
            moe_version_type=None,
            delay_load=False,
            convnext_img_size=1024,
            vision_tower_siglip_path=None,
            vision_tower_convnext_path='convnext_xxlarge.clip_laion2b_soup',
            normalize_type='siglip',
            **kwargs,
    ):
        super().__init__(**kwargs)

        self.normalize_type = normalize_type
        self.vision_path = vision_path
        self.mm_vision_select_layer = mm_vision_select_layer
        self.grid_size = grid_size
        self.input_image_size = input_image_size
        self.image_size = image_size
        self.hidden_size = hidden_size
        self.freeze_backbones = freeze_backbones
        self.moe_version_type = moe_version_type
        self.delay_load = delay_load
        self.convnext_img_size = convnext_img_size
        # extra args to keep this config compatible with eagle-next
        self.vision_tower_siglip_path = vision_tower_siglip_path
        self.vision_tower_convnext_path = vision_tower_convnext_path
        self.vision_tower = self.vision_path[4:]  # strip the leading `MOB:` prefix

        # validate that the two size fields agree (a plain `assert` would be
        # skipped under `python -O`)
        if image_size != input_image_size:
            raise ValueError(f"input_image_size ({input_image_size}) != image_size ({image_size})")

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> 'PretrainedConfig':
        config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)

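        # A full multimodal config may nest the vision settings under the
        # `vision_config` key; unwrap it before building this config.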
        if 'vision_config' in config_dict:
            config_dict = config_dict['vision_config']

        if 'model_type' in config_dict and hasattr(cls, 'model_type') and config_dict['model_type'] != cls.model_type:
            logger.warning(
                f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
                f'{cls.model_type}. This is not supported for all configurations of models and can yield errors.'
            )

        return cls.from_dict(config_dict, **kwargs)