---
tags:
- Foundation model for eye
license: cc-by-nc-4.0
---

## RETFound - A foundation model for retinal imaging

This is the official repo for RETFound, which is based on [MAE](https://github.com/facebookresearch/mae).

Please contact **ykzhoua@gmail.com** or **yukun.zhou.19@ucl.ac.uk** if you have questions.

A Keras version, implemented by Yuka Kihara, can be found [here](https://github.com/uw-biomedical-ml/RETFound_MAE).

### Key features

- RETFound is pre-trained on 1.6 million retinal images with self-supervised learning.
- RETFound has been validated on multiple disease detection tasks.
- RETFound can be efficiently adapted to customised tasks.

### News

- A [visualisation demo](https://github.com/rmaphoh/RETFound_MAE/blob/main/RETFound_visualize.ipynb) has been added.

### Install environment

1. Create the environment with conda:

   ```
   conda create -n retfound python=3.7.5 -y
   conda activate retfound
   ```

2. Install the dependencies:

   ```
   git clone https://github.com/rmaphoh/RETFound_MAE/
   cd RETFound_MAE
   pip install -r requirement.txt
   ```

### Fine-tuning with RETFound weights

To fine-tune RETFound on your own data, follow these steps:

1. Download the RETFound pre-trained weights:

   | Modality | ViT-Large |
   |---|---|
   | Colour fundus image | download |
   | OCT | download |

3. Start fine-tuning (using IDRiD as an example). A fine-tuned checkpoint is saved during training, and evaluation is run once training finishes.

   ```
   python -m torch.distributed.launch --nproc_per_node=1 --master_port=48798 main_finetune.py \
       --batch_size 16 \
       --world_size 1 \
       --model vit_large_patch16 \
       --epochs 50 \
       --blr 5e-3 --layer_decay 0.65 \
       --weight_decay 0.05 --drop_path 0.2 \
       --nb_classes 5 \
       --data_path ./IDRiD_data/ \
       --task ./finetune_IDRiD/ \
       --finetune ./RETFound_cfp_weights.pth
   ```

4. For evaluation only:

   ```
   python -m torch.distributed.launch --nproc_per_node=1 --master_port=48798 main_finetune.py \
       --eval --batch_size 16 \
       --world_size 1 \
       --model vit_large_patch16 \
       --epochs 50 \
       --blr 5e-3 --layer_decay 0.65 \
       --weight_decay 0.05 --drop_path 0.2 \
       --nb_classes 5 \
       --data_path ./IDRiD_data/ \
       --task ./internal_IDRiD/ \
       --resume ./finetune_IDRiD/checkpoint-best.pth
   ```

### Load the model and weights (if you want to call the model in your code)

```python
# Run from the RETFound_MAE repo root so that models_vit and util are importable.
import torch

import models_vit
from util.pos_embed import interpolate_pos_embed
from timm.models.layers import trunc_normal_

# call the model
model = models_vit.__dict__['vit_large_patch16'](
    num_classes=2,
    drop_path_rate=0.2,
    global_pool=True,
)

# load RETFound weights
checkpoint = torch.load('RETFound_cfp_weights.pth', map_location='cpu')
checkpoint_model = checkpoint['model']
state_dict = model.state_dict()

# drop the pre-trained head if its shape does not match the new number of classes
for k in ['head.weight', 'head.bias']:
    if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
        print(f"Removing key {k} from pretrained checkpoint")
        del checkpoint_model[k]

# interpolate the position embedding if the input resolution differs from pre-training
interpolate_pos_embed(model, checkpoint_model)

# load pre-trained model; the new head and fc_norm (used with global pooling) are not in the checkpoint
msg = model.load_state_dict(checkpoint_model, strict=False)

assert set(msg.missing_keys) == {'head.weight', 'head.bias', 'fc_norm.weight', 'fc_norm.bias'}

# manually initialize fc layer
trunc_normal_(model.head.weight, std=2e-5)

print("Model = %s" % str(model))
```
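
The `--blr` and `--layer_decay` flags in the fine-tuning commands do not set per-layer learning rates directly. Since RETFound is based on MAE, they are most likely interpreted as in MAE's fine-tuning script: the absolute learning rate is `blr * effective_batch_size / 256`, and each ViT layer's rate is multiplied by `layer_decay ** (num_layers - layer_id)`, so the classification head trains at the full rate and earlier blocks at progressively smaller ones. The sketch below reproduces that arithmetic in plain Python. It assumes RETFound's `main_finetune.py` keeps MAE's conventions; the helper names (`effective_lr`, `layer_id`, `lr_scale`) are hypothetical and not part of the repo, and `accum_iter` (MAE's gradient-accumulation setting) is assumed to be 1.

```python
# Sketch (assumption): MAE-style learning-rate arithmetic for ViT fine-tuning.
# Hypothetical helpers for illustration, not RETFound's actual code.


def effective_lr(blr: float, batch_size: int, world_size: int = 1, accum_iter: int = 1) -> float:
    """Absolute learning rate from the base learning rate (linear scaling rule)."""
    eff_batch_size = batch_size * accum_iter * world_size
    return blr * eff_batch_size / 256


def layer_id(param_name: str, num_layers: int) -> int:
    """Map a ViT parameter name to its layer index for layer-wise lr decay."""
    if param_name in ("cls_token", "pos_embed") or param_name.startswith("patch_embed"):
        return 0  # embeddings: smallest learning rate
    if param_name.startswith("blocks."):
        return int(param_name.split(".")[1]) + 1  # transformer block i -> layer i + 1
    return num_layers  # head, fc_norm, etc.: full learning rate


def lr_scale(param_name: str, layer_decay: float, depth: int = 24) -> float:
    """Multiplier on the absolute lr; ViT-Large has 24 transformer blocks."""
    num_layers = depth + 1
    return layer_decay ** (num_layers - layer_id(param_name, num_layers))


if __name__ == "__main__":
    lr = effective_lr(blr=5e-3, batch_size=16)  # values from the IDRiD command
    for name in ("patch_embed.proj.weight", "blocks.0.attn.qkv.weight",
                 "blocks.23.mlp.fc2.weight", "head.weight"):
        print(f"{name:28s} lr = {lr * lr_scale(name, layer_decay=0.65):.3e}")
```

With the IDRiD settings (`--batch_size 16` on one GPU), the absolute learning rate works out to 5e-3 × 16 / 256 = 3.125e-4. Under these assumptions the head uses that full rate, the last block (`blocks.23`) uses 0.65 times it, and the patch embedding uses 0.65**25 (about 2e-5) times it, so the pre-trained early layers change very little.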