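# convert-tf-weights-to-d2.py
# Convert an official kMaX-DeepLab TensorFlow-2 checkpoint into a pickle that
# Detectron2 can load ("matching_heuristics" format): read the TF variables,
# rename them to the PyTorch model's state_dict keys, and fix up tensor
# layouts (e.g. HWIO conv kernels -> OIHW).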
import pickle as pkl
import sys

import numpy as np
def load_tf_weights(ckpt_path):
    # Read every variable from a TF2 object-based checkpoint into a
    # {cleaned name: numpy array} dict, skipping optimizer slots and
    # checkpoint bookkeeping entries.
# https://stackoverflow.com/questions/40118062/how-to-read-weights-saved-in-tensorflow-checkpoint-file
from tensorflow.python.training import py_checkpoint_reader
reader = py_checkpoint_reader.NewCheckpointReader(ckpt_path)
state_dict = {}
for k in reader.get_variable_to_shape_map():
if '.OPTIMIZER_SLOT' in k or 'optimizer' in k or '_CHECKPOINTABLE_OBJECT_GRAPH' in k or 'save_counter' in k or 'global_step' in k:
continue
v = reader.get_tensor(k)
state_dict[k.replace('/.ATTRIBUTES/VARIABLE_VALUE', '')] = v
for k in sorted(state_dict.keys()):
print(k, state_dict[k].shape)
return state_dict
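# Example: a raw checkpoint entry such as
#   'encoder/_stem/_conv/kernel/.ATTRIBUTES/VARIABLE_VALUE'
# comes back from load_tf_weights as 'encoder/_stem/_conv/kernel'.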
def map_bn(name1, name2):
    # Map one TF batch-norm layer's four variables to the PyTorch norm names.
res = {}
res[name1 + '/gamma'] = name2 + ".weight"
res[name1 + '/beta'] = name2 + ".bias"
res[name1 + '/moving_mean'] = name2 + ".running_mean"
res[name1 + '/moving_variance'] = name2 + ".running_var"
return res
def map_conv(name1, name2, dw=False, bias=False):
    # Map a TF conv kernel (depthwise kernel if dw=True) and optional bias to
    # the PyTorch conv names.
res = {}
if dw:
res[name1 + '/depthwise_kernel'] = name2 + ".weight"
else:
res[name1 + '/kernel'] = name2 + ".weight"
if bias:
res[name1 + '/bias'] = name2 + ".bias"
return res
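# Example of the helpers above:
#   map_bn('encoder/_stem/_batch_norm', 'backbone.stem.conv1.norm') returns
#   {'encoder/_stem/_batch_norm/gamma': 'backbone.stem.conv1.norm.weight',
#    'encoder/_stem/_batch_norm/beta': 'backbone.stem.conv1.norm.bias',
#    'encoder/_stem/_batch_norm/moving_mean': 'backbone.stem.conv1.norm.running_mean',
#    'encoder/_stem/_batch_norm/moving_variance': 'backbone.stem.conv1.norm.running_var'}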
def tf_2_torch_mapping_r50():
    # Name mapping for the ResNet-50 backbone.
res = {}
res.update(map_conv('encoder/_stem/_conv', 'backbone.stem.conv1'))
res.update(map_bn('encoder/_stem/_batch_norm', 'backbone.stem.conv1.norm'))
block_num = {2: 3, 3: 4, 4: 6, 5: 3}
for stage_idx in range(2, 6):
for block_idx in range(1, block_num[stage_idx] + 1):
res.update(map_conv(f'encoder/_stage{stage_idx}/_block{block_idx}/_conv1_bn_act/_conv',
f'backbone.res{stage_idx}.{block_idx-1}.conv1'))
res.update(map_bn(f'encoder/_stage{stage_idx}/_block{block_idx}/_conv1_bn_act/_batch_norm',
f'backbone.res{stage_idx}.{block_idx-1}.conv1.norm'))
res.update(map_conv(f'encoder/_stage{stage_idx}/_block{block_idx}/_conv2_bn_act/_conv',
f'backbone.res{stage_idx}.{block_idx-1}.conv2'))
res.update(map_bn(f'encoder/_stage{stage_idx}/_block{block_idx}/_conv2_bn_act/_batch_norm',
f'backbone.res{stage_idx}.{block_idx-1}.conv2.norm'))
res.update(map_conv(f'encoder/_stage{stage_idx}/_block{block_idx}/_conv3_bn/_conv',
f'backbone.res{stage_idx}.{block_idx-1}.conv3'))
res.update(map_bn(f'encoder/_stage{stage_idx}/_block{block_idx}/_conv3_bn/_batch_norm',
f'backbone.res{stage_idx}.{block_idx-1}.conv3.norm'))
res.update(map_conv(f'encoder/_stage{stage_idx}/_block{block_idx}/_shortcut/_conv',
f'backbone.res{stage_idx}.{block_idx-1}.shortcut'))
res.update(map_bn(f'encoder/_stage{stage_idx}/_block{block_idx}/_shortcut/_batch_norm',
f'backbone.res{stage_idx}.{block_idx-1}.shortcut.norm'))
return res
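# TF names backbone stages 2-5 with 1-based blocks ('_stage2/_block1'); the
# Detectron2 backbone uses 'res2'-'res5' with 0-based blocks, hence
# {block_idx-1}. Shortcut entries are emitted for every block, but only the
# keys actually present in the checkpoint (a stage's first block) are ever
# looked up, since the conversion loop iterates over checkpoint keys.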
def tf_2_torch_mapping_convnext():
    # Name mapping for the ConvNeXt backbone.
res = {}
for i in range(4):
if i == 0:
res.update(map_conv(f'encoder/downsample_layers/{i}/layer_with_weights-0',
f'backbone.downsample_layers.{i}.0', bias=True))
res.update(map_bn(f'encoder/downsample_layers/{i}/layer_with_weights-1',
f'backbone.downsample_layers.{i}.1'))
else:
res.update(map_conv(f'encoder/downsample_layers/{i}/layer_with_weights-1',
f'backbone.downsample_layers.{i}.1', bias=True))
res.update(map_bn(f'encoder/downsample_layers/{i}/layer_with_weights-0',
f'backbone.downsample_layers.{i}.0'))
block_num = {0: 3, 1: 3, 2: 27, 3: 3}
for stage_idx in range(4):
for block_idx in range(block_num[stage_idx]):
res.update(map_conv(f'encoder/stages/{stage_idx}/layer_with_weights-{block_idx}/depthwise_conv',
f'backbone.stages.{stage_idx}.{block_idx}.dwconv', bias=True))
res.update(map_bn(f'encoder/stages/{stage_idx}/layer_with_weights-{block_idx}/norm',
f'backbone.stages.{stage_idx}.{block_idx}.norm'))
res.update(map_conv(f'encoder/stages/{stage_idx}/layer_with_weights-{block_idx}/pointwise_conv1',
f'backbone.stages.{stage_idx}.{block_idx}.pwconv1', bias=True))
res.update(map_conv(f'encoder/stages/{stage_idx}/layer_with_weights-{block_idx}/pointwise_conv2',
f'backbone.stages.{stage_idx}.{block_idx}.pwconv2', bias=True))
res[f'encoder/stages/{stage_idx}/layer_with_weights-{block_idx}/layer_scale'] = f'backbone.stages.{stage_idx}.{block_idx}.gamma'
return res
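# block_num matches ConvNeXt-Base depths (3, 3, 27, 3). The stem (i == 0)
# applies conv then norm, while later downsample layers apply norm then conv,
# hence the swapped layer_with_weights indices. 'layer_scale' maps onto the
# block's 'gamma' parameter and is flattened to 1-D in the conversion loop.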
def tf_2_torch_mapping_pixel_dec():
    # Name mapping for the pixel decoder.
    res = {}
    for i in range(4):
        res.update(map_bn(f'pixel_decoder/_backbone_norms/{i}', f'sem_seg_head.pixel_decoder._in_norms.{i}'))
for i in range(3):
res.update(map_conv(f'pixel_decoder/_skip_connections/{i}/_resized_conv_bn1/_conv',
f'sem_seg_head.pixel_decoder._resized_fuses.{i}._conv_bn_low.conv'))
res.update(map_bn(f'pixel_decoder/_skip_connections/{i}/_resized_conv_bn1/_batch_norm',
f'sem_seg_head.pixel_decoder._resized_fuses.{i}._conv_bn_low.norm'))
res.update(map_conv(f'pixel_decoder/_skip_connections/{i}/_resized_conv_bn2/_conv',
f'sem_seg_head.pixel_decoder._resized_fuses.{i}._conv_bn_high.conv'))
res.update(map_bn(f'pixel_decoder/_skip_connections/{i}/_resized_conv_bn2/_batch_norm',
f'sem_seg_head.pixel_decoder._resized_fuses.{i}._conv_bn_high.norm'))
    num_blocks = {0: 1, 1: 5, 2: 1, 3: 1}
for stage_idx in range(4):
for block_idx in range(1, 1+num_blocks[stage_idx]):
res.update(map_conv(f'pixel_decoder/_stages/{stage_idx}/_block{block_idx}/_shortcut/_conv',
f'sem_seg_head.pixel_decoder._stages.{stage_idx}._blocks.{block_idx-1}._shortcut.conv'))
res.update(map_bn(f'pixel_decoder/_stages/{stage_idx}/_block{block_idx}/_shortcut/_batch_norm',
f'sem_seg_head.pixel_decoder._stages.{stage_idx}._blocks.{block_idx-1}._shortcut.norm'))
res.update(map_conv(f'pixel_decoder/_stages/{stage_idx}/_block{block_idx}/_conv1_bn_act/_conv',
f'sem_seg_head.pixel_decoder._stages.{stage_idx}._blocks.{block_idx-1}._conv1_bn_act.conv'))
res.update(map_bn(f'pixel_decoder/_stages/{stage_idx}/_block{block_idx}/_conv1_bn_act/_batch_norm',
f'sem_seg_head.pixel_decoder._stages.{stage_idx}._blocks.{block_idx-1}._conv1_bn_act.norm'))
res.update(map_conv(f'pixel_decoder/_stages/{stage_idx}/_block{block_idx}/_conv3_bn/_conv',
f'sem_seg_head.pixel_decoder._stages.{stage_idx}._blocks.{block_idx-1}._conv3_bn.conv'))
res.update(map_bn(f'pixel_decoder/_stages/{stage_idx}/_block{block_idx}/_conv3_bn/_batch_norm',
f'sem_seg_head.pixel_decoder._stages.{stage_idx}._blocks.{block_idx-1}._conv3_bn.norm'))
if stage_idx <= 1:
for attn in ['height', 'width']:
res.update(map_bn(f'pixel_decoder/_stages/{stage_idx}/_block{block_idx}/_attention/_{attn}_axis/_batch_norm_qkv',
f'sem_seg_head.pixel_decoder._stages.{stage_idx}._blocks.{block_idx-1}._attention._{attn}_axis._batch_norm_qkv'))
res.update(map_bn(f'pixel_decoder/_stages/{stage_idx}/_block{block_idx}/_attention/_{attn}_axis/_batch_norm_retrieved_output',
f'sem_seg_head.pixel_decoder._stages.{stage_idx}._blocks.{block_idx-1}._attention._{attn}_axis._batch_norm_retrieved_output'))
res.update(map_bn(f'pixel_decoder/_stages/{stage_idx}/_block{block_idx}/_attention/_{attn}_axis/_batch_norm_similarity',
f'sem_seg_head.pixel_decoder._stages.{stage_idx}._blocks.{block_idx-1}._attention._{attn}_axis._batch_norm_similarity'))
res[f'pixel_decoder/_stages/{stage_idx}/_block{block_idx}/_attention/_{attn}_axis/_key_rpe/embeddings'] = (
f'sem_seg_head.pixel_decoder._stages.{stage_idx}._blocks.{block_idx-1}._attention._{attn}_axis._key_rpe._embeddings.weight')
res[f'pixel_decoder/_stages/{stage_idx}/_block{block_idx}/_attention/_{attn}_axis/_query_rpe/embeddings'] = (
f'sem_seg_head.pixel_decoder._stages.{stage_idx}._blocks.{block_idx-1}._attention._{attn}_axis._query_rpe._embeddings.weight')
res[f'pixel_decoder/_stages/{stage_idx}/_block{block_idx}/_attention/_{attn}_axis/_value_rpe/embeddings'] = (
f'sem_seg_head.pixel_decoder._stages.{stage_idx}._blocks.{block_idx-1}._attention._{attn}_axis._value_rpe._embeddings.weight')
res[f'pixel_decoder/_stages/{stage_idx}/_block{block_idx}/_attention/_{attn}_axis/qkv_kernel'] = (
f'sem_seg_head.pixel_decoder._stages.{stage_idx}._blocks.{block_idx-1}._attention._{attn}_axis.qkv_transform.conv.weight')
else:
res.update(map_conv(f'pixel_decoder/_stages/{stage_idx}/_block{block_idx}/_conv2_bn_act/_conv',
f'sem_seg_head.pixel_decoder._stages.{stage_idx}._blocks.{block_idx-1}._conv2_bn_act.conv'))
res.update(map_bn(f'pixel_decoder/_stages/{stage_idx}/_block{block_idx}/_conv2_bn_act/_batch_norm',
f'sem_seg_head.pixel_decoder._stages.{stage_idx}._blocks.{block_idx-1}._conv2_bn_act.norm'))
return res
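# The axial-attention tensors above need layout fixes in the conversion loop:
# 'qkv_kernel' of shape (in, out) becomes an (out, in, 1) conv1d weight, and
# the rank-2 RPE embedding tables pass through unchanged.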
def tf_2_torch_mapping_predictor(prefix_tf, prefix_torch):
    # Name mapping for one clustering predictor head. Note: the '_predcitor.'
    # spelling in the torch prefixes passed by callers matches the target
    # PyTorch model's own attribute name, so it is kept verbatim.
res = {}
res.update(map_bn(prefix_tf + 'pixel_space_feature_batch_norm',
prefix_torch + '_pixel_space_head_last_convbn.norm'))
res[prefix_tf + 'pixel_space_head/conv_block/_conv1_bn_act/_depthwise/_depthwise_conv/depthwise_kernel'] = (
prefix_torch + '_pixel_space_head_conv0bnact.conv.weight'
)
res.update(map_bn(prefix_tf + 'pixel_space_head/conv_block/_conv1_bn_act/_depthwise/_batch_norm',
prefix_torch + '_pixel_space_head_conv0bnact.norm'))
res.update(map_conv(prefix_tf + 'pixel_space_head/conv_block/_conv1_bn_act/_pointwise/_conv',
prefix_torch + '_pixel_space_head_conv1bnact.conv'))
res.update(map_bn(prefix_tf + 'pixel_space_head/conv_block/_conv1_bn_act/_pointwise/_batch_norm',
prefix_torch + '_pixel_space_head_conv1bnact.norm'))
res.update(map_conv(prefix_tf + 'pixel_space_head/final_conv',
prefix_torch + '_pixel_space_head_last_convbn.conv', bias=True))
res.update(map_bn(prefix_tf + 'pixel_space_mask_batch_norm',
prefix_torch + '_pixel_space_mask_batch_norm'))
res.update(map_conv(prefix_tf + 'transformer_class_head/_conv',
prefix_torch + '_transformer_class_head.conv', bias=True))
res.update(map_conv(prefix_tf + 'transformer_mask_head/_conv',
prefix_torch + '_transformer_mask_head.conv'))
res.update(map_bn(prefix_tf + 'transformer_mask_head/_batch_norm',
prefix_torch + '_transformer_mask_head.norm'))
return res
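# The depthwise kernel above is assigned by raw TF variable name; the
# conversion loop transposes it from TF's (H, W, C, 1) depthwise layout to
# PyTorch's (C, 1, H, W).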
def tf_2_torch_mapping_trans_dec():
    # Name mapping for the kMaX transformer decoder.
res = {}
res.update(map_bn('transformer_decoder/_class_embedding_projection/_batch_norm',
'sem_seg_head.predictor._class_embedding_projection.norm'))
res.update(map_conv('transformer_decoder/_class_embedding_projection/_conv',
'sem_seg_head.predictor._class_embedding_projection.conv'))
res.update(map_bn('transformer_decoder/_mask_embedding_projection/_batch_norm',
'sem_seg_head.predictor._mask_embedding_projection.norm'))
res.update(map_conv('transformer_decoder/_mask_embedding_projection/_conv',
'sem_seg_head.predictor._mask_embedding_projection.conv'))
res['transformer_decoder/cluster_centers'] = 'sem_seg_head.predictor._cluster_centers.weight'
    res.update(tf_2_torch_mapping_predictor(
        prefix_tf='',
        prefix_torch='sem_seg_head.predictor._predcitor.'  # sic, see note above
    ))
for kmax_idx in range(6):
        res.update(tf_2_torch_mapping_predictor(
            prefix_tf=f'transformer_decoder/_kmax_decoder/{kmax_idx}/_block1_transformer/_auxiliary_clustering_predictor/_',
            prefix_torch=f'sem_seg_head.predictor._kmax_transformer_layers.{kmax_idx}._predcitor.'
        ))
common_prefix_tf = f'transformer_decoder/_kmax_decoder/{kmax_idx}/_block1_transformer/'
common_prefix_torch = f'sem_seg_head.predictor._kmax_transformer_layers.{kmax_idx}.'
res.update(map_bn(common_prefix_tf + '_kmeans_memory_batch_norm_retrieved_value',
common_prefix_torch + '_kmeans_query_batch_norm_retrieved_value'))
res.update(map_bn(common_prefix_tf + '_kmeans_memory_conv3_bn/_batch_norm',
common_prefix_torch + '_kmeans_query_conv3_bn.norm'))
res.update(map_conv(common_prefix_tf + '_kmeans_memory_conv3_bn/_conv',
common_prefix_torch + '_kmeans_query_conv3_bn.conv'))
res.update(map_bn(common_prefix_tf + '_memory_attention/_batch_norm_retrieved_value',
common_prefix_torch + '_query_self_attention._batch_norm_retrieved_value'))
res.update(map_bn(common_prefix_tf + '_memory_attention/_batch_norm_similarity',
common_prefix_torch + '_query_self_attention._batch_norm_similarity'))
res.update(map_bn(common_prefix_tf + '_memory_conv1_bn_act/_batch_norm',
common_prefix_torch + '_query_conv1_bn_act.norm'))
res.update(map_conv(common_prefix_tf + '_memory_conv1_bn_act/_conv',
common_prefix_torch + '_query_conv1_bn_act.conv'))
res.update(map_bn(common_prefix_tf + '_memory_conv3_bn/_batch_norm',
common_prefix_torch + '_query_conv3_bn.norm'))
res.update(map_conv(common_prefix_tf + '_memory_conv3_bn/_conv',
common_prefix_torch + '_query_conv3_bn.conv'))
res.update(map_bn(common_prefix_tf + '_memory_ffn_conv1_bn_act/_batch_norm',
common_prefix_torch + '_query_ffn_conv1_bn_act.norm'))
res.update(map_conv(common_prefix_tf + '_memory_ffn_conv1_bn_act/_conv',
common_prefix_torch + '_query_ffn_conv1_bn_act.conv'))
res.update(map_bn(common_prefix_tf + '_memory_ffn_conv2_bn/_batch_norm',
common_prefix_torch + '_query_ffn_conv2_bn.norm'))
res.update(map_conv(common_prefix_tf + '_memory_ffn_conv2_bn/_conv',
common_prefix_torch + '_query_ffn_conv2_bn.conv'))
res.update(map_bn(common_prefix_tf + '_memory_qkv_conv_bn/_batch_norm',
common_prefix_torch + '_query_qkv_conv_bn.norm'))
res.update(map_conv(common_prefix_tf + '_memory_qkv_conv_bn/_conv',
common_prefix_torch + '_query_qkv_conv_bn.conv'))
res.update(map_bn(common_prefix_tf + '_pixel_conv1_bn_act/_batch_norm',
common_prefix_torch + '_pixel_conv1_bn_act.norm'))
res.update(map_conv(common_prefix_tf + '_pixel_conv1_bn_act/_conv',
common_prefix_torch + '_pixel_conv1_bn_act.conv'))
res.update(map_bn(common_prefix_tf + '_pixel_v_conv_bn/_batch_norm',
common_prefix_torch + '_pixel_v_conv_bn.norm'))
res.update(map_conv(common_prefix_tf + '_pixel_v_conv_bn/_conv',
common_prefix_torch + '_pixel_v_conv_bn.conv'))
return res
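# Naming note: what the TF implementation calls 'memory' (the cluster centers,
# i.e. the object queries) is called 'query' on the PyTorch side, so every
# '_memory_*' layer above maps to a '_query_*' module.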
def tf_2_torch_mapping_aux_semantic_dec():
    # Name mapping for the auxiliary semantic decoder (ASPP + fusion head).
res = {}
res.update(map_conv('semantic_decoder/_aspp/_conv_bn_act/_conv',
'sem_seg_head.predictor._auxiliary_semantic_predictor._aspp._aspp_conv0.conv'))
res.update(map_bn('semantic_decoder/_aspp/_conv_bn_act/_batch_norm',
'sem_seg_head.predictor._auxiliary_semantic_predictor._aspp._aspp_conv0.norm'))
res.update(map_conv('semantic_decoder/_aspp/_aspp_pool/_conv_bn_act/_conv',
'sem_seg_head.predictor._auxiliary_semantic_predictor._aspp._aspp_pool.conv'))
res.update(map_bn('semantic_decoder/_aspp/_aspp_pool/_conv_bn_act/_batch_norm',
'sem_seg_head.predictor._auxiliary_semantic_predictor._aspp._aspp_pool.norm'))
res.update(map_conv('semantic_decoder/_aspp/_proj_conv_bn_act/_conv',
'sem_seg_head.predictor._auxiliary_semantic_predictor._aspp._proj_conv_bn_act.conv'))
res.update(map_bn('semantic_decoder/_aspp/_proj_conv_bn_act/_batch_norm',
'sem_seg_head.predictor._auxiliary_semantic_predictor._aspp._proj_conv_bn_act.norm'))
for i in range(1, 4):
res.update(map_conv(f'semantic_decoder/_aspp/_aspp_conv{i}/_conv_bn_act/_conv',
f'sem_seg_head.predictor._auxiliary_semantic_predictor._aspp._aspp_conv{i}.conv'))
res.update(map_bn(f'semantic_decoder/_aspp/_aspp_conv{i}/_conv_bn_act/_batch_norm',
f'sem_seg_head.predictor._auxiliary_semantic_predictor._aspp._aspp_conv{i}.norm'))
res.update({
'semantic_decoder/_fusion_conv1/_conv1_bn_act/_depthwise/_depthwise_conv/depthwise_kernel':
'sem_seg_head.predictor._auxiliary_semantic_predictor._low_level_fusion_os8_conv0_bn_act.conv.weight'})
res.update(map_bn('semantic_decoder/_fusion_conv1/_conv1_bn_act/_depthwise/_batch_norm',
'sem_seg_head.predictor._auxiliary_semantic_predictor._low_level_fusion_os8_conv0_bn_act.norm'))
res.update({
'semantic_decoder/_fusion_conv1/_conv1_bn_act/_pointwise/_conv/kernel':
'sem_seg_head.predictor._auxiliary_semantic_predictor._low_level_fusion_os8_conv1_bn_act.conv.weight'})
res.update(map_bn('semantic_decoder/_fusion_conv1/_conv1_bn_act/_pointwise/_batch_norm',
'sem_seg_head.predictor._auxiliary_semantic_predictor._low_level_fusion_os8_conv1_bn_act.norm'))
res.update({
'semantic_decoder/_fusion_conv2/_conv1_bn_act/_depthwise/_depthwise_conv/depthwise_kernel':
'sem_seg_head.predictor._auxiliary_semantic_predictor._low_level_fusion_os4_conv0_bn_act.conv.weight'})
res.update(map_bn('semantic_decoder/_fusion_conv2/_conv1_bn_act/_depthwise/_batch_norm',
'sem_seg_head.predictor._auxiliary_semantic_predictor._low_level_fusion_os4_conv0_bn_act.norm'))
res.update({
'semantic_decoder/_fusion_conv2/_conv1_bn_act/_pointwise/_conv/kernel':
'sem_seg_head.predictor._auxiliary_semantic_predictor._low_level_fusion_os4_conv1_bn_act.conv.weight'})
res.update(map_bn('semantic_decoder/_fusion_conv2/_conv1_bn_act/_pointwise/_batch_norm',
'sem_seg_head.predictor._auxiliary_semantic_predictor._low_level_fusion_os4_conv1_bn_act.norm'))
res.update({
'semantic_decoder/_low_level_conv1/_conv/kernel':
'sem_seg_head.predictor._auxiliary_semantic_predictor._low_level_projection_os8.conv.weight'})
res.update(map_bn('semantic_decoder/_low_level_conv1/_batch_norm',
'sem_seg_head.predictor._auxiliary_semantic_predictor._low_level_projection_os8.norm'))
res.update({
'semantic_decoder/_low_level_conv2/_conv/kernel':
'sem_seg_head.predictor._auxiliary_semantic_predictor._low_level_projection_os4.conv.weight'})
res.update(map_bn('semantic_decoder/_low_level_conv2/_batch_norm',
'sem_seg_head.predictor._auxiliary_semantic_predictor._low_level_projection_os4.norm'))
res.update({
'semantic_head_without_last_layer/_conv1_bn_act/_depthwise/_depthwise_conv/depthwise_kernel':
'sem_seg_head.predictor._auxiliary_semantic_predictor.conv_block_0.conv.weight'})
res.update(map_bn('semantic_head_without_last_layer/_conv1_bn_act/_depthwise/_batch_norm',
'sem_seg_head.predictor._auxiliary_semantic_predictor.conv_block_0.norm'))
res.update({
'semantic_head_without_last_layer/_conv1_bn_act/_pointwise/_conv/kernel':
'sem_seg_head.predictor._auxiliary_semantic_predictor.conv_block_1.conv.weight'})
res.update(map_bn('semantic_head_without_last_layer/_conv1_bn_act/_pointwise/_batch_norm',
'sem_seg_head.predictor._auxiliary_semantic_predictor.conv_block_1.norm'))
res.update({
'semantic_last_layer/kernel':
'sem_seg_head.predictor._auxiliary_semantic_predictor.final_conv.conv.weight'})
res.update({
'semantic_last_layer/bias':
'sem_seg_head.predictor._auxiliary_semantic_predictor.final_conv.conv.bias'})
return res
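# The '_conv1_bn_act/_depthwise' + '_pointwise' pairs above are depthwise-
# separable convs that the PyTorch model splits into two conv-bn-act blocks
# ('...conv0...' and '...conv1...'), so their kernels are mapped by raw TF
# variable name instead of via map_conv.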
# Usage: python3 convert-tf-weights-to-d2.py kmax_resnet50_coco_train/ckpt-150000 tf_kmax_r50.pkl
if __name__ == "__main__":
    ckpt_path = sys.argv[1]
    state_dict = load_tf_weights(ckpt_path)
    state_dict_torch = {}
    mapping_key = {}
    # Pick the backbone mapping from the checkpoint path.
    if 'resnet50' in ckpt_path:
        mapping_key.update(tf_2_torch_mapping_r50())
    elif 'convnext' in ckpt_path:
        mapping_key.update(tf_2_torch_mapping_convnext())
    mapping_key.update(tf_2_torch_mapping_pixel_dec())
    mapping_key.update(tf_2_torch_mapping_trans_dec())
    mapping_key.update(tf_2_torch_mapping_aux_semantic_dec())
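    # Convert tensor layouts from TF to PyTorch conventions. Generic rules
    # (the elif branches below cover layers that deviate from them):
    #   rank 4 (conv2d kernels): HWIO (H, W, in, out) -> OIHW (out, in, H, W)
    #   rank 3 (conv1d kernels): (L, in, out) -> (out, in, L)
    #   rank 1-2 (biases, norm params, RPE tables): kept as-is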
    for k, value in state_dict.items():
k2 = mapping_key[k]
rank = len(value.shape)
        if '_batch_norm_retrieved_output' in k2 or '_batch_norm_similarity' in k2 or '_batch_norm_retrieved_value' in k2:
            # attention batch-norm parameters: flatten to 1-D
            value = np.reshape(value, [-1])
elif 'qkv_transform.conv.weight' in k2:
# (512, 1024) -> (1024, 512, 1)
value = np.transpose(value, (1, 0))[:, :, None]
elif '_cluster_centers.weight' in k2:
# (1, 128, 256) -> (256, 128)
value = np.transpose(value[0], (1, 0))
elif '_pixel_conv1_bn_act.conv.weight' in k2:
# (1, 512, 256) -> (256, 512, 1, 1)
value = np.transpose(value, (2, 1, 0))[:, :, :, None]
elif '_pixel_v_conv_bn.conv.weight' in k2:
# (1, 256, 256) -> (256, 256, 1, 1)
value = np.transpose(value, (2, 1, 0))[:, :, :, None]
elif '_pixel_space_head_conv0bnact.conv.weight' in k2:
# (5, 5, 256, 1) -> (256, 1, 5, 5)
value = np.transpose(value, (2, 3, 0, 1))
        elif '/layer_scale' in k:
            # ConvNeXt layer scale: flatten to 1-D to match the 'gamma' parameter
            value = np.reshape(value, [-1])
elif 'pwconv1.weight' in k2 or 'pwconv2.weight' in k2:
# (128, 512) -> (512, 128)
value = np.transpose(value, (1, 0))
elif ('_low_level_fusion_os4_conv0_bn_act.conv.weight' in k2
or '_low_level_fusion_os8_conv0_bn_act.conv.weight' in k2
or 'sem_seg_head.predictor._auxiliary_semantic_predictor.conv_block_0.conv.weight' in k2):
value = np.transpose(value, (2, 3, 0, 1))
else:
if rank == 1: # bias, norm etc
pass
elif rank == 2: # _query_rpe
pass
elif rank == 3: # conv 1d kernel, etc
value = np.transpose(value, (2, 1, 0))
elif rank == 4: # conv 2d kernel, etc
value = np.transpose(value, (3, 2, 0, 1))
state_dict_torch[k2] = value
res = {"model": state_dict_torch, "__author__": "third_party", "matching_heuristics": True}
with open(sys.argv[2], "wb") as f:
pkl.dump(res, f)
# Reference results (original TF -> converted model, w/ eps 1e-3):
# r50: 52.85 -> 52.71
# convnext-base: 56.85 -> 56.97
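# Sanity check (a minimal sketch, assuming the output path used above):
#   import pickle
#   with open('tf_kmax_r50.pkl', 'rb') as f:
#       res = pickle.load(f)
#   print(len(res['model']), 'tensors converted')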