Gurveer05 commited on
Commit
d9d8dae
·
1 Parent(s): bf1f674

Improved model

Browse files
data/test.txt CHANGED
@@ -1 +1 @@
1
- CTCAAGCTGAGCAGTGGGTTTGCTCTGGAGGGGAAGCTCAACGGTGGCGACAAGGAAGAATCTGCTTGCGAGGCGAGCCCTGACGCCGCTGATAGCGACCAAAGGTGGATTAAACAACCCATTTCATCATTCTTCTTCCTTGTTAGTTATGATTCCCACGCTTGCCTTTCATGAATCATGATCCTATATGTATATTGATATTAATCAGTTCTAGAAAGTTCAACAACATTTGAGCATGTCAAAACCTGATCGTTGCCTGTTCCATGTCAACAGTGGATTATAACACGTGCAAATGTAGCTATTTGTGTGAGAAGACGTGTGATCGACTCTTTTTTTATATAGATAGCATTGAGATCAACTGTTTGTATATATCTTGTCATAACATTTTTACTTCGTAGCAACGTACGAGCGTTCACCTATTTGTATATAAGTTATCATGATATTTATAAGTTACCGTTGCAACGCACGGACACTCACCTAGTATAGTTTATGTATTACAGTACTAGGAGCCCTAGGCTTCCAATAACTAGAAAAAGTCCTGGTCAGTCGAACCAAACCACAATCCGACGTATACATTCTGGTTCCCCCACGCCCCCATCCGTTCGATTCA
 
1
+ ATGGACAAACTCTAGTAACGGT
models/transformer/prediction-model/saved_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1b398b2fb6306ba1659ba6aefe6e280cc2b65d61bd15ed4a6234b0a8df43e0cc
3
+ size 191081761
module/__pycache__/dataio.cpython-311.pyc CHANGED
Binary files a/module/__pycache__/dataio.cpython-311.pyc and b/module/__pycache__/dataio.cpython-311.pyc differ
 
module/__pycache__/metrics.cpython-311.pyc CHANGED
Binary files a/module/__pycache__/metrics.cpython-311.pyc and b/module/__pycache__/metrics.cpython-311.pyc differ
 
module/__pycache__/transformers_utility.cpython-311.pyc CHANGED
Binary files a/module/__pycache__/transformers_utility.cpython-311.pyc and b/module/__pycache__/transformers_utility.cpython-311.pyc differ
 
module/transformers_utility.py CHANGED
@@ -1,6 +1,6 @@
1
  from pathlib import PosixPath
2
  from typing import Union, Optional
3
-
4
  from transformers import (
5
  RobertaConfig,
6
  RobertaTokenizerFast,
@@ -81,8 +81,13 @@ def load_model(model_name: str,
81
  )
82
  if pretrained_model:
83
  # print(f"Loading from pretrained model {pretrained_model}")
84
- model = model_class.from_pretrained(
85
- str(pretrained_model), config=config_obj)
 
 
 
 
 
86
  else:
87
  print("Loading untrained model")
88
  model = model_class(config=config_obj)
 
1
  from pathlib import PosixPath
2
  from typing import Union, Optional
3
+ import torch
4
  from transformers import (
5
  RobertaConfig,
6
  RobertaTokenizerFast,
 
81
  )
82
  if pretrained_model:
83
  # print(f"Loading from pretrained model {pretrained_model}")
84
+ model = model_class(config=config_obj)
85
+ state_dict = torch.load(pretrained_model)
86
+ state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
87
+ unexpected_keys = [k for k in state_dict.keys() if 'position_ids' in k]
88
+ for key in unexpected_keys:
89
+ del state_dict[key]
90
+ model.load_state_dict(state_dict)
91
  else:
92
  print("Loading untrained model")
93
  model = model_class(config=config_obj)
prediction.py CHANGED
@@ -1,10 +1,11 @@
1
  from module import config, transformers_utility as tr, utils, metrics, dataio
2
  from prettytable import PrettyTable
 
3
 
4
  table = PrettyTable()
5
  table.field_names = config.tissues
6
  TOKENIZER_DIR = config.models / "byte-level-bpe-tokenizer"
7
- PRETRAINED_MODEL = config.models / "transformer" / "prediction-model"
8
  DATA_DIR = config.data
9
 
10
  def load_model(args, settings):
@@ -49,7 +50,7 @@ def main(TEST_DATA):
49
  dataset_test = datasets["train"]
50
 
51
  print("Getting predictions:")
52
- preds = metrics.get_predictions(model, dataset_test)
53
  for e in preds:
54
  table.add_row(e)
55
  print(table)
 
1
  from module import config, transformers_utility as tr, utils, metrics, dataio
2
  from prettytable import PrettyTable
3
+ import numpy as np
4
 
5
  table = PrettyTable()
6
  table.field_names = config.tissues
7
  TOKENIZER_DIR = config.models / "byte-level-bpe-tokenizer"
8
+ PRETRAINED_MODEL = config.models / "transformer" / "prediction-model" / "saved_model.pth"
9
  DATA_DIR = config.data
10
 
11
  def load_model(args, settings):
 
50
  dataset_test = datasets["train"]
51
 
52
  print("Getting predictions:")
53
+ preds = np.exp(np.array(metrics.get_predictions(model, dataset_test))) - 1
54
  for e in preds:
55
  table.add_row(e)
56
  print(table)