visheratin
commited on
Update nllb_mrl.py
Browse files- nllb_mrl.py +10 -6
nllb_mrl.py
CHANGED
@@ -4,7 +4,6 @@ import torch
|
|
4 |
import torch.nn as nn
|
5 |
import torch.nn.functional as F
|
6 |
from open_clip import create_model, get_tokenizer
|
7 |
-
from open_clip.pretrained import get_pretrained_cfg
|
8 |
from open_clip.transform import PreprocessCfg, image_transform_v2
|
9 |
from PIL import Image
|
10 |
from transformers import PretrainedConfig, PreTrainedModel
|
@@ -16,7 +15,7 @@ class MatryoshkaNllbClipConfig(PretrainedConfig):
|
|
16 |
clip_model_name: str = "",
|
17 |
target_resolution: int = -1,
|
18 |
mrl_resolutions: List[int] = [],
|
19 |
-
preprocess_cfg: Union[
|
20 |
**kwargs,
|
21 |
):
|
22 |
super().__init__(**kwargs)
|
@@ -53,11 +52,16 @@ class MatryoshkaNllbClip(PreTrainedModel):
|
|
53 |
if isinstance(device, str):
|
54 |
device = torch.device(device)
|
55 |
self.config = config
|
56 |
-
self.model = create_model(
|
57 |
-
|
|
|
|
|
|
|
|
|
|
|
58 |
)
|
59 |
self.transform = image_transform_v2(
|
60 |
-
|
61 |
is_train=False,
|
62 |
)
|
63 |
self._device = device
|
@@ -108,7 +112,7 @@ class MatryoshkaNllbClip(PreTrainedModel):
|
|
108 |
)
|
109 |
features = self.matryoshka_layer.layers[str(resolution)](features)
|
110 |
return F.normalize(features, dim=-1) if normalize else features
|
111 |
-
|
112 |
def encode_text(
|
113 |
self,
|
114 |
text,
|
|
|
4 |
import torch.nn as nn
|
5 |
import torch.nn.functional as F
|
6 |
from open_clip import create_model, get_tokenizer
|
|
|
7 |
from open_clip.transform import PreprocessCfg, image_transform_v2
|
8 |
from PIL import Image
|
9 |
from transformers import PretrainedConfig, PreTrainedModel
|
|
|
15 |
clip_model_name: str = "",
|
16 |
target_resolution: int = -1,
|
17 |
mrl_resolutions: List[int] = [],
|
18 |
+
preprocess_cfg: Union[dict, None] = None,
|
19 |
**kwargs,
|
20 |
):
|
21 |
super().__init__(**kwargs)
|
|
|
52 |
if isinstance(device, str):
|
53 |
device = torch.device(device)
|
54 |
self.config = config
|
55 |
+
self.model = create_model(config.clip_model_name, output_dict=True)
|
56 |
+
pp_cfg = PreprocessCfg(
|
57 |
+
size=config.preprocess_cfg["size"],
|
58 |
+
mean=config.preprocess_cfg["mean"],
|
59 |
+
std=config.preprocess_cfg["std"],
|
60 |
+
interpolation=config.preprocess_cfg["interpolation"],
|
61 |
+
resize_mode=config.preprocess_cfg["resize_mode"],
|
62 |
)
|
63 |
self.transform = image_transform_v2(
|
64 |
+
pp_cfg,
|
65 |
is_train=False,
|
66 |
)
|
67 |
self._device = device
|
|
|
112 |
)
|
113 |
features = self.matryoshka_layer.layers[str(resolution)](features)
|
114 |
return F.normalize(features, dim=-1) if normalize else features
|
115 |
+
|
116 |
def encode_text(
|
117 |
self,
|
118 |
text,
|