MilaDeepGraph committed on
Commit 314a644
1 Parent(s): 6821c3c

init from Jiqing's repo

Files changed (5)
  1. README.md +108 -3
  2. config.json +63 -0
  3. configuration_protst.py +53 -0
  4. modeling_protst.py +285 -0
  5. pytorch_model.bin +3 -0
README.md CHANGED
@@ -1,3 +1,108 @@
- ---
- license: apache-2.0
- ---

## Abstract
Current protein language models (PLMs) learn protein representations mainly based on their sequences, thereby well capturing co-evolutionary information, but they are unable to explicitly acquire protein functions, which is the end goal of protein representation learning. Fortunately, for many proteins, their textual property descriptions are available, where their various functions are also described. Motivated by this fact, we first build the ProtDescribe dataset to augment protein sequences with text descriptions of their functions and other important properties. Based on this dataset, we propose the [ProtST framework](https://arxiv.org/abs/2301.12040) to enhance Protein Sequence pre-training and understanding by biomedical Texts. During pre-training, we design three types of tasks, i.e., unimodal mask prediction, multimodal representation alignment and multimodal mask prediction, to enhance a PLM with protein property information of different granularities and, at the same time, preserve the PLM's original representation power. On downstream tasks, ProtST enables both supervised learning and zero-shot prediction. We verify the superiority of ProtST-induced PLMs over previous ones on diverse representation learning benchmarks. Under the zero-shot setting, we show the effectiveness of ProtST on zero-shot protein classification, and ProtST also enables functional protein retrieval from a large-scale database without any function annotation. Source code and model weights are available at [https://github.com/DeepGraphLearning/ProtST](https://github.com/DeepGraphLearning/ProtST).

![image/png](https://cdn-uploads.huggingface.co/production/uploads/62f0a673f0d40f6aae296b4a/o4F5-Cm-gGdHPpX5rPVKx.png)

## Example
This example shows how to use ProtST for zero-shot classification.
```python
import logging
import functools
from tqdm import tqdm
import torch
from datasets import load_dataset
from transformers import AutoModel, AutoTokenizer

logger = logging.getLogger(__name__)


def tokenize_protein(example, protein_tokenizer=None, padding=None):
    protein_seqs = example["prot_seq"]

    protein_inputs = protein_tokenizer(protein_seqs, padding=padding, add_special_tokens=True)
    example["protein_input_ids"] = protein_inputs.input_ids
    example["protein_attention_mask"] = protein_inputs.attention_mask

    return example


def label_embedding(labels, text_tokenizer, text_model, device):
    # embed label descriptions
    label_feature = []
    with torch.inference_mode():
        for label in labels:
            label_input_ids = text_tokenizer.encode(label, max_length=128,
                                                    truncation=True, add_special_tokens=False)
            label_input_ids = [text_tokenizer.cls_token_id] + label_input_ids
            label_input_ids = torch.tensor(label_input_ids, dtype=torch.long, device=device).unsqueeze(0)
            attention_mask = label_input_ids != text_tokenizer.pad_token_id
            attention_mask = attention_mask.to(device)

            text_outputs = text_model(label_input_ids, attention_mask=attention_mask)

            label_feature.append(text_outputs["text_feature"])

    label_feature = torch.cat(label_feature, dim=0)
    label_feature = label_feature / label_feature.norm(dim=-1, keepdim=True)

    return label_feature


def zero_shot_eval(logger, device,
                   test_dataset, target_field, protein_model, logit_scale, label_feature):

    # get prediction and target
    test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False)
    preds, targets = [], []
    with torch.inference_mode():
        for data in tqdm(test_dataloader):
            target = data[target_field]
            targets.append(target)

            protein_input_ids = torch.tensor(data["protein_input_ids"], dtype=torch.long, device=device).unsqueeze(0)
            attention_mask = torch.tensor(data["protein_attention_mask"], dtype=torch.long, device=device).unsqueeze(0)

            protein_outputs = protein_model(protein_input_ids, attention_mask=attention_mask)

            protein_feature = protein_outputs["protein_feature"]
            protein_feature = protein_feature / protein_feature.norm(dim=-1, keepdim=True)
            pred = logit_scale * protein_feature @ label_feature.t()
            preds.append(pred)

    preds = torch.cat(preds, dim=0)
    targets = torch.tensor(targets, dtype=torch.long, device=device)
    accuracy = (preds.argmax(dim=-1) == targets).float().mean().item()
    logger.warning("Zero-shot accuracy: %.6f" % accuracy)


if __name__ == "__main__":
    # get datasets
    raw_datasets = load_dataset("Jiqing/ProtST-SubcellularLocalization", split="test")  # cache_dir defaults to "~/.cache/huggingface/datasets"

    # device = torch.device("cuda:0")
    device = torch.device("cpu")

    protst_model = AutoModel.from_pretrained("Jiqing/ProtST-esm1b", trust_remote_code=True, torch_dtype=torch.bfloat16).to(device)
    protein_model = protst_model.protein_model
    text_model = protst_model.text_model
    logit_scale = protst_model.logit_scale
    logit_scale.requires_grad = False
    logit_scale = logit_scale.to(device)
    logit_scale = logit_scale.exp()

    protein_tokenizer = AutoTokenizer.from_pretrained("facebook/esm1b_t33_650M_UR50S")
    text_tokenizer = AutoTokenizer.from_pretrained("microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract")

    func_tokenize_protein = functools.partial(tokenize_protein, protein_tokenizer=protein_tokenizer, padding=False)
    test_dataset = raw_datasets.map(
        func_tokenize_protein, batched=False,
        remove_columns=["prot_seq"],
        desc="Running tokenize_protein on dataset",
    )

    labels = load_dataset("Jiqing/subloc_template")["train"]["name"]

    label_feature = label_embedding(labels, text_tokenizer, text_model, device)
    zero_shot_eval(logger, device, test_dataset, "localization",
                   protein_model, logit_scale, label_feature)
```
config.json ADDED
@@ -0,0 +1,63 @@
{
  "architectures": [
    "ProtSTModel"
  ],
  "auto_map": {
    "AutoModel": "modeling_protst.ProtSTModel",
    "AutoConfig": "configuration_protst.ProtSTConfig"
  },
  "model_type": "protst",
  "protein_config": {
    "_name_or_path": "/tmp/facebook/esm1b_t33_650M_UR50S",
    "architectures": [
      "EsmForMaskedLM"
    ],
    "attention_probs_dropout_prob": 0.0,
    "classifier_dropout": null,
    "emb_layer_norm_before": true,
    "esmfold_config": null,
    "hidden_act": "gelu",
    "hidden_dropout_prob": 0.0,
    "hidden_size": 1280,
    "initializer_range": 0.02,
    "intermediate_size": 5120,
    "is_folding_model": false,
    "layer_norm_eps": 1e-05,
    "mask_token_id": 32,
    "max_position_embeddings": 1026,
    "model_type": "esm",
    "num_attention_heads": 20,
    "num_hidden_layers": 33,
    "cls_token_id": 0,
    "pad_token_id": 1,
    "eos_token_id": 2,
    "position_embedding_type": "absolute",
    "token_dropout": true,
    "torch_dtype": "float32",
    "use_cache": true,
    "vocab_list": null,
    "vocab_size": 33
  },
  "text_config": {
    "architectures": [
      "BertForMaskedLM"
    ],
    "model_type": "bert",
    "attention_probs_dropout_prob": 0.1,
    "hidden_act": "gelu",
    "pad_token_id": 0,
    "cls_token_id": 2,
    "sep_token_id": 3,
    "hidden_dropout_prob": 0.1,
    "hidden_size": 768,
    "initializer_range": 0.02,
    "intermediate_size": 3072,
    "max_position_embeddings": 512,
    "num_attention_heads": 12,
    "num_hidden_layers": 12,
    "type_vocab_size": 2,
    "vocab_size": 30522
  },
  "torch_dtype": "float32",
  "transformers_version": "4.37.0.dev0"
}
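
The `auto_map` entry above registers the custom `ProtSTConfig` and `ProtSTModel` classes with the Auto API, so the checkpoint must be loaded with `trust_remote_code=True`. A minimal sketch of inspecting this configuration (the repo id `Jiqing/ProtST-esm1b` is taken from the README example; the commented values come from the config above):

```python
from transformers import AutoConfig

# trust_remote_code=True is required because ProtSTConfig/ProtSTModel ship with this
# repository rather than with the transformers library itself.
config = AutoConfig.from_pretrained("Jiqing/ProtST-esm1b", trust_remote_code=True)

print(config.model_type)                  # "protst"
print(config.protein_config.hidden_size)  # 1280 (ESM-1b protein encoder)
print(config.text_config.hidden_size)     # 768 (PubMedBERT text encoder)
```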
configuration_protst.py ADDED
@@ -0,0 +1,53 @@
from transformers import PretrainedConfig
from transformers.utils import logging
from transformers.models.esm import EsmConfig
from transformers.models.bert import BertConfig

logger = logging.get_logger(__name__)


class ProtSTConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`ProtSTModel`].

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        protein_config (`dict`, *optional*):
            Dictionary of configuration options used to initialize [`EsmForProteinRepresentation`].
        text_config (`dict`, *optional*):
            Dictionary of configuration options used to initialize [`BertForPubMed`].
    """

    model_type = "protst"

    def __init__(
        self,
        protein_config=None,
        text_config=None,
        **kwargs,
    ):
        super().__init__(**kwargs)

        if protein_config is None:
            protein_config = {}
            logger.info("`protein_config` is `None`. Initializing the protein config (`EsmConfig`) with default values.")

        if text_config is None:
            text_config = {}
            logger.info("`text_config` is `None`. Initializing the text config (`BertConfig`) with default values.")

        self.protein_config = EsmConfig(**protein_config)
        self.text_config = BertConfig(**text_config)

    @classmethod
    def from_protein_text_configs(
        cls, protein_config: EsmConfig, text_config: BertConfig, **kwargs
    ):
        r"""
        Instantiate a [`ProtSTConfig`] (or a derived class) from a ProtST protein model configuration and a text
        model configuration.

        Returns:
            [`ProtSTConfig`]: An instance of a configuration object
        """

        return cls(protein_config=protein_config.to_dict(), text_config=text_config.to_dict(), **kwargs)
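
`from_protein_text_configs` composes a ProtST configuration from the two backbone configurations. A minimal sketch, assuming `configuration_protst.py` is importable locally (it normally ships with the checkpoint and is loaded through `trust_remote_code`) and reusing the base checkpoints referenced in the README example:

```python
from transformers.models.bert import BertConfig
from transformers.models.esm import EsmConfig

from configuration_protst import ProtSTConfig  # local import; path is an assumption

# Build the composite ProtST config from the two backbone configs.
protein_config = EsmConfig.from_pretrained("facebook/esm1b_t33_650M_UR50S")
text_config = BertConfig.from_pretrained("microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract")
config = ProtSTConfig.from_protein_text_configs(protein_config, text_config)

print(type(config.protein_config).__name__)  # EsmConfig
print(type(config.text_config).__name__)     # BertConfig
```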
modeling_protst.py ADDED
@@ -0,0 +1,285 @@
import math
import torch
import torch.nn as nn
from typing import Optional, Tuple, Union
from dataclasses import dataclass
from transformers import PreTrainedModel
from transformers.modeling_outputs import ModelOutput
from transformers.models.esm import EsmPreTrainedModel, EsmModel
from transformers.models.bert import BertPreTrainedModel, BertModel
from .configuration_protst import ProtSTConfig


@dataclass
class EsmProteinRepresentationOutput(ModelOutput):

    protein_feature: torch.FloatTensor = None
    residue_feature: torch.FloatTensor = None


@dataclass
class BertTextRepresentationOutput(ModelOutput):

    text_feature: torch.FloatTensor = None
    word_feature: torch.FloatTensor = None


@dataclass
class ProtSTClassificationOutput(ModelOutput):

    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None


class ProtSTHead(nn.Module):
    def __init__(self, config, out_dim=512):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.out_proj = nn.Linear(config.hidden_size, out_dim)

    def forward(self, x):
        x = self.dense(x)
        x = nn.functional.relu(x)
        x = self.out_proj(x)
        return x


class BertForPubMed(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.pad_token_id = config.pad_token_id
        self.cls_token_id = config.cls_token_id
        self.sep_token_id = config.sep_token_id

        self.bert = BertModel(config, add_pooling_layer=False)
        self.text_mlp = ProtSTHead(config)
        self.word_mlp = ProtSTHead(config)

        # Initialize weights and apply final processing.
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], ModelOutput]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        word_feature = outputs.last_hidden_state
        # Mean-pool over non-special tokens ([CLS], [SEP] and padding are masked out).
        is_special = (input_ids == self.cls_token_id) | (input_ids == self.sep_token_id) | (input_ids == self.pad_token_id)
        special_mask = (~is_special).to(torch.int64).unsqueeze(-1)
        pooled_feature = ((word_feature * special_mask).sum(1) / (special_mask.sum(1) + 1.0e-6)).to(word_feature.dtype)
        pooled_feature = self.text_mlp(pooled_feature)
        word_feature = self.word_mlp(word_feature)

        if not return_dict:
            return (pooled_feature, word_feature)

        return BertTextRepresentationOutput(text_feature=pooled_feature, word_feature=word_feature)


class EsmForProteinRepresentation(EsmPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.cls_token_id = config.cls_token_id
        self.pad_token_id = config.pad_token_id
        self.eos_token_id = config.eos_token_id

        self.esm = EsmModel(config, add_pooling_layer=False)
        self.protein_mlp = ProtSTHead(config)
        self.residue_mlp = ProtSTHead(config)

        # Initialize weights and apply final processing.
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, EsmProteinRepresentationOutput]:

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.esm(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        residue_feature = outputs.last_hidden_state  # [batch_size, seq_len, hidden_dim]

        # Mean readout over non-special tokens (<cls>, <eos> and padding are masked out).
        is_special = (
            (input_ids == self.cls_token_id) | (input_ids == self.eos_token_id) | (input_ids == self.pad_token_id)
        )
        special_mask = (~is_special).to(torch.int64).unsqueeze(-1)
        protein_feature = ((residue_feature * special_mask).sum(1) / (special_mask.sum(1) + 1.0e-6)).to(residue_feature.dtype)

        # Projection heads used for ProtST pre-training and zero-shot prediction.
        protein_feature = self.protein_mlp(protein_feature)
        residue_feature = self.residue_mlp(residue_feature)

        return EsmProteinRepresentationOutput(
            protein_feature=protein_feature, residue_feature=residue_feature
        )


class ProtSTPreTrainedModel(PreTrainedModel):
    config_class = ProtSTConfig

    def _compute_protein_feature(
        self, protein_input_ids, protein_attention_mask, protein_position_ids,
        output_attentions, output_hidden_states
    ):

        protein_outputs = self.protein_model(
            protein_input_ids,
            attention_mask=protein_attention_mask,
            position_ids=protein_position_ids,
            head_mask=None,
            inputs_embeds=None,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=None,
        )

        return protein_outputs

    def _compute_text_feature(
        self, text_input_ids, text_attention_mask, text_position_ids,
        output_attentions, output_hidden_states
    ):
        text_outputs = self.text_model(
            text_input_ids,
            attention_mask=text_attention_mask,
            position_ids=text_position_ids,
            head_mask=None,
            inputs_embeds=None,
            encoder_hidden_states=None,
            encoder_attention_mask=None,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=None,
        )

        return text_outputs


class ProtSTModel(ProtSTPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.config = config
        self.protein_model = EsmForProteinRepresentation(config.protein_config)
        self.text_model = BertForPubMed(config.text_config)
        # Learnable temperature for the contrastive objective, initialized to log(1 / 0.07).
        self.logit_scale = nn.Parameter(torch.ones([]) * math.log(1 / 0.07))

        # Initialize weights and apply final processing.
        self.post_init()

    def forward(self,
        protein_input_ids: Optional[torch.LongTensor] = None,
        text_input_ids: Optional[torch.LongTensor] = None,
        protein_attention_mask: Optional[torch.Tensor] = None,
        text_attention_mask: Optional[torch.Tensor] = None,
        protein_position_ids: Optional[torch.LongTensor] = None,
        text_position_ids: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
    ):
        # Not implemented yet; use the protein_model and text_model sub-modules directly
        # (see the zero-shot example in the README).
        return None


class ProtSTForProteinPropertyPrediction(ProtSTPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.config = config
        self.protein_model = EsmForProteinRepresentation(config.protein_config)
        self.logit_scale = nn.Parameter(torch.ones([]) * math.log(1 / 0.07))
        self.classifier = ProtSTHead(config.protein_config, out_dim=config.num_labels)

        # Initialize weights and apply final processing.
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, ProtSTClassificationOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the protein classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.protein_model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        logits = self.classifier(outputs.protein_feature)  # [batch_size, feature_dim] -> [batch_size, num_labels]

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()

            labels = labels.to(logits.device)
            loss = loss_fct(logits.view(-1, logits.shape[-1]), labels.view(-1))

        if not return_dict:
            output = (logits,)
            return ((loss,) + output) if loss is not None else output

        return ProtSTClassificationOutput(loss=loss, logits=logits)
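
`EsmForProteinRepresentation` mean-pools residue embeddings over non-special tokens and projects both the pooled and per-residue features through `ProtSTHead` (projection dimension 512); the README's zero-shot example relies on exactly this sub-module. A minimal standalone sketch, reusing the checkpoint and tokenizer from the README example (the input sequence is an arbitrary illustration):

```python
import torch
from transformers import AutoModel, AutoTokenizer

# Assumption: same checkpoint and ESM-1b tokenizer as in the README example.
tokenizer = AutoTokenizer.from_pretrained("facebook/esm1b_t33_650M_UR50S")
protst = AutoModel.from_pretrained("Jiqing/ProtST-esm1b", trust_remote_code=True)
protein_model = protst.protein_model  # EsmForProteinRepresentation

inputs = tokenizer("MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ", return_tensors="pt")
with torch.inference_mode():
    outputs = protein_model(inputs.input_ids, attention_mask=inputs.attention_mask)

# Pooled, projected protein embedding and per-residue embeddings.
print(outputs.protein_feature.shape)  # torch.Size([1, 512])
print(outputs.residue_feature.shape)  # torch.Size([1, sequence_length, 512])
```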
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c59f77e12992626701f6bdfb732b5b9171f753fda86df7f68aa2135ebd421868
size 135