# Source: Oilkkkkbb/scripts/modify_checkpoints.py (Hugging Face Space)
# Uploaded by Shatei — commit 5e373a9 ("Update space"), 478 bytes.
# Copyright 2024 Adobe. All rights reserved.
import torch
# Default input/output locations, kept at module level so they can be reused.
pretrained_model_path = 'pretrained_models/sd-v1-4.ckpt'
output_model_path = "pretrained_models/sd-v1-4-modified-9channel.ckpt"

# State-dict key of the first convolution of the diffusion UNet, whose input
# channel dimension (dim=1) is widened by this script.
INPUT_CONV_KEY = 'model.diffusion_model.input_blocks.0.0.weight'

# Number of zero-initialised input channels appended by default.
DEFAULT_EXTRA_CHANNELS = 5


def expand_input_channels(weight, extra_channels=DEFAULT_EXTRA_CHANNELS):
    """Return *weight* with ``extra_channels`` zero input channels appended.

    The zeros are concatenated along dim 1 (the input-channel axis of a conv
    weight laid out as ``(out_channels, in_channels, *kernel)``), so the
    original filters are preserved and the new channels initially contribute
    nothing to the output.

    Args:
        weight: Tensor with at least two dimensions.
        extra_channels: Number of zero channels to append (>= 0).

    Returns:
        A new tensor with the same dtype/device as *weight* and
        ``weight.shape[1] + extra_channels`` channels along dim 1.

    Raises:
        ValueError: If *weight* has fewer than two dimensions or
            *extra_channels* is negative.
    """
    if weight.dim() < 2:
        raise ValueError(
            f"expected a weight with at least 2 dims, got shape {tuple(weight.shape)}"
        )
    if extra_channels < 0:
        raise ValueError(f"extra_channels must be >= 0, got {extra_channels}")
    out_channels, _, *kernel = weight.shape
    # new_zeros matches the source dtype/device, so torch.cat neither fails
    # nor silently promotes a half-precision weight to float32.
    zeros = weight.new_zeros((out_channels, extra_channels, *kernel))
    return torch.cat((weight, zeros), dim=1)


def modify_checkpoint(
    src_path=pretrained_model_path,
    dst_path=output_model_path,
    extra_channels=DEFAULT_EXTRA_CHANNELS,
    key=INPUT_CONV_KEY,
):
    """Widen one conv weight in a checkpoint and save the result.

    Loads *src_path* onto the CPU, replaces ``ckpt['state_dict'][key]`` with
    a copy that has *extra_channels* zero input channels appended, and saves
    the whole (otherwise unchanged) checkpoint to *dst_path*.

    Returns:
        The modified checkpoint object that was saved.

    Raises:
        KeyError: If the checkpoint has no ``'state_dict'`` or no *key*.
    """
    # NOTE: torch.load unpickles arbitrary objects; only use it on trusted
    # checkpoint files. On newer PyTorch releases torch.load defaults to
    # weights_only=True, which may reject full training checkpoints that
    # pickle non-tensor objects — pass weights_only=False there if needed.
    ckpt_file = torch.load(src_path, map_location='cpu')
    state_dict = ckpt_file['state_dict']
    state_dict[key] = expand_input_channels(state_dict[key], extra_channels)
    torch.save(ckpt_file, dst_path)
    return ckpt_file


if __name__ == "__main__":
    modify_checkpoint()