Spaces:
Sleeping
Sleeping
import os | |
import re | |
import shutil | |
import time | |
from types import SimpleNamespace | |
from typing import Any | |
import gradio as gr | |
import numpy as np | |
from detectron2 import engine | |
from PIL import Image | |
from inference import main, setup_cfg | |
# ---------------------------------------------------------------------------
# Internal settings
# ---------------------------------------------------------------------------

# Passed to ``engine.launch`` as the process count when running inference.
NUM_PROCESSES = 1
# Forwarded unchanged to the inference entry point.
CROP = False
SCORE_THRESHOLD = 0.8
# Number of "Part N" image slots created in the UI; also caps how many
# predicted parts are loaded after a run.
MAX_PARTS = 5

# Argument namespace handed to ``setup_cfg``; ``output`` is also the folder the
# app scans for generated frames.
ARGS = SimpleNamespace(
    config_file="configs/coco/instance-segmentation/swin/opd_v1_real.yaml",
    model="../data/models/motion_state_pred_opdformerp_rgb.pth",
    input_format="RGB",
    output=".output",
    cpu=True,
)

# Initial value of the "Number of samples" input.
NUM_SAMPLES = 10

# Frames of the most recent prediction: one list of PIL images per part.
# Written by ``predict`` and read by the animation callbacks.
outputs: list[list[Any]] = []
def _natural_key(name: str) -> list[Any]:
    """Sort key that compares digit runs numerically (``9.png`` before ``10.png``)."""
    # re.split with a capturing group alternates text / digits, so the element
    # types line up position-by-position and never mix str with int.
    return [int(token) if token.isdigit() else token for token in re.split(r"(\d+)", name)]


def _clear_directory(path: str) -> None:
    """Make sure ``path`` exists and remove everything currently inside it."""
    # The folder is missing on the very first run; listing it would raise.
    os.makedirs(path, exist_ok=True)
    for entry in os.listdir(path):
        full_path = os.path.join(path, entry)
        if os.path.isdir(full_path) and not os.path.islink(full_path):
            shutil.rmtree(full_path)
        else:
            os.remove(full_path)


def _find_part_images(path: str) -> dict[str, list[str]]:
    """Map each sub-folder of ``path`` to the ordered PNG files it contains.

    Sub-folders without any PNG are left out, so every returned list is non-empty.
    """
    images: dict[str, list[str]] = {}
    # Sort the folders too: os.listdir order is arbitrary, which would make the
    # choice of displayed parts differ from run to run.
    for part in sorted(os.listdir(path), key=_natural_key):
        sub_path = os.path.join(path, part)
        if not os.path.isdir(sub_path):
            continue
        frames = [
            os.path.join(sub_path, image_file)
            for image_file in sorted(os.listdir(sub_path), key=_natural_key)
            if image_file.endswith(".png")
        ]
        if frames:
            images[part] = frames
    return images


def _load_frames(paths: list[str]) -> list[Any]:
    """Read every image fully into memory and release its file handle."""
    frames = []
    for frame_path in paths:
        # Image.open is lazy; copying inside the context manager forces the
        # pixel data to be read so the frame survives the file being deleted
        # by the next run.
        with Image.open(frame_path) as frame:
            frames.append(frame.copy())
    return frames


def predict(rgb_image: str, depth_image: str, intrinsics: np.ndarray, num_samples: int) -> list[Any]:
    """Run the model on one RGB/depth pair and publish the rendered frames.

    Clears ``ARGS.output``, runs ``main`` through ``engine.launch`` and then
    loads the PNG frames found in the sub-folders of ``ARGS.output`` (one
    sub-folder per part, at most ``MAX_PARTS`` of them) into the module-level
    ``outputs`` list.

    Args:
        rgb_image: File path of the uploaded RGB image.
        depth_image: File path of the uploaded depth image.
        intrinsics: Intrinsics matrix as entered in the UI, forwarded as-is.
        num_samples: Number of samples, forwarded as-is to ``main``.

    Returns:
        Exactly ``MAX_PARTS`` ``gr.update`` objects: one visible update showing
        the first frame of each loaded part, followed by hidden updates for the
        unused image slots.

    Raises:
        gr.Error: If either image is missing.
    """
    global outputs

    if not rgb_image or not depth_image:
        raise gr.Error("Please provide both an RGB image and a depth image.")

    # clear old predictions
    _clear_directory(ARGS.output)

    cfg = setup_cfg(ARGS)
    engine.launch(
        main,
        NUM_PROCESSES,
        args=(
            cfg,
            rgb_image,
            depth_image,
            intrinsics,
            num_samples,
            CROP,
            SCORE_THRESHOLD,
        ),
    )

    # process output
    # TODO: may want to select these in decreasing order of score
    part_images = _find_part_images(ARGS.output)
    selected = list(part_images.values())[:MAX_PARTS]

    # Build the new result locally and rebind the global in one step, so the
    # animation callbacks never observe a half-filled list.
    outputs = [_load_frames(frame_paths) for frame_paths in selected]

    return [
        *[gr.update(value=out[0], visible=True) for out in outputs],
        *[gr.update(visible=False) for _ in range(MAX_PARTS - len(outputs))],
    ]
def get_trigger(idx: int, fps: int = 40, oscillate: bool = True):
    """Create the animation callback for the image slot of part ``idx``.

    Args:
        idx: Index into the module-level ``outputs`` list.
        fps: Frames yielded per second; must be positive.
        oscillate: When true, the frames are played forwards and then
            backwards; otherwise forwards only.

    Returns:
        A generator function (accepting and ignoring any arguments) that
        yields the frames of part ``idx`` one at a time, sleeping ``1 / fps``
        seconds before each. Iterating it raises ``ValueError`` when
        ``outputs`` holds no entry for ``idx``.

    Raises:
        ValueError: If ``fps`` is not positive.
    """
    if fps <= 0:
        # Fail at wiring time with a clear message instead of a
        # ZeroDivisionError in the middle of a stream.
        raise ValueError(f"fps must be a positive number, got {fps!r}.")
    frame_delay = 1.0 / fps

    def iter_images(*args, **kwargs):
        # Read the global once: ``predict`` rebinds ``outputs`` on every run,
        # so re-reading it between the bounds check and the sweeps could see a
        # shorter list and fail with a stray IndexError.
        current = outputs
        if idx >= len(current):
            raise ValueError("Could not find any images to load into this module.")
        frames = current[idx]

        sweeps = [frames, reversed(frames)] if oscillate else [frames]
        for sweep in sweeps:
            for im in sweep:
                time.sleep(frame_delay)
                yield im

    return iter_images
# Build the UI: two image uploads, the intrinsics table, the sample count,
# a few bundled examples, the run button and MAX_PARTS result slots.
with gr.Blocks() as demo:
    gr.Markdown("\n# OPDMulti Demo\nUpload an image to see its range of motion.\n")

    # Options shared by both upload widgets.
    _upload_options = dict(source="upload", type="filepath", show_label=True, interactive=True)

    # NOTE(review): the original indentation was lost; this assumes the row
    # holds only the two image inputs - confirm against the deployed layout.
    with gr.Row():
        rgb_image = gr.Image(image_mode="RGB", label="RGB Image", **_upload_options)
        depth_image = gr.Image(image_mode="I;16", label="Depth Image", **_upload_options)

    # Editable 3x3 matrix, pre-filled with default values.
    intrinsics = gr.Dataframe(
        value=[
            [214.85935872395834, 0.0, 125.90160319010417],
            [0.0, 214.85935872395834, 95.13726399739583],
            [0.0, 0.0, 1.0],
        ],
        row_count=(3, "fixed"),
        col_count=(3, "fixed"),
        datatype="number",
        type="numpy",
        label="Intrinsics matrix",
        show_label=True,
        interactive=True,
    )

    num_samples = gr.Number(
        value=NUM_SAMPLES,
        label="Number of samples",
        show_label=True,
        interactive=True,
        precision=0,
        minimum=3,
        maximum=20,
    )

    # Each example is an RGB file plus its depth counterpart ("<stem>_d.png").
    _example_stems = ("59-4860", "174-8460", "187-0", "187-23040")
    examples = gr.Examples(
        examples=[[f"examples/{stem}.png", f"examples/{stem}_d.png"] for stem in _example_stems],
        inputs=[rgb_image, depth_image],
        api_name=False,
        examples_per_page=2,
    )

    submit_btn = gr.Button("Run model")

    # TODO: do we want to set a maximum limit on how many parts we render? We could also show the number of
    # components identified.
    # One hidden slot per part; ``predict`` reveals the ones it fills.
    images = [gr.Image(type="pil", label=f"Part {idx + 1}", visible=False) for idx in range(MAX_PARTS)]

    # Selecting a part's image streams that part's frames back into the same slot.
    for part_idx, part_view in enumerate(images):
        part_view.select(get_trigger(part_idx), inputs=[], outputs=part_view, api_name=False)

    submit_btn.click(
        fn=predict,
        inputs=[rgb_image, depth_image, intrinsics, num_samples],
        outputs=images,
        api_name=False,
    )

# Queueing is enabled before launch; the API is kept closed.
demo.queue(api_open=False)
demo.launch()