Spaces:
Runtime error
Runtime error
Update models/unet.py
Browse files- models/unet.py +45 -45
models/unet.py
CHANGED
@@ -640,53 +640,53 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin):
|
|
640 |
# 'CrossAttnUpBlock3D']}
|
641 |
|
642 |
model = cls.from_config(config)
|
643 |
-
model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
|
644 |
-
if not os.path.isfile(model_file):
|
645 |
-
|
646 |
-
state_dict = torch.load(model_file, map_location="cpu")
|
647 |
-
|
648 |
-
if use_concat:
|
649 |
-
|
650 |
-
|
651 |
-
|
652 |
|
653 |
-
|
654 |
-
|
655 |
-
|
656 |
-
|
657 |
-
|
658 |
-
|
659 |
-
|
660 |
-
|
661 |
-
|
662 |
-
|
663 |
-
|
664 |
-
|
665 |
-
|
666 |
-
|
667 |
|
668 |
-
|
669 |
-
|
670 |
-
|
671 |
-
|
672 |
-
|
673 |
-
|
674 |
-
|
675 |
-
|
676 |
-
|
677 |
-
else:
|
678 |
-
|
679 |
-
|
680 |
-
|
681 |
-
|
682 |
-
|
683 |
-
|
684 |
-
|
685 |
-
|
686 |
-
|
687 |
-
|
688 |
-
|
689 |
-
|
690 |
|
691 |
return model
|
692 |
|
|
|
640 |
# 'CrossAttnUpBlock3D']}
|
641 |
|
642 |
model = cls.from_config(config)
|
643 |
+
# model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
|
644 |
+
# if not os.path.isfile(model_file):
|
645 |
+
# raise RuntimeError(f"{model_file} does not exist")
|
646 |
+
# state_dict = torch.load(model_file, map_location="cpu")
|
647 |
+
|
648 |
+
# if use_concat:
|
649 |
+
# new_state_dict = {}
|
650 |
+
# conv_in_weight = state_dict["conv_in.weight"]
|
651 |
+
# new_conv_weight = torch.zeros((conv_in_weight.shape[0], 9, *conv_in_weight.shape[2:]), dtype=conv_in_weight.dtype)
|
652 |
|
653 |
+
# for i, j in zip([0, 1, 2, 3], [0, 1, 2, 3, 4, 5, 6, 7, 8]):
|
654 |
+
# new_conv_weight[:, j] = conv_in_weight[:, i]
|
655 |
+
# new_state_dict["conv_in.weight"] = new_conv_weight
|
656 |
+
# new_state_dict["conv_in.bias"] = state_dict["conv_in.bias"]
|
657 |
+
# for k, v in model.state_dict().items():
|
658 |
+
# # print(k)
|
659 |
+
# if '_temp.' in k:
|
660 |
+
# new_state_dict.update({k: v})
|
661 |
+
# if 'attn_fcross' in k: # conpy parms of attn1 to attn_fcross
|
662 |
+
# k = k.replace('attn_fcross', 'attn1')
|
663 |
+
# state_dict.update({k: state_dict[k]})
|
664 |
+
# if 'norm_fcross' in k:
|
665 |
+
# k = k.replace('norm_fcross', 'norm1')
|
666 |
+
# state_dict.update({k: state_dict[k]})
|
667 |
|
668 |
+
# if 'conv_in' in k:
|
669 |
+
# continue
|
670 |
+
# else:
|
671 |
+
# new_state_dict[k] = v
|
672 |
+
# # # tmp
|
673 |
+
# # if 'class_embedding' in k:
|
674 |
+
# # state_dict.update({k: v})
|
675 |
+
# # breakpoint()
|
676 |
+
# model.load_state_dict(new_state_dict)
|
677 |
+
# else:
|
678 |
+
# for k, v in model.state_dict().items():
|
679 |
+
# # print(k)
|
680 |
+
# if '_temp' in k:
|
681 |
+
# state_dict.update({k: v})
|
682 |
+
# if 'attn_fcross' in k: # conpy parms of attn1 to attn_fcross
|
683 |
+
# k = k.replace('attn_fcross', 'attn1')
|
684 |
+
# state_dict.update({k: state_dict[k]})
|
685 |
+
# if 'norm_fcross' in k:
|
686 |
+
# k = k.replace('norm_fcross', 'norm1')
|
687 |
+
# state_dict.update({k: state_dict[k]})
|
688 |
+
|
689 |
+
# model.load_state_dict(state_dict)
|
690 |
|
691 |
return model
|
692 |
|