Hiraishin commited on
Commit
aca3292
1 Parent(s): aee2dcc

Upload model

Browse files
Files changed (3) hide show
  1. config.json +31 -0
  2. mistral_contrastive.py +60 -0
  3. model.safetensors +3 -0
config.json ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "embedding-model-mistral-64m-contrastive/checkpoint-7750",
3
+ "architectures": [
4
+ "MistralModelEmbedding"
5
+ ],
6
+ "attention_dropout": 0.0,
7
+ "auto_map": {
8
+ "AutoModel": "mistral_contrastive.MistralModelEmbedding"
9
+ },
10
+ "bos_token_id": 1,
11
+ "embedding_size": 768,
12
+ "eos_token_id": 2,
13
+ "hidden_act": "silu",
14
+ "hidden_size": 512,
15
+ "initializer_range": 0.02,
16
+ "intermediate_size": 2048,
17
+ "max_position_embeddings": 32768,
18
+ "model_type": "mistral",
19
+ "num_attention_heads": 16,
20
+ "num_hidden_layers": 8,
21
+ "num_key_value_heads": 8,
22
+ "pad_token_id": 0,
23
+ "rms_norm_eps": 1e-05,
24
+ "rope_theta": 10000.0,
25
+ "sliding_window": 4096,
26
+ "tie_word_embeddings": false,
27
+ "torch_dtype": "bfloat16",
28
+ "transformers_version": "4.38.1",
29
+ "use_cache": true,
30
+ "vocab_size": 32000
31
+ }
mistral_contrastive.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from transformers import MistralPreTrainedModel, MistralModel, MistralConfig
3
+ from typing import Dict
4
+ from transformers.file_utils import ModelOutput
5
+ from typing import List, Optional, Tuple, Union
6
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
7
+ from torch import nn, Tensor
8
+ from dataclasses import dataclass
9
+ from torch import nn
10
+ import torch
11
+ from transformers.file_utils import ModelOutput
12
+ import torch.nn.functional as F
13
+
14
+ COSINE_DISTANCE = lambda x, y: 1-F.cosine_similarity(x, y)
15
+
16
+ @dataclass
17
+ class EncoderOutput(ModelOutput):
18
+ loss: Optional[Tensor] = None
19
+
20
+ class MistralModelEmbedding(MistralPreTrainedModel):
21
+ def __init__(self, config, **kwargs):
22
+ super().__init__(config, **kwargs)
23
+
24
+ self.model = MistralModel(config)
25
+ self.dense_layer = nn.Linear(
26
+ self.config.hidden_size,
27
+ self.config.embedding_size,
28
+ bias=False
29
+ )
30
+ self.post_init()
31
+
32
+
33
+ def encode(self, features):
34
+ if features is None:
35
+ return None
36
+ psg_out = self.model.forward(**features,return_dict=True)
37
+ logits = self.dense_layer(psg_out.last_hidden_state)
38
+ input_ids = features['input_ids']
39
+ batch_size = input_ids.shape[0]
40
+ sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
41
+ sequence_lengths = sequence_lengths % input_ids.shape[-1]
42
+ sequence_lengths = sequence_lengths.to(logits.device)
43
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
44
+ return pooled_logits
45
+
46
+
47
+ def forward(self, query: Dict[str, Tensor] = None,
48
+ passage: Dict[str, Tensor] = None, labels = None, margin = 1.0):
49
+ q_reps = self.encode(query)
50
+ p_reps = self.encode(passage)
51
+
52
+ loss = None
53
+ if labels is not None:
54
+ distances = COSINE_DISTANCE(q_reps, p_reps)
55
+ losses = 0.5 * (labels.float() * distances.pow(2) + (1 - labels).float() * F.relu(margin - distances).pow(2))
56
+ loss = losses.mean()
57
+
58
+ return EncoderOutput(
59
+ loss=loss,
60
+ )
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8779c25f343a3e09f5602e3e32991a6e7db339d547adc5b4091113068d6052f6
3
+ size 96494632