Gurveer05 commited on
Commit
bf1f674
Β·
1 Parent(s): 60e7251

Added pred func

Browse files
config.yaml ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Project-wide configuration settings
2
+
3
+ # Variables for train-test-split
4
+ TRAIN_SIZE: 0.7
5
+
6
+ # General parameters
7
+ max_len: 1000
8
+ num_tissues: 8
9
+ expressed_threshold: 0.1
10
+ random_seed: 766
11
+
12
+ dnabert:
13
+ max_seq_len: 512
14
+ kmer: 6
15
+ test_size: 0.2
16
+ tokenizer:
17
+ vocab_size: 5000
18
+ data:
19
+ max_seq_len: 1000
20
+ test_size: 0.2
21
+ num_labels: 8
22
+ training:
23
+ pretrain:
24
+ num_train_epochs: 3
25
+ per_device_train_batch_size: 64
26
+ per_device_eval_batch_size: 64
27
+ fp16: true
28
+ logging_steps: 50
29
+ eval_steps: 200
30
+ save_steps: 100
31
+ save_total_limit: 20
32
+ gradient_accumulation_steps: 25
33
+ learning_rate: 1.e-4
34
+ weight_decay: 0
35
+ adam_epsilon: 1.e-8
36
+ max_grad_norm: 10
37
+ warmup_steps: 50
38
+ optimizer: "lamb"
39
+ scheduler: "linear"
40
+ mlm_prob: 0.15
41
+ finetune:
42
+ # num_train_epochs: 10
43
+ num_train_epochs: 3
44
+ per_device_train_batch_size: 64
45
+ per_device_eval_batch_size: 8
46
+ fp16: true
47
+ logging_steps: 50
48
+ eval_steps: 500
49
+ save_steps: 500
50
+ save_total_limit: 10
51
+ # gradient_accumulation_steps: 1
52
+ gradient_accumulation_steps: 10
53
+ eval_accumulation_steps: 64
54
+ learning_rate: 1.e-3
55
+ # learning_rate: 1.e-1
56
+ # lr: 1.e-3
57
+ betas:
58
+ - 0.9
59
+ - 0.999
60
+ eps: 1.e-8
61
+ weight_decay: 0
62
+ adam_epsilon: 1.e-8
63
+ max_grad_norm: 10
64
+ warmup_steps: 200
65
+ num_cooldown_steps: 2000
66
+ optimizer: "lamb"
67
+ # optimizer: "adamw"
68
+ # scheduler: "delay"
69
+ scheduler: "constant"
70
+ # num_param_groups: 0
71
+ # param_group_size: 2 # Except for the classification head, which has param_group_size == 1
72
+ delay_size: 0
73
+ models:
74
+ roberta-base:
75
+ num_attention_heads: 6
76
+ num_hidden_layers: 6
77
+ type_vocab_size: 1
78
+ block_size: 258
79
+ max_tokenized_len: 256
80
+ roberta-lm: {}
81
+ roberta-pred: {}
82
+ roberta-pred-mean-pool:
83
+ hidden_dropout_prob: 0.2
84
+ output_mode: "regression"
85
+ # For sparse (bce + mse) loss
86
+ # output_mode: "sparse"
87
+ threshold: 1
88
+ alpha: 0.1
89
+ dnabert-base:
90
+ block_size: 512
91
+ max_tokenized_len: 510
92
+ dnabert-lm: {}
93
+ dnabert-pred: {}
94
+ dnabert-pred-mean-pool:
95
+ hidden_dropout_prob: 0.2
96
+ output_mode: "regression"
data/.gitkeep ADDED
File without changes
data/test.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ CTCAAGCTGAGCAGTGGGTTTGCTCTGGAGGGGAAGCTCAACGGTGGCGACAAGGAAGAATCTGCTTGCGAGGCGAGCCCTGACGCCGCTGATAGCGACCAAAGGTGGATTAAACAACCCATTTCATCATTCTTCTTCCTTGTTAGTTATGATTCCCACGCTTGCCTTTCATGAATCATGATCCTATATGTATATTGATATTAATCAGTTCTAGAAAGTTCAACAACATTTGAGCATGTCAAAACCTGATCGTTGCCTGTTCCATGTCAACAGTGGATTATAACACGTGCAAATGTAGCTATTTGTGTGAGAAGACGTGTGATCGACTCTTTTTTTATATAGATAGCATTGAGATCAACTGTTTGTATATATCTTGTCATAACATTTTTACTTCGTAGCAACGTACGAGCGTTCACCTATTTGTATATAAGTTATCATGATATTTATAAGTTACCGTTGCAACGCACGGACACTCACCTAGTATAGTTTATGTATTACAGTACTAGGAGCCCTAGGCTTCCAATAACTAGAAAAAGTCCTGGTCAGTCGAACCAAACCACAATCCGACGTATACATTCTGGTTCCCCCACGCCCCCATCCGTTCGATTCA
models/.gitkeep ADDED
File without changes
{byte-level-bpe-tokenizer β†’ models/byte-level-bpe-tokenizer}/merges.txt RENAMED
File without changes
{byte-level-bpe-tokenizer β†’ models/byte-level-bpe-tokenizer}/vocab.json RENAMED
File without changes
{transformer β†’ models/transformer}/language-model/config.json RENAMED
File without changes
{transformer β†’ models/transformer}/language-model/pytorch_model.bin RENAMED
File without changes
{transformer β†’ models/transformer}/language-model/training_args.bin RENAMED
File without changes
{transformer β†’ models/transformer}/prediction-model/config.json RENAMED
File without changes
{transformer β†’ models/transformer}/prediction-model/pytorch_model.bin RENAMED
File without changes
{transformer β†’ models/transformer}/prediction-model/training_args.bin RENAMED
File without changes
module/.gitkeep ADDED
File without changes
module/__pycache__/config.cpython-311.pyc ADDED
Binary file (1.78 kB). View file
 
module/__pycache__/dataio.cpython-311.pyc ADDED
Binary file (6.98 kB). View file
 
module/__pycache__/metrics.cpython-311.pyc ADDED
Binary file (3.01 kB). View file
 
module/__pycache__/models.cpython-311.pyc ADDED
Binary file (17.7 kB). View file
 
module/__pycache__/transformers_utility.cpython-311.pyc ADDED
Binary file (4.02 kB). View file
 
module/__pycache__/utils.cpython-311.pyc ADDED
Binary file (12.6 kB). View file
 
module/config.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ import yaml
3
+ import random
4
+
5
+ import numpy as np
6
+ import torch
7
+
8
+
9
+ root = Path(__file__).parent.parent
10
+ data = root / 'data'
11
+ models = root / 'models'
12
+ notebooks = root / 'notebooks'
13
+ scripts = root / 'scripts'
14
+ output = root / 'output'
15
+ docs = root / 'docs'
16
+
17
+ # Data specific paths
18
+ data_raw = data / 'raw'
19
+ data_processed = data / 'processed'
20
+ data_final = data / 'final'
21
+
22
+ # Location of tools
23
+ libs = root / 'libs'
24
+ samtools = libs / 'samtools'
25
+ bedtools = libs / 'bedtools'
26
+ dnabert = root / 'DNABERT'
27
+
28
+ # Locations of specific files
29
+ bpe_tokenizer = data_final / 'tokenizer' / 'maize_bpe_full.tokenizer.json'
30
+
31
+ # Loading settings
32
+ settings = yaml.full_load((root / 'config.yaml').open('r'))
33
+
34
+ # Setting random seeds across the whole project
35
+ random_seed = settings['random_seed']
36
+ random.seed(random_seed)
37
+ np.random.seed(random_seed)
38
+ torch.manual_seed(random_seed)
39
+
40
+
41
+ def reload_settings():
42
+ global settings
43
+ settings = yaml.full_load((root / 'config.yaml').open('r'))
44
+ tissues = [
45
+ 'tassel',
46
+ 'base',
47
+ 'anther',
48
+ 'middle',
49
+ 'ear',
50
+ 'shoot',
51
+ 'tip',
52
+ 'root'
53
+ ]
module/dataio.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ Utilities for reading and writing data files.
2
+ """
3
+ import multiprocessing as mp
4
+ import os
5
+ from pathlib import PosixPath
6
+ from typing import Callable, Dict, List, Optional, Tuple, Union
7
+ from datasets import load_dataset
8
+ from torch.utils.data import Dataset
9
+
10
+ from transformers import (
11
+ DataCollatorForLanguageModeling,
12
+ PreTrainedTokenizer,
13
+ default_data_collator,
14
+ )
15
+
16
+ from . import config
17
+
18
+ # To avoid huggingface warning
19
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
20
+ UBUNTU_ROOT = str(config.root)
21
+
22
+ def load_datasets(
23
+ tokenizer: PreTrainedTokenizer,
24
+ train_data: Union[str, PosixPath],
25
+ eval_data: Optional[Union[str, PosixPath]] = None,
26
+ test_data: Union[str, PosixPath] = None,
27
+ file_type: str = "csv",
28
+ delimiter: str = "\t",
29
+ seq_key: str = "sequence",
30
+ shuffle: bool = True,
31
+ filter_empty: bool = False,
32
+ n_workers: int = mp.cpu_count(),
33
+ **kwargs,
34
+ ) -> Dataset:
35
+ """Load and cache data using Huggingface datasets library
36
+
37
+ Args:
38
+ tokenizer (PreTrainedTokenizer): tokenizer to apply to the sequences
39
+ train_data (Union[str, PosixPath]): location of training data
40
+ eval_data (Union[str, PosixPath], optional): location of evaluation data. Defaults to None.
41
+ test_data (Union[str, PosixPath], optional): location of test data. Defaults to None.
42
+ file_type (str, optional): type of file. Possible values are 'text' and 'csv'. Defaults to 'csv'.
43
+ delimiter (str, optional): Defaults to '\t'.
44
+ seq_key (str, optional): Column name of sequence data Can be 'sequence', 'seq', or 'text'. Defaults to 'sequence'.
45
+ shuffle (bool, optional): Whether to shuffle the dataset. Defaults to True.
46
+ filter_empty (bool, optional): Whether to filter out empty sequences. Defaults to False.
47
+ NOTE: This completes an additional iteration, which can be time-consuming.
48
+ Only enable if you have reason to believe that preprocessing steps will
49
+ result in empty sequences.
50
+ transformation (str, optional): type of transformation to apply.
51
+ Options are 'log', 'boxcox'. Defaults to None.
52
+ log_offset (Union[float, int]): value to offset gene expression values
53
+ by before log transforming. Defaults to 1.
54
+ preprocessor (BaseEstimator): preprocessor Yeoh-Johnson transformation.
55
+ tissue_subset (Union[str, int, list], optional): tissues to subset labels to.
56
+ Defaults to None.
57
+ nshards (int, optional): Number of shards to divide data into, only
58
+ keeping the first. Defaults to None.
59
+ threshold (float, optional): filter out rows where all labels are
60
+ below `threshold`. OR if `discretize` is True, see `discretize`.
61
+ Defaults to None.
62
+ discretize (bool, optional): set gene expression values below
63
+ `threshold` to 0, above `threshold` to 1.
64
+ kmer (int, optional): whether to run the kmer flip experiment and if so,
65
+ how large kmers to flip. Defaults to None.
66
+ n_workers (int, optional): number of processes to use for preprocessing.
67
+ Defaults to `mp.cpu_count()` (number of available CPUs).
68
+ position_buckets (Tuple[int], optional): the different buckets for the bucketed
69
+ positional importance experiment
70
+
71
+ Returns:
72
+ Dataset
73
+ """
74
+ data_files = {"train": str(train_data)}
75
+ if eval_data:
76
+ data_files["eval"] = str(eval_data)
77
+ if test_data:
78
+ data_files["test"] = str(test_data)
79
+ if file_type == "csv":
80
+ kwargs.update({"delimiter": delimiter})
81
+ datasets = load_dataset(file_type, data_files=data_files, **kwargs)
82
+ # Tokenizing
83
+ preprocess_fn = make_preprocess_function(tokenizer, seq_key=seq_key)
84
+ # print("Tokenizing...")
85
+ datasets = datasets.map(preprocess_fn, batched=True, num_proc=n_workers)
86
+ if filter_empty:
87
+ datasets = datasets.filter(filter_empty_sequence)
88
+ if shuffle:
89
+ seed = config.settings["random_seed"]
90
+ datasets = datasets.shuffle(seeds={"train": seed, "eval": seed, "test": seed})
91
+ return datasets
92
+
93
+
94
+ def make_preprocess_function(tokenizer, seq_key: str = "sequence") -> callable:
95
+ """Make a preprocessing function that selects the appropriate column and
96
+ tokenizes it.
97
+
98
+ Args:
99
+ tokenizer (PreTrainedTokenizer): tokenizer to apply to each sequence
100
+ seq_key (str, optional): column name of the text data. Defaults to 'sequence'.
101
+
102
+ Returns:
103
+ callable: preprocessing function
104
+ """
105
+
106
+ def preprocess_function(examples):
107
+ if seq_key:
108
+ seqs = examples[seq_key]
109
+ else:
110
+ seqs = examples
111
+ return tokenizer(
112
+ seqs,
113
+ max_length=tokenizer.model_max_length,
114
+ truncation=True,
115
+ padding="max_length",
116
+ )
117
+
118
+ return preprocess_function
119
+
120
+ def filter_empty_sequence(example: dict) -> bool:
121
+ """Filter out empty sequences."""
122
+ # sum(example['attention_mask']) gives the number of tokens, including SOS and EOS
123
+ return sum(example["attention_mask"]) > 2
124
+
125
+ def load_data_collator(model_type: str, tokenizer=None, mlm_prob=None):
126
+ if model_type == "language-model":
127
+ assert (
128
+ tokenizer is not None
129
+ ), "tokenizer must not be None if model is type language-model"
130
+ assert (
131
+ mlm_prob is not None
132
+ ), "mlm_prob must not be None if model is type language-model"
133
+
134
+ return DataCollatorForLanguageModeling(
135
+ tokenizer=tokenizer, mlm=True, mlm_probability=mlm_prob
136
+ )
137
+ else:
138
+ return default_data_collator
module/metrics.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Reusable metrics functions for evaluating models
2
+ """
3
+ import multiprocessing as mp
4
+ from typing import List
5
+
6
+ import torch
7
+ from torch.utils.data import DataLoader
8
+ from transformers import default_data_collator
9
+ from tqdm import tqdm
10
+
11
+ def get_predictions(
12
+ model: torch.nn.Module,
13
+ dataset: torch.utils.data.Dataset,
14
+ ) -> List:
15
+ """Compute model predictions for `dataset`.
16
+
17
+ Args:
18
+ model (torch.nn.Module): Model to evaluate
19
+ dataset (torch.utils.data.Dataset): Dataset to get predictions for
20
+ return_labels (bool, optional): Whether to return the labels (predictions are always returned).
21
+ Defaults to True.
22
+
23
+ Returns:
24
+ Tuple[torch.Tensor, torch.Tensor]: 'true_labels', 'pred_labels'
25
+ """
26
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
27
+ model.to(device)
28
+ model.eval()
29
+ loader = DataLoader(
30
+ dataset,
31
+ batch_size=64,
32
+ collate_fn=default_data_collator,
33
+ drop_last=False,
34
+ num_workers=mp.cpu_count(),
35
+ )
36
+ pred_labels = []
37
+ for batch in tqdm(loader):
38
+ inputs = {k: batch[k].to(device) for k in ["attention_mask", "input_ids"]}
39
+ with torch.no_grad():
40
+ outputs = model(**inputs)
41
+ del inputs # to free up space on GPU
42
+ logits = outputs[0]
43
+ pred_labels.append([round(e, 4) for e in logits.cpu().tolist()[0]])
44
+
45
+ return pred_labels
module/models.py ADDED
@@ -0,0 +1,441 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Modified HuggingFace transformer model classes
3
+ """
4
+ from typing import Tuple
5
+
6
+ import numpy as np
7
+ import torch
8
+ from torch import nn
9
+ from torch.nn import BCELoss, BCEWithLogitsLoss, MSELoss, PoissonNLLLoss, KLDivLoss
10
+
11
+ from transformers import BertConfig, BertModel, RobertaConfig, RobertaModel
12
+ from transformers import BertPreTrainedModel
13
+ from transformers.modeling_outputs import SequenceClassifierOutput
14
+ from transformers import RobertaPreTrainedModel
15
+
16
+
17
+ class RobertaMeanPoolConfig(RobertaConfig):
18
+ model_type = "roberta"
19
+
20
+ def __init__(
21
+ self,
22
+ output_mode="regression",
23
+ freeze_base=True,
24
+ start_token_idx=0,
25
+ end_token_idx=1,
26
+ threshold=1,
27
+ alpha=0.5,
28
+ log_offset=1,
29
+ batch_norm=False,
30
+ **kwargs,
31
+ ):
32
+ """Constructs RobertaConfig."""
33
+ super().__init__(**kwargs)
34
+ self.output_mode = output_mode
35
+ self.freeze_base = freeze_base
36
+ self.start_token_idx = start_token_idx
37
+ self.end_token_idx = end_token_idx
38
+ self.threshold = threshold
39
+ self.alpha = alpha
40
+ self.log_offset = log_offset
41
+ self.batch_norm = batch_norm
42
+
43
+
44
+ class ClassificationHeadMeanPool(nn.Module):
45
+ """Head for sentence-level classification tasks.
46
+
47
+ Modifications:
48
+ 1. Using mean-pooling over tokens instead of CLS token
49
+ 2. Multi-output regression
50
+ """
51
+
52
+ def __init__(self, config: RobertaMeanPoolConfig):
53
+ super().__init__()
54
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
55
+ self.dense2 = nn.Linear(config.hidden_size, config.hidden_size)
56
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
57
+ self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
58
+ self.start_token_idx = config.start_token_idx
59
+ self.end_token_idx = config.end_token_idx
60
+ self.batch_norm = (
61
+ nn.BatchNorm1d(config.hidden_size) if config.batch_norm else None
62
+ )
63
+ if self.batch_norm is not None:
64
+ print("Using batch_norm")
65
+
66
+ def forward(self, features, attention_mask=None, input_ids=None, **kwargs):
67
+ x = self.embed(features, attention_mask, input_ids, **kwargs)
68
+ x = self.out_proj(x)
69
+ return x
70
+
71
+ def embed(self, features, attention_mask=None, input_ids=None, **kwargs):
72
+ attention_mask[input_ids == self.start_token_idx] = 0
73
+ attention_mask[input_ids == self.end_token_idx] = 0
74
+ x = torch.sum(features * attention_mask.unsqueeze(2), dim=1) / torch.sum(
75
+ attention_mask, dim=1, keepdim=True
76
+ ) # Mean pooling over non-padding tokens
77
+
78
+ x = self.dropout(x)
79
+ x = self.dense(x)
80
+ x = torch.tanh(x)
81
+ x = self.dropout(x)
82
+
83
+ # Batchnorm
84
+ x = self.normalize(x)
85
+
86
+ # Second linear layer
87
+ x = self.dense2(x)
88
+ x = torch.tanh(x)
89
+ return x
90
+
91
+ def normalize(self, x: torch.Tensor) -> torch.Tensor:
92
+ if self.batch_norm is not None:
93
+ return self.batch_norm(x)
94
+ return x
95
+
96
+
97
+ class ClassificationHeadMeanPoolSparse(nn.Module):
98
+ """Classification head that predicts binary outcome (expressed/not)
99
+ and real-valued gene expression values.
100
+ """
101
+
102
+ def __init__(self, config):
103
+ super().__init__()
104
+ self.classification_head = ClassificationHeadMeanPool(config)
105
+ self.regression_head = ClassificationHeadMeanPool(config)
106
+
107
+ def forward(
108
+ self, features, attention_mask=None, input_ids=None, **kwargs
109
+ ) -> Tuple[torch.Tensor]:
110
+ """Compute binarized logits and real-valued gene expressions for each tissue.
111
+
112
+ Args:
113
+ features (torch.Tensor): outputs of RoBERTa
114
+ attention_mask (Optional[torch.Tensor]): attention mask for sentence
115
+ input_ids (Optional[torch.Tensor]): original sequence inputs
116
+
117
+ Returns:
118
+ (torch.Tensor): classification logits (whether gene is expressed/not for tissue)
119
+ (torch.Tensor): gene expression value predictions (real-valued)
120
+ """
121
+ # Consider using .clone().detach()
122
+ attention_mask_copy = attention_mask.clone()
123
+ return (
124
+ self.classification_head(
125
+ features, attention_mask=attention_mask, input_ids=input_ids, **kwargs
126
+ ),
127
+ self.regression_head(
128
+ features,
129
+ attention_mask=attention_mask_copy,
130
+ input_ids=input_ids,
131
+ **kwargs,
132
+ ),
133
+ )
134
+
135
+
136
+ class SparseMSELoss(nn.Module):
137
+ """Custom loss function that takes in two inputs:
138
+ 1. Predicted logits for whether gene is expressed (1) or not (0)
139
+ 2. Real-valued log-TPM values for gene expression predictions.
140
+ """
141
+
142
+ def __init__(self, threshold: float = 1, alpha: float = 0.5):
143
+ """
144
+ Args:
145
+ threshold (float): any value below this threshold (in natural
146
+ scale, NOT log-scale) is considered "not expressed"
147
+ alpha (float): parameter controlling importance of classification
148
+ in overall accuracy. alpha == 1 means this is identical to
149
+ classification. alpha == 0 means this is identical to regression.
150
+ """
151
+ super().__init__()
152
+ self.threshold = np.log(threshold)
153
+ self.alpha = alpha
154
+ self.mse = MSELoss()
155
+ self.bce = BCEWithLogitsLoss()
156
+
157
+ def forward(self, logits: Tuple[torch.Tensor], labels: torch.Tensor):
158
+ classification_outputs, regression_outputs = logits
159
+ binarized_labels = (labels >= self.threshold).float()
160
+
161
+ mse_loss = self.mse(regression_outputs, labels)
162
+ bce_loss = self.bce(classification_outputs, binarized_labels)
163
+
164
+ # Weight the losses by the logits
165
+ # the mse loss should be weighted by the probability of being expressed
166
+ # the bce loss should be weighted by the probability of not being expressed
167
+
168
+ loss = self.alpha * bce_loss + (1 - self.alpha) * mse_loss
169
+ return loss
170
+
171
+
172
+ class ZeroInflatedNegativeBinomialNLL(nn.Module):
173
+ """Custom loss function that calculates the negative log-likelihood
174
+ according to a zero-inflated negative binomial model.
175
+ """
176
+
177
+ pass
178
+
179
+
180
+ # -------------------------------------- #
181
+ # #
182
+ # ---------- Modified RoBERTa ---------- #
183
+ # #
184
+ # -------------------------------------- #
185
+
186
+
187
+ class RobertaForSequenceClassificationMeanPool(RobertaPreTrainedModel):
188
+ """RobertaForSequenceClassification using modified classification head
189
+
190
+ Args:
191
+ RobertaPreTrainedModel ([type]): [description]
192
+
193
+ Returns:
194
+ [type]: [description]
195
+ """
196
+
197
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
198
+
199
+ def __init__(self, config: RobertaMeanPoolConfig):
200
+ super().__init__(config)
201
+ self.num_labels = config.num_labels
202
+ self.output_mode = config.output_mode or "regression"
203
+ self.roberta = RobertaModel(config, add_pooling_layer=False)
204
+ self.threshold = config.threshold
205
+ self.alpha = config.alpha
206
+ self.log_offset = config.log_offset
207
+
208
+ if self.output_mode == "sparse":
209
+ self.classifier = ClassificationHeadMeanPoolSparse(config)
210
+ else:
211
+ self.classifier = ClassificationHeadMeanPool(config)
212
+
213
+ self.init_weights()
214
+
215
+ def forward(
216
+ self,
217
+ input_ids=None,
218
+ attention_mask=None,
219
+ token_type_ids=None,
220
+ position_ids=None,
221
+ head_mask=None,
222
+ inputs_embeds=None,
223
+ labels=None,
224
+ output_attentions=None,
225
+ output_hidden_states=None,
226
+ return_dict=None,
227
+ ):
228
+ r"""
229
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
230
+ Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ...,
231
+ config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
232
+ If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
233
+ """
234
+ return_dict = (
235
+ return_dict if return_dict is not None else self.config.use_return_dict
236
+ )
237
+
238
+ outputs = self.roberta(
239
+ input_ids,
240
+ attention_mask=attention_mask,
241
+ token_type_ids=token_type_ids,
242
+ position_ids=position_ids,
243
+ head_mask=head_mask,
244
+ inputs_embeds=inputs_embeds,
245
+ output_attentions=output_attentions,
246
+ output_hidden_states=output_hidden_states,
247
+ return_dict=return_dict,
248
+ )
249
+ sequence_output = outputs[0]
250
+ logits = self.classifier(
251
+ sequence_output, attention_mask=attention_mask, input_ids=input_ids
252
+ )
253
+
254
+ loss = None
255
+ if labels is not None:
256
+ if self.output_mode == "regression":
257
+ loss_fct = MSELoss()
258
+ elif self.output_mode == "sparse":
259
+ loss_fct = SparseMSELoss(threshold=self.threshold, alpha=self.alpha)
260
+ elif self.output_mode == "classification":
261
+ loss_fct = BCEWithLogitsLoss()
262
+ elif self.output_mode == "poisson":
263
+ loss_fct = PoissonNLLLoss()
264
+
265
+ loss = loss_fct(
266
+ logits.view(-1, self.num_labels), labels.view(-1, self.num_labels)
267
+ )
268
+
269
+ if not return_dict:
270
+ output = (logits,) + outputs[2:]
271
+ return ((loss,) + output) if loss is not None else output
272
+
273
+ return SequenceClassifierOutput(
274
+ loss=loss,
275
+ logits=logits,
276
+ hidden_states=outputs.hidden_states,
277
+ attentions=outputs.attentions,
278
+ )
279
+
280
+ def embed(
281
+ self,
282
+ input_ids=None,
283
+ attention_mask=None,
284
+ token_type_ids=None,
285
+ position_ids=None,
286
+ head_mask=None,
287
+ inputs_embeds=None,
288
+ labels=None,
289
+ output_attentions=None,
290
+ output_hidden_states=None,
291
+ return_dict=None,
292
+ ):
293
+ """Embed sequences by running the `forward` method up to the dense layer of the classifier"""
294
+ outputs = self.roberta(
295
+ input_ids,
296
+ attention_mask=attention_mask,
297
+ token_type_ids=token_type_ids,
298
+ position_ids=position_ids,
299
+ head_mask=head_mask,
300
+ inputs_embeds=inputs_embeds,
301
+ output_attentions=output_attentions,
302
+ output_hidden_states=output_hidden_states,
303
+ return_dict=return_dict,
304
+ )
305
+ sequence_output = outputs[0]
306
+ embeddings = self.classifier.embed(
307
+ sequence_output, attention_mask=attention_mask, input_ids=input_ids
308
+ )
309
+ return embeddings
310
+
311
+ def get_tissue_embeddings(self):
312
+ return self.classifier.out_proj.weight.detach()
313
+
314
+ def predict(
315
+ self,
316
+ input_ids=None,
317
+ attention_mask=None,
318
+ token_type_ids=None,
319
+ position_ids=None,
320
+ head_mask=None,
321
+ inputs_embeds=None,
322
+ labels=None,
323
+ output_attentions=None,
324
+ output_hidden_states=None,
325
+ return_dict=None,
326
+ ):
327
+ logits = self.forward(
328
+ input_ids=input_ids,
329
+ attention_mask=attention_mask,
330
+ token_type_ids=token_type_ids,
331
+ position_ids=position_ids,
332
+ head_mask=head_mask,
333
+ inputs_embeds=inputs_embeds,
334
+ output_attentions=output_attentions,
335
+ output_hidden_states=output_hidden_states,
336
+ return_dict=return_dict,
337
+ )[0]
338
+ if self.output_mode == "sparse":
339
+ binary_logits, pred_values = logits
340
+ # Convert logits to binary predictions
341
+ binary_preds = binary_logits < 0
342
+ # return binary_preds * pred_values
343
+ pred_values[binary_preds] = np.log(self.log_offset)
344
+ return pred_values
345
+ return logits
346
+
347
+
348
+ # -------------------------------------- #
349
+ # #
350
+ # ---------- Modified BERT ----------- #
351
+ # #
352
+ # -------------------------------------- #
353
+
354
+
355
+ class BertMeanPoolConfig(BertConfig):
356
+ model_type = "bert"
357
+
358
+ def __init__(
359
+ self, output_mode="regression", start_token_idx=2, end_token_idx=3, **kwargs
360
+ ):
361
+ """Constructs BertConfig."""
362
+ super().__init__(**kwargs)
363
+ self.output_mode = output_mode
364
+ self.start_token_idx = start_token_idx
365
+ self.end_token_idx = end_token_idx
366
+
367
+
368
+ class BertForSequenceClassificationMeanPool(BertPreTrainedModel):
369
+ def __init__(self, config):
370
+ super().__init__(config)
371
+ self.num_labels = config.num_labels
372
+ self.output_mode = config.output_mode or "regression"
373
+ self.bert = BertModel(config)
374
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
375
+
376
+ self.classifier = ClassificationHeadMeanPool(config)
377
+
378
+ self.init_weights()
379
+
380
+ def forward(
381
+ self,
382
+ input_ids=None,
383
+ attention_mask=None,
384
+ token_type_ids=None,
385
+ position_ids=None,
386
+ head_mask=None,
387
+ inputs_embeds=None,
388
+ labels=None,
389
+ output_attentions=None,
390
+ output_hidden_states=None,
391
+ return_dict=None,
392
+ ):
393
+ r"""
394
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
395
+ Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ...,
396
+ config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
397
+ If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
398
+ """
399
+ return_dict = (
400
+ return_dict if return_dict is not None else self.config.use_return_dict
401
+ )
402
+
403
+ outputs = self.bert(
404
+ input_ids,
405
+ attention_mask=attention_mask,
406
+ token_type_ids=token_type_ids,
407
+ position_ids=position_ids,
408
+ head_mask=head_mask,
409
+ inputs_embeds=inputs_embeds,
410
+ output_attentions=output_attentions,
411
+ output_hidden_states=output_hidden_states,
412
+ return_dict=return_dict,
413
+ )
414
+
415
+ pooled_output = outputs[0]
416
+
417
+ pooled_output = self.dropout(pooled_output)
418
+ logits = self.classifier(
419
+ pooled_output, attention_mask=attention_mask, input_ids=input_ids
420
+ )
421
+
422
+ loss = None
423
+ if labels is not None:
424
+ if self.output_mode == "regression":
425
+ # We are doing regression
426
+ loss_fct = MSELoss()
427
+ loss = loss_fct(logits.view(-1), labels.view(-1))
428
+ else:
429
+ loss_fct = BCELoss()
430
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
431
+
432
+ if not return_dict:
433
+ output = (logits,) + outputs[2:]
434
+ return ((loss,) + output) if loss is not None else output
435
+
436
+ return SequenceClassifierOutput(
437
+ loss=loss,
438
+ logits=logits,
439
+ hidden_states=outputs.hidden_states,
440
+ attentions=outputs.attentions,
441
+ )
module/transformers_utility.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import PosixPath
2
+ from typing import Union, Optional
3
+
4
+ from transformers import (
5
+ RobertaConfig,
6
+ RobertaTokenizerFast,
7
+ RobertaForMaskedLM,
8
+ RobertaForSequenceClassification,
9
+ )
10
+
11
+ from .models import (
12
+ RobertaMeanPoolConfig,
13
+ RobertaForSequenceClassificationMeanPool,
14
+ )
15
+
16
+ RobertaSettings = dict(
17
+ padding_side='left'
18
+ )
19
+
20
+
21
+ MODELS = {
22
+ "roberta-lm": (RobertaConfig, RobertaTokenizerFast, RobertaForMaskedLM, RobertaSettings),
23
+ "roberta-pred": (RobertaConfig, RobertaTokenizerFast, RobertaForSequenceClassification, RobertaSettings),
24
+ "roberta-pred-mean-pool": (RobertaMeanPoolConfig, RobertaTokenizerFast, RobertaForSequenceClassificationMeanPool, RobertaSettings)
25
+ }
26
+
27
+
28
+ def load_model(model_name: str,
29
+ tokenizer_dir: Union[str, PosixPath],
30
+ max_tokenized_len: int = 254,
31
+ pretrained_model: Union[str, PosixPath] = None,
32
+ k: Optional[int] = None,
33
+ do_lower_case: Optional[bool] = None,
34
+ padding_side: Optional[str] = 'left',
35
+ **config_settings) -> tuple:
36
+ """Load specified model, config, and tokenizer.
37
+
38
+ Args:
39
+ model_name (str): Name of model. Acceptable options are
40
+ - 'roberta-lm',
41
+ - 'roberta-pred',
42
+ - 'roberta-pred-mean-pool'
43
+ tokenizer_dir (Union[str, PosixPath]): Directory containing tokenizer
44
+ files: merges.txt and vocab.txt
45
+ max_len (int, optional): Maximum tokenized length,
46
+ not including SOS and EOS. Defaults to 254.
47
+ pretrained_model (Union[str, PosixPath], optional): path to saved
48
+ pretrained RoBERTa transformer model. Defaults to None.
49
+ k (Optional[int], optional): Size of kmers (for DNABERT model). Defaults to 6.
50
+ do_lower_case (bool, optional): Whether to convert all inputs to lower case. Defaults to None.
51
+ padding_side (str, optional): Which side to pad on. Defaults to 'left'.
52
+
53
+ Returns:
54
+ tuple: config_obj, tokenizer, model
55
+ """
56
+ config_settings = config_settings or {}
57
+ max_position_embeddings = max_tokenized_len + 2 # To include SOS and EOS
58
+ config_class, tokenizer_class, model_class, tokenizer_settings = MODELS[model_name]
59
+
60
+ kwargs = dict(
61
+ max_len=max_tokenized_len,
62
+ truncate=True,
63
+ padding="max_length",
64
+ **tokenizer_settings
65
+ )
66
+ if k is not None:
67
+ kwargs.update(dict(k=k))
68
+ if do_lower_case is not None:
69
+ kwargs.update(dict(do_lower_case=do_lower_case))
70
+ if padding_side is not None:
71
+ kwargs.update(dict(padding_side=padding_side))
72
+
73
+ tokenizer = tokenizer_class.from_pretrained(str(tokenizer_dir), **kwargs)
74
+ name_or_path = str(pretrained_model) or ''
75
+ config_obj = config_class(
76
+ vocab_size=len(tokenizer),
77
+ max_position_embeddings=max_position_embeddings,
78
+ name_or_path=name_or_path,
79
+ output_hidden_states=True,
80
+ **config_settings
81
+ )
82
+ if pretrained_model:
83
+ # print(f"Loading from pretrained model {pretrained_model}")
84
+ model = model_class.from_pretrained(
85
+ str(pretrained_model), config=config_obj)
86
+ else:
87
+ print("Loading untrained model")
88
+ model = model_class(config=config_obj)
89
+ model.resize_token_embeddings(len(tokenizer))
90
+ return config_obj, tokenizer, model
module/utils.py ADDED
@@ -0,0 +1,264 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import wget
4
+ import requests
5
+ import re
6
+ import argparse
7
+ from types import GeneratorType, ModuleType
8
+ from typing import Union, Tuple
9
+ import subprocess
10
+ from pathlib import PosixPath, Path
11
+ import importlib as im
12
+ import json
13
+ import pickle
14
+
15
+ import pandas as pd
16
+ import numpy as np
17
+ from IPython.display import display
18
+ import torch
19
+ from tqdm import tqdm
20
+ from sklearn.metrics import r2_score
21
+
22
+ from .config import settings, output, data_final, models
23
+
24
+ def preprocess_genex(genex_data: pd.DataFrame, settings: dict) -> pd.DataFrame:
25
+ if settings["data"].get("preprocess", False):
26
+ preproc_dict = settings["data"]["preprocess"]
27
+ preproc_type = preproc_dict["type"]
28
+ if preproc_type == "log":
29
+ delta = preproc_dict["delta"]
30
+ df_preprocessed = genex_data.applymap(lambda x: np.log(x + delta))
31
+ elif preproc_type == "binary":
32
+ thresh = preproc_dict["threshold"]
33
+ df_preprocessed = genex_data.applymap(lambda x: float(x > thresh))
34
+ elif preproc_type == "ceiling":
35
+ ceiling = preproc_dict["ceiling"]
36
+ df_preprocessed = genex_data.applymap(lambda x: min(ceiling, x))
37
+ else:
38
+ df_preprocessed = genex_data
39
+ return df_preprocessed
40
+ else:
41
+ return genex_data
42
+
43
+ def get_args(
44
+ data_dir=data_final / "transformer" / "seq",
45
+ train_data="all_seqs_train.txt",
46
+ eval_data=None,
47
+ test_data="all_seqs_test.txt",
48
+ output_dir=models / "transformer" / "language-model",
49
+ model_name=None,
50
+ pretrained_model=None,
51
+ tokenizer_dir=None,
52
+ log_offset=None,
53
+ preprocessor=None,
54
+ filter_empty=False,
55
+ hyperparam_search_metrics=None,
56
+ hyperparam_search_trials=None,
57
+ transformation=None,
58
+ output_mode=None,
59
+ ) -> argparse.Namespace:
60
+ """Use Python's ArgumentParser to create a namespace from (optional) user input
61
+
62
+ Args:
63
+ data_dir ([type], optional): Base location of data files. Defaults to data_final/'transformer'/'seq'.
64
+ train_data (str, optional): Name of train data file in `data_dir` Defaults to 'all_seqs_train.txt'.
65
+ test_data (str, optional): Name of test data file in `data_dir`. Defaults to 'all_seqs_test.txt'.
66
+ output_dir ([type], optional): Location to save trained model. Defaults to models/'transformer'/'language-model'.
67
+ model_name (Union[str, PosixPath], optional): Name of model
68
+ pretrained_mdoel (Union[str, PosixPath], optional): path to config and weights for huggingface pretrained model.
69
+ tokenizer_dir (Union[str, PosixPath], optional): path to config files for huggingface pretrained tokenizer.
70
+ filter_empty (bool, optional): Whether to filter out empty sequences.
71
+ Necessary for kmer-based models; takes additional time.
72
+ hyperparam_search_metrics (Union[list, str], optional): metrics for hyperparameter search.
73
+ hyperparam_search_trials (int, optional): number of trials to run hyperparameter search.
74
+ transformation (str, optional): how to transform data. Defaults to None.
75
+ output_mode (str, optional): default output mode for model and data transformation. Defaults to None.
76
+ Returns:
77
+ argparse.Namespace: parsed arguments
78
+ """
79
+ parser = argparse.ArgumentParser()
80
+ parser.add_argument(
81
+ "-w",
82
+ "--warmstart",
83
+ action="store_true",
84
+ help="Whether to start with a saved checkpoint",
85
+ default=False,
86
+ )
87
+ parser.add_argument("--num-embeddings", type=int, default=-1)
88
+ parser.add_argument(
89
+ "--data-dir",
90
+ type=str,
91
+ default=str(data_dir),
92
+ help="Directory containing train/eval data. Defaults to data/final/transformer/seq",
93
+ )
94
+ parser.add_argument(
95
+ "--train-data",
96
+ type=str,
97
+ default=train_data,
98
+ help="Name of training data file. Will be added to the end of `--data-dir`.",
99
+ )
100
+ parser.add_argument(
101
+ "--eval-data",
102
+ type=str,
103
+ default=eval_data,
104
+ help="Name of eval data file. Will be added to the end of `--data-dir`.",
105
+ )
106
+ parser.add_argument(
107
+ "--test-data",
108
+ type=str,
109
+ default=test_data,
110
+ help="Name of test data file. Will be added to the end of `--data-dir`.",
111
+ )
112
+ parser.add_argument("--output-dir", type=str, default=str(output_dir))
113
+ parser.add_argument(
114
+ "--model-name",
115
+ type=str,
116
+ help='Name of model. Supported values are "roberta-lm", "roberta-pred", "roberta-pred-mean-pool", "dnabert-lm", "dnabert-pred", "dnabert-pred-mean-pool"',
117
+ default=model_name,
118
+ )
119
+ parser.add_argument(
120
+ "--pretrained-model",
121
+ type=str,
122
+ help="Directory containing config.json and pytorch_model.bin files for loading pretrained huggingface model",
123
+ default=(str(pretrained_model) if pretrained_model else None),
124
+ )
125
+ parser.add_argument(
126
+ "--tokenizer-dir",
127
+ type=str,
128
+ help="Directory containing necessary files to instantiate pretrained tokenizer.",
129
+ default=str(tokenizer_dir),
130
+ )
131
+ parser.add_argument(
132
+ "--log-offset",
133
+ type=float,
134
+ help="Offset to apply to gene expression values before log transform",
135
+ default=log_offset,
136
+ )
137
+ parser.add_argument(
138
+ "--preprocessor",
139
+ type=str,
140
+ help="Path to pickled preprocessor file",
141
+ default=preprocessor,
142
+ )
143
+ parser.add_argument(
144
+ "--filter-empty",
145
+ help="Whether to filter out empty sequences.",
146
+ default=filter_empty,
147
+ action="store_true",
148
+ )
149
+ parser.add_argument(
150
+ "--tissue-subset", default=None, help="Subset of tissues to use", nargs="*"
151
+ )
152
+ parser.add_argument("--hyperparameter-search", action="store_true", default=False)
153
+ parser.add_argument("--ntrials", default=hyperparam_search_trials, type=int)
154
+ parser.add_argument("--metrics", default=hyperparam_search_metrics, nargs="*")
155
+ parser.add_argument("--direction", type=str, default="minimize")
156
+ parser.add_argument(
157
+ "--nshards",
158
+ type=int,
159
+ default=None,
160
+ help="Number of shards to divide data into; only the first is kept.",
161
+ )
162
+ parser.add_argument(
163
+ "--nshards-eval",
164
+ type=int,
165
+ default=None,
166
+ help="Number of shards to divide eval data into.",
167
+ )
168
+ parser.add_argument(
169
+ "--threshold",
170
+ type=float,
171
+ default=None,
172
+ help="Minimum value for filtering gene expression values.",
173
+ )
174
+ parser.add_argument(
175
+ "--transformation",
176
+ type=str,
177
+ default=transformation,
178
+ help='How to transform the data. Options are "log", "boxcox"',
179
+ )
180
+ parser.add_argument(
181
+ "--freeze-base",
182
+ action="store_true",
183
+ help="Freeze the pretrained base of the model",
184
+ )
185
+ parser.add_argument(
186
+ "--output-mode",
187
+ type=str,
188
+ help='Output mode for model: {"regression", "classification"}',
189
+ default=output_mode,
190
+ )
191
+ parser.add_argument(
192
+ "--learning-rate",
193
+ type=float,
194
+ help="Learning rate for training. Default None",
195
+ default=None,
196
+ )
197
+ parser.add_argument(
198
+ "--num-train-epochs",
199
+ type=int,
200
+ help="Number of epochs to train for",
201
+ default=None,
202
+ )
203
+ parser.add_argument(
204
+ "--search-metric",
205
+ type=str,
206
+ help="Metric to optimize in hyperparameter search",
207
+ default=None,
208
+ )
209
+ parser.add_argument("--batch-norm", action="store_true", default=False)
210
+ args = parser.parse_args()
211
+
212
+ if args.pretrained_model and not args.pretrained_model.startswith("/"):
213
+ args.pretrained_model = str(Path.cwd() / args.pretrained_model)
214
+
215
+ args.data_dir = Path(args.data_dir)
216
+ args.output_dir = Path(args.output_dir)
217
+
218
+ args.train_data = _get_fpath_if_not_none(args.data_dir, args.train_data)
219
+ args.eval_data = _get_fpath_if_not_none(args.data_dir, args.eval_data)
220
+ args.test_data = _get_fpath_if_not_none(args.data_dir, args.test_data)
221
+
222
+ args.preprocessor = Path(args.preprocessor) if args.preprocessor else None
223
+
224
+ if args.tissue_subset is not None:
225
+ if isinstance(args.tissue_subset, (int, str)):
226
+ args.tissue_subset = [args.tissue_subset]
227
+ args.tissue_subset = [
228
+ int(t) if t.isnumeric() else t for t in args.tissue_subset
229
+ ]
230
+ return args
231
+
232
+ def get_model_settings(
233
+ settings: dict, args: dict = None, model_name: str = None
234
+ ) -> dict:
235
+ """Get the appropriate model settings from the dictionary `settings`."""
236
+ if model_name is None:
237
+ model_name = args.model_name
238
+ base_model_name = model_name.split("-")[0] + "-base"
239
+ base_model_settings = settings["models"].get(base_model_name, {})
240
+ model_settings = settings["models"].get(model_name, {})
241
+ data_settings = settings["data"]
242
+ settings = dict(**base_model_settings, **model_settings, **data_settings)
243
+
244
+ if args is not None:
245
+ if args.output_mode:
246
+ settings["output_mode"] = args.output_mode
247
+ if args.tissue_subset is not None:
248
+ settings["num_labels"] = len(args.tissue_subset)
249
+ if args.batch_norm:
250
+ settings["batch_norm"] = args.batch_norm
251
+
252
+ return settings
253
+
254
+ def _get_fpath_if_not_none(
255
+ dirpath: PosixPath, fpath: PosixPath
256
+ ) -> Union[None, PosixPath]:
257
+ if fpath:
258
+ return dirpath / fpath
259
+ return None
260
+
261
+ def load_pickle(path: PosixPath) -> object:
262
+ with path.open("rb") as f:
263
+ obj = pickle.load(f)
264
+ return obj
prediction.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from module import config, transformers_utility as tr, utils, metrics, dataio
2
+ from prettytable import PrettyTable
3
+
4
+ table = PrettyTable()
5
+ table.field_names = config.tissues
6
+ TOKENIZER_DIR = config.models / "byte-level-bpe-tokenizer"
7
+ PRETRAINED_MODEL = config.models / "transformer" / "prediction-model"
8
+ DATA_DIR = config.data
9
+
10
+ def load_model(args, settings):
11
+ return tr.load_model(
12
+ args.model_name,
13
+ args.tokenizer_dir,
14
+ pretrained_model=args.pretrained_model,
15
+ log_offset=args.log_offset,
16
+ **settings,
17
+ )
18
+
19
+ def main(TEST_DATA):
20
+ args = utils.get_args(
21
+ data_dir=DATA_DIR,
22
+ train_data=TEST_DATA,
23
+ test_data=TEST_DATA,
24
+ pretrained_model=PRETRAINED_MODEL,
25
+ tokenizer_dir=TOKENIZER_DIR,
26
+ model_name="roberta-pred-mean-pool",
27
+ )
28
+
29
+ settings = utils.get_model_settings(config.settings, args)
30
+ if args.output_mode:
31
+ settings["output_mode"] = args.output_mode
32
+ if args.tissue_subset is not None:
33
+ settings["num_labels"] = len(args.tissue_subset)
34
+
35
+ print("Loading model...")
36
+ config_obj, tokenizer, model = load_model(args, settings)
37
+
38
+ print("Loading data...")
39
+ datasets = dataio.load_datasets(
40
+ tokenizer,
41
+ args.train_data,
42
+ eval_data=args.eval_data,
43
+ test_data=args.test_data,
44
+ seq_key="text",
45
+ file_type="text",
46
+ filter_empty=args.filter_empty,
47
+ shuffle=False,
48
+ )
49
+ dataset_test = datasets["train"]
50
+
51
+ print("Getting predictions:")
52
+ preds = metrics.get_predictions(model, dataset_test)
53
+ for e in preds:
54
+ table.add_row(e)
55
+ print(table)
56
+
57
+ if __name__ == "__main__":
58
+ main("test.txt")