|
import torchio as tio |
|
import torch |
|
from apps.model import model |
|
# Default checkpoint location. It is a relative path, so it is resolved against
# the process's current working directory, not against this file's directory.
DEFAULT_CKPT_PATH = 'apps/best_model_36.ckpt'

# (ckpt_path, device) whose weights were last loaded into the shared `model`.
# Lets repeated calls skip re-reading the checkpoint from disk.
# NOTE(review): assumes nothing else in the process swaps `model`'s weights
# between calls; if something does, reset this to None after doing so.
_loaded_checkpoint_key = None


def _build_dataset(uploaded_file, target_shape):
    """Wrap the CT volume in a one-subject dataset with the inference transforms."""
    subject = tio.Subject({"CT": tio.ScalarImage(uploaded_file)})
    transform = tio.Compose([
        tio.ToCanonical(),  # normalize orientation before any spatial op
        tio.RescaleIntensity((0, 1)),
        tio.Resize(target_shape),
    ])
    return tio.SubjectsDataset([subject], transform=transform)


def _ensure_model_loaded(ckpt_path, device):
    """Load ``ckpt_path`` into the shared ``model`` once per (path, device), then set eval mode."""
    global _loaded_checkpoint_key
    key = (str(ckpt_path), str(device))
    if _loaded_checkpoint_key != key:
        # NOTE(review): newer PyTorch releases default torch.load to
        # weights_only=True, which may reject a checkpoint that pickles
        # non-tensor objects alongside 'state_dict' -- confirm for this file.
        checkpoint = torch.load(ckpt_path, map_location=device)
        model.load_state_dict(checkpoint['state_dict'])
        # Only record the key after a successful load so failures are retried.
        _loaded_checkpoint_key = key
    model.to(device)
    model.eval()


def _predict_volume(subject, device, patch_size, patch_overlap, batch_size):
    """Run ``model`` over ``subject`` patch by patch and stitch the predictions back together."""
    grid_sampler = tio.inference.GridSampler(subject, patch_size, patch_overlap)
    aggregator = tio.inference.GridAggregator(grid_sampler)
    patch_loader = tio.data.SubjectsLoader(grid_sampler, batch_size=batch_size)
    with torch.no_grad():
        for patches_batch in patch_loader:
            input_tensor = patches_batch['CT']["data"].to(device)
            locations = patches_batch[tio.LOCATION]
            aggregator.add_batch(model(input_tensor), locations)
    return aggregator.get_output_tensor()


def preprocess_input(
    uploaded_file,
    *,
    ckpt_path=DEFAULT_CKPT_PATH,
    target_shape=(300, 300, 400),
    patch_size=96,
    patch_overlap=(8, 8, 8),
    batch_size=4,
):
    """Preprocess a CT volume and run patch-based inference with the shared ``model``.

    Despite the name, this does more than preprocess. It also loads the model
    weights (cached across calls) and runs the full sliding-window prediction.

    Args:
        uploaded_file: Source handed to ``tio.ScalarImage``. Presumably a
            filesystem path to a volume TorchIO can read -- confirm against
            callers if a file-like upload object is passed instead.
        ckpt_path: Checkpoint whose ``'state_dict'`` entry is loaded into
            ``model``. Relative paths depend on the working directory.
        target_shape: Spatial size passed to ``tio.Resize``.
        patch_size: Patch size for ``tio.inference.GridSampler``.
        patch_overlap: Patch overlap for ``tio.inference.GridSampler``.
        batch_size: Number of patches per forward pass.

    Returns:
        A tuple of:
        - the aggregated prediction tensor from ``GridAggregator.get_output_tensor()``;
        - the one-subject ``tio.SubjectsDataset``. Indexing it re-applies the
          preprocessing transform.
    """
    dataset = _build_dataset(uploaded_file, target_shape)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    _ensure_model_loaded(ckpt_path, device)
    # dataset[0] applies the transform, yielding the preprocessed subject.
    output_tensor = _predict_volume(
        dataset[0], device, patch_size, patch_overlap, batch_size
    )
    return output_tensor, dataset
|
|