import os
import re
import shutil
import time
from types import SimpleNamespace
from typing import Any, Callable, Generator

import gradio as gr
import numpy as np
from detectron2 import engine
from huggingface_hub import hf_hub_download
from natsort import natsorted
from PIL import Image

from inference import main, setup_cfg

# internal settings
NUM_PROCESSES = 1
CROP = False
SCORE_THRESHOLD = 0.8
MAX_PARTS = 5  # TODO: we can replace this by having a slider and a single image visualization component rather than multiple components
HF_MODEL_PATH = {"repo_id": "3dlg-hcvc/opdmulti-motion-state-rgb-model", "filename": "pytorch_model.pth"}
ARGS = SimpleNamespace(
    config_file="configs/coco/instance-segmentation/swin/opd_v1_real.yaml",
    model=None,
    input_format="RGB",
    output=".output",
    cpu=True,
)
NUM_SAMPLES = 10

# This variable holds the current state of results, since the user needs to be able to "reload" the results in order
# to visualize the demo again. The output images are cached by the temporary path of the image, so multiple users
# should be able to run the demo simultaneously. Gradio should be able to handle the case where multiple distinct
# images are uploaded with the same name, as I believe the caching of the temp path is based on a base64 encoding, not
# the filename itself.
# TODO: right now there is no garbage-collection system for outputs, which means that with enough traffic per unit
# time, such that all outputs are generated on the same system instantiation of the code, the RAM could max out; note
# also that this is not designed to run on GPU, so the model itself must also be held in CPU memory. Solutions could
# include
#  1. a caching design that removes old results periodically, especially if the image is reset;
#  2. caching results on disk rather than in memory, since the cap is higher; or
#  3. caching results in the browser instead of the backend (couldn't figure out a way to do this earlier).
outputs: dict[str, list[list[Image.Image]]] = {}
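
# A minimal sketch of option 1 above, for illustration only and not wired into the demo; MAX_CACHED_RESULTS is a
# hypothetical knob, not something Gradio provides.
MAX_CACHED_RESULTS = 8


def evict_oldest_outputs() -> None:
    """Evict the oldest cached result sets until the cache fits under the cap (dicts preserve insertion order)."""
    while len(outputs) > MAX_CACHED_RESULTS:
        outputs.pop(next(iter(outputs)))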


def predict(rgb_image: str, depth_image: str, intrinsic: np.ndarray, num_samples: int) -> list[Any]:
    """
    Run model on input image and generate output visualizations.

    :param rgb_image: local path to RGB image file, used for model prediction and visualization
    :param depth_image: local path to depth image file, used for visualization
    :param intrinsic: array of dimension (3, 3) representing the intrinsic matrix of the camera
    :param num_samples: number of visualization states to generate.
    :return: list of updates for the image components, showing the first image of each part's visualization sequence
    and hiding any unused components.
    """
    global outputs

    def find_images(path: str) -> dict[str, list[str]]:
        """Scrape folders for all generated image files."""
        images = {}
        for file in os.listdir(path):
            sub_path = os.path.join(path, file)
            if os.path.isdir(sub_path):
                images[file] = []
                for image_file in natsorted(os.listdir(sub_path)):
                    if re.match(r".*\.png$", image_file):
                        images[file].append(os.path.join(sub_path, image_file))
        return images

    # clear old predictions
    # TODO: there might be a better place for this than at the beginning of every invocation
    os.makedirs(ARGS.output, exist_ok=True)
    for path in os.listdir(ARGS.output):
        full_path = os.path.join(ARGS.output, path)
        if os.path.isdir(full_path):
            shutil.rmtree(full_path)
        else:
            os.remove(full_path)

    # gr.Error must be raised (not just constructed) for Gradio to surface it in the UI
    if not rgb_image:
        raise gr.Error("You must provide an RGB image before running the model.")

    if not depth_image:
        raise gr.Error("You must provide a depth image before running the model.")

    # run model
    ARGS.model = hf_hub_download(repo_id=HF_MODEL_PATH["repo_id"], filename=HF_MODEL_PATH["filename"])
    cfg = setup_cfg(ARGS)
    engine.launch(
        main,
        NUM_PROCESSES,
        args=(
            cfg,
            rgb_image,
            depth_image,
            intrinsic,
            num_samples,
            CROP,
            SCORE_THRESHOLD,
        ),
    )

    # process output
    # TODO: may want to select these in decreasing order of score
    outputs[rgb_image] = []
    image_files = find_images(ARGS.output)
    for count, part in enumerate(image_files):
        if count < MAX_PARTS:  # only visualize up to MAX_PARTS parts
            outputs[rgb_image].append([Image.open(im) for im in image_files[part]])

    # show the first frame of each detected part; hide the components beyond this image's part count
    return [
        *[gr.update(value=out[0], visible=True) for out in outputs[rgb_image]],
        *[gr.update(visible=False) for _ in range(MAX_PARTS - len(outputs[rgb_image]))],
    ]


def get_trigger(
    idx: int, fps: int = 15, oscillate: bool = True
) -> Callable[[str], Generator[Image.Image, None, None]]:
    """
    Return event listener trigger function for image component to animate image sequence.

    :param idx: index of the part to animate from the output
    :param fps: approximate rate at which images are cycled through, in frames per second. Note that the effective fps
    cannot exceed the rate at which images can be returned and rendered. Defaults to 15
    :param oscillate: if True, animates the part in reverse after running from start to end. Defaults to True
    """

    def iter_images(rgb_image: str) -> Generator[Image.Image, None, None]:
        """Iterator to yield sequence of images for rendering, based on temp RGB image path"""
        start_time = time.time()

        def wait_until_next_frame(frame_count: int) -> None:
            """Wait until the appropriate time for the specified fps, relative to the start time of iteration."""
            time_to_sleep = frame_count / fps - (time.time() - start_time)
            if time_to_sleep <= 0:
                print("[WARNING] frames cannot be rendered at the specified FPS due to processing/rendering time.")
            time.sleep(max(time_to_sleep, 0))

        if not rgb_image or rgb_image not in outputs:
            gr.Warning("You must upload an image and run the model before you can view the output.")

        elif idx < len(outputs[rgb_image]):
            frame_count = 0

            # iterate forward
            for im in outputs[rgb_image][idx]:
                wait_until_next_frame(frame_count)
                yield im
                frame_count += 1

            # iterate in reverse
            if oscillate:
                for im in reversed(outputs[rgb_image][idx]):
                    wait_until_next_frame(frame_count)
                    yield im
                    frame_count += 1

        else:
            raise gr.Error("Could not find any images to load into this module.")

    return iter_images


def clear_outputs():
    """
    Remove images from image components.
    """
    return [gr.update(value=None, visible=(idx == 0)) for idx in range(MAX_PARTS)]


def run():
    with gr.Blocks() as demo:
        gr.Markdown(
            """
        # OPDMulti Demo
        We tackle the openable-part-detection (OPD) problem: given a single-view image, identify the parts that are openable and their motion parameters. Our OPDFormer architecture outputs segmentations for openable parts on potentially multiple objects, along with each part's motion parameters: the motion type (translation or rotation, indicated by a blue or purple mask) and the motion axis and origin (shown as green arrows and points). For each openable part, we predict the motion parameters (axis and origin) in object coordinates, along with an object pose prediction to convert them to camera coordinates.
    
        More information about the project, including code, can be found [here](https://3dlg-hcvc.github.io/OPDMulti/).
    
        Upload an image to see a visualization of its range of motion below. Only the RGB image is needed for the model itself, but the depth image is currently required to visualize the motion.
    
        If you know the intrinsic matrix of your camera, you can specify it here; otherwise, the default matrix will work with any of the provided examples.
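
        The intrinsic matrix follows the standard pinhole convention [[fx, 0, cx], [0, fy, cy], [0, 0, 1]], where fx and fy are the focal lengths in pixels and (cx, cy) is the principal point.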
    
        You can also change the number of samples to control how many motion states are generated in the visualization.
        """
        )
    
        # inputs
        with gr.Row():
            rgb_image = gr.Image(
                image_mode="RGB", source="upload", type="filepath", label="RGB Image", show_label=True, interactive=True
            )
            depth_image = gr.Image(
                image_mode="I;16", source="upload", type="filepath", label="Depth Image", show_label=True, interactive=True
            )
    
        intrinsic = gr.Dataframe(
            value=[
                [
                    214.85935872395834,
                    0.0,
                    125.90160319010417,
                ],
                [
                    0.0,
                    214.85935872395834,
                    95.13726399739583,
                ],
                [
                    0.0,
                    0.0,
                    1.0,
                ],
            ],
            row_count=(3, "fixed"),
            col_count=(3, "fixed"),
            datatype="number",
            type="numpy",
            label="Intrinsic matrix",
            show_label=True,
            interactive=True,
        )
        num_samples = gr.Number(
            value=NUM_SAMPLES,
            label="Number of samples",
            show_label=True,
            interactive=True,
            precision=0,
            minimum=3,
            maximum=20,
        )
    
        # specify examples which can be used to start
        examples = gr.Examples(
            examples=[
                ["examples/59-4860.png", "examples/59-4860_d.png"],
                ["examples/174-8460.png", "examples/174-8460_d.png"],
                ["examples/187-0.png", "examples/187-0_d.png"],
                ["examples/187-23040.png", "examples/187-23040_d.png"],
            ],
            inputs=[rgb_image, depth_image],
            api_name=False,
            examples_per_page=2,
        )
    
        submit_btn = gr.Button("Run model")
    
        # output
        explanation = gr.Markdown(
            value=f"# Output\nClick on an image to see an animation of the part motion. As of now, only up to {MAX_PARTS} parts can be visualized due to limitations of the visualizer."
        )
    
        images = [
            gr.Image(type="pil", label=f"Part {idx + 1}", show_download_button=False, visible=(idx == 0))
            for idx in range(MAX_PARTS)
        ]
        for idx, image_comp in enumerate(images):
            image_comp.select(get_trigger(idx), inputs=rgb_image, outputs=image_comp, api_name=False)
    
        # if user changes input, clear output images
        rgb_image.change(clear_outputs, inputs=[], outputs=images, api_name=False)
        depth_image.change(clear_outputs, inputs=[], outputs=images, api_name=False)
    
        submit_btn.click(
            fn=predict, inputs=[rgb_image, depth_image, intrinsic, num_samples], outputs=images, api_name=False
        )
    
    demo.queue(api_open=False)
    demo.launch(server_name="0.0.0.0", server_port=7860)


if __name__ == "__main__":
    print("Starting up app...")
    run()