|
## RETFound - A foundation model for retinal imaging |
|
This is the official repository for RETFound, which is based on [MAE](https://github.com/facebookresearch/mae).
|
|
|
Please contact **[email protected]** or **[email protected]** if you have questions. |
|
|
|
A Keras version, implemented by Yuka Kihara, can be found [here](https://github.com/uw-biomedical-ml/RETFound_MAE).
|
|
|
|
|
### Key features |
|
|
|
- RETFound is pre-trained on 1.6 million retinal images with self-supervised learning |
|
- RETFound has been validated in multiple disease detection tasks |
|
- RETFound can be efficiently adapted to customised tasks |
|
|
|
|
|
### News |
|
|
|
- A [visualisation demo](https://github.com/rmaphoh/RETFound_MAE/blob/main/RETFound_visualize.ipynb) has been added
|
|
|
### Install environment |
|
|
|
1. Create environment with conda: |
|
|
|
```
conda create -n retfound python=3.7.5 -y
conda activate retfound
```
|
|
|
2. Install dependencies:
|
|
|
```
git clone https://github.com/rmaphoh/RETFound_MAE/
cd RETFound_MAE
pip install -r requirement.txt
```
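
As an optional sanity check (assuming the CUDA build of PyTorch pinned in `requirement.txt` matches your driver), confirm that PyTorch imports and can see the GPU:

```
python -c "import torch; print(torch.__version__, torch.cuda.is_available())"
```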
|
|
|
|
|
### Fine-tuning with RETFound weights |
|
|
|
To fine-tune RETFound on your own data, follow these steps:
|
|
|
1. Download the RETFound pre-trained weights:
|
<table><tbody>
<!-- START TABLE -->
<!-- TABLE HEADER -->
<tr>
<th valign="bottom"></th>
<th valign="bottom">ViT-Large</th>
</tr>
<!-- TABLE BODY -->
<tr><td align="left">Colour fundus image</td>
<td align="center"><a href="https://drive.google.com/file/d/1l62zbWUFTlp214SvK6eMwPQZAzcwoeBE/view?usp=sharing">download</a></td>
</tr>
<tr><td align="left">OCT</td>
<td align="center"><a href="https://drive.google.com/file/d/1m6s7QYkjyjJDlpEuXm7Xp3PmjN-elfW2/view?usp=sharing">download</a></td>
</tr>
</tbody></table>
|
|
|
2. Organise your data into this directory structure (using IDRiD as an [example](Example.ipynb)); a text sketch of the expected layout is given after the image below
|
|
|
<p align="left"> |
|
<img src="./pic/file_index.jpg" width="160"> |
|
</p> |
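
A minimal sketch of the expected layout, assuming the standard `torchvision` `ImageFolder` convention used by MAE-style fine-tuning code (the class folder names below are placeholders; use your dataset's own labels):

```
IDRiD_data/
├── train/
│   ├── class_0/
│   │   ├── image_001.jpg
│   │   └── ...
│   └── class_1/
│       └── ...
├── val/
│   └── ...
└── test/
    └── ...
```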
|
|
|
|
|
3. Start fine-tuning (using IDRiD as an example). A fine-tuned checkpoint is saved during training, and evaluation runs automatically once training finishes.
|
|
|
|
|
```
python -m torch.distributed.launch --nproc_per_node=1 --master_port=48798 main_finetune.py \
    --batch_size 16 \
    --world_size 1 \
    --model vit_large_patch16 \
    --epochs 50 \
    --blr 5e-3 --layer_decay 0.65 \
    --weight_decay 0.05 --drop_path 0.2 \
    --nb_classes 5 \
    --data_path ./IDRiD_data/ \
    --task ./finetune_IDRiD/ \
    --finetune ./RETFound_cfp_weights.pth
```
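
Note that `--blr` sets a *base* learning rate: upstream MAE fine-tuning code derives the absolute learning rate from it via a linear scaling rule. A minimal sketch of that rule follows, assuming this repo keeps the MAE convention (verify against `main_finetune.py`):

```python
def mae_scaled_lr(blr: float, batch_size: int, accum_iter: int = 1, world_size: int = 1) -> float:
    """Absolute LR under the MAE linear scaling rule: lr = blr * effective_batch / 256."""
    eff_batch_size = batch_size * accum_iter * world_size
    return blr * eff_batch_size / 256

# With the command above (blr=5e-3, batch_size=16, single GPU):
print(mae_scaled_lr(5e-3, 16))  # 0.0003125
```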
|
|
|
|
|
4. For evaluation only, point `--resume` at the checkpoint saved in step 3:
|
|
|
|
|
```
python -m torch.distributed.launch --nproc_per_node=1 --master_port=48798 main_finetune.py \
    --eval --batch_size 16 \
    --world_size 1 \
    --model vit_large_patch16 \
    --epochs 50 \
    --blr 5e-3 --layer_decay 0.65 \
    --weight_decay 0.05 --drop_path 0.2 \
    --nb_classes 5 \
    --data_path ./IDRiD_data/ \
    --task ./internal_IDRiD/ \
    --resume ./finetune_IDRiD/checkpoint-best.pth
```
|
|
|
|
|
### Load the model and weights (to call the model from your own code)
|
|
|
```python
import torch
import models_vit
from util.pos_embed import interpolate_pos_embed
from timm.models.layers import trunc_normal_

# build the model architecture (ViT-Large/16 with a new 2-class head)
model = models_vit.__dict__['vit_large_patch16'](
    num_classes=2,
    drop_path_rate=0.2,
    global_pool=True,
)

# load RETFound weights
checkpoint = torch.load('RETFound_cfp_weights.pth', map_location='cpu')
checkpoint_model = checkpoint['model']
state_dict = model.state_dict()

# drop the pre-trained classification head if its shape does not match the new task
for k in ['head.weight', 'head.bias']:
    if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
        print(f"Removing key {k} from pretrained checkpoint")
        del checkpoint_model[k]

# interpolate position embedding (handles a different input resolution)
interpolate_pos_embed(model, checkpoint_model)

# load the pre-trained weights; the head and fc_norm are expected to be missing
msg = model.load_state_dict(checkpoint_model, strict=False)
assert set(msg.missing_keys) == {'head.weight', 'head.bias', 'fc_norm.weight', 'fc_norm.bias'}

# manually initialise the new classification head
trunc_normal_(model.head.weight, std=2e-5)

print("Model = %s" % str(model))
```
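
A minimal usage sketch (an assumption for illustration, not from the repo): run a forward pass on a dummy 224×224 RGB batch, the default input size for `vit_large_patch16`. In practice, images should be preprocessed with the same resizing and normalisation as the fine-tuning pipeline.

```python
model.eval()
dummy = torch.randn(1, 3, 224, 224)  # one fake RGB image at the ViT-L/16 default resolution
with torch.no_grad():
    logits = model(dummy)
print(logits.shape)  # torch.Size([1, 2]) -- one logit per class
```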
|
|
|
|