File size: 1,322 Bytes
117183e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import torch
import argparse
import os.path
from PIL import Image
from models.color_naming import ColorNaming
from torchvision.transforms import functional as TF

def parse_args(argv=None) -> argparse.Namespace:
    """Parse the command-line options of this script.

    Args:
        argv: Optional sequence of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, so existing
            zero-argument callers behave exactly as before.

    Returns:
        An ``argparse.Namespace`` with ``num_categories`` (int, default 6),
        ``image_path`` (str) and ``output_path`` (str, ``None`` when the
        option is not given).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--num_categories',
        type=int,
        default=6,
        help='Number of categories passed to the ColorNaming model.',
    )
    # NOTE(review): the default is a machine-specific absolute path; it is
    # kept unchanged for backward compatibility.
    parser.add_argument(
        '--image_path',
        type=str,
        default='/home/dserrano/Documents/datasets/FiveK-DPE/input/a0001-jmac_DSC1459.png',
        help='Path of the input image.',
    )
    parser.add_argument(
        '--output_path',
        type=str,
        help='Optional output path.',
    )
    return parser.parse_args(argv)

if __name__ == "__main__":
    # Script entry point: run the color-naming model on one image and show,
    # for every category, the input image blended towards white according to
    # that category's output map.
    args = parse_args()
    color_naming = ColorNaming(num_categories=args.num_categories)

    # Fail loudly instead of silently doing nothing when the input is missing.
    if not os.path.isfile(args.image_path):
        raise SystemExit(f"Input image not found: {args.image_path}")

    # Open the image in a context manager so the underlying file handle is
    # released; convert() returns a fully loaded copy that outlives the file.
    with Image.open(args.image_path) as source_image:
        image_tensor = TF.pil_to_tensor(source_image.convert('RGB')).unsqueeze(0)

    # NOTE(review): presumably color_naming returns one single-channel map
    # per category with values in [0, 1] -- confirm against ColorNaming.
    # The channel axis is tiled 3x so it lines up with the RGB image.
    cn_probs = color_naming(image_tensor).float().repeat(1, 3, 1, 1).cpu()

    # Per-pixel blend: weight 1 keeps the original pixel, weight 0 gives
    # white (255). The scalar 255 broadcasts, so no explicit ones tensor is
    # needed.
    image_batch = image_tensor.repeat(args.num_categories, 1, 1, 1)
    output_images = (1 - cn_probs) * 255 + cn_probs * image_batch

    import matplotlib.pyplot as plt

    # squeeze=False keeps `axes` two-dimensional even for a single category.
    fig, axes = plt.subplots(1, args.num_categories, figsize=(20, 20), squeeze=False)
    for i, ax in enumerate(axes[0]):
        # CHW -> HWC for imshow; cast to uint8 so values are read as 0..255.
        ax.imshow(output_images[i].permute(1, 2, 0).numpy().astype('uint8'))

    # Honour --output_path when given; the interactive display is unchanged.
    if args.output_path:
        fig.savefig(args.output_path)
    plt.show()