import os
import argparse
from glob import glob

import numpy as np
from tqdm import tqdm
import torch
from torch import utils
from torch.nn import functional as F
from torchvision.transforms import functional as TF
from torchvision.transforms import InterpolationMode

from video_module.dataset import Video_DS
from video_module.model import AFB_URR, FeatureBank
from test_image_seg import test_waterseg
import myutils

# Inference only: disable gradient tracking globally.
torch.set_grad_enabled(False)


def get_args():
    parser = argparse.ArgumentParser(description='V-FloodNet: Water Video Segmentation')
    parser.add_argument('--gpu', type=int, default=0,
                        help='GPU card id.')
    parser.add_argument('--budget', type=int, default=250000,
                        help='Max number of features in the feature bank.')
    # NOTE: with default=True, this store_true flag is always on and cannot be disabled.
    parser.add_argument('--viz', action='store_true', default=True,
                        help='Visualize data.')
    parser.add_argument('--model-path', type=str, required=True,
                        help='Path to the checkpoint.')
    parser.add_argument('--update-rate', type=float, default=0.1,
                        help='Update rate for merging new features.')
    parser.add_argument('--merge-thres', type=float, default=0.95,
                        help='Merging rate threshold.')
    parser.add_argument('--test-path', type=str, required=True,
                        help='Path to the test video frames.')
    parser.add_argument('--test-name', type=str, required=True,
                        help='Name for the test video.')
    return parser.parse_args()


def main(args, device):
    model = AFB_URR(device, update_bank=True, load_imagenet_params=False)
    model = model.to(device)
    model.eval()

    downsample_size = 480

    # Load the video segmentation checkpoint.
    if os.path.isfile(args.model_path):
        checkpoint = torch.load(args.model_path)
        end_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['model'], strict=False)
        train_loss = checkpoint['loss']
        seed = checkpoint['seed']
        print(myutils.gct(),
              f'Loaded checkpoint {args.model_path}. '
              f'(end_epoch: {end_epoch}, train_loss: {train_loss}, seed: {seed})')
    else:
        print(myutils.gct(), f'No checkpoint found at {args.model_path}')
        raise IOError

    img_list = sorted(glob(os.path.join(args.test_path, '*.jpg')) +
                      glob(os.path.join(args.test_path, '*.png')))
    first_frame = myutils.load_image_in_PIL(img_list[0])
    first_name = os.path.basename(img_list[0])[:-4]

    out_dir = './output/segs'
    mask_dir = os.path.join(out_dir, args.test_name, 'mask')
    mask_path = os.path.join(mask_dir, first_name + '.png')

    # If the first frame has no mask yet, run the image segmentation model once
    # to bootstrap the video segmentation.
    if not os.path.exists(mask_path):
        image_model_path = './records/link_efficientb4_model.pth'
        test_waterseg(image_model_path, img_list[0], args.test_name, out_dir, device)
    first_mask = myutils.load_image_in_PIL(mask_path, 'P')

    seq_dataset = Video_DS(img_list, first_frame, first_mask)
    seq_loader = utils.data.DataLoader(seq_dataset, batch_size=1, shuffle=False, num_workers=1)

    seg_dir = os.path.join(out_dir, args.test_name, 'mask')
    os.makedirs(seg_dir, exist_ok=True)
    if args.viz:
        overlay_dir = os.path.join(out_dir, args.test_name, 'overlay')
        os.makedirs(overlay_dir, exist_ok=True)

    obj_n = seq_dataset.obj_n
    fb = FeatureBank(obj_n, args.budget, device,
                     update_rate=args.update_rate, thres_close=args.merge_thres)

    ori_first_frame = seq_dataset.first_frame.unsqueeze(0).to(device)
    ori_first_mask = seq_dataset.first_mask.unsqueeze(0).to(device)

    first_frame = TF.resize(ori_first_frame, downsample_size, InterpolationMode.BICUBIC)
    first_mask = TF.resize(ori_first_mask, downsample_size, InterpolationMode.NEAREST)

    # Save the first-frame mask (and overlay) at the original resolution.
    pred = torch.argmax(ori_first_mask[0], dim=0).cpu().numpy().astype(np.uint8)
    seg_path = os.path.join(seg_dir, f'{first_name}.png')
    myutils.save_seg_mask(pred, seg_path, myutils.color_palette)
    if args.viz:
        overlay_path = os.path.join(overlay_dir, f'{first_name}.png')
        myutils.save_overlay(ori_first_frame[0], pred, overlay_path, myutils.color_palette)

    with torch.no_grad():
        # Initialize the feature bank from the first frame and its mask.
        k4_list, v4_list = model.memorize(first_frame, first_mask)
        fb.init_bank(k4_list, v4_list)

        # Segment the remaining frames, updating the feature bank as we go.
        for idx, (frame, frame_name) in enumerate(tqdm(seq_loader)):
            ori_frame = frame.to(device)
            ori_size = ori_frame.shape[-2:]
            frame = TF.resize(ori_frame, downsample_size, InterpolationMode.BICUBIC)

            score, _ = model.segment(frame, fb)
            pred_mask = F.softmax(score, dim=1)

            k4_list, v4_list = model.memorize(frame, pred_mask)
            fb.update(k4_list, v4_list, idx + 1)

            # Upsample the prediction back to the original frame size.
            pred = TF.resize(pred_mask, ori_size, InterpolationMode.BICUBIC)
            pred = torch.argmax(pred[0], dim=0).cpu().numpy().astype(np.uint8)
            pred = myutils.postprocessing_pred(pred)

            seg_path = os.path.join(seg_dir, f'{frame_name[0]}.png')
            myutils.save_seg_mask(pred, seg_path, myutils.color_palette)
            if args.viz:
                overlay_path = os.path.join(overlay_dir, f'{frame_name[0]}.png')
                myutils.save_overlay(ori_frame[0], pred, overlay_path, myutils.color_palette)

    fb.print_peak_mem()


if __name__ == '__main__':
    args = get_args()
    print(myutils.gct(), 'Args =', args)

    if args.gpu >= 0 and torch.cuda.is_available():
        device = torch.device('cuda', args.gpu)
    else:
        raise ValueError('CUDA is required. --gpu must be >= 0.')

    assert os.path.isdir(args.test_path)

    main(args, device)

    print(myutils.gct(), 'Test video segmentation done.')
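
# Example invocation (a sketch only: the script filename, checkpoint name, and
# frame directory below are illustrative placeholders, not paths shipped with
# the repository):
#
#   python test_video_seg.py \
#       --model-path ./records/<video_model_checkpoint>.pth \
#       --test-path ./data/<video_name>/frames \
#       --test-name <video_name>
#
# Outputs are written under ./output/segs/<video_name>/mask (and .../overlay
# when visualization is enabled).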