import torch
import torch.nn.functional as F
import numpy as np
from PIL import Image

BLOCKS = {
    'content': ['down_blocks'],
    'style': ["up_blocks"],
}

controlnet_BLOCKS = {
    'content': [],
    'style': ["down_blocks"],
}
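
# Note (assumption): these block-name lists are typically matched against the
# UNet's / ControlNet's module names so that content features are injected into
# the 'content' blocks and style features into the 'style' blocks; the actual
# injection logic lives with the attention-processor setup elsewhere in the repo.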


def resize_width_height(width, height, min_short_side=512, max_long_side=1024):
    """Upscale so the short side reaches `min_short_side` (if smaller), then
    downscale so the long side does not exceed `max_long_side`, preserving
    the aspect ratio."""
    if width < height:
        # Width is the short side.
        if width < min_short_side:
            scale_factor = min_short_side / width
            new_width = min_short_side
            new_height = int(height * scale_factor)
        else:
            new_width, new_height = width, height
    else:
        # Height is the short side.
        if height < min_short_side:
            scale_factor = min_short_side / height
            new_width = int(width * scale_factor)
            new_height = min_short_side
        else:
            new_width, new_height = width, height

    # Cap the long side.
    if max(new_width, new_height) > max_long_side:
        scale_factor = max_long_side / max(new_width, new_height)
        new_width = int(new_width * scale_factor)
        new_height = int(new_height * scale_factor)
    return new_width, new_height


def resize_content(content_image):
    """Resize a PIL content image so its long side is ~1024 px and both
    dimensions are multiples of 16."""
    max_long_side = 1024
    min_short_side = 1024

    new_width, new_height = resize_width_height(content_image.size[0], content_image.size[1],
                                                min_short_side=min_short_side, max_long_side=max_long_side)
    # Round down to multiples of 16.
    height = new_height // 16 * 16
    width = new_width // 16 * 16
    content_image = content_image.resize((width, height))

    return width, height, content_image


attn_maps = {}


def hook_fn(name):
    """Return a forward hook that stashes the attention map stored by the
    module's attention processor (if any) into the global `attn_maps` dict
    under `name`."""
    def forward_hook(module, input, output):
        if hasattr(module.processor, "attn_map"):
            attn_maps[name] = module.processor.attn_map
            del module.processor.attn_map

    return forward_hook


def register_cross_attention_hook(unet):
    """Attach `hook_fn` to every cross-attention (`attn2`) module of the UNet."""
    for name, module in unet.named_modules():
        if name.split('.')[-1].startswith('attn2'):
            module.register_forward_hook(hook_fn(name))

    return unet
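
# Example usage (a minimal sketch): `pipe` is assumed to be an already-loaded
# diffusers pipeline whose attention processors expose an `attn_map` attribute
# (e.g. the custom processors used elsewhere in this repo). Registering the hooks
# makes every `attn2` layer copy its map into the global `attn_maps` dict on each
# forward pass:
#
#   pipe.unet = register_cross_attention_hook(pipe.unet)
#   image = pipe(prompt).images[0]  # attn_maps is populated during generation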


def upscale(attn_map, target_size):
    """Average an attention map over heads, reshape it to its latent spatial
    resolution, and bilinearly upsample it to `target_size` (height, width)."""
    attn_map = torch.mean(attn_map, dim=0)
    attn_map = attn_map.permute(1, 0)
    temp_size = None

    # Recover the latent spatial size: the map was flattened from a feature map
    # downsampled by 8 * 2**i relative to the output image.
    for i in range(0, 5):
        scale = 2 ** i
        if (target_size[0] // scale) * (target_size[1] // scale) == attn_map.shape[1] * 64:
            temp_size = (target_size[0] // (scale * 8), target_size[1] // (scale * 8))
            break

    assert temp_size is not None, "temp_size cannot be None"

    attn_map = attn_map.view(attn_map.shape[0], *temp_size)

    attn_map = F.interpolate(
        attn_map.unsqueeze(0).to(dtype=torch.float32),
        size=target_size,
        mode='bilinear',
        align_corners=False
    )[0]

    attn_map = torch.softmax(attn_map, dim=0)
    return attn_map


def get_net_attn_map(image_size, batch_size=2, instance_or_negative=False, detach=True):
    """Average the hooked cross-attention maps across layers, upscaled to `image_size`.

    With classifier-free guidance the batch typically holds the [unconditional,
    conditional] halves; `instance_or_negative=False` selects the conditional half.
    """
    idx = 0 if instance_or_negative else 1
    net_attn_maps = []

    for name, attn_map in attn_maps.items():
        attn_map = attn_map.cpu() if detach else attn_map
        attn_map = torch.chunk(attn_map, batch_size)[idx].squeeze()
        attn_map = upscale(attn_map, image_size)
        net_attn_maps.append(attn_map)

    net_attn_maps = torch.mean(torch.stack(net_attn_maps, dim=0), dim=0)

    return net_attn_maps


def attnmaps2images(net_attn_maps):
    """Convert per-token attention maps into grayscale PIL images."""
    images = []

    for attn_map in net_attn_maps:
        attn_map = attn_map.cpu().numpy()

        # Min-max normalize to the 0-255 range for visualization.
        normalized_attn_map = (attn_map - np.min(attn_map)) / (np.max(attn_map) - np.min(attn_map)) * 255
        normalized_attn_map = normalized_attn_map.astype(np.uint8)

        image = Image.fromarray(normalized_attn_map)

        images.append(image)

    return images
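
# Example usage (a minimal sketch, assuming hooks were registered and a generation
# run has populated `attn_maps`): average the maps, upscale them to the output
# resolution, and save one grayscale preview per token.
#
#   maps = get_net_attn_map((height, width))  # (H, W), matching F.interpolate's size convention
#   for i, im in enumerate(attnmaps2images(maps)):
#       im.save(f"attn_token_{i}.png")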


def is_torch2_available():
    """Return True if torch provides scaled_dot_product_attention (PyTorch >= 2.0)."""
    return hasattr(F, "scaled_dot_product_attention")


def get_generator(seed, device):
    """Build a seeded torch.Generator (or a list of them) for reproducible sampling.

    `seed` may be None (returns None), an int, or a list of ints.
    """
    if seed is not None:
        if isinstance(seed, list):
            generator = [torch.Generator(device).manual_seed(seed_item) for seed_item in seed]
        else:
            generator = torch.Generator(device).manual_seed(seed)
    else:
        generator = None

    return generator
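
# Example usage (a minimal sketch, not part of the library): `pipe` is assumed to
# be an already-loaded diffusers pipeline; `generator=` is the standard diffusers
# argument for seeded, reproducible sampling.
#
#   generator = get_generator(42, "cuda")
#   image = pipe(prompt, generator=generator).images[0]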