Spaces:
Runtime error
Runtime error
import torch | |
def conv_forward(self):
    """Build a replacement ``forward`` for a ResNet block that records its features.

    The returned closure reproduces the block's residual computation:
    norm1 -> nonlinearity -> optional up/down-sample -> conv1 -> time-embedding
    injection -> norm2 -> nonlinearity -> dropout -> conv2 -> residual add.
    It additionally stores the ``conv2`` output on ``self.record_hidden_state``
    so the feature can be collected after a forward pass.

    Args:
        self: the ResNet block module. Its sub-modules (``norm1``, ``conv1``,
            ``time_emb_proj``, ``conv_shortcut``, ...) and settings
            (``time_embedding_norm``, ``output_scale_factor``) are read by the
            closure at call time.

    Returns:
        A callable ``forward(input_tensor, temb, scale=1.0)`` intended to be
        assigned to ``self.forward``.
    """

    def forward(input_tensor, temb, scale=1.0):
        """Run the residual block and record the residual-branch output.

        Args:
            input_tensor: block input. The projected time embedding is
                broadcast over every axis after the first two (presumably
                batch and channels -- confirm against the model).
            temb: time embedding, or ``None`` to skip time conditioning.
            scale: accepted so callers that pass a ``scale`` argument keep
                working; it is not used here.

        Returns:
            ``(shortcut(input_tensor) + residual) / self.output_scale_factor``.
        """
        hidden_states = self.norm1(input_tensor)
        hidden_states = self.nonlinearity(hidden_states)

        if self.upsample is not None:
            # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
            if hidden_states.shape[0] >= 64:
                input_tensor = input_tensor.contiguous()
                hidden_states = hidden_states.contiguous()
            input_tensor = self.upsample(input_tensor)
            hidden_states = self.upsample(hidden_states)
        elif self.downsample is not None:
            input_tensor = self.downsample(input_tensor)
            hidden_states = self.downsample(hidden_states)

        hidden_states = self.conv1(hidden_states)

        if temb is not None:
            temb = self.time_emb_proj(self.nonlinearity(temb))
            # Insert one singleton axis per trailing axis of hidden_states so
            # the embedding broadcasts over them. For 5-D activations this is
            # exactly ``[:, :, None, None, None]``. Broadcasting yields the same
            # values as repeating along axis 2, without materialising a copy.
            trailing_axes = (None,) * (hidden_states.ndim - 2)
            temb = temb[(slice(None), slice(None)) + trailing_axes]

        if temb is not None and self.time_embedding_norm == "default":
            hidden_states = hidden_states + temb

        hidden_states = self.norm2(hidden_states)

        if temb is not None and self.time_embedding_norm == "scale_shift":
            # Split along the channel axis into scale and shift halves. These
            # locals are named so they do not shadow the ``scale`` argument.
            emb_scale, emb_shift = torch.chunk(temb, 2, dim=1)
            hidden_states = hidden_states * (1 + emb_scale) + emb_shift

        hidden_states = self.nonlinearity(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states)

        # Keep the residual-branch features so get_conv_feat can read them later.
        self.record_hidden_state = hidden_states

        if self.conv_shortcut is not None:
            input_tensor = self.conv_shortcut(input_tensor)

        return (input_tensor + hidden_states) / self.output_scale_factor

    return forward
def get_conv_feat(unet):
    """Collect the features recorded by the patched up-block ResNets.

    Args:
        unet: model whose ``up_blocks[i].resnets[j]`` modules were patched by
            ``prep_unet_conv`` and have since run at least one forward pass.

    Returns:
        dict mapping ``"up_blocks.{i}.resnets.{j}"`` to that resnet's
        ``record_hidden_state``, in up-block / resnet order.

    Raises:
        AttributeError: if a resnet has no ``record_hidden_state`` yet. The
            message names the offending module.
    """
    features = {}
    for i, up_block in enumerate(unet.up_blocks):
        for j, resnet in enumerate(up_block.resnets):
            name = f"up_blocks.{i}.resnets.{j}"
            try:
                features[name] = resnet.record_hidden_state
            except AttributeError as err:
                raise AttributeError(
                    f"{name} has no recorded hidden state; patch it with "
                    "prep_unet_conv(unet) and run a forward pass first"
                ) from err
    return features
def prep_unet_conv(unet):
    """Patch every ResNet in ``unet.up_blocks`` to record its features.

    Each resnet gets an instance-level ``forward`` built by ``conv_forward``.
    Subsequent forward passes therefore store that resnet's ``conv2`` output
    in ``record_hidden_state``. The model is modified in place and returned
    for convenience.
    """
    for up_block in unet.up_blocks:
        for resnet in up_block.resnets:
            resnet.forward = conv_forward(resnet)
    return unet