Spaces:
Running
on
A10G
Running
on
A10G
File size: 1,856 Bytes
320e465 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 |
import sys
import argparse
import os
import cv2
import glob
import numpy as np
import torch
from PIL import Image
from .raft import RAFT
from .utils import flow_viz
from .utils.utils import InputPadder
DEVICE = 'cuda'
def load_image(imfile):
img = np.array(Image.open(imfile)).astype(np.uint8)
img = torch.from_numpy(img).permute(2, 0, 1).float()
return img
def load_image_list(image_files):
images = []
for imfile in sorted(image_files):
images.append(load_image(imfile))
images = torch.stack(images, dim=0)
images = images.to(DEVICE)
padder = InputPadder(images.shape)
return padder.pad(images)[0]
def viz(img, flo):
img = img[0].permute(1,2,0).cpu().numpy()
flo = flo[0].permute(1,2,0).cpu().numpy()
# map flow to rgb image
flo = flow_viz.flow_to_image(flo)
# img_flo = np.concatenate([img, flo], axis=0)
img_flo = flo
cv2.imwrite('/home/chengao/test/flow.png', img_flo[:, :, [2,1,0]])
# cv2.imshow('image', img_flo[:, :, [2,1,0]]/255.0)
# cv2.waitKey()
def demo(args):
model = torch.nn.DataParallel(RAFT(args))
model.load_state_dict(torch.load(args.model))
model = model.module
model.to(DEVICE)
model.eval()
with torch.no_grad():
images = glob.glob(os.path.join(args.path, '*.png')) + \
glob.glob(os.path.join(args.path, '*.jpg'))
images = load_image_list(images)
for i in range(images.shape[0]-1):
image1 = images[i,None]
image2 = images[i+1,None]
flow_low, flow_up = model(image1, image2, iters=20, test_mode=True)
viz(image1, flow_up)
def RAFT_infer(args):
model = torch.nn.DataParallel(RAFT(args))
model.load_state_dict(torch.load(args.model))
model = model.module
model.to(DEVICE)
model.eval()
return model
|