File size: 1,369 Bytes
d679c76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import torchio as tio
import torch
from apps.model import model
def _build_dataset(uploaded_file):
    """Wrap *uploaded_file* in a one-subject torchio dataset with preprocessing.

    The attached transform reorients to canonical orientation
    (``tio.ToCanonical``), rescales intensities to ``[0, 1]`` and resizes the
    volume to ``(300, 300, 400)``.
    """
    subject = tio.Subject({"CT": tio.ScalarImage(uploaded_file)})
    transform = tio.Compose([
        tio.ToCanonical(),
        tio.RescaleIntensity((0, 1)),
        tio.Resize((300, 300, 400)),
    ])
    return tio.SubjectsDataset([subject], transform=transform)


def _load_model(ckpt_path, device):
    """Load ``checkpoint['state_dict']`` into the shared ``model``, move it to *device*, set eval mode."""
    # NOTE: torch.load unpickles the file -- only point this at trusted checkpoints.
    checkpoint = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(checkpoint['state_dict'])
    model.to(device)
    model.eval()


def _predict_patches(subject, device, patch_size, patch_overlap, batch_size):
    """Run ``model`` over *subject* patch by patch and stitch predictions with a GridAggregator."""
    grid_sampler = tio.inference.GridSampler(subject, patch_size, patch_overlap)
    aggregator = tio.inference.GridAggregator(grid_sampler)
    patch_loader = tio.data.SubjectsLoader(grid_sampler, batch_size=batch_size)
    with torch.no_grad():
        for patches_batch in patch_loader:
            input_tensor = patches_batch['CT']["data"].to(device)  # batch of patches
            locations = patches_batch[tio.LOCATION]  # where each patch sits in the volume
            aggregator.add_batch(model(input_tensor), locations)
    return aggregator.get_output_tensor()


def preprocess_input(
    uploaded_file,
    *,
    ckpt_path='apps/best_model_36.ckpt',
    patch_size=96,
    patch_overlap=(8, 8, 8),
    batch_size=4,
):
    """Preprocess a CT volume and run patch-based inference on it.

    Despite the name, this does more than preprocessing. It builds the
    preprocessed dataset, loads the checkpoint into the shared ``model``
    (mutating it), and returns the stitched prediction.

    Args:
        uploaded_file: Source passed to ``tio.ScalarImage``. Presumably a
            file path; confirm what callers pass.
        ckpt_path: Checkpoint file containing a ``'state_dict'`` entry. The
            default is a relative path, so it resolves against the current
            working directory.
        patch_size: Patch size given to ``tio.inference.GridSampler``.
        patch_overlap: Patch overlap given to ``tio.inference.GridSampler``.
        batch_size: Number of patches per forward pass.

    Returns:
        Tuple ``(output_tensor, dataset)``:

        - ``output_tensor`` is the aggregated model prediction (presumably
          shaped ``(channels, 300, 300, 400)``; confirm against the model).
        - ``dataset`` is the one-subject ``tio.SubjectsDataset`` used.
    """
    dataset = _build_dataset(uploaded_file)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    _load_model(ckpt_path, device)
    output_tensor = _predict_patches(
        dataset[0], device, patch_size, patch_overlap, batch_size
    )
    return output_tensor, dataset