|
from transformers import PretrainedConfig |
|
from typing import List, Optional, Tuple |
|
|
|
|
|
class CXRConfig(PretrainedConfig):
    """Hugging Face configuration for the ``cxr_basic`` model type.

    Every constructor argument is stored as a plain instance attribute, and any
    remaining keyword arguments are forwarded to ``PretrainedConfig.__init__``.
    The model that consumes these values is not defined here. The descriptions
    below are based on parameter names and defaults and should be confirmed
    against the model code.

    Args:
        backbone: Encoder backbone identifier. The default looks like a timm
            model name; confirm against the model code.
        feature_dim: Feature dimension (presumably for the classification head).
        seg_dropout: Dropout rate, presumably for the segmentation head.
        cls_dropout: Dropout rate, presumably for the classification head.
        seg_num_classes: Number of segmentation output classes (presumed).
        cls_num_classes: Number of classification output classes (presumed).
        in_chans: Number of input image channels.
        img_size: Input image size as a 2-tuple.
        decoder_n_blocks: Number of decoder blocks.
        decoder_channels: Per-block decoder channel widths. ``None`` (the
            default) means ``[256, 128, 64, 32, 16]``. The value is copied into
            a new list, so the config never aliases the caller's list.
        encoder_channels: Per-stage encoder channel widths. ``None`` (the
            default) means ``[24, 48, 64, 160, 256]``. The value is copied into
            a new list, so the config never aliases the caller's list.
        decoder_center_block: Flag stored for the decoder (semantics defined
            by the model code).
        decoder_norm_layer: Decoder normalization-layer identifier (default
            ``"bn"``).
        decoder_attention_type: Optional decoder attention identifier, or
            ``None``.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    # Identifier Hugging Face uses for this config class (e.g. AutoConfig
    # registration and the "model_type" field of a saved config).
    model_type = "cxr_basic"

    def __init__(
        self,
        backbone: str = "tf_efficientnetv2_s",
        feature_dim: int = 256,
        seg_dropout: float = 0.1,
        cls_dropout: float = 0.1,
        seg_num_classes: int = 4,
        cls_num_classes: int = 5,
        in_chans: int = 1,
        img_size: Tuple[int, int] = (320, 320),
        decoder_n_blocks: int = 5,
        decoder_channels: Optional[List[int]] = None,
        encoder_channels: Optional[List[int]] = None,
        decoder_center_block: bool = False,
        decoder_norm_layer: str = "bn",
        decoder_attention_type: Optional[str] = None,
        **kwargs,
    ):
        self.backbone = backbone
        self.feature_dim = feature_dim
        self.seg_dropout = seg_dropout
        self.cls_dropout = cls_dropout
        self.seg_num_classes = seg_num_classes
        self.cls_num_classes = cls_num_classes
        self.in_chans = in_chans
        self.img_size = img_size
        self.decoder_n_blocks = decoder_n_blocks

        # Build a new list for every instance. A list literal used as a default
        # argument would be one object shared by every config, so mutating it
        # on one config would change the default for all later configs.
        # Copying caller-supplied lists avoids aliasing in the other direction.
        if decoder_channels is None:
            self.decoder_channels = [256, 128, 64, 32, 16]
        else:
            self.decoder_channels = list(decoder_channels)

        if encoder_channels is None:
            self.encoder_channels = [24, 48, 64, 160, 256]
        else:
            self.encoder_channels = list(encoder_channels)

        self.decoder_center_block = decoder_center_block
        self.decoder_norm_layer = decoder_norm_layer
        self.decoder_attention_type = decoder_attention_type

        super().__init__(**kwargs)
|
|