File size: 3,453 Bytes
561c629
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import argparse
import os, sys
import torch
import cv2
from torchvision import transforms
from PIL import Image
import torch.nn.functional as F
import numpy as np
from matplotlib import pyplot as plt
from tqdm import tqdm

# Import files from the local folder
root_path = os.path.abspath('.')
sys.path.append(root_path)
from opt import opt
from dataset_curation_pipeline.IC9600.ICNet import ICNet



# ICNet preprocessing: fixed 512x512 input normalized with the standard
# ImageNet channel statistics (matches the IC9600 training setup).
inference_transform = transforms.Compose(
    [
        transforms.Resize((512, 512)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

def blend(ori_img, ic_img, alpha = 0.8, cm = plt.get_cmap("magma")):
    """Overlay a colormapped complexity heatmap onto the original image.

    Args:
        ori_img: H x W x 3 uint8 image as a numpy array.
        ic_img: H x W complexity map with values in [0, 1].
        alpha: blend weight given to the heatmap.
        cm: matplotlib colormap applied to ``ic_img``.

    Returns:
        numpy uint8 array containing the blended visualization.
    """
    # The colormap yields RGBA floats in [0, 1]; `[:, :, -2::-1]` drops the
    # alpha channel and reverses RGB -> BGR (the caller writes via cv2).
    colored = cm(ic_img)
    heatmap = Image.fromarray((colored[:, :, -2::-1] * 255).astype(np.uint8))
    base = Image.fromarray(ori_img)
    blended = Image.blend(base, heatmap, alpha=alpha)
    return np.array(blended)

        
def infer_one_image(model, img_path):
    """Compute the IC9600 complexity score for a single image.

    Args:
        model: ICNet model, already in eval mode and on the target device.
        img_path: path to the image file to score.

    Returns:
        float: scalar complexity score predicted by the model.
    """
    with torch.no_grad():
        ori_img = Image.open(img_path).convert("RGB")
        img = inference_transform(ori_img)  # 512x512 + ImageNet normalization
        # BUG FIX: was `img.cuda()`, which ignored args.device and crashed on
        # CPU-only machines. Run on whatever device the model lives on.
        device = next(model.parameters()).device
        img = img.unsqueeze(0).to(device)
        ic_score, ic_map = model(img)
        # ic_map (the per-pixel complexity heatmap) is unused here; it can be
        # upsampled to the original size and saved/blended if visualization
        # output is needed.
        return ic_score.item()

    
    
def infer_directory(img_dir):
    """Score every image in *img_dir* and print the 50 highest-complexity ones.

    Relies on the module-level ``model`` created in the ``__main__`` block.

    Args:
        img_dir: directory containing the images to score.
    """
    imgs = sorted(os.listdir(img_dir))
    scores = []
    for img in tqdm(imgs):
        img_path = os.path.join(img_dir, img)
        # BUG FIX: the original called infer_one_image(img_path) without the
        # required model argument, raising TypeError on every image.
        score = infer_one_image(model, img_path)
        scores.append((score, img_path))
        print(img_path, score)

    # Highest complexity first (replaces sort-ascending + reverse slice).
    scores.sort(key=lambda x: x[0], reverse=True)
    for score in scores[:50]:
        print(score)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('-i', '--input', type = str, default = './example',
                        help='image file or directory of images to score')
    parser.add_argument('-o', '--output', type = str, default = './out',
                        help='output directory (for optional map/blend dumps)')
    parser.add_argument('-d', '--device', type = int, default=0,
                        help='CUDA device index')

    args  = parser.parse_args()

    # Load weights on CPU first, then move the model to the requested device.
    model = ICNet()
    model.load_state_dict(torch.load('./checkpoint/ck.pth', map_location=torch.device('cpu')))
    model.eval()
    device = torch.device(args.device)
    model.to(device)

    # NOTE: a redundant re-definition of `inference_transform` (byte-identical
    # to the module-level one above) was removed here.

    if os.path.isfile(args.input):
        # BUG FIX: the original called infer_one_image(args.input) without the
        # model argument (TypeError) and silently discarded the score.
        print(args.input, infer_one_image(model, args.input))
    else:
        infer_directory(args.input)