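# Convert a TensorFlow kMaX-DeepLab checkpoint into a Detectron2-style .pkl weight file:
# every TF variable is renamed to its PyTorch counterpart and conv kernels are
# transposed into PyTorch's channel-first layout.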
import tensorflow as tf
import pickle as pkl
import sys

import torch
import numpy as np

def load_tf_weights(ckpt_path):
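    """Read every variable from a TF2 checkpoint into a flat {name: ndarray} dict,
    skipping optimizer slots and checkpoint bookkeeping entries, and stripping the
    '/.ATTRIBUTES/VARIABLE_VALUE' suffix from the variable names."""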
    # https://stackoverflow.com/questions/40118062/how-to-read-weights-saved-in-tensorflow-checkpoint-file
    from tensorflow.python.training import py_checkpoint_reader
    reader = py_checkpoint_reader.NewCheckpointReader(ckpt_path)
    state_dict = {}
    for k in reader.get_variable_to_shape_map():
        if '.OPTIMIZER_SLOT' in k or 'optimizer' in k or '_CHECKPOINTABLE_OBJECT_GRAPH' in k or 'save_counter' in k or 'global_step' in k:
            continue
        v = reader.get_tensor(k)
        state_dict[k.replace('/.ATTRIBUTES/VARIABLE_VALUE', '')] = v
    for k in sorted(state_dict.keys()):
        print(k, state_dict[k].shape)
    return state_dict

def map_bn(name1, name2):
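    """Map the TF batch-norm variables (gamma/beta/moving stats) under `name1` to the
    PyTorch BatchNorm parameter and buffer names under `name2`."""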
    res = {}
    res[name1 + '/gamma'] = name2 + ".weight"
    res[name1 + '/beta'] = name2 + ".bias"
    res[name1 + '/moving_mean'] = name2 + ".running_mean"
    res[name1 + '/moving_variance'] = name2 + ".running_var"
    return res


def map_conv(name1, name2, dw=False, bias=False):
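    """Map a TF conv kernel (and optionally its bias) under `name1` to the PyTorch conv
    parameter names under `name2`; `dw=True` selects the depthwise kernel variable."""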
    res = {}
    if dw:
        res[name1 + '/depthwise_kernel'] = name2 + ".weight"
    else:
        res[name1 + '/kernel'] = name2 + ".weight"
    if bias:
        res[name1 + '/bias'] = name2 + ".bias"
    return res


def tf_2_torch_mapping_r50():
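    """Build the TF -> PyTorch name mapping for the ResNet-50 backbone: the stem conv/BN
    plus the bottleneck blocks of stages res2-res5 (3, 4, 6, 3 blocks)."""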
    res = {}
    res.update(map_conv('encoder/_stem/_conv', 'backbone.stem.conv1'))
    res.update(map_bn('encoder/_stem/_batch_norm', 'backbone.stem.conv1.norm'))
    block_num = {2: 3, 3: 4, 4: 6, 5: 3}
    for stage_idx in range(2, 6):
        for block_idx in range(1, block_num[stage_idx] + 1):
            res.update(map_conv(f'encoder/_stage{stage_idx}/_block{block_idx}/_conv1_bn_act/_conv',
             f'backbone.res{stage_idx}.{block_idx-1}.conv1'))
            res.update(map_bn(f'encoder/_stage{stage_idx}/_block{block_idx}/_conv1_bn_act/_batch_norm',
             f'backbone.res{stage_idx}.{block_idx-1}.conv1.norm'))
            res.update(map_conv(f'encoder/_stage{stage_idx}/_block{block_idx}/_conv2_bn_act/_conv',
             f'backbone.res{stage_idx}.{block_idx-1}.conv2'))
            res.update(map_bn(f'encoder/_stage{stage_idx}/_block{block_idx}/_conv2_bn_act/_batch_norm',
             f'backbone.res{stage_idx}.{block_idx-1}.conv2.norm'))
            res.update(map_conv(f'encoder/_stage{stage_idx}/_block{block_idx}/_conv3_bn/_conv',
             f'backbone.res{stage_idx}.{block_idx-1}.conv3'))
            res.update(map_bn(f'encoder/_stage{stage_idx}/_block{block_idx}/_conv3_bn/_batch_norm',
             f'backbone.res{stage_idx}.{block_idx-1}.conv3.norm'))
            res.update(map_conv(f'encoder/_stage{stage_idx}/_block{block_idx}/_shortcut/_conv',
             f'backbone.res{stage_idx}.{block_idx-1}.shortcut'))
            res.update(map_bn(f'encoder/_stage{stage_idx}/_block{block_idx}/_shortcut/_batch_norm',
             f'backbone.res{stage_idx}.{block_idx-1}.shortcut.norm'))
    return res

def tf_2_torch_mapping_convnext():
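    """Build the TF -> PyTorch name mapping for the ConvNeXt backbone: the four
    downsample layers plus the four stages with depths (3, 3, 27, 3)."""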
    res = {}
    for i in range(4):
        if i == 0:
            res.update(map_conv(f'encoder/downsample_layers/{i}/layer_with_weights-0',
                f'backbone.downsample_layers.{i}.0', bias=True))
            res.update(map_bn(f'encoder/downsample_layers/{i}/layer_with_weights-1',
                f'backbone.downsample_layers.{i}.1'))
        else:
            res.update(map_conv(f'encoder/downsample_layers/{i}/layer_with_weights-1',
                f'backbone.downsample_layers.{i}.1', bias=True))
            res.update(map_bn(f'encoder/downsample_layers/{i}/layer_with_weights-0',
                f'backbone.downsample_layers.{i}.0'))
    
    block_num = {0: 3, 1: 3, 2: 27, 3: 3}
    for stage_idx in range(4):
        for block_idx in range(block_num[stage_idx]):
            res.update(map_conv(f'encoder/stages/{stage_idx}/layer_with_weights-{block_idx}/depthwise_conv',
                f'backbone.stages.{stage_idx}.{block_idx}.dwconv', bias=True))
            res.update(map_bn(f'encoder/stages/{stage_idx}/layer_with_weights-{block_idx}/norm',
                f'backbone.stages.{stage_idx}.{block_idx}.norm'))
            res.update(map_conv(f'encoder/stages/{stage_idx}/layer_with_weights-{block_idx}/pointwise_conv1',
                f'backbone.stages.{stage_idx}.{block_idx}.pwconv1', bias=True))
            res.update(map_conv(f'encoder/stages/{stage_idx}/layer_with_weights-{block_idx}/pointwise_conv2',
                f'backbone.stages.{stage_idx}.{block_idx}.pwconv2', bias=True))
            res[f'encoder/stages/{stage_idx}/layer_with_weights-{block_idx}/layer_scale'] = f'backbone.stages.{stage_idx}.{block_idx}.gamma'
    
    return res

def tf_2_torch_mapping_pixel_dec():
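    """Build the TF -> PyTorch name mapping for the pixel decoder: input norms for the
    four backbone features, three resized fusion modules, and the blocks of the four
    decoder stages (axial attention in stages 0-1, plain convs in stages 2-3)."""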
    res = {}
    for i in range(4):
        res.update(map_bn(f'pixel_decoder/_backbone_norms/{i}', f'sem_seg_head.pixel_decoder._in_norms.{i}'))
    
    for i in range(3):
        res.update(map_conv(f'pixel_decoder/_skip_connections/{i}/_resized_conv_bn1/_conv',
         f'sem_seg_head.pixel_decoder._resized_fuses.{i}._conv_bn_low.conv'))
        res.update(map_bn(f'pixel_decoder/_skip_connections/{i}/_resized_conv_bn1/_batch_norm',
         f'sem_seg_head.pixel_decoder._resized_fuses.{i}._conv_bn_low.norm'))
        res.update(map_conv(f'pixel_decoder/_skip_connections/{i}/_resized_conv_bn2/_conv',
         f'sem_seg_head.pixel_decoder._resized_fuses.{i}._conv_bn_high.conv'))
        res.update(map_bn(f'pixel_decoder/_skip_connections/{i}/_resized_conv_bn2/_batch_norm',
         f'sem_seg_head.pixel_decoder._resized_fuses.{i}._conv_bn_high.norm'))

    num_blocks = {0: 1, 1: 5, 2: 1, 3: 1}
    for stage_idx in range(4):
        for block_idx in range(1, 1+num_blocks[stage_idx]):
            res.update(map_conv(f'pixel_decoder/_stages/{stage_idx}/_block{block_idx}/_shortcut/_conv',
                f'sem_seg_head.pixel_decoder._stages.{stage_idx}._blocks.{block_idx-1}._shortcut.conv'))
            res.update(map_bn(f'pixel_decoder/_stages/{stage_idx}/_block{block_idx}/_shortcut/_batch_norm',
                f'sem_seg_head.pixel_decoder._stages.{stage_idx}._blocks.{block_idx-1}._shortcut.norm'))
            res.update(map_conv(f'pixel_decoder/_stages/{stage_idx}/_block{block_idx}/_conv1_bn_act/_conv',
                f'sem_seg_head.pixel_decoder._stages.{stage_idx}._blocks.{block_idx-1}._conv1_bn_act.conv'))
            res.update(map_bn(f'pixel_decoder/_stages/{stage_idx}/_block{block_idx}/_conv1_bn_act/_batch_norm',
                f'sem_seg_head.pixel_decoder._stages.{stage_idx}._blocks.{block_idx-1}._conv1_bn_act.norm'))
            res.update(map_conv(f'pixel_decoder/_stages/{stage_idx}/_block{block_idx}/_conv3_bn/_conv',
                f'sem_seg_head.pixel_decoder._stages.{stage_idx}._blocks.{block_idx-1}._conv3_bn.conv'))
            res.update(map_bn(f'pixel_decoder/_stages/{stage_idx}/_block{block_idx}/_conv3_bn/_batch_norm',
                f'sem_seg_head.pixel_decoder._stages.{stage_idx}._blocks.{block_idx-1}._conv3_bn.norm'))
            if stage_idx <= 1:
                for attn in ['height', 'width']:
                    res.update(map_bn(f'pixel_decoder/_stages/{stage_idx}/_block{block_idx}/_attention/_{attn}_axis/_batch_norm_qkv',
                    f'sem_seg_head.pixel_decoder._stages.{stage_idx}._blocks.{block_idx-1}._attention._{attn}_axis._batch_norm_qkv'))
                    res.update(map_bn(f'pixel_decoder/_stages/{stage_idx}/_block{block_idx}/_attention/_{attn}_axis/_batch_norm_retrieved_output',
                    f'sem_seg_head.pixel_decoder._stages.{stage_idx}._blocks.{block_idx-1}._attention._{attn}_axis._batch_norm_retrieved_output'))
                    res.update(map_bn(f'pixel_decoder/_stages/{stage_idx}/_block{block_idx}/_attention/_{attn}_axis/_batch_norm_similarity',
                    f'sem_seg_head.pixel_decoder._stages.{stage_idx}._blocks.{block_idx-1}._attention._{attn}_axis._batch_norm_similarity'))
                    res[f'pixel_decoder/_stages/{stage_idx}/_block{block_idx}/_attention/_{attn}_axis/_key_rpe/embeddings'] = (
                        f'sem_seg_head.pixel_decoder._stages.{stage_idx}._blocks.{block_idx-1}._attention._{attn}_axis._key_rpe._embeddings.weight')
                    res[f'pixel_decoder/_stages/{stage_idx}/_block{block_idx}/_attention/_{attn}_axis/_query_rpe/embeddings'] = (
                        f'sem_seg_head.pixel_decoder._stages.{stage_idx}._blocks.{block_idx-1}._attention._{attn}_axis._query_rpe._embeddings.weight')
                    res[f'pixel_decoder/_stages/{stage_idx}/_block{block_idx}/_attention/_{attn}_axis/_value_rpe/embeddings'] = (
                        f'sem_seg_head.pixel_decoder._stages.{stage_idx}._blocks.{block_idx-1}._attention._{attn}_axis._value_rpe._embeddings.weight')
                    res[f'pixel_decoder/_stages/{stage_idx}/_block{block_idx}/_attention/_{attn}_axis/qkv_kernel'] = (
                        f'sem_seg_head.pixel_decoder._stages.{stage_idx}._blocks.{block_idx-1}._attention._{attn}_axis.qkv_transform.conv.weight')
            else:
                res.update(map_conv(f'pixel_decoder/_stages/{stage_idx}/_block{block_idx}/_conv2_bn_act/_conv',
                    f'sem_seg_head.pixel_decoder._stages.{stage_idx}._blocks.{block_idx-1}._conv2_bn_act.conv'))
                res.update(map_bn(f'pixel_decoder/_stages/{stage_idx}/_block{block_idx}/_conv2_bn_act/_batch_norm',
                    f'sem_seg_head.pixel_decoder._stages.{stage_idx}._blocks.{block_idx-1}._conv2_bn_act.norm'))
    return res


def tf_2_torch_mapping_predictor(prefix_tf, prefix_torch):
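    """Build the TF -> PyTorch name mapping for one prediction head (pixel-space head,
    transformer class head, transformer mask head). The `prefix_torch` values passed by
    the callers must match the PyTorch model's state_dict keys exactly, so they are left
    untouched here."""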
    res = {}
    res.update(map_bn(prefix_tf + 'pixel_space_feature_batch_norm',
        prefix_torch + '_pixel_space_head_last_convbn.norm'))
    res[prefix_tf + 'pixel_space_head/conv_block/_conv1_bn_act/_depthwise/_depthwise_conv/depthwise_kernel'] = (
        prefix_torch + '_pixel_space_head_conv0bnact.conv.weight'
    )
    res.update(map_bn(prefix_tf + 'pixel_space_head/conv_block/_conv1_bn_act/_depthwise/_batch_norm',
        prefix_torch + '_pixel_space_head_conv0bnact.norm'))
    res.update(map_conv(prefix_tf + 'pixel_space_head/conv_block/_conv1_bn_act/_pointwise/_conv',
        prefix_torch + '_pixel_space_head_conv1bnact.conv'))
    res.update(map_bn(prefix_tf + 'pixel_space_head/conv_block/_conv1_bn_act/_pointwise/_batch_norm',
        prefix_torch + '_pixel_space_head_conv1bnact.norm'))
    res.update(map_conv(prefix_tf + 'pixel_space_head/final_conv',
        prefix_torch + '_pixel_space_head_last_convbn.conv', bias=True))
    res.update(map_bn(prefix_tf + 'pixel_space_mask_batch_norm',
        prefix_torch + '_pixel_space_mask_batch_norm'))
    res.update(map_conv(prefix_tf + 'transformer_class_head/_conv',
        prefix_torch + '_transformer_class_head.conv', bias=True))
    res.update(map_conv(prefix_tf + 'transformer_mask_head/_conv',
        prefix_torch + '_transformer_mask_head.conv'))
    res.update(map_bn(prefix_tf + 'transformer_mask_head/_batch_norm',
        prefix_torch + '_transformer_mask_head.norm'))
    
    return res


def tf_2_torch_mapping_trans_dec():
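    """Build the TF -> PyTorch name mapping for the kMaX transformer decoder:
    class/mask embedding projections, cluster centers, the top-level prediction head,
    and the six kMaX transformer layers with their auxiliary prediction heads.
    The '_predcitor' spelling in the target prefixes is kept as-is because these strings
    must match the PyTorch model's state_dict keys."""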
    res = {}

    res.update(map_bn('transformer_decoder/_class_embedding_projection/_batch_norm',
        'sem_seg_head.predictor._class_embedding_projection.norm'))
    res.update(map_conv('transformer_decoder/_class_embedding_projection/_conv',
        'sem_seg_head.predictor._class_embedding_projection.conv'))
    res.update(map_bn('transformer_decoder/_mask_embedding_projection/_batch_norm',
        'sem_seg_head.predictor._mask_embedding_projection.norm'))
    res.update(map_conv('transformer_decoder/_mask_embedding_projection/_conv',
        'sem_seg_head.predictor._mask_embedding_projection.conv'))

    res['transformer_decoder/cluster_centers'] = 'sem_seg_head.predictor._cluster_centers.weight'

    res.update(tf_2_torch_mapping_predictor(
            prefix_tf = '',
            prefix_torch = 'sem_seg_head.predictor._predcitor.'
        ))
    for kmax_idx in range(6):
        res.update(tf_2_torch_mapping_predictor(
            prefix_tf = f'transformer_decoder/_kmax_decoder/{kmax_idx}/_block1_transformer/_auxiliary_clustering_predictor/_',
            prefix_torch = f'sem_seg_head.predictor._kmax_transformer_layers.{kmax_idx}._predcitor.'
        ))
        common_prefix_tf = f'transformer_decoder/_kmax_decoder/{kmax_idx}/_block1_transformer/'
        common_prefix_torch = f'sem_seg_head.predictor._kmax_transformer_layers.{kmax_idx}.'
        res.update(map_bn(common_prefix_tf + '_kmeans_memory_batch_norm_retrieved_value',
            common_prefix_torch + '_kmeans_query_batch_norm_retrieved_value'))
        res.update(map_bn(common_prefix_tf + '_kmeans_memory_conv3_bn/_batch_norm',
            common_prefix_torch + '_kmeans_query_conv3_bn.norm'))
        res.update(map_conv(common_prefix_tf + '_kmeans_memory_conv3_bn/_conv',
            common_prefix_torch + '_kmeans_query_conv3_bn.conv'))
        res.update(map_bn(common_prefix_tf + '_memory_attention/_batch_norm_retrieved_value',
            common_prefix_torch + '_query_self_attention._batch_norm_retrieved_value'))
        res.update(map_bn(common_prefix_tf + '_memory_attention/_batch_norm_similarity',
            common_prefix_torch + '_query_self_attention._batch_norm_similarity'))

        res.update(map_bn(common_prefix_tf + '_memory_conv1_bn_act/_batch_norm',
            common_prefix_torch + '_query_conv1_bn_act.norm'))
        res.update(map_conv(common_prefix_tf + '_memory_conv1_bn_act/_conv',
            common_prefix_torch + '_query_conv1_bn_act.conv'))

        res.update(map_bn(common_prefix_tf + '_memory_conv3_bn/_batch_norm',
            common_prefix_torch + '_query_conv3_bn.norm'))
        res.update(map_conv(common_prefix_tf + '_memory_conv3_bn/_conv',
            common_prefix_torch + '_query_conv3_bn.conv'))

        res.update(map_bn(common_prefix_tf + '_memory_ffn_conv1_bn_act/_batch_norm',
            common_prefix_torch + '_query_ffn_conv1_bn_act.norm'))
        res.update(map_conv(common_prefix_tf + '_memory_ffn_conv1_bn_act/_conv',
            common_prefix_torch + '_query_ffn_conv1_bn_act.conv'))

        res.update(map_bn(common_prefix_tf + '_memory_ffn_conv2_bn/_batch_norm',
            common_prefix_torch + '_query_ffn_conv2_bn.norm'))
        res.update(map_conv(common_prefix_tf + '_memory_ffn_conv2_bn/_conv',
            common_prefix_torch + '_query_ffn_conv2_bn.conv'))

        res.update(map_bn(common_prefix_tf + '_memory_qkv_conv_bn/_batch_norm',
            common_prefix_torch + '_query_qkv_conv_bn.norm'))
        res.update(map_conv(common_prefix_tf + '_memory_qkv_conv_bn/_conv',
            common_prefix_torch + '_query_qkv_conv_bn.conv'))

        res.update(map_bn(common_prefix_tf + '_pixel_conv1_bn_act/_batch_norm',
            common_prefix_torch + '_pixel_conv1_bn_act.norm'))
        res.update(map_conv(common_prefix_tf + '_pixel_conv1_bn_act/_conv',
            common_prefix_torch + '_pixel_conv1_bn_act.conv'))

        res.update(map_bn(common_prefix_tf + '_pixel_v_conv_bn/_batch_norm',
            common_prefix_torch + '_pixel_v_conv_bn.norm'))
        res.update(map_conv(common_prefix_tf + '_pixel_v_conv_bn/_conv',
            common_prefix_torch + '_pixel_v_conv_bn.conv'))

    return res


def tf_2_torch_mapping_aux_semantic_dec():
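    """Build the TF -> PyTorch name mapping for the auxiliary semantic decoder:
    the ASPP module, low-level projections and fusions at OS8/OS4, and the final
    semantic prediction head."""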
    res = {}
    res.update(map_conv('semantic_decoder/_aspp/_conv_bn_act/_conv',
             'sem_seg_head.predictor._auxiliary_semantic_predictor._aspp._aspp_conv0.conv'))
    res.update(map_bn('semantic_decoder/_aspp/_conv_bn_act/_batch_norm',
             'sem_seg_head.predictor._auxiliary_semantic_predictor._aspp._aspp_conv0.norm'))
    
    res.update(map_conv('semantic_decoder/_aspp/_aspp_pool/_conv_bn_act/_conv',
             'sem_seg_head.predictor._auxiliary_semantic_predictor._aspp._aspp_pool.conv'))
    res.update(map_bn('semantic_decoder/_aspp/_aspp_pool/_conv_bn_act/_batch_norm',
             'sem_seg_head.predictor._auxiliary_semantic_predictor._aspp._aspp_pool.norm'))

    res.update(map_conv('semantic_decoder/_aspp/_proj_conv_bn_act/_conv',
             'sem_seg_head.predictor._auxiliary_semantic_predictor._aspp._proj_conv_bn_act.conv'))
    res.update(map_bn('semantic_decoder/_aspp/_proj_conv_bn_act/_batch_norm',
             'sem_seg_head.predictor._auxiliary_semantic_predictor._aspp._proj_conv_bn_act.norm'))
    for i in range(1, 4):
        res.update(map_conv(f'semantic_decoder/_aspp/_aspp_conv{i}/_conv_bn_act/_conv',
             f'sem_seg_head.predictor._auxiliary_semantic_predictor._aspp._aspp_conv{i}.conv'))
        res.update(map_bn(f'semantic_decoder/_aspp/_aspp_conv{i}/_conv_bn_act/_batch_norm',
             f'sem_seg_head.predictor._auxiliary_semantic_predictor._aspp._aspp_conv{i}.norm'))
    
    res.update({
        'semantic_decoder/_fusion_conv1/_conv1_bn_act/_depthwise/_depthwise_conv/depthwise_kernel':
         'sem_seg_head.predictor._auxiliary_semantic_predictor._low_level_fusion_os8_conv0_bn_act.conv.weight'})
    res.update(map_bn('semantic_decoder/_fusion_conv1/_conv1_bn_act/_depthwise/_batch_norm',
             'sem_seg_head.predictor._auxiliary_semantic_predictor._low_level_fusion_os8_conv0_bn_act.norm'))
    res.update({
        'semantic_decoder/_fusion_conv1/_conv1_bn_act/_pointwise/_conv/kernel':
         'sem_seg_head.predictor._auxiliary_semantic_predictor._low_level_fusion_os8_conv1_bn_act.conv.weight'})
    res.update(map_bn('semantic_decoder/_fusion_conv1/_conv1_bn_act/_pointwise/_batch_norm',
             'sem_seg_head.predictor._auxiliary_semantic_predictor._low_level_fusion_os8_conv1_bn_act.norm'))
    
    res.update({
        'semantic_decoder/_fusion_conv2/_conv1_bn_act/_depthwise/_depthwise_conv/depthwise_kernel':
         'sem_seg_head.predictor._auxiliary_semantic_predictor._low_level_fusion_os4_conv0_bn_act.conv.weight'})
    res.update(map_bn('semantic_decoder/_fusion_conv2/_conv1_bn_act/_depthwise/_batch_norm',
             'sem_seg_head.predictor._auxiliary_semantic_predictor._low_level_fusion_os4_conv0_bn_act.norm'))
    res.update({
        'semantic_decoder/_fusion_conv2/_conv1_bn_act/_pointwise/_conv/kernel':
         'sem_seg_head.predictor._auxiliary_semantic_predictor._low_level_fusion_os4_conv1_bn_act.conv.weight'})
    res.update(map_bn('semantic_decoder/_fusion_conv2/_conv1_bn_act/_pointwise/_batch_norm',
             'sem_seg_head.predictor._auxiliary_semantic_predictor._low_level_fusion_os4_conv1_bn_act.norm'))

    res.update({
        'semantic_decoder/_low_level_conv1/_conv/kernel':
         'sem_seg_head.predictor._auxiliary_semantic_predictor._low_level_projection_os8.conv.weight'})
    res.update(map_bn('semantic_decoder/_low_level_conv1/_batch_norm',
             'sem_seg_head.predictor._auxiliary_semantic_predictor._low_level_projection_os8.norm'))
    res.update({
        'semantic_decoder/_low_level_conv2/_conv/kernel':
         'sem_seg_head.predictor._auxiliary_semantic_predictor._low_level_projection_os4.conv.weight'})
    res.update(map_bn('semantic_decoder/_low_level_conv2/_batch_norm',
             'sem_seg_head.predictor._auxiliary_semantic_predictor._low_level_projection_os4.norm'))
    

    res.update({
        'semantic_head_without_last_layer/_conv1_bn_act/_depthwise/_depthwise_conv/depthwise_kernel':
         'sem_seg_head.predictor._auxiliary_semantic_predictor.conv_block_0.conv.weight'})
    res.update(map_bn('semantic_head_without_last_layer/_conv1_bn_act/_depthwise/_batch_norm',
             'sem_seg_head.predictor._auxiliary_semantic_predictor.conv_block_0.norm'))
    res.update({
        'semantic_head_without_last_layer/_conv1_bn_act/_pointwise/_conv/kernel':
         'sem_seg_head.predictor._auxiliary_semantic_predictor.conv_block_1.conv.weight'})
    res.update(map_bn('semantic_head_without_last_layer/_conv1_bn_act/_pointwise/_batch_norm',
             'sem_seg_head.predictor._auxiliary_semantic_predictor.conv_block_1.norm'))

    res.update({
        'semantic_last_layer/kernel':
         'sem_seg_head.predictor._auxiliary_semantic_predictor.final_conv.conv.weight'})
    res.update({
        'semantic_last_layer/bias':
         'sem_seg_head.predictor._auxiliary_semantic_predictor.final_conv.conv.bias'})
    return res


# python3 convert-tf-weights-to-d2.py kmax_resnet50_coco_train/ckpt-150000 tf_kmax_r50.pkl

if __name__ == "__main__":
    ckpt_path = sys.argv[1]

    state_dict = load_tf_weights(ckpt_path)

    state_dict_torch = {}

    mapping_key = {}
    if 'resnet50' in ckpt_path:
        mapping_key.update(tf_2_torch_mapping_r50())
    elif 'convnext' in ckpt_path:
        mapping_key.update(tf_2_torch_mapping_convnext())
    mapping_key.update(tf_2_torch_mapping_pixel_dec())
    mapping_key.update(tf_2_torch_mapping_trans_dec())

    mapping_key.update(tf_2_torch_mapping_aux_semantic_dec())

    for k, value in state_dict.items():
        k2 = mapping_key[k]
        rank = len(value.shape)

        if '_batch_norm_retrieved_output' in k2 or '_batch_norm_similarity' in k2 or '_batch_norm_retrieved_value' in k2:
            value = np.reshape(value, [-1])
        elif 'qkv_transform.conv.weight' in k2:
            # (512, 1024) -> (1024, 512, 1)
            value = np.transpose(value, (1, 0))[:, :, None]
        elif '_cluster_centers.weight' in k2:
            # (1, 128, 256) -> (256, 128)
            value = np.transpose(value[0], (1, 0))
        elif '_pixel_conv1_bn_act.conv.weight' in k2:
            # (1, 512, 256) -> (256, 512, 1, 1)
            value = np.transpose(value, (2, 1, 0))[:, :, :, None]
        elif '_pixel_v_conv_bn.conv.weight' in k2:
            # (1, 256, 256) -> (256, 256, 1, 1) 
            value = np.transpose(value, (2, 1, 0))[:, :, :, None]
        elif '_pixel_space_head_conv0bnact.conv.weight' in k2:
            # (5, 5, 256, 1) -> (256, 1, 5, 5)
            value = np.transpose(value, (2, 3, 0, 1))
        elif '/layer_scale' in k:
            value = np.reshape(value, [-1])
        elif 'pwconv1.weight' in k2 or 'pwconv2.weight' in k2:
            # (128, 512) -> (512, 128)
            value = np.transpose(value, (1, 0))
        elif ('_low_level_fusion_os4_conv0_bn_act.conv.weight' in k2 
        or '_low_level_fusion_os8_conv0_bn_act.conv.weight' in k2 
        or 'sem_seg_head.predictor._auxiliary_semantic_predictor.conv_block_0.conv.weight' in k2):
            value = np.transpose(value, (2, 3, 0, 1))
        else:
            if rank == 1: # bias, BN params, etc.; nothing to reorder
                pass
            elif rank == 2: # relative position embeddings (_query_rpe / _key_rpe / _value_rpe)
                pass
            elif rank == 3: # conv1d kernel: TF (L, C_in, C_out) -> torch (C_out, C_in, L)
                value = np.transpose(value, (2, 1, 0))
            elif rank == 4: # conv2d kernel: TF (H, W, C_in, C_out) -> torch (C_out, C_in, H, W)
                value = np.transpose(value, (3, 2, 0, 1))

        state_dict_torch[k2] = value

    res = {"model": state_dict_torch, "__author__": "third_party", "matching_heuristics": True}

    with open(sys.argv[2], "wb") as f:
        pkl.dump(res, f)


# r50: 52.85 -> 52.71 w/ eps 1e-3
# convnext-base: 56.85 -> 56.97 w/ eps 1e-3
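
# A minimal sketch of how the converted weights could be loaded, assuming a
# Detectron2-based kMaX model instance `model` built from the matching config
# (the variable name is illustrative):
#
#   from detectron2.checkpoint import DetectionCheckpointer
#   DetectionCheckpointer(model).load("tf_kmax_r50.pkl")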