File size: 3,076 Bytes
cdc2be3
 
 
 
 
 
 
 
 
e512885
cdc2be3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e512885
 
 
 
 
 
 
 
 
cdc2be3
 
 
 
e512885
cdc2be3
 
 
 
 
 
 
 
e512885
cdc2be3
 
 
e512885
 
cdc2be3
 
 
 
 
 
 
 
 
 
 
 
 
 
e512885
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import os
import tempfile
from pathlib import Path
from typing import Union

import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

import depth_pro
import gradio as gr

def predict_depth(image: Image.Image, auto_rotate: bool, remove_alpha: bool, grayscale: bool, model, transform):
    """Estimate a metric depth map for a PIL image and return a visualization.

    The image is round-tripped through a temporary file because
    ``depth_pro.load_rgb`` expects a file path (it also handles EXIF
    auto-rotation and alpha removal there).

    Args:
        image: Uploaded image (any mode PIL supports).
        auto_rotate: Forwarded to ``depth_pro.load_rgb`` (EXIF orientation).
        remove_alpha: Forwarded to ``depth_pro.load_rgb``.
        grayscale: If True, render the inverse-depth map as 8-bit grayscale;
            otherwise colorize it with the reversed "turbo" colormap.
        model: A Depth Pro model (from ``depth_pro.create_model_and_transforms``).
        transform: The matching preprocessing transform.

    Returns:
        Tuple of (depth visualization as a PIL image, estimated focal length
        in pixels as a NumPy scalar).
    """
    # Use a unique temp file (not a fixed name in CWD) so concurrent requests
    # don't clobber each other. Save as PNG, not JPEG: JPEG cannot encode an
    # RGBA upload and would raise OSError before remove_alpha ever applies.
    tmp = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
    image_path = tmp.name
    tmp.close()
    try:
        image.save(image_path)

        # Load and preprocess the image from the given path.
        loaded_image, _, f_px = depth_pro.load_rgb(
            image_path, auto_rotate=auto_rotate, remove_alpha=remove_alpha
        )
        loaded_image = transform(loaded_image)

        # Run inference.
        prediction = model.infer(loaded_image, f_px=f_px)
        depth = prediction["depth"].detach().cpu().numpy().squeeze()  # Depth in [m]
    finally:
        # Clean up the temporary image even if inference fails.
        os.remove(image_path)

    inverse_depth = 1 / depth
    # Visualize inverse depth instead of depth, clipped to [0.1m;250m] range
    # for better visualization.
    max_invdepth_vizu = min(inverse_depth.max(), 1 / 0.1)
    min_invdepth_vizu = max(1 / 250, inverse_depth.min())
    span = max_invdepth_vizu - min_invdepth_vizu
    if span > 0:
        inverse_depth_normalized = (inverse_depth - min_invdepth_vizu) / span
    else:
        # Uniform depth map: avoid division by zero; render as all-zero.
        inverse_depth_normalized = np.zeros_like(inverse_depth)

    focallength = prediction["focallength_px"].cpu().numpy()

    if grayscale:
        # Normalize the inverse depth map to 0-255 and convert to grayscale.
        grayscale_depth = (inverse_depth_normalized * 255).astype(np.uint8)
        depth_image = Image.fromarray(grayscale_depth, mode="L")
    else:
        # Normalize and colorize depth map (reversed turbo: near = warm).
        cmap = plt.get_cmap("turbo_r")
        color_depth = (cmap(inverse_depth_normalized)[..., :3] * 255).astype(np.uint8)
        depth_image = Image.fromarray(color_depth)

    return depth_image, focallength  # Return depth map and f_px

def main():
    """Load the Depth Pro model once and serve it behind a Gradio interface."""
    # Load model and preprocessing transform (shared across all requests).
    model, transform = depth_pro.create_model_and_transforms()
    model.eval()  # inference mode: freezes dropout / batch-norm statistics

    # Set up Gradio interface. The lambda closes over the preloaded model and
    # transform so they are not reloaded per request.
    iface = gr.Interface(
        fn=lambda image, auto_rotate, remove_alpha, grayscale: predict_depth(
            image, auto_rotate, remove_alpha, grayscale, model, transform
        ),
        inputs=[
            gr.Image(type="pil", label="Upload Image"),  # Use image browser for input
            gr.Checkbox(label="Auto Rotate", value=True),  # Checkbox for auto_rotate
            gr.Checkbox(label="Remove Alpha", value=True),  # Checkbox for remove_alpha
            gr.Checkbox(label="Grayscale Depth", value=False)  # Checkbox for grayscale
        ],
        outputs=[
            gr.Image(label="Depth Map"),  # Use PIL image output
            gr.Textbox(label="Focal Length in Pixels", placeholder="Focal length")  # Output for f_px
        ],
        title="Depth Pro: Sharp Monocular Metric Depth Estimation",
        description="Upload an image and adjust options to estimate its depth map using a depth estimation model.",
        # Gradio expects the strings "never"/"auto"/"manual" here; the boolean
        # False is rejected by Gradio 4.x and only warned about in 3.x.
        allow_flagging="never"  # Disable the flag button
    )

    # Launch the interface.
    iface.launch()

# Script entry point: only launch the UI when run directly, not on import.
if __name__ == "__main__":
    main()