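"""Gradio Space: depth and normal map generation.

Predicts a depth map with a custom DepthAnythingV2 checkpoint, blends the
depth with the image's intensity, and derives a DirectX-style normal map from
the result using the ComfyUI-style NormalMapSimple and ConvertNormals nodes.
"""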
import os
import cv2
import torch
import numpy as np
import gradio as gr
import spaces
from huggingface_hub import hf_hub_download
from depth_anything_v2.dpt import DepthAnythingV2

# Model configurations for the DepthAnythingV2 encoder variants
model_configs = {
    'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
    'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
    'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]}
}

class NormalMapSimple:
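    """ComfyUI-style node: builds a tangent-space normal map from image
    luminance, using Scharr gradients for the X and Y components."""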
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "images": ("IMAGE",),
                "scale_XY": ("FLOAT",{"default": 1, "min": 0, "max": 100, "step": 0.001}),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "normal_map"

    CATEGORY = "image/filters"

    def normal_map(self, images, scale_XY):
        t = images.detach().clone().cpu().numpy().astype(np.float32)
        # Mean of the RGB channels serves as the height field for the gradient estimate
        L = np.mean(t[:,:,:,:3], axis=3)
        for i in range(t.shape[0]):
            # Scharr derivatives give the X/Y slope; borderType must be passed by
            # keyword, otherwise it lands in the positional dst argument
            t[i,:,:,0] = cv2.Scharr(L[i], -1, 1, 0, borderType=cv2.BORDER_REFLECT) * -1
            t[i,:,:,1] = cv2.Scharr(L[i], -1, 0, 1, borderType=cv2.BORDER_REFLECT)
        t[:,:,:,2] = 1
        t = torch.from_numpy(t)
        t[:,:,:,:2] *= scale_XY
        # Normalize (X, Y, Z) to unit length, then remap from [-1, 1] to [0, 1]
        t[:,:,:,:3] = torch.nn.functional.normalize(t[:,:,:,:3], dim=3) / 2 + 0.5
        return (t,)

class ConvertNormals:
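    """ComfyUI-style node: converts a normal map between conventions
    (BAE, MiDaS, Standard, DirectX) by flipping or reordering channels,
    with optional re-normalization and black-pixel filling."""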
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "normals": ("IMAGE",),
                "input_mode": (["BAE", "MiDaS", "Standard", "DirectX"],),
                "output_mode": (["BAE", "MiDaS", "Standard", "DirectX"],),
                "scale_XY": ("FLOAT",{"default": 1, "min": 0, "max": 100, "step": 0.001}),
                "normalize": ("BOOLEAN", {"default": True}),
                "fix_black": ("BOOLEAN", {"default": True}),
            },
            "optional": {
                "optional_fill": ("IMAGE",),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "convert_normals"

    CATEGORY = "image/filters"

    def convert_normals(self, normals, input_mode, output_mode, scale_XY, normalize, fix_black, optional_fill=None):
        try:
            t = normals.detach().clone()
            
            if input_mode == "BAE":
                t[:,:,:,0] = 1 - t[:,:,:,0] # invert R
            elif input_mode == "MiDaS":
                t[:,:,:,:3] = torch.stack([1 - t[:,:,:,2], t[:,:,:,1], t[:,:,:,0]], dim=3) # BGR -> RGB and invert R
            elif input_mode == "DirectX":
                t[:,:,:,1] = 1 - t[:,:,:,1] # invert G
            
            if fix_black:
                key = torch.clamp(1 - t[:,:,:,2] * 2, min=0, max=1)
                if optional_fill is None:
                    t[:,:,:,0] += key * 0.5
                    t[:,:,:,1] += key * 0.5
                    t[:,:,:,2] += key
                else:
                    fill = optional_fill.detach().clone()
                    if fill.shape[1:3] != t.shape[1:3]:
                        fill = torch.nn.functional.interpolate(fill.movedim(-1,1), size=(t.shape[1], t.shape[2]), mode='bilinear').movedim(1,-1)
                    if fill.shape[0] != t.shape[0]:
                        fill = fill[0].unsqueeze(0).expand(t.shape[0], -1, -1, -1)
                    t[:,:,:,:3] += fill[:,:,:,:3] * key.unsqueeze(3).expand(-1, -1, -1, 3)
            
            t[:,:,:,:2] = (t[:,:,:,:2] - 0.5) * scale_XY + 0.5
            
            if normalize:
                # Transform to [-1, 1] range
                t_norm = t[:,:,:,:3] * 2 - 1
                
                # Calculate the length of each vector
                lengths = torch.sqrt(torch.sum(t_norm**2, dim=3, keepdim=True))
                
                # Avoid division by zero
                lengths = torch.clamp(lengths, min=1e-6)
                
                # Normalize each vector to unit length
                t_norm = t_norm / lengths
                
                # Transform back to [0, 1] range
                t[:,:,:,:3] = (t_norm + 1) / 2
            
            if output_mode == "BAE":
                t[:,:,:,0] = 1 - t[:,:,:,0] # invert R
            elif output_mode == "MiDaS":
                t[:,:,:,:3] = torch.stack([t[:,:,:,2], t[:,:,:,1], 1 - t[:,:,:,0]], dim=3) # invert R and BGR -> RGB
            elif output_mode == "DirectX":
                t[:,:,:,1] = 1 - t[:,:,:,1] # invert G
            
            return (t,)
        except Exception as e:
            print(f"Error in convert_normals: {str(e)}")
            return (normals,)

def get_image_intensity(img, gamma_correction=1.0):
    """
    Extract intensity map from an image using HSV color space
    """
    # Convert to HSV color space
    result = cv2.cvtColor(img, cv2.COLOR_RGB2HSV)
    # Extract Value channel (intensity)
    result = result[:, :, 2].astype(np.float32) / 255.0
    # Apply gamma correction
    result = result ** gamma_correction
    # Convert back to 0-255 range
    result = (result * 255.0).clip(0, 255).astype(np.uint8)
    # Convert to RGB (still grayscale but in RGB format)
    result = cv2.cvtColor(result, cv2.COLOR_GRAY2RGB)
    return result
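# Note: with the default gamma_correction=1.0 the V channel passes through unchanged,
# so the blend in process_image() effectively mixes raw depth with image brightness.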

def blend_numpy_images(image1, image2, blend_factor=0.4, mode="normal"):
    """
    Blend two numpy images. Only "normal" (linear) blending is implemented;
    the mode argument is accepted for API compatibility but otherwise ignored.
    """
    # Convert to float32 and normalize to 0-1
    img1 = image1.astype(np.float32) / 255.0
    img2 = image2.astype(np.float32) / 255.0
    
    # Normal blend mode
    blended = img1 * (1 - blend_factor) + img2 * blend_factor
    
    # Convert back to uint8
    blended = (blended * 255.0).clip(0, 255).astype(np.uint8)
    return blended

def process_normal_map(image):
    """
    Process image through NormalMapSimple and ConvertNormals
    """
    # Convert numpy image to torch tensor with batch dimension
    image_tensor = torch.from_numpy(image).unsqueeze(0).float() / 255.0
    
    # Create instances of the classes
    normal_map_generator = NormalMapSimple()
    normal_converter = ConvertNormals()
    
    # Generate initial normal map
    normal_map = normal_map_generator.normal_map(image_tensor, scale_XY=1.0)[0]
    
    # Convert normal map from Standard to DirectX
    converted_normal = normal_converter.convert_normals(
        normal_map,
        input_mode="Standard",
        output_mode="DirectX",
        scale_XY=1.0,
        normalize=True,
        fix_black=True
    )[0]
    
    # Convert back to numpy array
    result = (converted_normal.squeeze(0).numpy() * 255).astype(np.uint8)
    return result
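
# Standalone usage sketch (not executed by the app); "example.png" is a
# hypothetical file, not one shipped with this Space:
#
#   img = cv2.cvtColor(cv2.imread("example.png"), cv2.COLOR_BGR2RGB)
#   nrm = process_normal_map(img)  # uint8 RGB, DirectX-style normal map
#   cv2.imwrite("normal.png", cv2.cvtColor(nrm, cv2.COLOR_RGB2BGR))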

# Download and initialize model
def initialize_model():
    encoder = 'vitl'
    max_depth = 1
    
    model = DepthAnythingV2(**{**model_configs[encoder], 'max_depth': max_depth})
    
    # Download the checkpoint from the private repo; the 'Read' environment
    # variable is expected to hold a Hugging Face read token (e.g. a Space secret)
    model_path = hf_hub_download(
        "NightRaven109/DepthAnythingv2custom",
        "model95.pth",
        token=os.environ['Read']
    )
    
    # Load checkpoint
    checkpoint = torch.load(model_path, map_location='cpu')
    
    # The checkpoint stores training metadata alongside the weights; take the
    # single non-metadata entry (e.g. 'model') as the state dict
    state_dict = {}
    for key in checkpoint.keys():
        if key not in ['optimizer', 'epoch', 'previous_best']:
            state_dict = checkpoint[key]
    
    # Strip the 'module.' prefix left over from DataParallel/DistributedDataParallel training
    my_state_dict = {}
    for key in state_dict.keys():
        new_key = key.replace('module.', '')
        my_state_dict[new_key] = state_dict[key]
    
    model.load_state_dict(my_state_dict)
    return model

# Initialize model at startup
MODEL = initialize_model()
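# The model stays on the CPU between requests; process_image() moves it to the
# GPU only for the duration of a @spaces.GPU-allocated call and back afterwards.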

@spaces.GPU
def process_image(input_image):
    """
    Process the input image and return depth map and normal map
    """
    if input_image is None:
        return None, None
    
    # Move model to GPU for processing
    MODEL.to('cuda')
    MODEL.eval()
    
    # Convert from RGB to BGR for depth processing
    input_bgr = cv2.cvtColor(np.array(input_image), cv2.COLOR_RGB2BGR)
    
    with torch.no_grad():
        # Get depth map
        depth = MODEL.infer_image(input_bgr)
        
        # Normalize depth to 0-255 for visualization (guard against a constant depth map)
        depth_range = max(float(depth.max() - depth.min()), 1e-8)
        depth_normalized = ((depth - depth.min()) / depth_range * 255).astype(np.uint8)
    
    # Move model back to CPU
    MODEL.to('cpu')
    
    # Get intensity map
    intensity_map = get_image_intensity(np.array(input_image), gamma_correction=1.0)
    
    # Blend depth raw with intensity map
    blended_result = blend_numpy_images(
        cv2.cvtColor(depth_normalized, cv2.COLOR_GRAY2RGB),  # Convert depth to RGB
        intensity_map,
        blend_factor=0.4,
        mode="normal"
    )
    
    # Generate normal map from blended result
    normal_map = process_normal_map(blended_result)
    
    return depth_normalized, normal_map

@spaces.GPU
def gradio_interface(input_img):
    try:
        depth_raw, normal = process_image(input_img)
        return [depth_raw, normal]
    except Exception as e:
        print(f"Error processing image: {str(e)}")
        return [None, None]

# Define interface
iface = gr.Interface(
    fn=gradio_interface,
    inputs=gr.Image(label="Input Image"),
    outputs=[
        gr.Image(label="Raw Depth Map"),
        gr.Image(label="Normal Map")
    ],
    title="Depth and Normal Map Generation",
    description="Upload an image to generate its depth map and normal map.",
    examples=["image.jpg"]
)

# Launch the app
if __name__ == "__main__":
    iface.launch()