import copy
import gc
import html
import itertools
import os
from collections import OrderedDict
from pathlib import Path

import gradio as gr
import torch
from PIL import Image
from torch import optim
from tqdm.auto import tqdm, trange
from transformers import CLIPModel, CLIPProcessor

from modules import shared, scripts
from modules.shared import opts, device
|
aesthetic_embeddings_dir = os.path.join(scripts.basedir(), "aesthetic_embeddings")
os.makedirs(aesthetic_embeddings_dir, exist_ok=True)

aesthetic_embeddings = {}


def update_aesthetic_embeddings():
    """Rescan aesthetic_embeddings_dir and rebuild the name -> path mapping."""
    global aesthetic_embeddings
    # os.path.splitext strips only the trailing ".pt", unlike str.replace, which
    # would also mangle a ".pt" occurring elsewhere in the filename.
    aesthetic_embeddings = {os.path.splitext(f)[0]: os.path.join(aesthetic_embeddings_dir, f) for f in os.listdir(aesthetic_embeddings_dir) if f.endswith(".pt")}
    aesthetic_embeddings = OrderedDict(**{"None": None}, **aesthetic_embeddings)


update_aesthetic_embeddings()
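# For example, with hypothetical files "laion.pt" and "flowers.pt" in the
# directory, the mapping becomes (paths abbreviated):
#
#     OrderedDict([("None", None),
#                  ("laion", ".../aesthetic_embeddings/laion.pt"),
#                  ("flowers", ".../aesthetic_embeddings/flowers.pt")])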
|
|
|
|
|
def get_all_images_in_folder(folder):
    return [os.path.join(folder, f) for f in os.listdir(folder) if
            os.path.isfile(os.path.join(folder, f)) and check_is_valid_image_file(f)]


def check_is_valid_image_file(filename):
    return filename.lower().endswith(('.png', '.jpg', '.jpeg', '.gif', '.tiff', '.webp'))


def batched(dataset, total, n=1):
    # Yield lists of up to n items from an indexable dataset of known length.
    for ndx in range(0, total, n):
        yield [dataset[i] for i in range(ndx, min(ndx + n, total))]


def iter_to_batched(iterable, n=1):
    # Yield tuples of up to n items from an arbitrary iterable; the final
    # chunk may be shorter than n.
    it = iter(iterable)
    while True:
        chunk = tuple(itertools.islice(it, n))
        if not chunk:
            return
        yield chunk
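# Example (illustrative):
#
#     >>> list(iter_to_batched([1, 2, 3, 4, 5], n=2))
#     [(1, 2), (3, 4), (5,)]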
|
|
|
|
|
def create_ui():
    import modules.ui

    with gr.Group():
        with gr.Accordion("Open for Clip Aesthetic!", open=False):
            with gr.Row():
                aesthetic_weight = gr.Slider(minimum=0, maximum=1, step=0.01, label="Aesthetic weight", value=0.9)
                aesthetic_steps = gr.Slider(minimum=0, maximum=50, step=1, label="Aesthetic steps", value=5)

            with gr.Row():
                aesthetic_lr = gr.Textbox(label='Aesthetic learning rate', placeholder="Aesthetic learning rate", value="0.0001")
                aesthetic_slerp = gr.Checkbox(label="Slerp interpolation", value=False)
                aesthetic_imgs = gr.Dropdown(sorted(aesthetic_embeddings.keys()), label="Aesthetic imgs embedding", value="None")

                modules.ui.create_refresh_button(aesthetic_imgs, update_aesthetic_embeddings, lambda: {"choices": sorted(aesthetic_embeddings.keys())}, "refresh_aesthetic_embeddings")

            with gr.Row():
                aesthetic_imgs_text = gr.Textbox(label='Aesthetic text for imgs', placeholder="This text is used to rotate the feature space of the imgs embs", value="")
                aesthetic_slerp_angle = gr.Slider(label='Slerp angle', minimum=0, maximum=1, step=0.01, value=0.1)
                aesthetic_text_negative = gr.Checkbox(label="Is negative text", value=False)

    return aesthetic_weight, aesthetic_steps, aesthetic_lr, aesthetic_slerp, aesthetic_imgs, aesthetic_imgs_text, aesthetic_slerp_angle, aesthetic_text_negative
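# The returned components are unpacked by the caller (the script glue that
# registers this UI, not shown in this file) in the same order, e.g.:
#
#     (aesthetic_weight, aesthetic_steps, aesthetic_lr, aesthetic_slerp,
#      aesthetic_imgs, aesthetic_imgs_text, aesthetic_slerp_angle,
#      aesthetic_text_negative) = create_ui()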
|
|
|
|
|
aesthetic_clip_model = None


def aesthetic_clip():
    """Lazily load and cache the CLIP model matching the current checkpoint's text encoder."""
    global aesthetic_clip_model

    if aesthetic_clip_model is None or aesthetic_clip_model.name_or_path != shared.sd_model.cond_stage_model.wrapped.transformer.name_or_path:
        aesthetic_clip_model = CLIPModel.from_pretrained(shared.sd_model.cond_stage_model.wrapped.transformer.name_or_path)
        aesthetic_clip_model.cpu()

    return aesthetic_clip_model
|
|
def generate_imgs_embd(name, folder, batch_size):
    model = aesthetic_clip().to(device)
    processor = CLIPProcessor.from_pretrained(model.name_or_path)

    with torch.no_grad():
        embs = []
        for paths in tqdm(iter_to_batched(get_all_images_in_folder(folder), batch_size), desc=f"Generating embeddings for {name}"):
            if shared.state.interrupted:
                break
            inputs = processor(images=[Image.open(path) for path in paths], return_tensors="pt").to(device)
            outputs = model.get_image_features(**inputs).cpu()
            embs.append(torch.clone(outputs))
            inputs.to("cpu")
            del inputs, outputs

        # Guard against an empty folder or an immediate interrupt, where
        # torch.cat below would raise on an empty list.
        if len(embs) == 0:
            return f"No image embeddings were generated for {name}; the folder may be empty or generation was interrupted."

        # Average all image features into a single embedding for this folder.
        embs = torch.cat(embs, dim=0).mean(dim=0, keepdim=True)

    path = str(Path(aesthetic_embeddings_dir) / f"{name}.pt")
    torch.save(embs, path)

    model.cpu()
    del processor
    del embs
    gc.collect()
    torch.cuda.empty_cache()
    res = f"""
    Done generating embedding for {name}!
    Aesthetic embedding saved to {html.escape(path)}
    """
    update_aesthetic_embeddings()
    return res
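# Example usage (hypothetical name and folder, illustrative only):
#
#     generate_imgs_embd("my_style", "/path/to/reference/images", batch_size=8)
#
# writes aesthetic_embeddings_dir/my_style.pt containing the mean CLIP image
# embedding of every image in the folder.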
|
|
|
|
|
def slerp(low, high, val):
    """Spherical linear interpolation between two batches of row vectors."""
    low_norm = low / torch.norm(low, dim=1, keepdim=True)
    high_norm = high / torch.norm(high, dim=1, keepdim=True)
    # Clamp the dot product into acos's domain to avoid NaNs from float rounding.
    omega = torch.acos((low_norm * high_norm).sum(1).clamp(-1.0, 1.0))
    so = torch.sin(omega)
    # Note: if low and high are (anti-)parallel, so is ~0 and this divides by zero.
    res = (torch.sin((1.0 - val) * omega) / so).unsqueeze(1) * low + (torch.sin(val * omega) / so).unsqueeze(1) * high
    return res
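# Worked example (illustrative): for orthogonal unit vectors, val=0.5 lands on
# the 45-degree midpoint of the arc between them:
#
#     >>> slerp(torch.tensor([[1.0, 0.0]]), torch.tensor([[0.0, 1.0]]), 0.5)
#     tensor([[0.7071, 0.7071]])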
|
|
|
|
|
class AestheticCLIP:
    def __init__(self):
        self.skip = False
        self.aesthetic_steps = 0
        self.aesthetic_weight = 0
        self.aesthetic_lr = 0
        self.slerp = False
        self.aesthetic_text_negative = False
        self.aesthetic_slerp_angle = 0
        self.aesthetic_imgs_text = ""

        self.image_embs_name = None
        self.image_embs = None
        self.load_image_embs(None)
        # Set by the caller to the wrapped encoder's token-processing function.
        self.process_tokens = None
|
    def set_aesthetic_params(self, p, aesthetic_lr=0, aesthetic_weight=0, aesthetic_steps=0, image_embs_name=None,
                             aesthetic_slerp=True, aesthetic_imgs_text="",
                             aesthetic_slerp_angle=0.15,
                             aesthetic_text_negative=False):
        self.aesthetic_imgs_text = aesthetic_imgs_text
        self.aesthetic_slerp_angle = aesthetic_slerp_angle
        self.aesthetic_text_negative = aesthetic_text_negative
        self.slerp = aesthetic_slerp
        self.aesthetic_lr = aesthetic_lr
        self.aesthetic_weight = aesthetic_weight
        self.aesthetic_steps = aesthetic_steps
        self.load_image_embs(image_embs_name)

        if self.image_embs_name is not None:
            # Record the settings so they appear in the generation's infotext.
            p.extra_generation_params.update({
                "Aesthetic LR": aesthetic_lr,
                "Aesthetic weight": aesthetic_weight,
                "Aesthetic steps": aesthetic_steps,
                "Aesthetic embedding": self.image_embs_name,
                "Aesthetic slerp": aesthetic_slerp,
                "Aesthetic text": aesthetic_imgs_text,
                "Aesthetic text negative": aesthetic_text_negative,
                "Aesthetic slerp angle": aesthetic_slerp_angle,
            })

    def set_skip(self, skip):
        self.skip = skip
|
    def load_image_embs(self, image_embs_name):
        if image_embs_name is None or len(image_embs_name) == 0 or image_embs_name == "None":
            image_embs_name = None
            self.image_embs_name = None
        if image_embs_name is not None and self.image_embs_name != image_embs_name:
            self.image_embs_name = image_embs_name
            self.image_embs = torch.load(aesthetic_embeddings[self.image_embs_name], map_location=device)
            # Normalize to a unit vector; only the direction matters for the
            # cosine-similarity objective below.
            self.image_embs /= self.image_embs.norm(dim=-1, keepdim=True)
            self.image_embs.requires_grad_(False)
|
    def __call__(self, remade_batch_tokens, multipliers):
        # First compute the ordinary conditioning with the wrapped encoder.
        z = self.process_tokens(remade_batch_tokens, multipliers)

        if not self.skip and self.aesthetic_steps != 0 and self.aesthetic_lr != 0 and self.aesthetic_weight != 0 and self.image_embs_name is not None:
            tokenizer = shared.sd_model.cond_stage_model.tokenizer
            if not opts.use_old_emphasis_implementation:
                remade_batch_tokens = [[tokenizer.bos_token_id] + x[:75] + [tokenizer.eos_token_id] for x in remade_batch_tokens]

            tokens = torch.asarray(remade_batch_tokens).to(device)

            # Fine-tune a throwaway copy of the CLIP model so the cached
            # weights are never touched.
            model = copy.deepcopy(aesthetic_clip()).to(device)
            model.requires_grad_(True)
            if self.aesthetic_imgs_text is not None and len(self.aesthetic_imgs_text) > 0:
                # Optionally rotate the image embedding toward (or away from)
                # the embedding of a user-supplied text.
                text_embs_2 = model.get_text_features(
                    **tokenizer([self.aesthetic_imgs_text], padding=True, return_tensors="pt").to(device))
                if self.aesthetic_text_negative:
                    text_embs_2 = self.image_embs - text_embs_2
                    text_embs_2 /= text_embs_2.norm(dim=-1, keepdim=True)
                img_embs = slerp(self.image_embs, text_embs_2, self.aesthetic_slerp_angle)
            else:
                img_embs = self.image_embs

            with torch.enable_grad():
                # Briefly optimize the text encoder to maximize the cosine
                # similarity between the prompt embedding and the aesthetic
                # image embedding.
                optimizer = optim.Adam(model.text_model.parameters(), lr=self.aesthetic_lr)

                for _ in trange(self.aesthetic_steps, desc="Aesthetic optimization"):
                    text_embs = model.get_text_features(input_ids=tokens)
                    text_embs = text_embs / text_embs.norm(dim=-1, keepdim=True)
                    sim = text_embs @ img_embs.T
                    loss = -sim
                    optimizer.zero_grad()
                    loss.mean().backward()
                    optimizer.step()

                # Re-encode the prompt with the tuned encoder, honoring the
                # "clip skip" setting (any nonzero value enables hidden states).
                zn = model.text_model(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers)
                if opts.CLIP_stop_at_last_layers > 1:
                    zn = zn.hidden_states[-opts.CLIP_stop_at_last_layers]
                    zn = model.text_model.final_layer_norm(zn)
                else:
                    zn = zn.last_hidden_state
                model.cpu()
                del model
                gc.collect()
                torch.cuda.empty_cache()
                zn = torch.concat([zn[77 * i:77 * (i + 1)] for i in range(max(z.shape[1] // 77, 1))], 1)

            # Blend the tuned conditioning into the original one.
            if self.slerp:
                z = slerp(z, zn, self.aesthetic_weight)
            else:
                z = z * (1 - self.aesthetic_weight) + zn * self.aesthetic_weight

        return z
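# Example wiring (illustrative sketch; the real hook-up lives in the extension's
# script glue, and `original_process_tokens` is an assumed stand-in for the
# wrapped encoder's token-processing function):
#
#     aesthetic = AestheticCLIP()
#     aesthetic.process_tokens = original_process_tokens
#     aesthetic.set_aesthetic_params(p, aesthetic_lr=0.0001, aesthetic_weight=0.9,
#                                    aesthetic_steps=5, image_embs_name="my_style")
#     z = aesthetic(remade_batch_tokens, multipliers)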
|
|