Spaces:
Runtime error
Runtime error
File size: 4,826 Bytes
8a32844 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 |
import matplotlib.pyplot as plt
import os, cv2
import numpy as np
from mono.utils.transform import gray_to_colormap
import shutil
import glob
from mono.utils.running import main_process
import torch
from html4vision import Col, imagetable
def save_raw_imgs(
    pred: torch.tensor,
    rgb: torch.tensor,
    filename: str,
    save_dir: str,
    scale: float=200.0,
    target: torch.tensor=None,
):
    """
    Save the raw RGB image, the predicted depth and (optionally) the GT depth
    as separate files in ``save_dir``.

    Depth maps are multiplied by ``scale`` and written as 16-bit images, so a
    stored value ``v`` corresponds to ``v / scale`` in the input's units.
    Scaled values are clipped to the uint16 range before casting, so negative
    or very large depths saturate at 0 / 65535 instead of wrapping around.

    Args:
        pred: predicted depth map (numpy array or torch tensor).
        rgb: image handed to ``cv2.imwrite`` (numpy array or torch tensor);
            presumably [H, W, 3] uint8 in cv2's channel order -- confirm with caller.
        filename (str): source file name; its extension is replaced by the
            ``_rgb.jpg`` / ``_d.png`` / ``_gt.png`` suffixes.
        save_dir (str): output directory.
        scale (float): multiplier applied to depth before the uint16 cast.
        target: optional ground-truth depth, encoded the same way as ``pred``.
    """
    def _to_numpy(x):
        # Accept torch tensors as the annotations suggest; cv2 needs numpy arrays.
        return x.detach().cpu().numpy() if torch.is_tensor(x) else np.asarray(x)

    def _encode_depth(depth):
        # Clip before casting: a plain astype(uint16) wraps out-of-range values.
        scaled = _to_numpy(depth) * scale
        return np.clip(scaled, 0, np.iinfo(np.uint16).max).astype(np.uint16)

    # splitext handles extensions of any length (e.g. '.jpeg'), unlike filename[:-4].
    stem = os.path.join(save_dir, os.path.splitext(filename)[0])
    cv2.imwrite(stem + '_rgb.jpg', _to_numpy(rgb))
    cv2.imwrite(stem + '_d.png', _encode_depth(pred))
    if target is not None:
        cv2.imwrite(stem + '_gt.png', _encode_depth(target))
def save_val_imgs(
    iter: int,
    pred: torch.tensor,
    target: torch.tensor,
    rgb: torch.tensor,
    filename: str,
    save_dir: str,
    tb_logger=None
):
    """
    Save RGB, colorized prediction and colorized GT stacked vertically in one image.

    The merged image is written to ``<save_dir>/<stem>_merge.jpg`` and, when a
    logger is given, also logged under the tag ``<stem>_merge.jpg``.

    Args:
        iter (int): step passed to ``tb_logger.add_image`` (name kept for
            backward compatibility even though it shadows the builtin).
        pred (torch.tensor): predicted depth.
        target (torch.tensor): ground-truth depth.
        rgb (torch.tensor): normalized input image.
        filename (str): source file name; its extension (any length) is replaced.
        save_dir (str): output directory.
        tb_logger: optional logger exposing ``add_image(tag, img_chw, step)``
            (presumably a TensorBoard writer -- confirm with caller).
    """
    # The uint16-scaled depth maps are not needed for the visualization.
    rgb, _, _, pred_color, target_color = get_data_for_log(pred, target, rgb)
    rgb = rgb.transpose((1, 2, 0))  # [3, H, W] -> [H, W, 3] to stack with the colormaps
    cat_img = np.concatenate([rgb, pred_color, target_color], axis=0)

    stem = os.path.splitext(filename)[0]
    plt.imsave(os.path.join(save_dir, stem + '_merge.jpg'), cat_img)
    # save to tensorboard (expects channel-first)
    if tb_logger is not None:
        tb_logger.add_image(f'{stem}_merge.jpg', cat_img.transpose((2, 0, 1)), iter)
def save_normal_val_imgs(
    iter: int,
    pred: torch.tensor,
    targ: torch.tensor,
    rgb: torch.tensor,
    filename: str,
    save_dir: str,
    tb_logger=None,
    mask=None,
):
    """
    Save RGB, predicted normals and GT normals stacked vertically in one image.

    The merged image is written to ``<save_dir>/<stem>_merge.jpg`` and, when a
    logger is given, also logged under the tag ``<stem>_merge.jpg``.

    Args:
        iter (int): step passed to ``tb_logger.add_image`` (name kept for
            backward compatibility even though it shadows the builtin).
        pred (torch.tensor): predicted normals; after ``squeeze()`` either
            [3, H, W] (permuted to channel-last) or already [H, W, 3].
        targ (torch.tensor): ground-truth normals, same layout rules as ``pred``.
        rgb (torch.tensor): normalized input image, same layout rules as ``pred``.
        filename (str): source file name; its extension (any length) is replaced.
        save_dir (str): output directory.
        tb_logger: optional logger exposing ``add_image(tag, img_chw, step)``.
        mask: optional validity mask forwarded to ``vis_surface_normal``.
    """
    mean = np.array([123.675, 116.28, 103.53])[np.newaxis, np.newaxis, :]
    std = np.array([58.395, 57.12, 57.375])[np.newaxis, np.newaxis, :]

    pred = pred.squeeze()
    targ = targ.squeeze()
    rgb = rgb.squeeze()
    # Move channels last when tensors arrive as [3, H, W].
    if pred.size(0) == 3:
        pred = pred.permute(1, 2, 0)
    if targ.size(0) == 3:
        targ = targ.permute(1, 2, 0)
    if rgb.size(0) == 3:
        rgb = rgb.permute(1, 2, 0)

    pred_color = vis_surface_normal(pred, mask)
    targ_color = vis_surface_normal(targ, mask)
    # Undo normalization; clip so out-of-range values saturate instead of wrapping in uint8.
    rgb_color = np.clip(rgb.detach().cpu().numpy() * std + mean, 0, 255).astype(np.uint8)

    try:
        cat_img = np.concatenate([rgb_color, pred_color, targ_color], axis=0)
    except ValueError:
        # Non-stacking dims differ (e.g. normals at another resolution):
        # resize the normal maps to the RGB resolution and retry.
        size = (rgb_color.shape[1], rgb_color.shape[0])  # cv2 expects (width, height)
        pred_color = cv2.resize(pred_color, size)
        targ_color = cv2.resize(targ_color, size)
        cat_img = np.concatenate([rgb_color, pred_color, targ_color], axis=0)

    stem = os.path.splitext(filename)[0]
    plt.imsave(os.path.join(save_dir, stem + '_merge.jpg'), cat_img)
    # save to tensorboard (expects channel-first)
    if tb_logger is not None:
        tb_logger.add_image(f'{stem}_merge.jpg', cat_img.transpose((2, 0, 1)), iter)
def get_data_for_log(pred: torch.tensor, target: torch.tensor, rgb: torch.tensor):
    """
    Convert a depth prediction, its GT and the input image into loggable arrays.

    The input tensors are never modified.

    Args:
        pred (torch.tensor): predicted depth; singleton dims are squeezed away.
        target (torch.tensor): ground-truth depth with the same layout as ``pred``.
        rgb (torch.tensor): normalized image, [3, H, W] after squeezing.

    Returns:
        tuple: ``(rgb, pred_scale, target_scale, pred_color, target_color)``:
            - rgb: de-normalized image, uint8, [3, H, W], clipped to [0, 255].
            - pred_scale / target_scale: uint16 depth maps with negatives set to 0,
              rescaled so the larger of the two maxima maps to 10000.
            - pred_color / target_color: ``gray_to_colormap`` outputs resized to (H, W).
    """
    mean = np.array([123.675, 116.28, 103.53])[:, np.newaxis, np.newaxis]
    std = np.array([58.395, 57.12, 57.375])[:, np.newaxis, np.newaxis]
    # .numpy() on a CPU tensor shares its memory, so an in-place clamp would
    # silently modify the caller's tensor; np.maximum returns new arrays instead.
    # detach() also lets tensors that require grad be converted.
    pred = np.maximum(pred.detach().squeeze().cpu().numpy(), 0)
    target = np.maximum(target.detach().squeeze().cpu().numpy(), 0)
    rgb = rgb.detach().squeeze().cpu().numpy()

    max_scale = max(pred.max(), target.max())
    if max_scale <= 0:
        # All-zero depth: avoid 0/0 (NaN) and emit zero maps instead.
        max_scale = 1.0
    pred_scale = (pred / max_scale * 10000).astype(np.uint16)
    target_scale = (target / max_scale * 10000).astype(np.uint16)

    height, width = rgb.shape[1], rgb.shape[2]
    pred_color = cv2.resize(gray_to_colormap(pred), (width, height))
    target_color = cv2.resize(gray_to_colormap(target), (width, height))
    # Clip before the uint8 cast so out-of-range pixels saturate instead of wrapping.
    rgb = np.clip(rgb * std + mean, 0, 255).astype(np.uint8)
    return rgb, pred_scale, target_scale, pred_color, target_color
def create_html(name2path, save_path='index.html', size=(256, 384)):
    """
    Write an HTML image table with one column per entry of ``name2path``.

    Args:
        name2path (dict): maps a column title to the image content of that
            column, forwarded unchanged to ``Col('img', title, content)``.
        save_path (str): path of the generated HTML file.
        size (tuple): image display size, forwarded to ``imagetable`` as ``imsize``.
    """
    columns = [Col('img', title, content) for title, content in name2path.items()]
    imagetable(columns, out_file=save_path, imsize=size)
def _squeeze_extra_dims(arr: np.ndarray, ndim: int) -> np.ndarray:
    # Drop singleton axes (e.g. a batch dim) only until `arr` is `ndim`-D, so the
    # spatial axes of 1-pixel-high/wide images are never squeezed away.
    while arr.ndim > ndim and 1 in arr.shape:
        arr = arr.squeeze(axis=arr.shape.index(1))
    return arr


def vis_surface_normal(normal: torch.tensor, mask: torch.tensor=None) -> np.ndarray:
    """
    Visualize surface normals as a uint8 color image.

    Each normal vector is L2-normalized (components in [-1, 1]) and mapped
    linearly with ``v * 127 + 128`` (truncated) into the uint8 range [1, 255].

    Args:
        normal (torch.tensor or np.ndarray, [h, w, 3]): surface normals; extra
            singleton dims (e.g. a leading batch dim) are dropped.
        mask (torch.tensor or np.ndarray, [h, w], optional): valid-pixel mask;
            any non-zero value counts as valid. Invalid pixels are set to 0.

    Returns:
        np.ndarray: uint8 array of shape [h, w, 3].
    """
    if torch.is_tensor(normal):
        normal = normal.detach().cpu().numpy()
    normal = _squeeze_extra_dims(np.asarray(normal), 3)
    n_img_L2 = np.sqrt(np.sum(normal ** 2, axis=2, keepdims=True))
    # Epsilon avoids division by zero for all-zero (missing) normals.
    n_img_norm = normal / (n_img_L2 + 1e-8)
    normal_vis = (n_img_norm * 127 + 128).astype(np.uint8)
    if mask is not None:
        if torch.is_tensor(mask):
            mask = mask.detach().cpu().numpy()
        # Cast to bool: `~` on an integer mask is a bitwise-not, which would turn
        # the masking into integer fancy indexing.
        mask = _squeeze_extra_dims(np.asarray(mask), 2).astype(bool)
        normal_vis[~mask] = 0
    return normal_vis
|