MingGatsby commited on
Commit
7c91758
1 Parent(s): 67cbdeb

Upload 2 files

Browse files
Files changed (2) hide show
  1. Util/Custom_Model.py +44 -0
  2. Util/DICOM.py +81 -0
Util/Custom_Model.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import timm
2
+ import torch.nn as nn
3
+
4
+ class Build_Custom_Model(nn.Module):
5
+ def __init__(self, model_name, target_size, pretrained=False):
6
+ super().__init__()
7
+ self.model = timm.create_model(model_name, pretrained=pretrained, in_chans=1)
8
+ if(model_name=="vit_base_patch16_224" or model_name=="swin_base_patch4_window7_224"):
9
+ self.n_features = self.model.head.in_features
10
+ self.model.head = nn.Linear(self.n_features, target_size)
11
+ if(model_name=="resnet34d"):
12
+ self.n_features = self.model.fc.in_features
13
+ self.model.fc = nn.Linear(self.n_features, target_size)
14
+ if(model_name=="resnet18d"):
15
+ self.n_features = self.model.fc.in_features
16
+ self.model.fc = nn.Linear(self.n_features, target_size)
17
+ if(model_name=="tf_efficientnet_b7_ns"):
18
+ self.n_features = self.model.classifier.in_features
19
+ self.model.classifier = nn.Linear(self.n_features, target_size)
20
+ if(model_name=="tf_efficientnet_b0_ns"):
21
+ self.n_features = self.model.classifier.in_features
22
+ self.model.classifier = nn.Linear(self.n_features, target_size)
23
+ if(model_name=="tf_efficientnet_lite0"):
24
+ self.n_features = self.model.classifier.in_features
25
+ self.model.classifier = nn.Linear(self.n_features, target_size)
26
+ if(model_name=="mobilenetv2_050"):
27
+ self.n_features = self.model.classifier.in_features
28
+ self.model.classifier = nn.Linear(self.n_features, target_size)
29
+ if(model_name=="eca_nfnet_l0"):
30
+ self.n_features = self.model.head.fc.in_features
31
+ self.model.head.fc = nn.Linear(self.n_features, target_size)
32
+
33
+ def forward(self, x):
34
+ output = self.model(x)
35
+ return output
36
+
37
+ def reshape_transform(tensor, height=7, width=7):
38
+ result = tensor.reshape(tensor.size(0),
39
+ height, width, tensor.size(2))
40
+
41
+ # Bring the channels to the first dimension,
42
+ # like in CNNs.
43
+ result = result.transpose(2, 3).transpose(1, 2)
44
+ return result
Util/DICOM.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+
4
+ class DICOM_Utils(object):
5
+ def apply_windowing(image_array, window_center, window_width):
6
+ """
7
+ Apply windowing to a DICOM image array.
8
+
9
+ Parameters:
10
+ - image_array: numpy array of the DICOM image
11
+ - window_center: center of the window
12
+ - window_width: width of the window
13
+
14
+ Returns:
15
+ - Windowed image array
16
+ """
17
+ lower_bound = window_center - (window_width / 2)
18
+ upper_bound = window_center + (window_width / 2)
19
+
20
+ # Apply windowing
21
+ windowed_image = image_array.copy()
22
+ windowed_image[windowed_image < lower_bound] = lower_bound
23
+ windowed_image[windowed_image > upper_bound] = upper_bound
24
+
25
+ # Normalize to [0, 255]
26
+ windowed_image = ((windowed_image - lower_bound) / window_width) * 255
27
+
28
+ return windowed_image.astype('uint8')
29
+
30
+ def transform_image_for_display(image_array):
31
+ """
32
+ Transform the image for display: Flip horizontally and then rotate 90 degrees to the right.
33
+
34
+ Parameters:
35
+ - image_array: numpy array of the image
36
+
37
+ Returns:
38
+ - Transformed image array
39
+ """
40
+ # Flip horizontally
41
+ flipped_image = np.fliplr(image_array)
42
+
43
+ # Rotate 90 degrees to the right
44
+ rotated_image = np.rot90(flipped_image, 1)
45
+
46
+ return rotated_image
47
+
48
+ def apply_CAM_overlay(heatmap, windowed_image, overlay_alpha=0.4):
49
+ """
50
+ Apply CAM (Class Activation Map) overlay to a given image.
51
+
52
+ Parameters:
53
+ - heatmap: torch.Tensor, the heatmap generated by CAM.
54
+ - windowed_image: numpy.ndarray, the windowed image to overlay the heatmap on.
55
+ - overlay_alpha: float, the transparency for overlaying heatmap. Default is 0.4.
56
+
57
+ Returns:
58
+ - overlayed: numpy.ndarray, the resulting image after overlaying the heatmap.
59
+ """
60
+ # Convert the heatmap tensor to a numpy array
61
+ heatmap_np = heatmap.cpu().numpy().squeeze()
62
+
63
+ # Normalize the heatmap to [0, 255]
64
+ heatmap_normalized = ((heatmap_np - heatmap_np.min()) /
65
+ (heatmap_np.max() - heatmap_np.min()) * 255).astype(np.uint8)
66
+
67
+ # Convert the normalized heatmap to a colormap (for example, using the "jet" colormap)
68
+ heatmap_colormap = cv2.applyColorMap(heatmap_normalized, cv2.COLORMAP_JET)
69
+
70
+ # Resize the colormap to the original image size
71
+ heatmap_resized = cv2.resize(heatmap_colormap,
72
+ (windowed_image.shape[1], windowed_image.shape[0]))
73
+
74
+ # Convert the grayscale windowed_image to 3 channels
75
+ windowed_image_colored = cv2.cvtColor(windowed_image, cv2.COLOR_GRAY2BGR)
76
+
77
+ # Overlay the heatmap on the original image with a certain transparency
78
+ overlayed = cv2.addWeighted(windowed_image_colored, 1 - overlay_alpha,
79
+ heatmap_resized, overlay_alpha, 0)
80
+
81
+ return overlayed