import tensorflow as tf
import pickle as pkl
import sys
import torch
import numpy as np


def load_tf_weights(ckpt_path):
    # https://stackoverflow.com/questions/40118062/how-to-read-weights-saved-in-tensorflow-checkpoint-file
    from tensorflow.python.training import py_checkpoint_reader
    reader = py_checkpoint_reader.NewCheckpointReader(ckpt_path)
    state_dict = {}
    for k in reader.get_variable_to_shape_map():
        if ('.OPTIMIZER_SLOT' in k or 'optimizer' in k or '_CHECKPOINTABLE_OBJECT_GRAPH' in k
                or 'save_counter' in k or 'global_step' in k):
            continue
        v = reader.get_tensor(k)
        state_dict[k.replace('/.ATTRIBUTES/VARIABLE_VALUE', '')] = v
    for k in sorted(state_dict.keys()):
        print(k, state_dict[k].shape)
    return state_dict


def map_bn(name1, name2):
    res = {}
    res[name1 + '/gamma'] = name2 + ".weight"
    res[name1 + '/beta'] = name2 + ".bias"
    res[name1 + '/moving_mean'] = name2 + ".running_mean"
    res[name1 + '/moving_variance'] = name2 + ".running_var"
    return res


def map_conv(name1, name2, dw=False, bias=False):
    res = {}
    if dw:
        res[name1 + '/depthwise_kernel'] = name2 + ".weight"
    else:
        res[name1 + '/kernel'] = name2 + ".weight"
    if bias:
        res[name1 + '/bias'] = name2 + ".bias"
    return res
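
# For reference, the two helpers above expand a TF layer name and a PyTorch
# module name into checkpoint-key -> state_dict-key pairs, e.g. (names taken
# from the ResNet-50 mapping below):
#
#   map_conv('encoder/_stem/_conv', 'backbone.stem.conv1')
#     -> {'encoder/_stem/_conv/kernel': 'backbone.stem.conv1.weight'}
#   map_bn('encoder/_stem/_batch_norm', 'backbone.stem.conv1.norm')
#     -> {'encoder/_stem/_batch_norm/gamma': 'backbone.stem.conv1.norm.weight',
#         'encoder/_stem/_batch_norm/beta': 'backbone.stem.conv1.norm.bias',
#         'encoder/_stem/_batch_norm/moving_mean': 'backbone.stem.conv1.norm.running_mean',
#         'encoder/_stem/_batch_norm/moving_variance': 'backbone.stem.conv1.norm.running_var'}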


def tf_2_torch_mapping_r50():
    res = {}
    res.update(map_conv('encoder/_stem/_conv', 'backbone.stem.conv1'))
    res.update(map_bn('encoder/_stem/_batch_norm', 'backbone.stem.conv1.norm'))
    block_num = {2: 3, 3: 4, 4: 6, 5: 3}
    for stage_idx in range(2, 6):
        for block_idx in range(1, block_num[stage_idx] + 1):
            res.update(map_conv(f'encoder/_stage{stage_idx}/_block{block_idx}/_conv1_bn_act/_conv',
                                f'backbone.res{stage_idx}.{block_idx-1}.conv1'))
            res.update(map_bn(f'encoder/_stage{stage_idx}/_block{block_idx}/_conv1_bn_act/_batch_norm',
                              f'backbone.res{stage_idx}.{block_idx-1}.conv1.norm'))
            res.update(map_conv(f'encoder/_stage{stage_idx}/_block{block_idx}/_conv2_bn_act/_conv',
                                f'backbone.res{stage_idx}.{block_idx-1}.conv2'))
            res.update(map_bn(f'encoder/_stage{stage_idx}/_block{block_idx}/_conv2_bn_act/_batch_norm',
                              f'backbone.res{stage_idx}.{block_idx-1}.conv2.norm'))
            res.update(map_conv(f'encoder/_stage{stage_idx}/_block{block_idx}/_conv3_bn/_conv',
                                f'backbone.res{stage_idx}.{block_idx-1}.conv3'))
            res.update(map_bn(f'encoder/_stage{stage_idx}/_block{block_idx}/_conv3_bn/_batch_norm',
                              f'backbone.res{stage_idx}.{block_idx-1}.conv3.norm'))
            res.update(map_conv(f'encoder/_stage{stage_idx}/_block{block_idx}/_shortcut/_conv',
                                f'backbone.res{stage_idx}.{block_idx-1}.shortcut'))
            res.update(map_bn(f'encoder/_stage{stage_idx}/_block{block_idx}/_shortcut/_batch_norm',
                              f'backbone.res{stage_idx}.{block_idx-1}.shortcut.norm'))
    return res


def tf_2_torch_mapping_convnext():
    res = {}
    for i in range(4):
        if i == 0:
            res.update(map_conv(f'encoder/downsample_layers/{i}/layer_with_weights-0',
                                f'backbone.downsample_layers.{i}.0', bias=True))
            res.update(map_bn(f'encoder/downsample_layers/{i}/layer_with_weights-1',
                              f'backbone.downsample_layers.{i}.1'))
        else:
            res.update(map_conv(f'encoder/downsample_layers/{i}/layer_with_weights-1',
                                f'backbone.downsample_layers.{i}.1', bias=True))
            res.update(map_bn(f'encoder/downsample_layers/{i}/layer_with_weights-0',
                              f'backbone.downsample_layers.{i}.0'))
    block_num = {0: 3, 1: 3, 2: 27, 3: 3}
    for stage_idx in range(4):
        for block_idx in range(block_num[stage_idx]):
            res.update(map_conv(f'encoder/stages/{stage_idx}/layer_with_weights-{block_idx}/depthwise_conv',
                                f'backbone.stages.{stage_idx}.{block_idx}.dwconv', bias=True))
            res.update(map_bn(f'encoder/stages/{stage_idx}/layer_with_weights-{block_idx}/norm',
                              f'backbone.stages.{stage_idx}.{block_idx}.norm'))
            res.update(map_conv(f'encoder/stages/{stage_idx}/layer_with_weights-{block_idx}/pointwise_conv1',
                                f'backbone.stages.{stage_idx}.{block_idx}.pwconv1', bias=True))
            res.update(map_conv(f'encoder/stages/{stage_idx}/layer_with_weights-{block_idx}/pointwise_conv2',
                                f'backbone.stages.{stage_idx}.{block_idx}.pwconv2', bias=True))
            res[f'encoder/stages/{stage_idx}/layer_with_weights-{block_idx}/layer_scale'] = \
                f'backbone.stages.{stage_idx}.{block_idx}.gamma'
    return res


def tf_2_torch_mapping_pixel_dec():
    res = {}
    for i in range(4):
        res.update(map_bn(f'pixel_decoder/_backbone_norms/{i}',
                          f'sem_seg_head.pixel_decoder._in_norms.{i}'))
    for i in range(3):
        res.update(map_conv(f'pixel_decoder/_skip_connections/{i}/_resized_conv_bn1/_conv',
                            f'sem_seg_head.pixel_decoder._resized_fuses.{i}._conv_bn_low.conv'))
        res.update(map_bn(f'pixel_decoder/_skip_connections/{i}/_resized_conv_bn1/_batch_norm',
                          f'sem_seg_head.pixel_decoder._resized_fuses.{i}._conv_bn_low.norm'))
        res.update(map_conv(f'pixel_decoder/_skip_connections/{i}/_resized_conv_bn2/_conv',
                            f'sem_seg_head.pixel_decoder._resized_fuses.{i}._conv_bn_high.conv'))
        res.update(map_bn(f'pixel_decoder/_skip_connections/{i}/_resized_conv_bn2/_batch_norm',
                          f'sem_seg_head.pixel_decoder._resized_fuses.{i}._conv_bn_high.norm'))
    num_blocks = {0: 1, 1: 5, 2: 1, 3: 1}
    for stage_idx in range(4):
        for block_idx in range(1, 1 + num_blocks[stage_idx]):
            res.update(map_conv(f'pixel_decoder/_stages/{stage_idx}/_block{block_idx}/_shortcut/_conv',
                                f'sem_seg_head.pixel_decoder._stages.{stage_idx}._blocks.{block_idx-1}._shortcut.conv'))
            res.update(map_bn(f'pixel_decoder/_stages/{stage_idx}/_block{block_idx}/_shortcut/_batch_norm',
                              f'sem_seg_head.pixel_decoder._stages.{stage_idx}._blocks.{block_idx-1}._shortcut.norm'))
            res.update(map_conv(f'pixel_decoder/_stages/{stage_idx}/_block{block_idx}/_conv1_bn_act/_conv',
                                f'sem_seg_head.pixel_decoder._stages.{stage_idx}._blocks.{block_idx-1}._conv1_bn_act.conv'))
            res.update(map_bn(f'pixel_decoder/_stages/{stage_idx}/_block{block_idx}/_conv1_bn_act/_batch_norm',
                              f'sem_seg_head.pixel_decoder._stages.{stage_idx}._blocks.{block_idx-1}._conv1_bn_act.norm'))
            res.update(map_conv(f'pixel_decoder/_stages/{stage_idx}/_block{block_idx}/_conv3_bn/_conv',
                                f'sem_seg_head.pixel_decoder._stages.{stage_idx}._blocks.{block_idx-1}._conv3_bn.conv'))
            res.update(map_bn(f'pixel_decoder/_stages/{stage_idx}/_block{block_idx}/_conv3_bn/_batch_norm',
                              f'sem_seg_head.pixel_decoder._stages.{stage_idx}._blocks.{block_idx-1}._conv3_bn.norm'))
            if stage_idx <= 1:
                # Stages 0 and 1 carry axial-attention blocks; later stages use a plain conv2 instead.
                for attn in ['height', 'width']:
                    res.update(map_bn(f'pixel_decoder/_stages/{stage_idx}/_block{block_idx}/_attention/_{attn}_axis/_batch_norm_qkv',
                                      f'sem_seg_head.pixel_decoder._stages.{stage_idx}._blocks.{block_idx-1}._attention._{attn}_axis._batch_norm_qkv'))
                    res.update(map_bn(f'pixel_decoder/_stages/{stage_idx}/_block{block_idx}/_attention/_{attn}_axis/_batch_norm_retrieved_output',
                                      f'sem_seg_head.pixel_decoder._stages.{stage_idx}._blocks.{block_idx-1}._attention._{attn}_axis._batch_norm_retrieved_output'))
                    res.update(map_bn(f'pixel_decoder/_stages/{stage_idx}/_block{block_idx}/_attention/_{attn}_axis/_batch_norm_similarity',
                                      f'sem_seg_head.pixel_decoder._stages.{stage_idx}._blocks.{block_idx-1}._attention._{attn}_axis._batch_norm_similarity'))
                    res[f'pixel_decoder/_stages/{stage_idx}/_block{block_idx}/_attention/_{attn}_axis/_key_rpe/embeddings'] = (
                        f'sem_seg_head.pixel_decoder._stages.{stage_idx}._blocks.{block_idx-1}._attention._{attn}_axis._key_rpe._embeddings.weight')
                    res[f'pixel_decoder/_stages/{stage_idx}/_block{block_idx}/_attention/_{attn}_axis/_query_rpe/embeddings'] = (
                        f'sem_seg_head.pixel_decoder._stages.{stage_idx}._blocks.{block_idx-1}._attention._{attn}_axis._query_rpe._embeddings.weight')
                    res[f'pixel_decoder/_stages/{stage_idx}/_block{block_idx}/_attention/_{attn}_axis/_value_rpe/embeddings'] = (
                        f'sem_seg_head.pixel_decoder._stages.{stage_idx}._blocks.{block_idx-1}._attention._{attn}_axis._value_rpe._embeddings.weight')
                    res[f'pixel_decoder/_stages/{stage_idx}/_block{block_idx}/_attention/_{attn}_axis/qkv_kernel'] = (
                        f'sem_seg_head.pixel_decoder._stages.{stage_idx}._blocks.{block_idx-1}._attention._{attn}_axis.qkv_transform.conv.weight')
            else:
                res.update(map_conv(f'pixel_decoder/_stages/{stage_idx}/_block{block_idx}/_conv2_bn_act/_conv',
                                    f'sem_seg_head.pixel_decoder._stages.{stage_idx}._blocks.{block_idx-1}._conv2_bn_act.conv'))
                res.update(map_bn(f'pixel_decoder/_stages/{stage_idx}/_block{block_idx}/_conv2_bn_act/_batch_norm',
                                  f'sem_seg_head.pixel_decoder._stages.{stage_idx}._blocks.{block_idx-1}._conv2_bn_act.norm'))
    return res


def tf_2_torch_mapping_predictor(prefix_tf, prefix_torch):
    res = {}
    res.update(map_bn(prefix_tf + 'pixel_space_feature_batch_norm',
                      prefix_torch + '_pixel_space_head_last_convbn.norm'))
    res[prefix_tf + 'pixel_space_head/conv_block/_conv1_bn_act/_depthwise/_depthwise_conv/depthwise_kernel'] = (
        prefix_torch + '_pixel_space_head_conv0bnact.conv.weight')
    res.update(map_bn(prefix_tf + 'pixel_space_head/conv_block/_conv1_bn_act/_depthwise/_batch_norm',
                      prefix_torch + '_pixel_space_head_conv0bnact.norm'))
    res.update(map_conv(prefix_tf + 'pixel_space_head/conv_block/_conv1_bn_act/_pointwise/_conv',
                        prefix_torch + '_pixel_space_head_conv1bnact.conv'))
    res.update(map_bn(prefix_tf + 'pixel_space_head/conv_block/_conv1_bn_act/_pointwise/_batch_norm',
                      prefix_torch + '_pixel_space_head_conv1bnact.norm'))
    res.update(map_conv(prefix_tf + 'pixel_space_head/final_conv',
                        prefix_torch + '_pixel_space_head_last_convbn.conv', bias=True))
    res.update(map_bn(prefix_tf + 'pixel_space_mask_batch_norm',
                      prefix_torch + '_pixel_space_mask_batch_norm'))
    res.update(map_conv(prefix_tf + 'transformer_class_head/_conv',
                        prefix_torch + '_transformer_class_head.conv', bias=True))
    res.update(map_conv(prefix_tf + 'transformer_mask_head/_conv',
                        prefix_torch + '_transformer_mask_head.conv'))
    res.update(map_bn(prefix_tf + 'transformer_mask_head/_batch_norm',
                      prefix_torch + '_transformer_mask_head.norm'))
    return res


def tf_2_torch_mapping_trans_dec():
    res = {}
    res.update(map_bn('transformer_decoder/_class_embedding_projection/_batch_norm',
                      'sem_seg_head.predictor._class_embedding_projection.norm'))
    res.update(map_conv('transformer_decoder/_class_embedding_projection/_conv',
                        'sem_seg_head.predictor._class_embedding_projection.conv'))
    res.update(map_bn('transformer_decoder/_mask_embedding_projection/_batch_norm',
                      'sem_seg_head.predictor._mask_embedding_projection.norm'))
    res.update(map_conv('transformer_decoder/_mask_embedding_projection/_conv',
                        'sem_seg_head.predictor._mask_embedding_projection.conv'))
    res['transformer_decoder/cluster_centers'] = 'sem_seg_head.predictor._cluster_centers.weight'
    # The PyTorch-side name strings mirror the target model's state_dict keys
    # (including the '._predcitor.' spelling used there).
    res.update(tf_2_torch_mapping_predictor(
        prefix_tf='',
        prefix_torch='sem_seg_head.predictor._predcitor.'))
    for kmax_idx in range(6):
        res.update(tf_2_torch_mapping_predictor(
            prefix_tf=f'transformer_decoder/_kmax_decoder/{kmax_idx}/_block1_transformer/_auxiliary_clustering_predictor/_',
            prefix_torch=f'sem_seg_head.predictor._kmax_transformer_layers.{kmax_idx}._predcitor.'))
        common_prefix_tf = f'transformer_decoder/_kmax_decoder/{kmax_idx}/_block1_transformer/'
        common_prefix_torch = f'sem_seg_head.predictor._kmax_transformer_layers.{kmax_idx}.'
        # TF '_memory_*' variables map onto the PyTorch '_query_*' modules.
        res.update(map_bn(common_prefix_tf + '_kmeans_memory_batch_norm_retrieved_value',
                          common_prefix_torch + '_kmeans_query_batch_norm_retrieved_value'))
        res.update(map_bn(common_prefix_tf + '_kmeans_memory_conv3_bn/_batch_norm',
                          common_prefix_torch + '_kmeans_query_conv3_bn.norm'))
        res.update(map_conv(common_prefix_tf + '_kmeans_memory_conv3_bn/_conv',
                            common_prefix_torch + '_kmeans_query_conv3_bn.conv'))
        res.update(map_bn(common_prefix_tf + '_memory_attention/_batch_norm_retrieved_value',
                          common_prefix_torch + '_query_self_attention._batch_norm_retrieved_value'))
        res.update(map_bn(common_prefix_tf + '_memory_attention/_batch_norm_similarity',
                          common_prefix_torch + '_query_self_attention._batch_norm_similarity'))
        res.update(map_bn(common_prefix_tf + '_memory_conv1_bn_act/_batch_norm',
                          common_prefix_torch + '_query_conv1_bn_act.norm'))
        res.update(map_conv(common_prefix_tf + '_memory_conv1_bn_act/_conv',
                            common_prefix_torch + '_query_conv1_bn_act.conv'))
        res.update(map_bn(common_prefix_tf + '_memory_conv3_bn/_batch_norm',
                          common_prefix_torch + '_query_conv3_bn.norm'))
        res.update(map_conv(common_prefix_tf + '_memory_conv3_bn/_conv',
                            common_prefix_torch + '_query_conv3_bn.conv'))
        res.update(map_bn(common_prefix_tf + '_memory_ffn_conv1_bn_act/_batch_norm',
                          common_prefix_torch + '_query_ffn_conv1_bn_act.norm'))
        res.update(map_conv(common_prefix_tf + '_memory_ffn_conv1_bn_act/_conv',
                            common_prefix_torch + '_query_ffn_conv1_bn_act.conv'))
        res.update(map_bn(common_prefix_tf + '_memory_ffn_conv2_bn/_batch_norm',
                          common_prefix_torch + '_query_ffn_conv2_bn.norm'))
        res.update(map_conv(common_prefix_tf + '_memory_ffn_conv2_bn/_conv',
                            common_prefix_torch + '_query_ffn_conv2_bn.conv'))
        res.update(map_bn(common_prefix_tf + '_memory_qkv_conv_bn/_batch_norm',
                          common_prefix_torch + '_query_qkv_conv_bn.norm'))
        res.update(map_conv(common_prefix_tf + '_memory_qkv_conv_bn/_conv',
                            common_prefix_torch + '_query_qkv_conv_bn.conv'))
        res.update(map_bn(common_prefix_tf + '_pixel_conv1_bn_act/_batch_norm',
                          common_prefix_torch + '_pixel_conv1_bn_act.norm'))
        res.update(map_conv(common_prefix_tf + '_pixel_conv1_bn_act/_conv',
                            common_prefix_torch + '_pixel_conv1_bn_act.conv'))
        res.update(map_bn(common_prefix_tf + '_pixel_v_conv_bn/_batch_norm',
                          common_prefix_torch + '_pixel_v_conv_bn.norm'))
        res.update(map_conv(common_prefix_tf + '_pixel_v_conv_bn/_conv',
                            common_prefix_torch + '_pixel_v_conv_bn.conv'))
    return res


def tf_2_torch_mapping_aux_semantic_dec():
    res = {}
    res.update(map_conv('semantic_decoder/_aspp/_conv_bn_act/_conv',
                        'sem_seg_head.predictor._auxiliary_semantic_predictor._aspp._aspp_conv0.conv'))
    res.update(map_bn('semantic_decoder/_aspp/_conv_bn_act/_batch_norm',
                      'sem_seg_head.predictor._auxiliary_semantic_predictor._aspp._aspp_conv0.norm'))
    res.update(map_conv('semantic_decoder/_aspp/_aspp_pool/_conv_bn_act/_conv',
                        'sem_seg_head.predictor._auxiliary_semantic_predictor._aspp._aspp_pool.conv'))
    res.update(map_bn('semantic_decoder/_aspp/_aspp_pool/_conv_bn_act/_batch_norm',
                      'sem_seg_head.predictor._auxiliary_semantic_predictor._aspp._aspp_pool.norm'))
    res.update(map_conv('semantic_decoder/_aspp/_proj_conv_bn_act/_conv',
                        'sem_seg_head.predictor._auxiliary_semantic_predictor._aspp._proj_conv_bn_act.conv'))
    res.update(map_bn('semantic_decoder/_aspp/_proj_conv_bn_act/_batch_norm',
                      'sem_seg_head.predictor._auxiliary_semantic_predictor._aspp._proj_conv_bn_act.norm'))
    for i in range(1, 4):
        res.update(map_conv(f'semantic_decoder/_aspp/_aspp_conv{i}/_conv_bn_act/_conv',
                            f'sem_seg_head.predictor._auxiliary_semantic_predictor._aspp._aspp_conv{i}.conv'))
        res.update(map_bn(f'semantic_decoder/_aspp/_aspp_conv{i}/_conv_bn_act/_batch_norm',
                          f'sem_seg_head.predictor._auxiliary_semantic_predictor._aspp._aspp_conv{i}.norm'))
    res.update({'semantic_decoder/_fusion_conv1/_conv1_bn_act/_depthwise/_depthwise_conv/depthwise_kernel':
                'sem_seg_head.predictor._auxiliary_semantic_predictor._low_level_fusion_os8_conv0_bn_act.conv.weight'})
    res.update(map_bn('semantic_decoder/_fusion_conv1/_conv1_bn_act/_depthwise/_batch_norm',
                      'sem_seg_head.predictor._auxiliary_semantic_predictor._low_level_fusion_os8_conv0_bn_act.norm'))
    res.update({'semantic_decoder/_fusion_conv1/_conv1_bn_act/_pointwise/_conv/kernel':
                'sem_seg_head.predictor._auxiliary_semantic_predictor._low_level_fusion_os8_conv1_bn_act.conv.weight'})
    res.update(map_bn('semantic_decoder/_fusion_conv1/_conv1_bn_act/_pointwise/_batch_norm',
                      'sem_seg_head.predictor._auxiliary_semantic_predictor._low_level_fusion_os8_conv1_bn_act.norm'))
    res.update({'semantic_decoder/_fusion_conv2/_conv1_bn_act/_depthwise/_depthwise_conv/depthwise_kernel':
                'sem_seg_head.predictor._auxiliary_semantic_predictor._low_level_fusion_os4_conv0_bn_act.conv.weight'})
    res.update(map_bn('semantic_decoder/_fusion_conv2/_conv1_bn_act/_depthwise/_batch_norm',
                      'sem_seg_head.predictor._auxiliary_semantic_predictor._low_level_fusion_os4_conv0_bn_act.norm'))
    res.update({'semantic_decoder/_fusion_conv2/_conv1_bn_act/_pointwise/_conv/kernel':
                'sem_seg_head.predictor._auxiliary_semantic_predictor._low_level_fusion_os4_conv1_bn_act.conv.weight'})
    res.update(map_bn('semantic_decoder/_fusion_conv2/_conv1_bn_act/_pointwise/_batch_norm',
                      'sem_seg_head.predictor._auxiliary_semantic_predictor._low_level_fusion_os4_conv1_bn_act.norm'))
    res.update({'semantic_decoder/_low_level_conv1/_conv/kernel':
                'sem_seg_head.predictor._auxiliary_semantic_predictor._low_level_projection_os8.conv.weight'})
    res.update(map_bn('semantic_decoder/_low_level_conv1/_batch_norm',
                      'sem_seg_head.predictor._auxiliary_semantic_predictor._low_level_projection_os8.norm'))
    res.update({'semantic_decoder/_low_level_conv2/_conv/kernel':
                'sem_seg_head.predictor._auxiliary_semantic_predictor._low_level_projection_os4.conv.weight'})
    res.update(map_bn('semantic_decoder/_low_level_conv2/_batch_norm',
                      'sem_seg_head.predictor._auxiliary_semantic_predictor._low_level_projection_os4.norm'))
    res.update({'semantic_head_without_last_layer/_conv1_bn_act/_depthwise/_depthwise_conv/depthwise_kernel':
                'sem_seg_head.predictor._auxiliary_semantic_predictor.conv_block_0.conv.weight'})
    res.update(map_bn('semantic_head_without_last_layer/_conv1_bn_act/_depthwise/_batch_norm',
                      'sem_seg_head.predictor._auxiliary_semantic_predictor.conv_block_0.norm'))
    res.update({'semantic_head_without_last_layer/_conv1_bn_act/_pointwise/_conv/kernel':
                'sem_seg_head.predictor._auxiliary_semantic_predictor.conv_block_1.conv.weight'})
    res.update(map_bn('semantic_head_without_last_layer/_conv1_bn_act/_pointwise/_batch_norm',
                      'sem_seg_head.predictor._auxiliary_semantic_predictor.conv_block_1.norm'))
    res.update({'semantic_last_layer/kernel':
                'sem_seg_head.predictor._auxiliary_semantic_predictor.final_conv.conv.weight'})
    res.update({'semantic_last_layer/bias':
                'sem_seg_head.predictor._auxiliary_semantic_predictor.final_conv.conv.bias'})
    return res
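

# Optional sanity-check helper (a small sketch added here; it is not called by
# the conversion below). The loop in __main__ indexes mapping_key[k] directly,
# so any TF variable without a mapping raises a KeyError; calling this first
# reports all unmapped keys at once.
def check_mapping_coverage(state_dict, mapping_key):
    missing = [k for k in state_dict if k not in mapping_key]
    for k in missing:
        print('unmapped TF variable:', k, state_dict[k].shape)
    return missing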


# python3 convert-tf-weights-to-d2.py kmax_resnet50_coco_train/ckpt-150000 tf_kmax_r50.pkl
if __name__ == "__main__":
    ckpt_path = sys.argv[1]
    state_dict = load_tf_weights(ckpt_path)
    # exit()
    state_dict_torch = {}
    mapping_key = {}
    if 'resnet50' in ckpt_path:
        mapping_key.update(tf_2_torch_mapping_r50())
    elif 'convnext' in ckpt_path:
        mapping_key.update(tf_2_torch_mapping_convnext())
    mapping_key.update(tf_2_torch_mapping_pixel_dec())
    mapping_key.update(tf_2_torch_mapping_trans_dec())
    mapping_key.update(tf_2_torch_mapping_aux_semantic_dec())
    for k in state_dict.keys():
        value = state_dict[k]
        k2 = mapping_key[k]
        rank = len(value.shape)
        if ('_batch_norm_retrieved_output' in k2 or '_batch_norm_similarity' in k2
                or '_batch_norm_retrieved_value' in k2):
            value = np.reshape(value, [-1])
        elif 'qkv_transform.conv.weight' in k2:
            # (512, 1024) -> (1024, 512, 1)
            value = np.transpose(value, (1, 0))[:, :, None]
        elif '_cluster_centers.weight' in k2:
            # (1, 128, 256) -> (256, 128)
            value = np.transpose(value[0], (1, 0))
        elif '_pixel_conv1_bn_act.conv.weight' in k2:
            # (1, 512, 256) -> (256, 512, 1, 1)
            value = np.transpose(value, (2, 1, 0))[:, :, :, None]
        elif '_pixel_v_conv_bn.conv.weight' in k2:
            # (1, 256, 256) -> (256, 256, 1, 1)
            value = np.transpose(value, (2, 1, 0))[:, :, :, None]
        elif '_pixel_space_head_conv0bnact.conv.weight' in k2:
            # (5, 5, 256, 1) -> (256, 1, 5, 5)
            value = np.transpose(value, (2, 3, 0, 1))
        elif '/layer_scale' in k:
            value = np.reshape(value, [-1])
        elif 'pwconv1.weight' in k2 or 'pwconv2.weight' in k2:
            # (128, 512) -> (512, 128)
            value = np.transpose(value, (1, 0))
        elif ('_low_level_fusion_os4_conv0_bn_act.conv.weight' in k2
                or '_low_level_fusion_os8_conv0_bn_act.conv.weight' in k2
                or 'sem_seg_head.predictor._auxiliary_semantic_predictor.conv_block_0.conv.weight' in k2):
            # depthwise kernels: (H, W, C, 1) -> (C, 1, H, W)
            value = np.transpose(value, (2, 3, 0, 1))
        else:
            if rank == 1:
                # bias, norm etc
                pass
            elif rank == 2:
                # _query_rpe
                pass
            elif rank == 3:
                # conv 1d kernel, etc
                value = np.transpose(value, (2, 1, 0))
            elif rank == 4:
                # conv 2d kernel, etc
                value = np.transpose(value, (3, 2, 0, 1))
        state_dict_torch[k2] = value

    res = {"model": state_dict_torch, "__author__": "third_party", "matching_heuristics": True}
    with open(sys.argv[2], "wb") as f:
        pkl.dump(res, f)

# r50: 52.85 -> 52.71 w/ eps 1e-3
# convnext-base: 56.85 -> 56.97 w/ eps 1e-3
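
# To sanity-check the converted file, the pickle can be loaded back and a few
# shapes inspected, e.g.:
#
#   import pickle
#   with open('tf_kmax_r50.pkl', 'rb') as f:
#       ckpt = pickle.load(f)
#   for name, w in list(ckpt['model'].items())[:5]:
#       print(name, w.shape)
#
# The {"model": ..., "matching_heuristics": True} layout follows the pickle
# format accepted by Detectron2-style checkpointers, so the output file would
# typically be referenced via MODEL.WEIGHTS in the config (assumption: the
# consuming repo loads weights through a Detectron2 DetectionCheckpointer).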