from transformers import PretrainedConfig
from typing import List


class ViTMixConfig(PretrainedConfig):
    """Configuration for the model registered under ``model_type = "VitMix"``.

    The constructor validates that the image can be split into whole
    patches, stores every hyperparameter as an attribute of the same name,
    and forwards any remaining keyword arguments to ``PretrainedConfig``.

    Note: the expert layers cannot be changed for now.

    Args:
        image_size: Image side length; must be a multiple of ``patch_size``.
        patch_size: Patch side length.
        num_classes: Stored as-is (presumably the classifier output size --
            confirm against the model code).
        dim: Stored as-is (presumably the embedding width -- confirm).
        depth: Stored as-is (presumably the number of layers -- confirm).
        heads: Stored as-is (presumably the attention head count -- confirm).
        mlp_dim: Stored as-is (presumably the feed-forward hidden width --
            confirm).
        num_experts: Stored as-is (presumably the number of experts --
            confirm).
        **kwargs: Passed through unchanged to ``PretrainedConfig.__init__``.

    Raises:
        ValueError: If ``image_size`` is not divisible by ``patch_size``.
    """

    model_type = "VitMix"

    def __init__(
        self,
        image_size: int = 28,
        patch_size: int = 14,
        num_classes: int = 10,
        dim: int = 1024,
        depth: int = 6,
        heads: int = 16,
        mlp_dim: int = 2048,
        num_experts: int = 12,
        **kwargs,
    ):
        # Fail fast instead of printing and continuing with a config whose
        # image cannot be tiled by whole patches.
        if image_size % patch_size != 0:
            raise ValueError(
                "image_size must be divisible by patch_size! "
                f"image_size: {image_size} | patch_size: {patch_size}"
            )

        self.image_size = image_size
        self.patch_size = patch_size
        self.num_classes = num_classes
        self.dim = dim
        self.depth = depth
        self.heads = heads
        self.mlp_dim = mlp_dim
        self.num_experts = num_experts

        # Kept last, as in the original: the base class handles the
        # remaining generic keyword arguments.
        super().__init__(**kwargs)