|
from transformers.configuration_utils import PretrainedConfig |
|
import sys |
|
from transformers import ( |
|
AutoConfig, |
|
AutoModelForCausalLM, |
|
LlamaConfig, |
|
LlamaForCausalLM, |
|
PreTrainedModel, |
|
) |
|
from .attrdict_config import AttrDict |
|
|
|
class VisionConfig(PretrainedConfig):
    """Configuration registered under ``model_type = "vision"``.

    Attributes:
        cls: A name stored as a string. When the ``cls`` keyword argument is
            not a string, its ``__name__`` is stored instead. Presumably the
            name of the class to build for this component -- confirm against
            the model-construction code.
        params: The ``params`` keyword argument wrapped in an ``AttrDict``
            (built from an empty dict when the argument is absent).
    """

    model_type = "vision"
    # Class-level defaults; ``__init__`` assigns per-instance values.
    cls: str = ""
    params: AttrDict = {}

    def __init__(self, **kwargs):
        """Forward all keyword arguments to the base config, then set ``cls``/``params``."""
        super().__init__(**kwargs)

        # Accept either a name or an object exposing ``__name__`` and keep
        # only the name string.
        target = kwargs.get("cls", "")
        self.cls = target if isinstance(target, str) else target.__name__

        raw_params = kwargs.get("params", {})
        self.params = AttrDict(raw_params)
|
|
|
|
|
class AlignerConfig(PretrainedConfig):
    """Configuration registered under ``model_type = "aligner"``.

    Attributes:
        cls: A name stored as a string. When the ``cls`` keyword argument is
            not a string, its ``__name__`` is stored instead. Presumably the
            name of the class to build for this component -- confirm against
            the model-construction code.
        params: The ``params`` keyword argument wrapped in an ``AttrDict``
            (built from an empty dict when the argument is absent).
    """

    model_type = "aligner"
    # Class-level defaults; ``__init__`` assigns per-instance values.
    cls: str = ""
    params: AttrDict = {}

    def __init__(self, **kwargs):
        """Forward all keyword arguments to the base config, then set ``cls``/``params``."""
        super().__init__(**kwargs)

        # Accept either a name or an object exposing ``__name__`` and keep
        # only the name string.
        target = kwargs.get("cls", "")
        self.cls = target if isinstance(target, str) else target.__name__

        raw_params = kwargs.get("params", {})
        self.params = AttrDict(raw_params)
|
|
|
|
|
class GenVisionConfig(PretrainedConfig):
    """Configuration registered under ``model_type = "gen_vision"``.

    Attributes:
        cls: A name stored as a string. When the ``cls`` keyword argument is
            not a string, its ``__name__`` is stored instead. Presumably the
            name of the class to build for this component -- confirm against
            the model-construction code.
        params: The ``params`` keyword argument wrapped in an ``AttrDict``
            (built from an empty dict when the argument is absent).
    """

    model_type = "gen_vision"
    # Class-level defaults; ``__init__`` assigns per-instance values.
    cls: str = ""
    params: AttrDict = {}

    def __init__(self, **kwargs):
        """Forward all keyword arguments to the base config, then set ``cls``/``params``."""
        super().__init__(**kwargs)

        # Accept either a name or an object exposing ``__name__`` and keep
        # only the name string.
        target = kwargs.get("cls", "")
        self.cls = target if isinstance(target, str) else target.__name__

        raw_params = kwargs.get("params", {})
        self.params = AttrDict(raw_params)
|
|
|
|
|
class GenAlignerConfig(PretrainedConfig):
    """Configuration registered under ``model_type = "gen_aligner"``.

    Attributes:
        cls: A name stored as a string. When the ``cls`` keyword argument is
            not a string, its ``__name__`` is stored instead. Presumably the
            name of the class to build for this component -- confirm against
            the model-construction code.
        params: The ``params`` keyword argument wrapped in an ``AttrDict``
            (built from an empty dict when the argument is absent).
    """

    model_type = "gen_aligner"
    # Class-level defaults; ``__init__`` assigns per-instance values.
    cls: str = ""
    params: AttrDict = {}

    def __init__(self, **kwargs):
        """Forward all keyword arguments to the base config, then set ``cls``/``params``."""
        super().__init__(**kwargs)

        # Accept either a name or an object exposing ``__name__`` and keep
        # only the name string.
        target = kwargs.get("cls", "")
        self.cls = target if isinstance(target, str) else target.__name__

        raw_params = kwargs.get("params", {})
        self.params = AttrDict(raw_params)
|
|
|
|
|
class GenHeadConfig(PretrainedConfig):
    """Configuration registered under ``model_type = "gen_head"``.

    Attributes:
        cls: A name stored as a string. When the ``cls`` keyword argument is
            not a string, its ``__name__`` is stored instead. Presumably the
            name of the class to build for this component -- confirm against
            the model-construction code.
        params: The ``params`` keyword argument wrapped in an ``AttrDict``
            (built from an empty dict when the argument is absent).
    """

    model_type = "gen_head"
    # Class-level defaults; ``__init__`` assigns per-instance values.
    cls: str = ""
    params: AttrDict = {}

    def __init__(self, **kwargs):
        """Forward all keyword arguments to the base config, then set ``cls``/``params``."""
        super().__init__(**kwargs)

        # Accept either a name or an object exposing ``__name__`` and keep
        # only the name string.
        target = kwargs.get("cls", "")
        self.cls = target if isinstance(target, str) else target.__name__

        raw_params = kwargs.get("params", {})
        self.params = AttrDict(raw_params)
|
|
|
|
|
class MultiModalityConfig(PretrainedConfig):
    """Composite configuration registered under ``model_type = "multi_modality"``.

    Bundles six sub-configurations, each read from the keyword argument of
    the same name:

    * ``vision_config`` -> ``VisionConfig``
    * ``aligner_config`` -> ``AlignerConfig``
    * ``gen_vision_config`` -> ``GenVisionConfig``
    * ``gen_aligner_config`` -> ``GenAlignerConfig``
    * ``gen_head_config`` -> ``GenHeadConfig``
    * ``language_config`` -> ``LlamaConfig``

    Each keyword argument may be a dict of constructor keyword arguments, an
    already-built instance of the expected config class (used as-is), or
    missing/``None`` (the sub-config is then built with no arguments).
    """

    model_type = "multi_modality"
    vision_config: VisionConfig
    aligner_config: AlignerConfig

    gen_vision_config: GenVisionConfig
    gen_aligner_config: GenAlignerConfig
    gen_head_config: GenHeadConfig

    language_config: LlamaConfig

    def __init__(self, **kwargs):
        """Build the composite config.

        Args:
            **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``;
                the six ``*_config`` entries listed in the class docstring
                are additionally converted into config objects.

        Raises:
            TypeError: If a sub-config value is neither a mapping, ``None``,
                nor an instance of the expected config class.
        """
        super().__init__(**kwargs)

        build = self._build_sub_config
        self.vision_config = build(VisionConfig, kwargs.get("vision_config"))
        self.aligner_config = build(AlignerConfig, kwargs.get("aligner_config"))

        self.gen_vision_config = build(
            GenVisionConfig, kwargs.get("gen_vision_config")
        )
        self.gen_aligner_config = build(
            GenAlignerConfig, kwargs.get("gen_aligner_config")
        )
        self.gen_head_config = build(GenHeadConfig, kwargs.get("gen_head_config"))

        self.language_config = build(LlamaConfig, kwargs.get("language_config"))

    @staticmethod
    def _build_sub_config(config_cls, value):
        """Return ``value`` as an instance of ``config_cls``.

        An existing ``config_cls`` instance is returned unchanged, so every
        sub-config gets the pass-through previously reserved for
        ``language_config``. ``None`` is treated like an empty dict instead
        of failing on ``**None``. Anything else is unpacked as keyword
        arguments for ``config_cls``.
        """
        if isinstance(value, config_cls):
            return value
        if value is None:
            value = {}
        return config_cls(**value)
|
|