medical_imaging_segmentation / apps /data_preprocessing.py
Kukulauren's picture
models and preprocessing
d679c76 verified
raw
history blame
1.37 kB
import functools

import torch
import torchio as tio

from apps.model import model
# Default checkpoint location. It is a relative path, so it resolves against
# the current working directory (presumably the repository root when the app
# runs -- confirm against the deployment).
_DEFAULT_CKPT_PATH = 'apps/best_model_36.ckpt'


def _build_dataset(uploaded_file, target_shape):
    """Wrap *uploaded_file* in a one-subject dataset with spatial preprocessing."""
    subject = tio.Subject({"CT": tio.ScalarImage(uploaded_file)})
    transform = tio.Compose([
        tio.ToCanonical(),             # normalize orientation to canonical
        tio.RescaleIntensity((0, 1)),  # rescale intensities to [0, 1]
        tio.Resize(target_shape),      # resample onto a fixed voxel grid
    ])
    return tio.SubjectsDataset([subject], transform=transform)


@functools.lru_cache(maxsize=1)
def _load_weights(ckpt_path, device_name):
    """Load *ckpt_path* into the shared ``model`` and move it to *device_name*.

    The cache holds only the most recently loaded (path, device) pair, which
    is exactly the state the shared ``model`` is in. Loading a different
    checkpoint evicts that entry, so switching back reloads it from disk.
    Assumes nothing else overwrites ``model``'s weights between calls.
    """
    device = torch.device(device_name)
    # NOTE(review): torch >= 2.6 defaults torch.load to weights_only=True. If
    # this checkpoint stores non-tensor objects, loading may require
    # weights_only=False (acceptable only for a trusted local file).
    checkpoint = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(checkpoint['state_dict'])
    model.to(device)


def preprocess_input(
    uploaded_file,
    *,
    ckpt_path=_DEFAULT_CKPT_PATH,
    target_shape=(300, 300, 400),
    patch_size=96,
    patch_overlap=(8, 8, 8),
    batch_size=4,
):
    """Preprocess a CT volume and run patch-based inference with ``model``.

    The volume is reoriented to canonical orientation, rescaled to [0, 1] and
    resized to *target_shape*. It is then tiled into patches with
    ``tio.inference.GridSampler``, passed through ``model`` in batches, and
    stitched back together with ``tio.inference.GridAggregator``.

    Args:
        uploaded_file: Anything ``tio.ScalarImage`` accepts as its first
            argument (presumably a path to the uploaded CT file -- confirm
            against the caller).
        ckpt_path: Checkpoint file containing a ``'state_dict'`` entry.
            Weights are only re-read from disk when this path or the device
            changes.
        target_shape: Output shape passed to ``tio.Resize``.
        patch_size: Patch size passed to ``tio.inference.GridSampler``.
        patch_overlap: Patch overlap passed to ``tio.inference.GridSampler``.
        batch_size: Number of patches per forward pass.

    Returns:
        ``(output_tensor, dataset)``: the tensor returned by
        ``GridAggregator.get_output_tensor()`` and the ``tio.SubjectsDataset``
        holding the single preprocessed subject.
    """
    dataset = _build_dataset(uploaded_file, target_shape)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Cached: the checkpoint is not reloaded on every request.
    _load_weights(ckpt_path, str(device))
    # Cheap, and guards against the shared model having been put in train mode.
    model.eval()

    grid_sampler = tio.inference.GridSampler(dataset[0], patch_size, patch_overlap)
    aggregator = tio.inference.GridAggregator(grid_sampler)
    patch_loader = tio.data.SubjectsLoader(grid_sampler, batch_size=batch_size)

    with torch.no_grad():
        for patches_batch in patch_loader:
            input_tensor = patches_batch['CT']["data"].to(device)
            locations = patches_batch[tio.LOCATION]  # where each patch came from
            pred = model(input_tensor)
            aggregator.add_batch(pred, locations)

    output_tensor = aggregator.get_output_tensor()
    return output_tensor, dataset