---
base_model:
- google-t5/t5-small
---

1. Download the repo

```python
import os
import torch
from glob import glob
from transformers import AutoModelForSeq2SeqLM, AutoConfig

model_name = 'marsggbo/t5-small_dff2048_dmodel32_token-pattern-predictor_switch64_wmt16'

# This call is only used to download the checkpoint into the local Hugging Face cache.
# `ignore_mismatched_sizes=True` is required because the checkpoint's `lm_head` was
# replaced with a smaller head, so the `model` returned here does NOT contain the
# trained `lm_head` -- do not use it directly; build the usable model in step 2.
model = AutoModelForSeq2SeqLM.from_pretrained(
    model_name, ignore_mismatched_sizes=True, use_safetensors=False
)
```

2. Build the model

```python
from huggingface_hub import snapshot_download

num_classes = 64  # switch64

# Resolve the cached snapshot through the hub API instead of a hard-coded
# `~/.cache/huggingface/hub` path, so custom cache locations (HF_HOME / HF_HUB_CACHE)
# also work and the snapshot of the current revision is the one that gets used.
snapshot_dir = snapshot_download(model_name, allow_patterns=["*.bin"])
ckpt_files = glob(os.path.join(snapshot_dir, "*.bin"))
if len(ckpt_files) != 1:
    raise FileNotFoundError(
        f"Expected exactly one .bin checkpoint in {snapshot_dir}, found: {ckpt_files}"
    )
ckpt_path = ckpt_files[0]

model_config = AutoConfig.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_config(config=model_config)
# Swap the vocabulary head for the predictor head; its output size
# (num_classes * 6) must match the `lm_head` stored in the checkpoint.
model.lm_head = torch.nn.Linear(model.config.hidden_size, num_classes * 6, bias=False)
# weights_only=True restricts unpickling to tensors/containers, which is all a
# state dict needs, instead of executing arbitrary pickled objects from the download.
model.load_state_dict(
    torch.load(ckpt_path, map_location='cpu', weights_only=True), strict=True
)
```