Update modeling_cxrbert.py

#7
Files changed (1)
  1. modeling_cxrbert.py +45 -27
modeling_cxrbert.py CHANGED
@@ -3,6 +3,7 @@
 # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
 # ------------------------------------------------------------------------------------------
 
+from dataclasses import dataclass
 from typing import Any, Optional, Tuple, Union
 
 import torch
@@ -16,20 +17,24 @@ from .configuration_cxrbert import CXRBertConfig
 
 BERTTupleOutput = Tuple[T, T, T, T, T]
 
+
+@dataclass
 class CXRBertOutput(ModelOutput):
     last_hidden_state: torch.FloatTensor
-    logits: torch.FloatTensor
+    logits: Optional[torch.FloatTensor] = None
     cls_projected_embedding: Optional[torch.FloatTensor] = None
     hidden_states: Optional[Tuple[torch.FloatTensor]] = None
     attentions: Optional[Tuple[torch.FloatTensor]] = None
 
 
 class BertProjectionHead(nn.Module):
-    '''
-    Projection head to be used with BERT CLS token, it's similar to `BertPredictionHeadTransform` in HuggingFace library.
-    :param config: CXRBertConfig
-    :return: (batch_size, output_size)
-    '''
+    """Projection head to be used with BERT CLS token.
+
+    This is similar to ``BertPredictionHeadTransform`` in HuggingFace.
+
+    :param config: Configuration for BERT.
+    """
+
     def __init__(self, config: CXRBertConfig) -> None:
         super().__init__()
         self.dense_to_hidden = nn.Linear(config.hidden_size, config.projection_size)
@@ -50,13 +55,13 @@ class CXRBertModel(BertForMaskedLM):
     """
     Implements the CXR-BERT model outlined in the manuscript:
     Boecking et al. "Making the Most of Text Semantics to Improve Biomedical Vision-Language Processing", 2022
-    https://arxiv.org/abs/2204.09817
+    https://link.springer.com/chapter/10.1007/978-3-031-20059-5_1
 
-    Extends the HuggingFace BertForMaskedLM model by adding a separate projection head. The projection "[CLS]" token is used to align
-    the latent vectors of image and text modalities.
+    Extends the HuggingFace BertForMaskedLM model by adding a separate projection head. The projection "[CLS]" token is
+    used to align the latent vectors of image and text modalities.
     """
 
-    config_class = CXRBertConfig
+    config_class = CXRBertConfig  # type: ignore
 
     def __init__(self, config: CXRBertConfig):
         super().__init__(config)
@@ -78,21 +83,24 @@ class CXRBertModel(BertForMaskedLM):
         return_dict: Optional[bool] = None,
         **kwargs: Any
     ) -> Union[BERTTupleOutput, CXRBertOutput]:
-
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
-        bert_for_masked_lm_output = super().forward(input_ids=input_ids,
-                                                    attention_mask=attention_mask,
-                                                    token_type_ids=token_type_ids,
-                                                    position_ids=position_ids,
-                                                    head_mask=head_mask,
-                                                    inputs_embeds=inputs_embeds,
-                                                    output_attentions=output_attentions,
-                                                    output_hidden_states=True,
-                                                    return_dict=True)
+        bert_for_masked_lm_output = super().forward(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=True,
+            return_dict=True,
+        )
 
         last_hidden_state = bert_for_masked_lm_output.hidden_states[-1]
-        cls_projected_embedding = self.cls_projection_head(last_hidden_state[:, 0, :]) if output_cls_projected_embedding else None
+        cls_projected_embedding = (
+            self.cls_projection_head(last_hidden_state[:, 0, :]) if output_cls_projected_embedding else None
+        )
 
         if return_dict:
             return CXRBertOutput(
@@ -108,21 +116,31 @@ class CXRBertModel(BertForMaskedLM):
             bert_for_masked_lm_output.logits,
             cls_projected_embedding,
             bert_for_masked_lm_output.hidden_states,
-            bert_for_masked_lm_output.attentions,)
+            bert_for_masked_lm_output.attentions,
+        )
 
-    def get_projected_text_embeddings(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
+    def get_projected_text_embeddings(
+        self, input_ids: torch.Tensor, attention_mask: torch.Tensor, normalize_embeddings: bool = True
+    ) -> torch.Tensor:
         """
         Returns l2-normalised projected cls token embeddings for the given input token ids and attention mask.
         The joint latent space is trained using a contrastive objective between image and text data modalities.
 
         :param input_ids: (batch_size, sequence_length)
         :param attention_mask: (batch_size, sequence_length)
+        :param normalize_embeddings: Whether to l2-normalise the embeddings.
         :return: (batch_size, projection_size)
        """
 
-        outputs = self.forward(input_ids=input_ids, attention_mask=attention_mask,
-                               output_cls_projected_embedding=True, return_dict=True)
+        outputs = self.forward(
+            input_ids=input_ids, attention_mask=attention_mask, output_cls_projected_embedding=True, return_dict=True
+        )
         assert isinstance(outputs, CXRBertOutput)
 
-        normalized_cls_embedding = F.normalize(outputs.cls_projected_embedding, dim=1)
-        return normalized_cls_embedding
+        cls_projected_embedding = outputs.cls_projected_embedding
+        assert cls_projected_embedding is not None
+
+        if normalize_embeddings:
+            return F.normalize(cls_projected_embedding, dim=1)
+
+        return cls_projected_embedding
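
For reviewers trying the change out, here is a minimal usage sketch of the updated entry points. It assumes the model code from this PR is loaded through transformers with trust_remote_code=True; the checkpoint name and the example report texts are illustrative only and are not taken from this PR.

# Minimal usage sketch (assumption: an illustrative hub checkpoint exposing CXRBertModel).
import torch
from transformers import AutoModel, AutoTokenizer

checkpoint = "microsoft/BiomedVLP-CXR-BERT-specialized"  # illustrative checkpoint name
tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
model = AutoModel.from_pretrained(checkpoint, trust_remote_code=True)

reports = ["No acute cardiopulmonary process.", "Small left pleural effusion."]
tokens = tokenizer(reports, padding=True, return_tensors="pt")

with torch.no_grad():
    # forward() returns the (now @dataclass) CXRBertOutput when return_dict=True;
    # cls_projected_embedding is populated only if output_cls_projected_embedding=True.
    outputs = model(
        input_ids=tokens.input_ids,
        attention_mask=tokens.attention_mask,
        output_cls_projected_embedding=True,
        return_dict=True,
    )
    print(type(outputs).__name__, outputs.cls_projected_embedding.shape)

    # normalize_embeddings is the keyword added in this PR; it toggles l2-normalisation
    # of the projected [CLS] embeddings.
    embeddings = model.get_projected_text_embeddings(
        input_ids=tokens.input_ids,
        attention_mask=tokens.attention_mask,
        normalize_embeddings=True,
    )
    print(embeddings.shape)  # (batch_size, projection_size)

Setting normalize_embeddings=False returns the raw projected embeddings; the default of True preserves the previous l2-normalised behaviour.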