---
base_model:
- google-t5/t5-small
---
|
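1. Download the repo

Loading the model once with `from_pretrained` downloads the checkpoint into the local Hugging Face cache. This checkpoint's `lm_head` has a non-standard shape, so the model object returned here receives a freshly initialized head and is not meant to be used directly; step 2 rebuilds it with the correct head.

```python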
|
import os |
|
import torch |
|
|
|
from glob import glob |
|
from transformers import AutoModelForSeq2SeqLM, AutoConfig |
|
|
|
# Hub repository holding the fine-tuned checkpoint.
model_name = "marsggbo/t5-small_dff2048_dmodel32_token-pattern-predictor_switch64_xsum"

# This call mainly serves to pull the checkpoint into the local Hugging Face
# cache. The repo's lm_head was modified, so its shape differs from the stock
# T5 head; ignore_mismatched_sizes=True lets loading continue (that head is
# left freshly initialized) instead of raising a size-mismatch error.
# use_safetensors=False loads the PyTorch .bin weights
# (NOTE(review): presumably the repo ships no safetensors file — confirm).
model = AutoModelForSeq2SeqLM.from_pretrained(
    model_name,
    use_safetensors=False,
    ignore_mismatched_sizes=True,
)
|
```

2. Build the model with the custom `lm_head` and load the checkpoint weights

```python
|
home_path = os.path.expanduser("~")
num_classes = 64  # switch64
# NOTE(review): presumably the number of experts per MoE layer whose routing
# this head predicts — confirm against the training code.

# Resolve the Hugging Face hub cache from the environment variables documented
# by huggingface_hub (HF_HUB_CACHE / HUGGINGFACE_HUB_CACHE, else $HF_HOME/hub,
# else $XDG_CACHE_HOME/huggingface/hub, else ~/.cache/huggingface/hub).
# Hard-coding the default path misses relocated caches.
hf_home = os.path.expanduser(
    os.environ.get("HF_HOME")
    or os.path.join(os.environ.get("XDG_CACHE_HOME", os.path.join(home_path, ".cache")), "huggingface")
)
hub_cache = os.path.expanduser(
    os.environ.get("HF_HUB_CACHE")
    or os.environ.get("HUGGINGFACE_HUB_CACHE")
    or os.path.join(hf_home, "hub")
)

# Hub cache folders are named exactly models--<owner>--<repo>. Matching that
# name in full (instead of "*<repo>") cannot pick up an unrelated repo with the
# same suffix. pytorch_model.bin is the single-file weights name transformers
# loads with use_safetensors=False; a looser "*bin" pattern could also match
# non-weight files such as training_args.bin.
repo_dir = "models--" + model_name.replace("/", "--")
ckpt_candidates = glob(os.path.join(hub_cache, repo_dir, "snapshots", "*", "pytorch_model.bin"))
if not ckpt_candidates:
    raise FileNotFoundError(
        f"No pytorch_model.bin for {model_name!r} under {hub_cache!r}; run step 1 first."
    )
# More than one snapshot can be cached; take the most recently written one.
ckpt_path = max(ckpt_candidates, key=os.path.getmtime)

model_config = AutoConfig.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_config(config=model_config)

# Replace the output head before loading, so every tensor in the checkpoint
# (including lm_head.weight) has a same-shaped parameter and strict=True holds.
# NOTE(review): num_classes * 6 presumably means one group of num_classes
# logits per each of 6 layers — confirm against the training code.
model.lm_head = torch.nn.Linear(model.config.hidden_size, num_classes * 6, bias=False)

# weights_only=True refuses to unpickle arbitrary Python objects from the
# downloaded file; a plain state_dict of tensors still loads normally.
state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
model.load_state_dict(state_dict, strict=True)
|
``` |