visheratin commited on
Commit
1fe0679
·
verified ·
1 Parent(s): c63ab68

Update nllb_mrl.py

Browse files
Files changed (1) hide show
  1. nllb_mrl.py +10 -6
nllb_mrl.py CHANGED
@@ -4,7 +4,6 @@ import torch
4
  import torch.nn as nn
5
  import torch.nn.functional as F
6
  from open_clip import create_model, get_tokenizer
7
- from open_clip.pretrained import get_pretrained_cfg
8
  from open_clip.transform import PreprocessCfg, image_transform_v2
9
  from PIL import Image
10
  from transformers import PretrainedConfig, PreTrainedModel
@@ -16,7 +15,7 @@ class MatryoshkaNllbClipConfig(PretrainedConfig):
16
  clip_model_name: str = "",
17
  target_resolution: int = -1,
18
  mrl_resolutions: List[int] = [],
19
- preprocess_cfg: Union[PreprocessCfg, None] = None,
20
  **kwargs,
21
  ):
22
  super().__init__(**kwargs)
@@ -53,11 +52,16 @@ class MatryoshkaNllbClip(PreTrainedModel):
53
  if isinstance(device, str):
54
  device = torch.device(device)
55
  self.config = config
56
- self.model = create_model(
57
- config.clip_model_name, output_dict=True
 
 
 
 
 
58
  )
59
  self.transform = image_transform_v2(
60
- config.preprocess_cfg,
61
  is_train=False,
62
  )
63
  self._device = device
@@ -108,7 +112,7 @@ class MatryoshkaNllbClip(PreTrainedModel):
108
  )
109
  features = self.matryoshka_layer.layers[str(resolution)](features)
110
  return F.normalize(features, dim=-1) if normalize else features
111
-
112
  def encode_text(
113
  self,
114
  text,
 
4
  import torch.nn as nn
5
  import torch.nn.functional as F
6
  from open_clip import create_model, get_tokenizer
 
7
  from open_clip.transform import PreprocessCfg, image_transform_v2
8
  from PIL import Image
9
  from transformers import PretrainedConfig, PreTrainedModel
 
15
  clip_model_name: str = "",
16
  target_resolution: int = -1,
17
  mrl_resolutions: List[int] = [],
18
+ preprocess_cfg: Union[dict, None] = None,
19
  **kwargs,
20
  ):
21
  super().__init__(**kwargs)
 
52
  if isinstance(device, str):
53
  device = torch.device(device)
54
  self.config = config
55
+ self.model = create_model(config.clip_model_name, output_dict=True)
56
+ pp_cfg = PreprocessCfg(
57
+ size=config.preprocess_cfg["size"],
58
+ mean=config.preprocess_cfg["mean"],
59
+ std=config.preprocess_cfg["std"],
60
+ interpolation=config.preprocess_cfg["interpolation"],
61
+ resize_mode=config.preprocess_cfg["resize_mode"],
62
  )
63
  self.transform = image_transform_v2(
64
+ pp_cfg,
65
  is_train=False,
66
  )
67
  self._device = device
 
112
  )
113
  features = self.matryoshka_layer.layers[str(resolution)](features)
114
  return F.normalize(features, dim=-1) if normalize else features
115
+
116
  def encode_text(
117
  self,
118
  text,