Neural-Style-Transfer-GPU / dataTransform.py
ailm's picture
Update dataTransform.py
bddbfa0 verified
raw
history blame contribute delete
846 Bytes
import torch
from PIL import Image
import torchvision.transforms as transforms #to transform the images
def load_image(image_path, device, image_size=356):
    """Load an image file and return it as a batched float tensor on ``device``.

    The image is converted to 3-channel RGB, resized to a square of
    ``image_size`` x ``image_size`` pixels (aspect ratio is not preserved),
    and converted to a tensor with ``transforms.ToTensor()``.

    Args:
        image_path: Anything accepted by ``PIL.Image.open`` (a path or a
            binary file object).
        device: Target passed to ``Tensor.to`` (e.g. ``"cpu"``, ``"cuda"``).
        image_size: Side length of the square output, in pixels. Defaults to
            356, the value that was previously hard-coded here.

    Returns:
        A tensor of shape ``(1, 3, image_size, image_size)`` on ``device``.
    """
    loader = transforms.Compose(
        [
            transforms.Resize((image_size, image_size)),  # force a square size
            transforms.ToTensor(),  # PIL image -> float tensor of shape (C, H, W)
        ]
    )
    # Use a context manager so the underlying file handle is released
    # promptly. Force RGB so RGBA / palette / grayscale inputs also yield
    # exactly 3 channels. Alpha is simply dropped, not composited.
    # convert() loads the pixel data, so `rgb` stays valid after the file
    # is closed.
    with Image.open(image_path) as raw:
        rgb = raw.convert("RGB")
    batch = loader(rgb).unsqueeze(0)  # (C, H, W) -> (1, C, H, W): add batch dim
    return batch.to(device)
def tensor_to_image(tensor):
    """Turn an image tensor back into a ``PIL.Image``.

    Works on a detached copy, so the caller's tensor and any autograd graph
    are left untouched. A leading dimension of size 1 is dropped via
    ``squeeze(0)``. The values are clamped to ``[0, 1]`` and moved to the CPU
    before ``transforms.ToPILImage()`` is applied.

    NOTE(review): presumably the input is ``(1, C, H, W)`` or ``(C, H, W)``
    with values roughly in ``[0, 1]``. Confirm against callers.
    """
    to_pil = transforms.ToPILImage()
    # Same pipeline in one chain: copy+detach, drop batch axis, clamp, to CPU.
    frame = tensor.clone().detach().squeeze(0).clamp(0, 1).cpu()
    return to_pil(frame)