|
|
|
|
|
|
|
import os |
|
import sys |
|
import random |
|
import warnings |
|
|
|
# Install the local segment-anything / GroundingDINO checkouts and an up-to-date
# diffusers before they are imported further down, then make the checkouts
# importable from the working directory.
import subprocess

# Set the flag on this process so the pip builds spawned below inherit it;
# `os.system("export ...")` only changed the environment of a throwaway subshell.
os.environ["BUILD_WITH_CUDA"] = "True"

# Run pip through the current interpreter so packages land in the same
# environment that executes this script. Failures are ignored (best-effort),
# matching the previous os.system behaviour.
for pip_args in (
    ["install", "-e", "segment-anything"],
    ["install", "-e", "GroundingDINO"],
    ["install", "--upgrade", "diffusers[torch]"],
):
    subprocess.run([sys.executable, "-m", "pip", *pip_args], check=False)

sys.path.insert(0, './GroundingDINO')
sys.path.insert(0, './segment-anything')

warnings.filterwarnings("ignore")
|
|
|
import cv2 |
|
from scipy import ndimage |
|
|
|
import gradio as gr |
|
import argparse |
|
|
|
import numpy as np |
|
from PIL import Image |
|
from moviepy.editor import * |
|
import torch |
|
from torch.nn import functional as F |
|
import torchvision |
|
import networks |
|
import utils |
|
|
|
|
|
from groundingdino.util.inference import Model |
|
|
|
|
|
from segment_anything.utils.transforms import ResizeLongestSide |
|
|
|
|
|
from diffusers import StableDiffusionPipeline |
|
|
|
# Shared segment_anything ResizeLongestSide transform with target length 1024.
# run_grounded_sam applies it to frames, boxes and points, and then pads each
# frame to 1024x1024.
transform = ResizeLongestSide(1024)

# RGB colour used as the flat background of the green-screen output.
PALETTE_back = (51, 255, 146)

# Model config / checkpoint locations, relative to the working directory.
GROUNDING_DINO_CONFIG_PATH = "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py"
GROUNDING_DINO_CHECKPOINT_PATH = "checkpoints/groundingdino_swint_ogc.pth"
mam_checkpoint="checkpoints/mam_sam_vitb.pth"
# Directory created (if missing) by run_grounded_sam.
output_dir="outputs"
# NOTE(review): hard-coded; every model below is moved to this device, so a
# CUDA GPU is required to import this module.
device = 'cuda'
# Candidate images for the 'real_world_sample' background type. This is read at
# import time, so the directory must exist even though run_grounded_sam
# currently hard-codes the 'generated_by_text' background type.
background_list = os.listdir('assets/backgrounds')

# Matting Anything Model: built with a SAM segmenter and the
# 'sam_decoder_deep' m2m decoder. Weights are restored (strictly) from
# checkpoint['state_dict'] after utils.remove_prefix_state_dict. That helper
# presumably strips a wrapper prefix such as 'module.' — TODO confirm.
# The model is then put in eval mode for inference.
mam_model = networks.get_generator_m2m(seg='sam', m2m='sam_decoder_deep')
mam_model.to(device)
checkpoint = torch.load(mam_checkpoint, map_location=device)
mam_model.load_state_dict(utils.remove_prefix_state_dict(checkpoint['state_dict']), strict=True)
mam_model = mam_model.eval()

# Text-prompted box detector; run_grounded_sam uses its best box for the
# 'text' task.
grounding_dino_model = Model(model_config_path=GROUNDING_DINO_CONFIG_PATH, model_checkpoint_path=GROUNDING_DINO_CHECKPOINT_PATH, device=device)

# Stable Diffusion v1.5 pipeline (float16 weights) used to generate the
# background image from the user's background prompt.
generator = StableDiffusionPipeline.from_pretrained("checkpoints/stable-diffusion-v1-5", torch_dtype=torch.float16)
generator.to(device)
|
|
|
def get_frames(video_in):
    """Resize a video to 512 px height and dump every frame to a JPEG file.

    The clip is re-encoded to ``video_resized.mp4`` at its own frame rate,
    capped at 30 fps. It is then decoded with OpenCV, and frame ``i`` is
    written to ``kang<i>.jpg`` in the current working directory.

    Args:
        video_in: path of the input video file.

    Returns:
        A ``(frames, fps)`` tuple: the ordered list of written frame paths and
        the frame rate OpenCV reports for ``video_resized.mp4``.
    """
    clip = VideoFileClip(video_in)
    try:
        if clip.fps > 30:
            print("video rate is over 30, resetting to 30")
            target_fps = 30
        else:
            print("video rate is OK")
            target_fps = clip.fps
        clip_resized = clip.resize(height=512)
        clip_resized.write_videofile("video_resized.mp4", fps=target_fps)
    finally:
        # Close the source clip so the reader it holds on the file is released.
        clip.close()

    print("video resized to 512 height")

    frames = []
    cap = cv2.VideoCapture("video_resized.mp4")
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        print("video fps: " + str(fps))
        i = 0
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            frame_path = 'kang' + str(i) + '.jpg'
            cv2.imwrite(frame_path, frame)
            frames.append(frame_path)
            i += 1
    finally:
        # Release the capture even if decoding or writing fails. No GUI window
        # is opened here, so cv2.destroyAllWindows() is not needed (and it fails
        # on headless OpenCV builds).
        cap.release()

    print("broke the video into frames")

    return frames, fps
|
|
|
|
|
def create_video(frames, fps, type):
    """Encode a sequence of image files into ``video_<type>_result.mp4``.

    Args:
        frames: ordered list of image paths, one per video frame.
        fps: frame rate used both for the image sequence and for the encoder.
        type: tag embedded in the output file name (the parameter name
            shadows the builtin but is kept for caller compatibility).

    Returns:
        The path of the written video file.
    """
    print("building video result")
    output_path = f"video_{type}_result.mp4"
    sequence = ImageSequenceClip(frames, fps=fps)
    sequence.write_videofile(output_path, fps=fps)
    return output_path
|
|
|
|
|
def _composite(alpha, foreground, background):
    """Alpha-blend *foreground* over *background* and cast the result to uint8.

    *alpha* is a 2-D float matte. A trailing axis is added so it broadcasts
    over the colour channels of *foreground* and *background*.
    """
    alpha = alpha[..., None]
    return np.uint8(alpha * foreground + (1 - alpha) * background)


def _scribble_centers(scribble):
    """Return one centre of mass per connected scribble stroke as an (N, 2) array.

    Raises:
        ValueError: if the scribble contains no stroke (no pixel >= 255).
    """
    # transpose(2, 1, 0)[0] keeps the first channel with the two spatial axes
    # swapped. Assuming an (H, W, C) scribble, the centres come out as (x, y).
    strokes = scribble.transpose(2, 1, 0)[0]
    labeled_array, num_features = ndimage.label(strokes >= 255)
    if num_features == 0:
        raise ValueError("The scribble does not contain any stroke")
    centers = ndimage.center_of_mass(strokes, labeled_array, range(1, num_features + 1))
    return np.array(centers)


def _detect_text_bbox(image_ori, text_prompt, box_threshold, text_threshold, iou_threshold):
    """Return the highest-confidence Grounding DINO box (xyxy) for *text_prompt*.

    Raises:
        ValueError: if nothing matching the prompt is detected.
    """
    with torch.no_grad():
        detections, _ = grounding_dino_model.predict_with_caption(
            image=cv2.cvtColor(image_ori, cv2.COLOR_RGB2BGR),
            caption=text_prompt,
            box_threshold=box_threshold,
            text_threshold=text_threshold,
        )

    if len(detections.xyxy) == 0:
        raise ValueError("No object matching '{}' was detected".format(text_prompt))

    if len(detections.xyxy) > 1:
        # Suppress overlapping duplicates before picking the best box.
        nms_idx = torchvision.ops.nms(
            torch.from_numpy(detections.xyxy),
            torch.from_numpy(detections.confidence),
            iou_threshold,
        ).numpy().tolist()
        detections.xyxy = detections.xyxy[nms_idx]
        detections.confidence = detections.confidence[nms_idx]

    return detections.xyxy[np.argmax(detections.confidence)]


def run_grounded_sam(input_image, text_prompt, task_type, background_prompt, bg_already, scribble=None):
    """Matte one RGB frame with MAM and composite it over a background.

    Args:
        input_image: RGB frame as a numpy array. ``shape[:2]`` is used as the
            original size, and the frame is converted with ``COLOR_RGB2BGR``
            for Grounding DINO.
        text_prompt: caption used by Grounding DINO to locate the subject
            (required when ``task_type == 'text'``).
        task_type: ``'text'``, ``'scribble_point'`` or ``'scribble_box'``.
        background_prompt: Stable Diffusion prompt for the background.
        bg_already: when true, reuse the background cached in the
            module-level ``background_img`` by a previous call instead of
            generating a new one. A background is still generated if none is
            cached yet.
        scribble: scribble image required by the two scribble task types.
            Presumably (H, W, C) with strokes drawn at 255 in channel 0 —
            TODO confirm against the caller.

    Returns:
        ``(com_img, green_img, alpha_rgb)``, all uint8: the frame composited
        over the background, the frame composited over ``PALETTE_back``, and
        the alpha matte replicated to three channels.

    Raises:
        ValueError: on an unknown task type, an empty text prompt, a missing
            scribble, a missing background prompt, or when no subject is
            detected.
    """
    global background_img

    background_type = "generated_by_text"
    box_threshold = 0.25
    text_threshold = 0.25
    iou_threshold = 0.5
    scribble_mode = "split"
    guidance_mode = "alpha"

    # Validate everything up front so bad input fails before any model runs.
    if task_type not in ('text', 'scribble_point', 'scribble_box'):
        raise ValueError("task_type:{} error!".format(task_type))
    if task_type == 'text' and not text_prompt:
        raise ValueError('Please input non-empty text prompt')
    if task_type != 'text' and scribble is None:
        raise ValueError("task_type:{} requires a scribble image".format(task_type))
    if background_type != 'real_world_sample' and background_prompt is None:
        raise ValueError('Please input non-empty background prompt')

    os.makedirs(output_dir, exist_ok=True)

    image_ori = input_image
    original_size = image_ori.shape[:2]

    if task_type == 'text':
        bbox = _detect_text_bbox(image_ori, text_prompt, box_threshold, text_threshold, iou_threshold)
        bbox = transform.apply_boxes(bbox, original_size)
        bbox = torch.as_tensor(bbox, dtype=torch.float).to(device)

    # Resize with the shared transform, normalise per channel, then zero-pad
    # bottom/right to the 1024x1024 model input.
    image = transform.apply_image(image_ori)
    image = torch.as_tensor(image).to(device).permute(2, 0, 1).contiguous()
    pixel_mean = torch.tensor([123.675, 116.28, 103.53]).view(3, 1, 1).to(device)
    pixel_std = torch.tensor([58.395, 57.12, 57.375]).view(3, 1, 1).to(device)
    image = (image - pixel_mean) / pixel_std
    h, w = image.shape[-2:]
    pad_size = image.shape[-2:]  # un-padded size, used to crop predictions
    image = F.pad(image, (0, 1024 - w, 0, 1024 - h))

    if task_type == 'scribble_point':
        centers = transform.apply_coords(_scribble_centers(scribble), original_size)
        point_coords = torch.from_numpy(centers).to(device).unsqueeze(0)
        point_labels = torch.from_numpy(np.array([1] * len(centers))).unsqueeze(0).to(device)
        if scribble_mode == 'split':
            # One prompt per stroke: (1, N, 2) -> (N, 1, 2).
            point_coords = point_coords.permute(1, 0, 2)
            point_labels = point_labels.permute(1, 0)
        sample = {'image': image.unsqueeze(0), 'point': point_coords, 'label': point_labels, 'ori_shape': original_size, 'pad_shape': pad_size}
    else:
        if task_type == 'scribble_box':
            # Box spanning the centres of all strokes.
            centers = _scribble_centers(scribble)
            x_min, y_min = centers.min(axis=0)
            x_max, y_max = centers.max(axis=0)
            bbox = transform.apply_boxes(np.array([x_min, y_min, x_max, y_max]), original_size)
            bbox = torch.as_tensor(bbox, dtype=torch.float).to(device)
        sample = {'image': image.unsqueeze(0), 'bbox': bbox.unsqueeze(0), 'ori_shape': original_size, 'pad_shape': pad_size}

    with torch.no_grad():
        _, pred, post_mask = mam_model.forward_inference(sample)

        alpha_pred_os1, alpha_pred_os4, alpha_pred_os8 = pred['alpha_os1'], pred['alpha_os4'], pred['alpha_os8']

        # Drop the padding, then resize back to the original frame size.
        pad_h, pad_w = sample['pad_shape']
        alpha_pred_os8 = F.interpolate(alpha_pred_os8[..., :pad_h, :pad_w], sample['ori_shape'], mode="bilinear", align_corners=False)
        alpha_pred_os4 = F.interpolate(alpha_pred_os4[..., :pad_h, :pad_w], sample['ori_shape'], mode="bilinear", align_corners=False)
        alpha_pred_os1 = F.interpolate(alpha_pred_os1[..., :pad_h, :pad_w], sample['ori_shape'], mode="bilinear", align_corners=False)

        if guidance_mode == 'mask':
            weight_os8 = utils.get_unknown_tensor_from_mask_oneside(post_mask, rand_width=10, train_mode=False)
            post_mask[weight_os8 > 0] = alpha_pred_os8[weight_os8 > 0]
            alpha_pred = post_mask.clone().detach()
        else:
            weight_os8 = utils.get_unknown_box_from_mask(post_mask)
            alpha_pred_os8[weight_os8 > 0] = post_mask[weight_os8 > 0]
            alpha_pred = alpha_pred_os8.clone().detach()

        # Overwrite the pixels flagged by the utils "unknown" helpers with the
        # os4, then os1, predictions.
        weight_os4 = utils.get_unknown_tensor_from_pred_oneside(alpha_pred, rand_width=20, train_mode=False)
        alpha_pred[weight_os4 > 0] = alpha_pred_os4[weight_os4 > 0]

        weight_os1 = utils.get_unknown_tensor_from_pred_oneside(alpha_pred, rand_width=10, train_mode=False)
        alpha_pred[weight_os1 > 0] = alpha_pred_os1[weight_os1 > 0]

        alpha_pred = alpha_pred[0][0].cpu().numpy()

    alpha_rgb = cv2.cvtColor(np.uint8(alpha_pred * 255), cv2.COLOR_GRAY2RGB)

    if background_type == 'real_world_sample':
        background_img_file = os.path.join('assets/backgrounds', random.choice(background_list))
        background_img = cv2.cvtColor(cv2.imread(background_img_file), cv2.COLOR_BGR2RGB)
    else:
        # Generate once and reuse it on later calls (bg_already). Also generate
        # when nothing has been cached yet, instead of failing with NameError.
        if not bg_already or globals().get('background_img') is None:
            background_img = generator(background_prompt).images[0]
        background_img = np.array(background_img)
    background_img = cv2.resize(background_img, (image_ori.shape[1], image_ori.shape[0]))

    com_img = _composite(alpha_pred, image_ori, np.uint8(background_img))
    green_img = _composite(alpha_pred, image_ori, np.array([PALETTE_back], dtype='uint8'))

    return com_img, green_img, alpha_rgb
|
|
|
def infer(video_in, trim_value, prompt, background_prompt):
    """Matte the subject of a video and render three result videos.

    The video is split into frames with ``get_frames``. The first
    ``int(trim_value * fps)`` frames are matted with ``run_grounded_sam``, and
    the per-frame results are re-encoded with ``create_video``.

    Args:
        video_in: path of the uploaded video (``None`` when nothing was uploaded).
        trim_value: number of seconds to process from the start of the video.
        prompt: text prompt describing the subject to matte.
        background_prompt: Stable Diffusion prompt for the new background.

    Returns:
        Paths of the composited, green-screen and matte videos, in that order.

    Raises:
        ValueError: if no video was given or no frame is selected for processing.
    """
    print(prompt)
    if video_in is None:
        raise ValueError("Please upload a video")

    frames_list, fps = get_frames(video_in)
    n_frame = int(trim_value * fps)

    if n_frame >= len(frames_list):
        print("video is shorter than the cut value")
        n_frame = len(frames_list)
    if n_frame <= 0:
        # ImageSequenceClip cannot be built from an empty frame list.
        raise ValueError("No frames could be read from the video")

    with_bg_result_frames = []
    with_green_result_frames = []
    with_matte_result_frames = []

    print("set stop frames to: " + str(n_frame))
    # False for the first frame only, so the background is generated once and
    # reused for the rest of the video.
    bg_already = False
    for index, frame_path in enumerate(frames_list[:n_frame], start=1):
        # The context manager closes the JPEG file once its pixels are copied.
        with Image.open(frame_path) as frame:
            image_array = np.array(frame.convert("RGB"))

        com_img, green_img, matte_img = run_grounded_sam(image_array, prompt, "text", background_prompt, bg_already)
        bg_already = True

        bg_path = f"bg_result_img-{frame_path}.jpg"
        Image.fromarray(com_img).save(bg_path)
        with_bg_result_frames.append(bg_path)

        green_path = f"green_result_img-{frame_path}.jpg"
        Image.fromarray(green_img).save(green_path)
        with_green_result_frames.append(green_path)

        matte_path = f"matte_result_img-{frame_path}.jpg"
        Image.fromarray(matte_img).save(matte_path)
        with_matte_result_frames.append(matte_path)

        print("frame " + str(index) + "/" + str(n_frame) + ": done;")

    vid_bg = create_video(with_bg_result_frames, fps, "bg")
    vid_green = create_video(with_green_result_frames, fps, "greenscreen")
    vid_matte = create_video(with_matte_result_frames, fps, "matte")

    print("finished !")

    return vid_bg, vid_green, vid_matte
|
|
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser("MAM demo", add_help=True)
    parser.add_argument("--debug", action="store_true", help="using debug mode")
    parser.add_argument("--share", action="store_true", help="share the app")
    # None leaves the port choice to Gradio (launch(server_port=None)); an
    # explicit --port is forwarded to launch() below.
    parser.add_argument('--port', type=int, default=None, help='port to run the server')
    parser.add_argument('--no-gradio-queue', action="store_true", help='disable the gradio request queue')
    args = parser.parse_args()

    print(args)

    block = gr.Blocks()

    with block:
        gr.Markdown(
            """
            # Matting Anything in Video Demo
            Welcome to the Matting Anything in Video demo by @fffiloni and upload your video to get started <br/>
            You may open usage details below to understand how to use this demo.
            ## Usage
            <details>
            You may upload a video to start, for the moment we only support 1 prompt type to get the alpha matte of the target:
            **text**: Send text prompt to identify the target instance in the `Text prompt` box.

            We also only support 1 background type to support image composition with the alpha matte output:
            **generated_by_text**: Send background text prompt to create a background image with stable diffusion model in the `Background prompt` box.

            </details>
            <a href="https://huggingface.co/spaces/fffiloni/Video-Matting-Anything?duplicate=true" style="display: inline-block;margin-top: .5em;margin-right: .25em;" target="_blank">
            <img style="margin-bottom: 0em;display: inline;margin-top: -.25em;" src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>
            for longer sequences, more control and no queue.
            """)

        with gr.Row():
            with gr.Column():
                video_in = gr.Video()
                trim_in = gr.Slider(label="Cut video at (s)", minimum=1, maximum=10, step=1, value=1)
                text_prompt = gr.Textbox(label="Text prompt", placeholder="the girl in the middle", info="Describe the subject visible in your video that you want to matte")
                background_prompt = gr.Textbox(label="Background prompt", placeholder="downtown area in New York")
                run_button = gr.Button("Run")

            with gr.Column():
                vid_bg_out = gr.Video(label="Video with background")
                with gr.Row():
                    vid_green_out = gr.Video(label="Video green screen")
                    vid_matte_out = gr.Video(label="Video matte")

        gr.Examples(
            fn=infer,
            examples=[
                [
                    "./examples/example_men_bottle.mp4",
                    10,
                    "the man holding a bottle",
                    "the Sahara desert"
                ]
            ],
            inputs=[video_in, trim_in, text_prompt, background_prompt],
            outputs=[vid_bg_out, vid_green_out, vid_matte_out]
        )

        run_button.click(fn=infer, inputs=[
            video_in, trim_in, text_prompt, background_prompt], outputs=[vid_bg_out, vid_green_out, vid_matte_out], api_name="go_matte")

    # Enable the queue only when it was not disabled on the command line, so
    # --no-gradio-queue actually takes effect.
    if not args.no_gradio_queue:
        block.queue(max_size=24)
    block.launch(debug=args.debug, share=args.share, show_error=True, server_port=args.port)
|
|
|
|
|
|