jesseab committed
Commit 34517cd · 1 Parent(s): 31eae02

Updated model.py and added .joblib

Files changed (2):
  1. model.py +69 -75
  2. pca_model.joblib +3 -0
model.py CHANGED
@@ -1,9 +1,9 @@
 # model.py
 import os
-from typing import Optional
-
+import numpy as np
 import torch
 import torch.nn as nn
+
 from monai.transforms import (
     Compose,
     CopyItemsD,
@@ -14,11 +14,20 @@ from monai.transforms import (
     ScaleIntensityD,
 )
 
-# Constants for your typical config
+# If you used joblib or pickle to save your PCA model:
+from joblib import load  # or "import pickle"
+
+#################################################
+# Constants
+#################################################
 RESOLUTION = 2
-INPUT_SHAPE_AE = (80, 96, 80)
+INPUT_SHAPE_AE = (80, 96, 80)  # The typical shape from your pipelines
+FLATTENED_DIM = INPUT_SHAPE_AE[0] * INPUT_SHAPE_AE[1] * INPUT_SHAPE_AE[2]
 
-# Define the exact transform pipeline for input MRI
+
+#################################################
+# Define MONAI Transforms for Preprocessing
+#################################################
 transforms_fn = Compose([
     CopyItemsD(keys={'image_path'}, names=['image']),
     LoadImageD(image_only=True, keys=['image']),
@@ -28,94 +37,79 @@ transforms_fn = Compose([
     ScaleIntensityD(minv=0, maxv=1, keys=['image']),
 ])
 
-def preprocess_mri(image_path: str, device: str = "cpu") -> torch.Tensor:
+
+def preprocess_mri(image_path: str) -> torch.Tensor:
     """
     Preprocess an MRI using MONAI transforms to produce
-    a 5D tensor (batch=1, channels=1, D, H, W) for inference.
+    a 5D Torch tensor: (batch=1, channel=1, D, H, W).
     """
     data_dict = {"image_path": image_path}
     output_dict = transforms_fn(data_dict)
-    image_tensor = output_dict["image"]  # shape: (1, D, H, W)
-    image_tensor = image_tensor.unsqueeze(0)  # => (batch=1, channel=1, D, H, W)
-    return image_tensor.to(device)
+    # shape => (1, D, H, W)
+    image_tensor = output_dict["image"].unsqueeze(0)  # => (batch=1, channel=1, D, H, W)
+    return image_tensor.float()  # typically float32
 
 
-class ShallowLinearAutoencoder(nn.Module):
+#################################################
+# PCA "Autoencoder" Wrapper
+#################################################
+class PCABrain2vec(nn.Module):
     """
-    A purely linear autoencoder with one hidden layer.
-    - Flatten input into a vector
-    - Linear encoder (no activation)
-    - Linear decoder (no activation)
-    - Reshape output to original volume shape
-    """
-    def __init__(self, input_shape=(80, 96, 80), hidden_size=1200):
-        super().__init__()
-        self.input_shape = input_shape
-        self.input_dim = input_shape[0] * input_shape[1] * input_shape[2]
-        self.hidden_size = hidden_size
-
-        # Encoder (no activation for PCA-like behavior)
-        self.encoder = nn.Sequential(
-            nn.Flatten(),
-            nn.Linear(self.input_dim, self.hidden_size),
-        )
-
-        # Decoder (no activation)
-        self.decoder = nn.Sequential(
-            nn.Linear(self.hidden_size, self.input_dim),
-        )
+    A PCA-based 'autoencoder' that mimics the old interface:
+      - from_pretrained(...) to load a PCA model from disk
+      - forward(...) returns (reconstruction, embedding, None)
 
-    def encode(self, x: torch.Tensor):
-        return self.encoder(x)
+    Under the hood, it:
+      - takes in a torch tensor shape (N, 1, D, H, W)
+      - flattens it (N, 614400)
+      - uses PCA's transform(...) to get embeddings => shape (N, n_components)
+      - uses inverse_transform(...) to get reconstructions => shape (N, 614400)
+      - reshapes back to (N, 1, D, H, W)
+    """
 
-    def decode(self, z: torch.Tensor):
-        out = self.decoder(z)
-        # Reshape to (N, 1, D, H, W)
-        return out.view(-1, 1, *self.input_shape)
+    def __init__(self, pca_model=None):
+        super().__init__()
+        # We'll store the fitted PCA model (from scikit-learn)
+        self.pca_model = pca_model  # e.g., an instance of IncrementalPCA or PCA
 
     def forward(self, x: torch.Tensor):
         """
-        Return (reconstruction, embedding, None) to keep a similar API
-        to the old VAE-based code, though there's no σ for sampling.
+        Returns (reconstruction, embedding, None).
+
+        1) Convert x => numpy array => flatten => (N, 614400)
+        2) embedding = pca_model.transform(flat_x)
+        3) reconstruction_np = pca_model.inverse_transform(embedding)
+        4) reshape => (N, 1, 80, 96, 80)
+        5) convert to torch => return (recon, embed, None)
         """
-        z = self.encode(x)
-        reconstruction = self.decode(z)
-        return reconstruction, z, None
+        # Expect x shape => (N, 1, D, H, W) => flatten to (N, D*H*W)
+        n_samples = x.shape[0]
+        # Convert to CPU np
+        x_cpu = x.detach().cpu().numpy()  # shape: (N, 1, D, H, W)
+        x_flat = x_cpu.reshape(n_samples, -1)  # shape: (N, 614400)
 
+        # PCA transform => embeddings shape (N, n_components)
+        embedding_np = self.pca_model.transform(x_flat)
 
-class Brain2vec(nn.Module):
-    """
-    A wrapper around the ShallowLinearAutoencoder, providing a from_pretrained(...)
-    method for model loading, mirroring the old usage with AutoencoderKL.
-    """
-    def __init__(self, device: str = "cpu"):
-        super().__init__()
-        # Instantiate the shallow linear model
-        self.model = ShallowLinearAutoencoder(input_shape=INPUT_SHAPE_AE, hidden_size=1200)
-        self.to(device)
+        # PCA inverse_transform => recon shape (N, 614400)
+        recon_np = self.pca_model.inverse_transform(embedding_np)
+        # Reshape back => (N, 1, 80, 96, 80)
+        recon_np = recon_np.reshape(n_samples, 1, *INPUT_SHAPE_AE)
 
-    def forward(self, x: torch.Tensor):
-        """
-        Forward pass that returns (reconstruction, embedding, None).
-        """
-        return self.model(x)
+        # Convert back to torch
+        reconstruction_torch = torch.from_numpy(recon_np).float()
+        embedding_torch = torch.from_numpy(embedding_np).float()
+        return reconstruction_torch, embedding_torch, None
 
     @staticmethod
-    def from_pretrained(
-        checkpoint_path: Optional[str] = None,
-        device: str = "cpu"
-    ) -> nn.Module:
+    def from_pretrained(pca_path: str):
         """
-        Load a pretrained ShallowLinearAutoencoder if a checkpoint path is provided.
-        Args:
-            checkpoint_path (Optional[str]): path to a .pth checkpoint
-            device (str): "cpu", "cuda", etc.
+        Load a pre-trained PCA model (pickled or joblib).
+        Returns an instance of PCABrain2vec with that model.
        """
-        model = Brain2vec(device=device)
-        if checkpoint_path is not None:
-            if not os.path.exists(checkpoint_path):
-                raise FileNotFoundError(f"Checkpoint {checkpoint_path} not found.")
-            state_dict = torch.load(checkpoint_path, map_location=device)
-            model.load_state_dict(state_dict)
-        model.eval()
-        return model
+        if not os.path.exists(pca_path):
+            raise FileNotFoundError(f"Could not find PCA model at {pca_path}")
+        # Example: pca_model = pickle.load(open(pca_path, 'rb'))
+        # or use joblib:
+        pca_model = load(pca_path)
+        return PCABrain2vec(pca_model=pca_model)
 
 
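The new interface can be exercised end to end as in the minimal sketch below. The sketch is not part of the commit: the input path "subject_T1w.nii.gz" is a placeholder, and it assumes the pca_model.joblib added below holds a fitted scikit-learn PCA/IncrementalPCA whose feature dimension matches FLATTENED_DIM (80 * 96 * 80 = 614400).

    # Hypothetical usage sketch; the NIfTI path is a placeholder.
    from model import PCABrain2vec, preprocess_mri

    model = PCABrain2vec.from_pretrained("pca_model.joblib")
    x = preprocess_mri("subject_T1w.nii.gz")   # => (1, 1, 80, 96, 80)
    reconstruction, embedding, _ = model(x)
    print(embedding.shape)        # (1, n_components), set by the fitted PCA
    print(reconstruction.shape)   # (1, 1, 80, 96, 80)

One design consequence worth noting: because forward(...) detaches the input and round-trips through NumPy, no gradients flow through the PCA, so PCABrain2vec is effectively inference-only even though it subclasses nn.Module.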
pca_model.joblib ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1806d58fc32b8132cc7cfbc252dcb613d64a76bbc2836440a67f16eb3a585c4f
+size 2951592991
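The script that fitted this artifact is not part of the commit. Its size is at least consistent with a float32 components_ matrix of roughly 1200 x 614400 (1200 * 614400 * 4 bytes ≈ 2.95 GB), i.e., the same width as the old autoencoder's hidden_size of 1200. Below is a minimal sketch of how such a model could be fitted and saved; n_components = 1200 and the train/*.nii.gz glob are assumptions, not taken from the commit.

    # Hypothetical fitting sketch; n_components and the training glob are assumed.
    import glob
    import numpy as np
    from joblib import dump
    from sklearn.decomposition import IncrementalPCA

    from model import preprocess_mri

    N_COMPONENTS = 1200  # assumed; matches the old autoencoder's hidden_size
    train_paths = sorted(glob.glob("train/*.nii.gz"))  # hypothetical training MRIs

    ipca = IncrementalPCA(n_components=N_COMPONENTS)
    batch = []
    for path in train_paths:
        # (1, 1, 80, 96, 80) => flatten to (614400,)
        batch.append(preprocess_mri(path).detach().cpu().numpy().reshape(-1))
        # partial_fit requires at least n_components samples per call
        if len(batch) == N_COMPONENTS:
            ipca.partial_fit(np.stack(batch))  # ~2.95 GB per batch in float32
            batch = []
    # any leftover batch smaller than n_components cannot be passed to
    # partial_fit and is simply dropped in this sketch

    dump(ipca, "pca_model.joblib")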