from PIL import Image
import torch
import torchvision.transforms as transforms  # image <-> tensor conversions


def load_image(image_path, device, image_size=356) -> torch.Tensor:
    """Load an image file as a batched tensor on ``device``.

    The image is resized to ``image_size`` x ``image_size`` (aspect ratio is
    not preserved), converted with ``ToTensor`` and given a leading batch
    dimension.

    Args:
        image_path: Path (or file object) accepted by ``PIL.Image.open``.
        device: Target passed to ``Tensor.to`` (e.g. ``"cpu"`` or a
            ``torch.device``).
        image_size: Side length, in pixels, of the square output.

    Returns:
        Tensor of shape ``(1, C, image_size, image_size)``. ``ToTensor``
        produces channel-first ``(C, H, W)`` data; ``C`` follows the file's
        PIL mode.
        NOTE(review): the image mode is not normalized here, so grayscale or
        RGBA files do not yield 3 channels -- confirm whether callers need
        an explicit ``convert("RGB")``.
    """
    loader = transforms.Compose(
        [
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
        ]
    )
    # Image.open is lazy and keeps the file open; run the transforms inside
    # the context manager so the pixel data is read before the file closes.
    with Image.open(image_path) as image:
        batch = loader(image).unsqueeze(0)  # (C, H, W) -> (1, C, H, W)
    return batch.to(device)


def tensor_to_image(tensor: torch.Tensor) -> Image.Image:
    """Convert an image tensor back into a PIL image.

    Args:
        tensor: Image tensor, either ``(C, H, W)`` / ``(H, W)`` or a
            single-item batch ``(1, C, H, W)``.

    Returns:
        The PIL image built by ``transforms.ToPILImage`` from the tensor
        after clamping its values to ``[0, 1]``.
    """
    # Detach from the autograd graph. No copy is needed: nothing below
    # mutates in place (torch.clamp returns a new tensor).
    tensor = tensor.detach()
    # Drop only a leading batch dimension; an unconditional squeeze(0) would
    # also collapse the first axis of lower-rank inputs.
    if tensor.dim() == 4:
        tensor = tensor.squeeze(0)
    tensor = torch.clamp(tensor, 0, 1)
    unloader = transforms.ToPILImage()
    return unloader(tensor.cpu())