File size: 846 Bytes
bddbfa0
 
ed7df54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import torch
from PIL import Image
import torchvision.transforms as transforms  #to transform the images


def load_image(image_path, device, image_size=356):
    """Load an image from disk and return it as a batched float tensor.

    Parameters
    ----------
    image_path : str or Path
        Path to the image file on disk.
    device : torch.device or str
        Device the returned tensor is moved to.
    image_size : int, optional
        Side length the image is resized to (default 356, matching the
        original hard-coded value).

    Returns
    -------
    torch.Tensor
        Float tensor of shape (1, C, H, W) with values in [0, 1],
        resized to (image_size, image_size) and placed on `device`.
    """
    loader = transforms.Compose(
        [
            transforms.Resize((image_size, image_size)),  # force a fixed spatial size
            transforms.ToTensor(),  # PIL image -> float tensor in [0, 1], laid out (C, H, W)
        ]
    )

    # Context manager ensures the underlying file handle is closed.
    # convert("RGB") guards against grayscale/RGBA inputs that would
    # otherwise yield a tensor with the wrong channel count.
    with Image.open(image_path) as img:
        image = loader(img.convert("RGB"))

    # (C, H, W) -> (1, C, H, W): add the batch dimension the model expects.
    return image.unsqueeze(0).to(device)


def tensor_to_image(tensor):
    """Convert an image tensor back into a PIL image.

    The tensor is copied and detached from the autograd graph, a leading
    batch dimension (if present) is dropped, values are clipped to the
    valid [0, 1] range, and the conversion happens on the CPU.
    """
    to_pil = transforms.ToPILImage()
    img = tensor.clone().detach().squeeze(0).clamp(0, 1)
    return to_pil(img.cpu())