from transformers import VisionTextDualEncoderConfig
class VTDEConfig(VisionTextDualEncoderConfig):
    """``VisionTextDualEncoderConfig`` extended with per-tower pooling modes.

    Adds two fields, ``text_pooling_mode`` and ``vision_pooling_mode``, to the
    parent configuration. This class only stores them; whatever reads them
    (presumably the model's pooling code) lives elsewhere — TODO confirm.

    Args:
        projection_dim: Forwarded to ``VisionTextDualEncoderConfig``.
        logit_scale_init_value: Forwarded to ``VisionTextDualEncoderConfig``.
        text_pooling_mode: Pooling mode for the text side; one of
            ``'mean'``, ``'max'``, ``'cls'``. Defaults to ``'mean'``.
        vision_pooling_mode: Pooling mode for the vision side; one of
            ``'mean'``, ``'max'``, ``'cls'``. Defaults to ``'max'``.
        **kwargs: Forwarded unchanged to ``VisionTextDualEncoderConfig``.

    Raises:
        ValueError: If either pooling mode is not one of the supported values.

    References:
        https://arxiv.org/pdf/2210.09996.pdf
        https://github.com/kahnchana/clippy/blob/3c102c29c32f7c66c6e52e09b795fe9c061bbb03/src/open_clip/hf_model.py#L56
    """

    model_type = "vtde"

    # Accepted values for both pooling-mode arguments. This is a class
    # attribute, so it is not part of the instance __dict__ that gets serialized.
    _SUPPORTED_POOLING_MODES = ("mean", "max", "cls")

    def __init__(self, projection_dim=512, logit_scale_init_value=2.6592,
                 text_pooling_mode='mean',
                 vision_pooling_mode='max',
                 **kwargs):
        """Validate and store the pooling modes, then initialize the parent config."""
        # Fail fast on typos instead of persisting an unusable value into the
        # saved config.
        for arg_name, mode in (("text_pooling_mode", text_pooling_mode),
                               ("vision_pooling_mode", vision_pooling_mode)):
            if mode not in self._SUPPORTED_POOLING_MODES:
                raise ValueError(
                    f"{arg_name} must be one of {self._SUPPORTED_POOLING_MODES}, "
                    f"got {mode!r}"
                )
        self.text_pooling_mode = text_pooling_mode
        self.vision_pooling_mode = vision_pooling_mode
        # Forward by keyword so this does not silently break if the parent's
        # positional parameter order ever changes.
        super().__init__(
            projection_dim=projection_dim,
            logit_scale_init_value=logit_scale_init_value,
            **kwargs,
        )