|
import torch |
|
|
|
def tensor_to_size(source, dest_size):
    """Pad or truncate ``source`` along dim 0 so it has ``dest_size`` entries.

    When ``source`` is shorter than ``dest_size``, its last entry
    (``source[-1:]``) is repeated to fill the gap. When it is longer, the
    trailing entries are dropped.

    Args:
        source: Tensor with at least one dimension; only dim 0 is resized.
        dest_size: Target length as an int, or a tensor whose ``shape[0]``
            is used as the target length.

    Returns:
        ``source`` itself if it already has ``dest_size`` entries. Otherwise a
        tensor with ``dest_size`` entries along dim 0: a new tensor when
        padding, a slice of ``source`` when truncating.

    Raises:
        ValueError: if ``source`` is empty along dim 0 but padding is
            required (there is no last entry to repeat).
    """
    if isinstance(dest_size, torch.Tensor):
        dest_size = dest_size.shape[0]
    source_size = source.shape[0]

    if source_size < dest_size:
        if source_size == 0:
            # source[-1:] would be empty here, so the repeat/cat below would
            # silently return a zero-length tensor instead of dest_size rows.
            raise ValueError("cannot pad an empty tensor: no last element to repeat")
        # Repeat the last slice (dest_size - source_size) times along dim 0,
        # keeping every other dimension unchanged (repeat factor 1).
        shape = [dest_size - source_size] + [1] * (source.dim() - 1)
        source = torch.cat((source, source[-1:].repeat(shape)), dim=0)
    elif source_size > dest_size:
        source = source[:dest_size]

    return source
|
|
|
def tensor_to_image(tensor):
    """Convert a float image tensor to a uint8 NumPy array with channels reversed.

    Values are scaled by 255, rounded to the nearest integer, clamped to
    [0, 255] and cast to uint8 on the CPU. The last axis is then indexed with
    ``[2, 1, 0]`` (e.g. RGB -> BGR).

    Args:
        tensor: Float tensor, presumably with values in [0, 1] and at least
            3 entries on its last (channel) axis — TODO confirm with callers.

    Returns:
        numpy.ndarray of dtype uint8. It has the same leading shape as
        ``tensor`` and exactly 3 entries on the last axis (input channels
        2, 1, 0; any further channels are dropped).
    """
    # Round before the integer cast: .byte() truncates, so e.g. 0.999 * 255 =
    # 254.7 would become 254, and float round-trips of k / 255 * 255 that land
    # a hair below k would lose one level.
    image = tensor.mul(255).round().clamp(0, 255).byte().cpu()
    image = image[..., [2, 1, 0]].numpy()
    return image
|
|
|
def image_to_tensor(image):
    """Convert a NumPy image array to a float tensor in [0, 1] with channels reversed.

    The array is cast to float32, divided by 255 and clamped to [0, 1]. The
    last axis is then indexed with ``[2, 1, 0]`` (e.g. BGR -> RGB).

    Args:
        image: numpy.ndarray, presumably uint8 and channels-last with at
            least 3 channels — TODO confirm with callers. Views with negative
            strides (e.g. ``img[..., ::-1]`` or ``img[::-1]``) are accepted.

    Returns:
        float32 torch.Tensor holding channels 2, 1, 0 of ``image``.
    """
    # torch.from_numpy rejects arrays with negative strides; copy those into
    # a fresh C-contiguous buffer first. Other arrays are used without copying.
    if any(stride < 0 for stride in image.strides):
        image = image.copy()
    tensor = torch.clamp(torch.from_numpy(image).float() / 255., 0, 1)
    tensor = tensor[..., [2, 1, 0]]
    return tensor
|
|