svjack's picture
Upload folder using huggingface_hub
ce68674 verified
import torch
def conv_forward(self):
    """Build a replacement ``forward`` for a residual block that records features.

    The returned closure reads its layers from ``self`` (``norm1``,
    ``nonlinearity``, ``upsample``, ``downsample``, ``conv1``,
    ``time_emb_proj``, ``time_embedding_norm``, ``norm2``, ``dropout``,
    ``conv2``, ``conv_shortcut``, ``output_scale_factor``). On every call it
    stores the output of ``conv2`` on ``self.record_hidden_state`` before the
    residual connection is applied.

    Args:
        self: The block object whose layers the closure uses and on which the
            recorded hidden state is stored.

    Returns:
        A function ``forward(input_tensor, temb, scale=1.0)``.
    """

    def forward(input_tensor, temb, scale=1.0):
        """Run the residual block and record the ``conv2`` output.

        Args:
            input_tensor: Block input. NOTE(review): the time-embedding
                indexing below implies a 5-D tensor, presumably
                (batch, channels, frames, height, width) -- confirm.
            temb: Time embedding fed to ``time_emb_proj``, or ``None`` to
                skip time conditioning.
            scale: Accepted for call compatibility; never read by this
                implementation.

        Returns:
            ``(shortcut + hidden_states) / self.output_scale_factor``.
        """
        hidden_states = input_tensor

        hidden_states = self.norm1(hidden_states)
        hidden_states = self.nonlinearity(hidden_states)

        if self.upsample is not None:
            # upsample_nearest_nhwc fails with large batch sizes.
            # see https://github.com/huggingface/diffusers/issues/984
            if hidden_states.shape[0] >= 64:
                input_tensor = input_tensor.contiguous()
                hidden_states = hidden_states.contiguous()
            input_tensor = self.upsample(input_tensor)
            hidden_states = self.upsample(hidden_states)
        elif self.downsample is not None:
            input_tensor = self.downsample(input_tensor)
            hidden_states = self.downsample(hidden_states)

        hidden_states = self.conv1(hidden_states)

        if temb is not None:
            # Lift the projected embedding to 5-D and tile it along dim 2 so
            # it lines up with hidden_states along that axis.
            temb = self.time_emb_proj(self.nonlinearity(temb))
            temb = temb[:, :, None, None, None].repeat(
                1, 1, hidden_states.shape[2], 1, 1
            )

        if temb is not None and self.time_embedding_norm == "default":
            hidden_states = hidden_states + temb

        hidden_states = self.norm2(hidden_states)

        if temb is not None and self.time_embedding_norm == "scale_shift":
            # Distinct local names so the unused ``scale`` keyword argument
            # is not silently overwritten.
            temb_scale, temb_shift = torch.chunk(temb, 2, dim=1)
            hidden_states = hidden_states * (1 + temb_scale) + temb_shift

        hidden_states = self.nonlinearity(hidden_states)

        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states)

        # Record the pre-residual features so they can be collected later.
        self.record_hidden_state = hidden_states

        if self.conv_shortcut is not None:
            input_tensor = self.conv_shortcut(input_tensor)

        output_tensor = (input_tensor + hidden_states) / self.output_scale_factor

        return output_tensor

    return forward
def get_conv_feat(unet):
    """Collect the recorded hidden state of every resnet in ``unet.up_blocks``.

    Args:
        unet: Object exposing ``up_blocks``, each of which exposes ``resnets``.
            Every resnet must already carry a ``record_hidden_state``
            attribute; a missing attribute raises ``AttributeError``.

    Returns:
        A dict mapping ``"up_blocks.{i}.resnets.{j}"`` to that resnet's
        ``record_hidden_state``, in iteration order.
    """
    hidden_state_dict = {}
    for i, up_block in enumerate(unet.up_blocks):
        for j, resnet in enumerate(up_block.resnets):
            hidden_state_dict[f"up_blocks.{i}.resnets.{j}"] = (
                resnet.record_hidden_state
            )
    return hidden_state_dict
def prep_unet_conv(unet):
    """Patch every resnet in ``unet.up_blocks`` with a recording ``forward``.

    Each resnet's ``forward`` attribute is replaced by the closure returned
    from ``conv_forward(resnet)``. The patch is applied in place.

    Args:
        unet: Object exposing ``up_blocks``, each of which exposes ``resnets``.

    Returns:
        The same ``unet`` object, after patching.
    """
    for up_block in unet.up_blocks:
        for resnet in up_block.resnets:
            resnet.forward = conv_forward(resnet)
    return unet