import re
from datetime import datetime, timedelta

import numpy as np
import requests
import torch
import torch.nn.functional as F
from PIL import Image
# Cross-attention maps collected by the forward hooks, keyed by module name.
attn_maps = {}


def hook_fn(name):
    def forward_hook(module, input, output):
        # The attention processor is expected to stash its map on itself
        # as `attn_map`; collect it here and clear it to free memory.
        if hasattr(module.processor, "attn_map"):
            attn_maps[name] = module.processor.attn_map
            del module.processor.attn_map
    return forward_hook


def register_cross_attention_hook(unet):
    # Attach the hook to every cross-attention ("attn2") module in the UNet.
    for name, module in unet.named_modules():
        if name.split('.')[-1].startswith('attn2'):
            module.register_forward_hook(hook_fn(name))
    return unet
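

# A minimal usage sketch, assuming a diffusers pipeline whose attention
# processors stash their maps on an `attn_map` attribute (that attribute is
# set by a custom attention processor, not by diffusers itself; the model id
# below is illustrative).
def _example_register_hooks():
    from diffusers import StableDiffusionPipeline

    pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
    pipe.unet = register_cross_attention_hook(pipe.unet)
    # Running the pipeline now fills the global `attn_maps`, one entry per
    # cross-attention layer, keyed by module name.
    return pipe("a photo of an astronaut").images[0]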
def upscale(attn_map, target_size):
    # Average over attention heads, then put the text-token axis first:
    # (heads, spatial, tokens) -> (tokens, spatial).
    attn_map = torch.mean(attn_map, dim=0)
    attn_map = attn_map.permute(1, 0)

    # Infer the latent resolution this map was computed at. The UNet operates
    # on latents downsampled 8x from pixel space, then further by powers of 2.
    temp_size = None
    for i in range(5):
        scale = 2 ** i
        if (target_size[0] // scale) * (target_size[1] // scale) == attn_map.shape[1] * 64:
            temp_size = (target_size[0] // (scale * 8), target_size[1] // (scale * 8))
            break
    assert temp_size is not None, "temp_size cannot be None"

    # Reshape the flat spatial axis back to 2D and upsample to the image size.
    attn_map = attn_map.view(attn_map.shape[0], *temp_size)
    attn_map = F.interpolate(
        attn_map.unsqueeze(0).to(dtype=torch.float32),
        size=target_size,
        mode='bilinear',
        align_corners=False
    )[0]

    # Softmax across tokens so each pixel's attention sums to 1.
    attn_map = torch.softmax(attn_map, dim=0)
    return attn_map
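

# Worked shape example for upscale (numbers are illustrative): with
# target_size=(512, 512) and an attn_map of shape (heads, 4096, 77),
# head-averaging and permuting give (77, 4096); scale=1 satisfies
# (512 // 1) * (512 // 1) == 4096 * 64, so temp_size=(64, 64), and the
# result is a (77, 512, 512) map with a per-pixel softmax over 77 tokens.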
def get_net_attn_map(image_size, batch_size=2, instance_or_negative=False, detach=True):
    # With classifier-free guidance the batch is commonly [negative, positive],
    # so index 0 selects the negative half and index 1 the conditional half.
    idx = 0 if instance_or_negative else 1
    net_attn_maps = []
    for name, attn_map in attn_maps.items():
        attn_map = attn_map.cpu() if detach else attn_map
        attn_map = torch.chunk(attn_map, batch_size)[idx].squeeze()
        attn_map = upscale(attn_map, image_size)
        net_attn_maps.append(attn_map)
    # Average the per-layer maps into a single net map of shape (tokens, H, W).
    net_attn_maps = torch.mean(torch.stack(net_attn_maps, dim=0), dim=0)
    return net_attn_maps
def attnmaps2images(net_attn_maps):
    # Min-max normalize each token's map to [0, 255] and convert to PIL images.
    images = []
    for attn_map in net_attn_maps:
        attn_map = attn_map.cpu().numpy()
        normalized_attn_map = (attn_map - np.min(attn_map)) / (np.max(attn_map) - np.min(attn_map)) * 255
        normalized_attn_map = normalized_attn_map.astype(np.uint8)
        images.append(Image.fromarray(normalized_attn_map))
    return images
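

# Hedged end-to-end sketch: turn the collected maps into viewable images,
# assuming generation already ran with the hooks registered (see
# _example_register_hooks above). The 512x512 size matches the default SD
# output and is an assumption here.
def _example_visualize_attn(image_size=(512, 512)):
    net_attn_maps = get_net_attn_map(image_size)  # (tokens, H, W)
    images = attnmaps2images(net_attn_maps)       # one PIL image per token
    for i, im in enumerate(images):
        im.save(f"attn_token_{i}.png")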
def is_torch2_available():
    # torch >= 2.0 ships F.scaled_dot_product_attention.
    return hasattr(F, "scaled_dot_product_attention")
class RemoteJson:
    def __init__(self, url, refresh_gap_seconds=3600, processor=None):
        """
        Initialize the RemoteJson manager.

        :param url: The URL of the remote JSON file.
        :param refresh_gap_seconds: Time in seconds after which the JSON should be refreshed.
        :param processor: Optional callback to process the JSON after it's loaded successfully.
        """
        self.url = url
        self.refresh_gap_seconds = refresh_gap_seconds
        self.processor = processor
        self.json_data = None
        self.last_updated = None

    def _load_json(self):
        """
        Load JSON from the remote URL. If loading fails, return None.
        """
        try:
            response = requests.get(self.url, timeout=10)  # avoid hanging on a slow endpoint
            response.raise_for_status()
            return response.json()
        except requests.RequestException as e:
            print(f"Failed to fetch JSON: {e}")
            return None

    def _should_refresh(self):
        """
        Check whether the JSON should be refreshed based on the time gap.
        """
        if not self.last_updated:
            return True  # If never updated, always refresh
        return datetime.now() - self.last_updated > timedelta(seconds=self.refresh_gap_seconds)

    def _update_json(self):
        """
        Fetch and load the JSON from the remote URL. If it fails, keep the previous data.
        """
        new_json = self._load_json()
        if new_json is not None:  # empty-but-valid JSON still counts as a success
            self.json_data = new_json
            self.last_updated = datetime.now()
            print("JSON updated successfully.")
            if self.processor:
                self.json_data = self.processor(self.json_data)
        else:
            print("Failed to update JSON. Keeping the previous version.")

    def get(self):
        """
        Get the JSON, refreshing it first if the gap has elapsed.
        If a refresh is required, the new data is fetched and the processor applied.
        """
        if self._should_refresh():
            print("Refreshing JSON...")
            self._update_json()
        else:
            print("Using cached JSON.")
        return self.json_data
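

# Hedged usage sketch for RemoteJson; the URL and processor below are
# placeholders for illustration only.
def _example_remote_json():
    def keep_active(data):
        # Example processor: keep only records flagged as active.
        return [item for item in data if item.get("active")]

    manager = RemoteJson(
        "https://example.com/config.json",  # hypothetical endpoint
        refresh_gap_seconds=600,
        processor=keep_active,
    )
    config = manager.get()  # first call fetches from the network
    config = manager.get()  # served from cache until the gap elapses
    return config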
def extract_key_value_pairs(input_string):
    # Match [key:value] pairs, where the value may contain special characters.
    pattern = r"\[([^\]]+):([^\]]+)\]"
    matches = re.finditer(pattern, input_string)
    # Return one dict per match, including the raw matched string.
    result = [{"key": match.group(1), "value": match.group(2), "raw": match.group(0)} for match in matches]
    return result
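

# Illustration with a made-up prompt string (values chosen for the example):
#   extract_key_value_pairs("a portrait [style:oil painting], [mood:calm]")
# returns
#   [{'key': 'style', 'value': 'oil painting', 'raw': '[style:oil painting]'},
#    {'key': 'mood', 'value': 'calm', 'raw': '[mood:calm]'}]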
def extract_characters(prefix, input_string):
    # Match placeholders that start with the given prefix and end at a space,
    # a comma, or the end of the string.
    pattern = rf"{prefix}([^\s,$]+)(?=\s|,|$)"
    matches = re.findall(pattern, input_string)
    # Return one dict per placeholder with its raw form and bare key.
    result = [{"raw": f"{prefix}{match}", "key": match} for match in matches]
    return result
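

# Illustration with "@" as the prefix (the prefix is a caller choice, not
# fixed by this helper):
#   extract_characters("@", "@alice and @bob, walking in the park")
# returns
#   [{'raw': '@alice', 'key': 'alice'}, {'raw': '@bob', 'key': 'bob'}]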