import comfy.clip_vision
import comfy.clip_model
import comfy.model_management
import comfy.utils
from comfy.sd import CLIP
from itertools import zip_longest
from transformers import CLIPImageProcessor
from transformers.image_utils import PILImageResampling
from collections import Counter
import folder_paths
import torch
import os
from .model import PhotoMakerIDEncoder
from .utils import load_image, tokenize_with_weights, prepImage, crop_image_pil, LoadImageCustom
from folder_paths import folder_names_and_paths, models_dir, supported_pt_extensions, add_model_folder_path
from torch import Tensor
import hashlib


# Register models/photomaker as a model folder, and expose it on the "loras"
# search path as well: the PhotoMaker checkpoint bundles LoRA weights, so the
# same file can also be picked up by a LoraLoader node.
folder_names_and_paths["photomaker"] = ([os.path.join(models_dir, "photomaker")], supported_pt_extensions)
add_model_folder_path("loras", folder_names_and_paths["photomaker"][0][0])
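
# Illustrative (hypothetical filename): with models/photomaker/photomaker-v1.bin
# on disk, the same file now shows up in both listings:
#   folder_paths.get_filename_list("photomaker")  # ["photomaker-v1.bin"]
#   folder_paths.get_filename_list("loras")       # [..., "photomaker-v1.bin"]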


class PhotoMakerLoader:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "photomaker_model_name": (folder_paths.get_filename_list("photomaker"),),
        }}

    RETURN_TYPES = ("PHOTOMAKER",)
    FUNCTION = "load_photomaker_model"

    CATEGORY = "PhotoMaker"

    def load_photomaker_model(self, photomaker_model_name):
        photomaker_model_path = folder_paths.get_full_path("photomaker", photomaker_model_name)
        photomaker_model = PhotoMakerIDEncoder()
        data = comfy.utils.load_torch_file(photomaker_model_path, safe_load=True)
        # The official checkpoint stores the encoder weights under an "id_encoder"
        # key (alongside its LoRA weights); unwrap it if present.
        if "id_encoder" in data:
            data = data["id_encoder"]
        photomaker_model.load_state_dict(data)
        return (photomaker_model,)


class PhotoMakerEncodePlus:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "clip": ("CLIP",),
            "photomaker": ("PHOTOMAKER",),
            "image": ("IMAGE",),
            "trigger_word": ("STRING", {"default": "img"}),
            "text": ("STRING", {"multiline": True, "default": "photograph of a man img", "dynamicPrompts": True}),
        }}

    RETURN_TYPES = ("CONDITIONING",)
    FUNCTION = "apply_photomaker"

    CATEGORY = "PhotoMaker"

    @torch.no_grad()
    def apply_photomaker(self, clip: CLIP, photomaker: PhotoMakerIDEncoder, image: Tensor, trigger_word: str, text: str):
        if (num_id_images := len(image)) == 0:
            raise ValueError("No image provided or found.")
        trigger_word = trigger_word.strip()
        tokens = clip.tokenize(text)
        class_tokens_mask = {}
        for key in tokens:
            clip_tokenizer = getattr(clip.tokenizer, f'clip_{key}', clip.tokenizer)
            tkwp = tokenize_with_weights(clip_tokenizer, text, return_tokens=True)

            # Token id of the trigger word (e.g. "img"), without begin/end tokens.
            class_token = clip_tokenizer.tokenizer(trigger_word)["input_ids"][clip_tokenizer.tokens_start:-1][0]

            tmp = []
            mask = []
            num = num_id_images
            num_trigger_tokens_processed = 0
            for ls in tkwp:
                p = []
                pmask = []
                # Collapse consecutive occurrences of the trigger token into one.
                newls = [ls[0]] + [curr for prev, curr in zip_longest(ls, ls[1:])
                                   if not (curr and prev and curr[0] == class_token and prev[0] == class_token)]
                if newls and newls[-1] is None: newls.pop()
                for pair in newls:
                    if pair[0] != class_token:
                        p.append(pair)
                        pmask.append(pair)
                    else:
                        # Drop the trigger token itself; flag the preceding token
                        # (the class word, e.g. "man") with a -1 sentinel in the
                        # mask copy and duplicate it once per additional ID image.
                        num_trigger_tokens_processed += 1
                        if p:
                            pmask[-1] = (-1, pmask[-1][1])
                            if num - 1 > 0:
                                p.extend([p[-1]] * (num - 1))
                                pmask.extend([(-1, pmask[-1][1])] * (num - 1))
                        else:
                            # The trigger token opens this chunk, so the class word
                            # sits at the end of the previous chunk; patch it there.
                            if tmp and tmp[-1]:
                                last_ls = tmp[-1]
                                last_pair = last_ls[-1]
                                mask[-1][-1] = (-1, mask[-1][-1][1])
                                if num - 1 > 0:
                                    last_ls.extend([last_pair] * (num - 1))
                                    mask[-1].extend([(-1, mask[-1][-1][1])] * (num - 1))
                if p: tmp.append(p)
                if pmask: mask.append(pmask)
            token_weight_pairs = tmp
            token_weight_pairs_mask = mask

            # Re-tokenize from the edited pairs so padding and chunking stay valid.
            token_weight_pairs = tokenize_with_weights(clip_tokenizer, text, tokens=token_weight_pairs)
            token_weight_pairs_mask = tokenize_with_weights(clip_tokenizer, text, tokens=token_weight_pairs_mask)
            tokens[key] = token_weight_pairs

            # Boolean mask per chunk: True at positions the ID embeddings replace.
            class_tokens_mask[key] = list(map(lambda a: list(map(lambda b: b[0] < 0, a)), token_weight_pairs_mask))
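
        # Illustrative example (exact positions depend on the tokenizer): for the
        # default prompt "photograph of a man img" with two ID images, the "img"
        # token is removed, "man" is duplicated once, and the mask for that chunk
        # reads [False, ..., True, True, False, ...] with True at the two "man"
        # positions that the fused ID embeddings will overwrite.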

        prompt_embeds, pooled = clip.encode_from_tokens(tokens, return_pooled=True)
        cond = prompt_embeds
        device_orig = prompt_embeds.device
        first_key = next(iter(class_tokens_mask.keys()))
        class_tokens_mask = class_tokens_mask[first_key]
        # With several trigger words, reuse the same ID images for each of them.
        if num_trigger_tokens_processed > 1:
            image = image.repeat([num_trigger_tokens_processed] + [1] * (len(image.shape) - 1))

        photomaker = photomaker.to(device=photomaker.load_device)

        # The ID encoder expects 224x224 inputs; try a LANCZOS resize first and
        # fall back to the original tensor if the processor rejects it.
        _, h, w, _ = image.shape
        do_resize = (h, w) != (224, 224)
        image_bak = image
        try:
            if do_resize:
                clip_preprocess = CLIPImageProcessor(resample=PILImageResampling.LANCZOS, do_normalize=False, do_rescale=False, do_convert_rgb=False)
                image = clip_preprocess(image, return_tensors="pt").pixel_values.movedim(1, -1)
        except RuntimeError:
            image = image_bak
        pixel_values = comfy.clip_vision.clip_preprocess(image.to(photomaker.load_device)).float()
        # Fuse the ID embeddings into the prompt embeddings at the masked positions.
        cond = photomaker(id_pixel_values=pixel_values.unsqueeze(0), prompt_embeds=cond.to(photomaker.load_device),
                          class_tokens_mask=torch.tensor(class_tokens_mask, dtype=torch.bool, device=photomaker.load_device).unsqueeze(0))
        cond = cond.to(device=device_orig)

        return ([[cond, {"pooled_output": pooled}]],)


from .style_template import styles


class PhotoMakerStyles:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "style_name": (list(styles.keys()), {"default": "Photographic (Default)"}),
            },
            "optional": {
                "positive": ("STRING", {"multiline": True, "forceInput": True, "dynamicPrompts": True}),
                "negative": ("STRING", {"multiline": True, "forceInput": True, "dynamicPrompts": True}),
            },
        }

    RETURN_TYPES = ("STRING", "STRING",)
    RETURN_NAMES = ("POSITIVE", "NEGATIVE",)
    FUNCTION = "apply_photomaker_style"

    CATEGORY = "PhotoMaker"

    def apply_photomaker_style(self, style_name, positive: str = '', negative: str = ''):
        # Fall back to the default style's template pair, not the style name string.
        p, n = styles.get(style_name, styles["Photographic (Default)"])
        return p.replace("{prompt}", positive), n + ' ' + negative
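
# Worked example (hypothetical template; the actual strings live in style_template.py):
# if styles["Photographic (Default)"] == ("photo of {prompt}, 8k", "drawing, lowres"),
# then apply_photomaker_style("Photographic (Default)", "a man img", "blurry")
# returns ("photo of a man img, 8k", "drawing, lowres blurry").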


class PrepImagesForClipVisionFromPath:
    def __init__(self) -> None:
        self.image_loader = LoadImageCustom()
        self.load_device = comfy.model_management.text_encoder_device()
        self.offload_device = comfy.model_management.text_encoder_offload_device()

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "path": ("STRING", {"multiline": False}),
            "interpolation": (["nearest", "bilinear", "box", "bicubic", "lanczos", "hamming"], {"default": "lanczos"}),
            "crop_position": (["top", "bottom", "left", "right", "center", "pad"], {"default": "center"}),
        }}

    @classmethod
    def IS_CHANGED(s, path: str, interpolation, crop_position):
        # Hash local files so edits on disk invalidate the cache; remote URLs are skipped.
        image_path_list = s.get_images_paths(path)
        hashes = []
        for image_path in image_path_list:
            if not (image_path.startswith("http://") or image_path.startswith("https://")):
                m = hashlib.sha256()
                with open(image_path, 'rb') as f:
                    m.update(f.read())
                hashes.append(m.hexdigest())
        return Counter(hashes)

    @classmethod
    def VALIDATE_INPUTS(s, path: str, interpolation, crop_position):
        image_path_list = s.get_images_paths(path)
        if len(image_path_list) == 0:
            return "No image provided or found."
        return True

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "prep_images_for_clip_vision_from_path"

    CATEGORY = "ipadapter"

    @classmethod
    def get_images_paths(s, path: str):
        # A path may be a single file, a URL, or a directory of images.
        image_path_list = []
        path = path.strip()
        if path:
            image_path_list = [path]
            if not (path.startswith("http://") or path.startswith("https://")) and os.path.isdir(path):
                image_basename_list = os.listdir(path)
                image_path_list = [
                    os.path.join(path, basename)
                    for basename in image_basename_list
                    if not basename.startswith('.') and basename.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.webp', '.gif'))
                ]
        return image_path_list
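
    # Illustrative (hypothetical directory contents): for path "refs/" holding
    # a.png, b.jpg and notes.txt, this returns ["refs/a.png", "refs/b.jpg"];
    # hidden files and non-image extensions are filtered out.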

    def prep_images_for_clip_vision_from_path(self, path: str, interpolation: str, crop_position):
        image_path_list = self.get_images_paths(path)
        if len(image_path_list) == 0:
            raise ValueError("No image provided or found.")

        interpolation = interpolation.upper()
        size = (224, 224)
        try:
            # Crop to a square first unless the image is already 224x224, then let
            # CLIPImageProcessor handle any remaining resize.
            input_id_images = [img if (img := load_image(image_path)).size == size else crop_image_pil(img, crop_position) for image_path in image_path_list]
            do_resize = not all(img.size == size for img in input_id_images)
            resample = getattr(PILImageResampling, interpolation)
            clip_preprocess = CLIPImageProcessor(resample=resample, do_normalize=False, do_resize=do_resize)
            id_pixel_values = clip_preprocess(input_id_images, return_tensors="pt").pixel_values.movedim(1, -1)
        except TypeError as err:
            # Older transformers versions reject some of the arguments above;
            # fall back to the custom loader and tensor-based preprocessing.
            print('[PhotoMaker]:', err)
            print('[PhotoMaker]: You may need to update transformers.')
            input_id_images = [self.image_loader.load_image(image_path)[0] for image_path in image_path_list]
            do_resize = not all(img.shape[-3:-1] == size for img in input_id_images)
            if do_resize:
                id_pixel_values = torch.cat([prepImage(img, interpolation=interpolation, crop_position=crop_position) for img in input_id_images])
            else:
                id_pixel_values = torch.cat(input_id_images)
        return (id_pixel_values,)
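
# Minimal usage sketch (hypothetical folder of reference photos):
#   prep = PrepImagesForClipVisionFromPath()
#   (images,) = prep.prep_images_for_clip_vision_from_path("refs/", "lanczos", "center")
#   # images: an [N, 224, 224, C] tensor, ready to feed into PhotoMakerEncodePlus.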


# Prefer ComfyUI's built-in PhotoMakerLoader when core already provides one,
# and only register our own loader otherwise.
supported = False
try:
    from comfy_extras.nodes_photomaker import PhotoMakerLoader as _PhotoMakerLoader
    supported = True
except ImportError:
    pass


NODE_CLASS_MAPPINGS = {
    **({} if supported else {"PhotoMakerLoader": PhotoMakerLoader}),
    "PhotoMakerEncodePlus": PhotoMakerEncodePlus,
    "PhotoMakerStyles": PhotoMakerStyles,
    "PrepImagesForClipVisionFromPath": PrepImagesForClipVisionFromPath,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    **({} if supported else {"PhotoMakerLoader": "Load PhotoMaker"}),
    "PhotoMakerEncodePlus": "PhotoMaker Encode Plus",
    "PhotoMakerStyles": "Apply PhotoMaker Style",
    "PrepImagesForClipVisionFromPath": "Prepare Images For CLIP Vision From Path",
}