Config File Missing
Hi, I was trying to use this model in the following way:
from vit_prisma.models.base_vit import HookedViT
model_name = "IamYash/ImageNet-Tiny-AttentionOnly"
model = HookedViT.from_pretrained(
    model_name,
    is_timm=False,
    is_clip=False,
    center_writing_weights=True,
    center_unembed=True,
    fold_ln=True,
    refactor_factored_attn_matrices=True,
)
However, I get the following error:
OSError: IamYash/ImageNet-Tiny-AttentionOnly does not appear to have a file named config.json
When I check the files here on HuggingFace, I also cannot find a config.json file. How is this model intended to be used?
Thank you in advance.
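The error above can be confirmed before calling `from_pretrained` by listing the files in the repository. The following is a minimal sketch and not part of ViT-Prisma: the helper names `missing_files` and `check_repo` are hypothetical, and `check_repo` assumes the `huggingface_hub` package is installed and that network access is available.

```python
from typing import Iterable, List


def missing_files(repo_files: Iterable[str],
                  required: Iterable[str] = ("config.json",)) -> List[str]:
    """Return the required file names that are absent from repo_files."""
    present = set(repo_files)
    return [name for name in required if name not in present]


def check_repo(repo_id: str) -> List[str]:
    """List a Hub repo's files and report which required ones are missing.

    Needs network access; huggingface_hub is imported lazily so that
    missing_files stays usable without it.
    """
    from huggingface_hub import list_repo_files
    return missing_files(list_repo_files(repo_id))


# Made-up file listing for illustration, not the real contents of any repo.
print(missing_files(["README.md", "model.pth"]))  # prints ['config.json']
```

Calling `check_repo(model_name)` with the repository id from the question would return a non-empty list if `config.json` is indeed absent.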
Thanks for raising the issue. We're looking into it ASAP: https://github.com/soniajoseph/ViT-Prisma/issues/79
Hello, @karreska. I'd like to redirect you to the updated checkpoints at https://huggingface.co/Prisma-Multimodal/ImageNet-Tiny-AttentionOnly-Patch16. My sincere apologies for the inconvenience. We will update the documentation on loading pretrained Prisma checkpoints; for now, you can follow the steps below.
Begin by downloading the checkpoint that you'd like to load with Prisma, then build a matching config and load the weights into the model:
import torch
from vit_prisma.models.base_vit import HookedViT
from vit_prisma.configs.HookedViTConfig import HookedViTConfig

config = HookedViTConfig(
    n_layers=1, patch_size=16, d_model=768, d_head=192, d_mlp=3072,
    attn_only=True, n_classes=1000,
    return_type="class_logits", normalization_type=None,
)
model = HookedViT(config)
checkpoint_path = 'checkpoints/checkpoint.pth'  # replace with the path of the checkpoint that you'd like to load
# map_location="cpu" lets a checkpoint saved on a GPU load on a CPU-only machine
state_dict = torch.load(checkpoint_path, map_location="cpu")
model.load_and_process_state_dict(
    state_dict,
    fold_ln=True,
    center_writing_weights=True,
    fold_value_biases=True,
    refactor_factored_attn_matrices=True,
)
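If `torch.load` returns a full training checkpoint rather than a bare state dict, the weights may be nested under a key, and passing the outer dict to `load_and_process_state_dict` would fail. A minimal sketch of unwrapping it first is shown here. Nothing in this thread confirms how the Prisma checkpoints are laid out: the helper `unwrap_state_dict` and the candidate keys `model_state_dict`, `state_dict`, and `model` are hypothetical, chosen only because they are common conventions.

```python
from typing import Any, Dict, Sequence

# Hypothetical wrapper keys; common conventions, not confirmed for Prisma checkpoints.
CANDIDATE_KEYS: Sequence[str] = ("model_state_dict", "state_dict", "model")


def unwrap_state_dict(checkpoint: Dict[str, Any],
                      keys: Sequence[str] = CANDIDATE_KEYS) -> Dict[str, Any]:
    """Return the nested state dict if the checkpoint wraps one, else the checkpoint itself."""
    for key in keys:
        inner = checkpoint.get(key)
        if isinstance(inner, dict):
            return inner
    return checkpoint


# Plain lists stand in for tensors so the example runs without torch.
wrapped = {"epoch": 3, "model_state_dict": {"embed.W": [0.0]}}
flat = {"embed.W": [0.0]}
print(list(unwrap_state_dict(wrapped)))  # prints ['embed.W']
print(list(unwrap_state_dict(flat)))     # prints ['embed.W']
```

In the steps above this would wrap the load call, as in `state_dict = unwrap_state_dict(torch.load(checkpoint_path))`. A checkpoint that is already a flat state dict passes through unchanged.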
We're aware that this isn't the ideal workflow for using the pretrained Prisma models, and we will update it to be much simpler. Thank you for your patience.