# Provenance note (copied from the hosting page's file view): uploaded by 3v324v23;
# stored with Git LFS; revision 1e3b872; size 12 kB.
import comfy.clip_vision
import comfy.clip_model
import comfy.model_management
import comfy.utils
from comfy.sd import CLIP
from itertools import zip_longest
from transformers import CLIPImageProcessor
from transformers.image_utils import PILImageResampling
from collections import Counter
import folder_paths
import torch
import os
from .model import PhotoMakerIDEncoder
from .utils import load_image, tokenize_with_weights, prepImage, crop_image_pil, LoadImageCustom
from folder_paths import folder_names_and_paths, models_dir, supported_pt_extensions, add_model_folder_path
from torch import Tensor
import hashlib
# Register <models_dir>/photomaker as the "photomaker" model folder.
_photomaker_dir = os.path.join(models_dir, "photomaker")
if "photomaker" in folder_names_and_paths:
    # The key is already registered elsewhere: append our directory instead of
    # replacing the whole entry, which would discard the paths already listed there.
    add_model_folder_path("photomaker", _photomaker_dir)
else:
    folder_names_and_paths["photomaker"] = ([_photomaker_dir], supported_pt_extensions)
# Also make the same directory searchable as a "loras" folder.
add_model_folder_path("loras", _photomaker_dir)
class PhotoMakerLoader:
    """ComfyUI node that builds a PhotoMakerIDEncoder from a file in the "photomaker" model folder."""

    @classmethod
    def INPUT_TYPES(s):
        # Offer every file currently listed in the "photomaker" model folder.
        available = folder_paths.get_filename_list("photomaker")
        return {"required": {"photomaker_model_name": (available,)}}

    RETURN_TYPES = ("PHOTOMAKER",)
    FUNCTION = "load_photomaker_model"
    CATEGORY = "PhotoMaker"

    def load_photomaker_model(self, photomaker_model_name):
        """Load ``photomaker_model_name`` into a fresh PhotoMakerIDEncoder; return it as a 1-tuple."""
        checkpoint = folder_paths.get_full_path("photomaker", photomaker_model_name)
        encoder = PhotoMakerIDEncoder()
        weights = comfy.utils.load_torch_file(checkpoint, safe_load=True)
        # Some checkpoints keep the encoder weights under an "id_encoder" key; unwrap if present.
        weights = weights["id_encoder"] if "id_encoder" in weights else weights
        encoder.load_state_dict(weights)
        return (encoder,)
class PhotoMakerEncodePlus:
    """ComfyUI node: CLIP-encode ``text`` and apply a PhotoMaker ID encoder to the result.

    Occurrences of ``trigger_word`` in the prompt select where identity information goes:
    the trigger token is removed, the token just before it is repeated so that it appears
    once per reference image, and those positions are flagged in a boolean class-token
    mask. The prompt embeddings, that mask and the preprocessed reference images are then
    passed to ``photomaker`` (presumably it fuses ID embeddings into the masked positions --
    see PhotoMakerIDEncoder to confirm).
    """
    @classmethod
    def INPUT_TYPES(s):
        # Node inputs: CLIP model, loaded PhotoMaker encoder, reference image batch,
        # trigger word (default "img") and the prompt text.
        return {"required": {
                    "clip": ("CLIP",),
                    "photomaker": ("PHOTOMAKER",),
                    "image": ("IMAGE",),
                    "trigger_word": ("STRING", {"default": "img"}),
                    "text": ("STRING", {"multiline": True, "default": "photograph of a man img", "dynamicPrompts": True}),
                },
        }
    RETURN_TYPES = ("CONDITIONING",)
    FUNCTION = "apply_photomaker"
    CATEGORY = "PhotoMaker"
    @torch.no_grad()
    def apply_photomaker(self, clip: CLIP, photomaker: PhotoMakerIDEncoder, image: Tensor, trigger_word: str, text: str):
        """Build PhotoMaker-conditioned CLIP embeddings for ``text``.

        Args:
            clip: CLIP model used to tokenize and encode ``text``.
            photomaker: ID encoder; it is moved to ``photomaker.load_device`` here.
            image: batch of reference images. Assumed to be ComfyUI IMAGE layout
                (B, H, W, C) -- the ``_, h, w, _ = image.shape`` unpacking requires rank 4.
            trigger_word: word marking the identity slot(s); surrounding whitespace is stripped.
            text: prompt that should contain ``trigger_word``.

        Returns:
            A 1-tuple holding the conditioning list ``[[cond, {"pooled_output": pooled}]]``.

        Raises:
            ValueError: if ``image`` is empty.
        """
        if (num_id_images:=len(image)) == 0:
            raise ValueError("No image provided or found.")
        trigger_word=trigger_word.strip()
        tokens = clip.tokenize(text)
        # Per tokenizer key: list of per-segment boolean lists, True at class-token positions.
        class_tokens_mask = {}
        for key in tokens:
            # Use the per-key sub-tokenizer (clip_<key>) when it exists, else the tokenizer itself.
            clip_tokenizer = getattr(clip.tokenizer, f'clip_{key}', clip.tokenizer)
            tkwp = tokenize_with_weights(clip_tokenizer, text, return_tokens=True)
            # Token id of the trigger word (e.g.: 24157): the first id after the tokenizer's
            # start tokens; the trailing ``:-1`` presumably drops the end token.
            # NOTE(review): only the first sub-token is matched if the word splits into
            # several, and an empty trigger word would make the ``[0]`` index fail.
            class_token = clip_tokenizer.tokenizer(trigger_word)["input_ids"][clip_tokenizer.tokens_start:-1][0]
            # tmp: rewritten (token, weight) segments; mask: parallel copy whose class
            # positions have their token replaced by -1.
            tmp=[]
            mask=[]
            num = num_id_images
            # NOTE(review): reset for every key, so after the loop this holds the count
            # from the last key only.
            num_trigger_tokens_processed = 0
            for ls in tkwp:
                # recreate the list of pairs
                p = []
                pmask = []
                # remove consecutive duplicates of the trigger token (ls must be non-empty
                # because of ls[0]); zip_longest pads the final ``curr`` with None,
                # which is popped right after.
                newls = [ls[0]] + [curr for prev, curr in zip_longest(ls, ls[1:])
                                   if not (curr and prev and curr[0] == class_token and prev[0] == class_token)]
                if newls and newls[-1] is None: newls.pop()
                for pair in newls:
                    # Non-matches simply get appended to the list.
                    if pair[0] != class_token:
                        p.append(pair)
                        pmask.append(pair)
                    else:
                        # Found a match: the trigger token itself is dropped; the token before
                        # it (in this segment or the previous one) becomes the class token.
                        num_trigger_tokens_processed += 1
                        if p:
                            # take the last element of the list we're creating and repeat it
                            # so it appears ``num`` times; flag every copy with -1 in the mask.
                            pmask[-1] = (-1, pmask[-1][1])
                            if num-1 > 0:
                                p.extend([p[-1]] * (num-1))
                                pmask.extend([( -1, pmask[-1][1] )] * (num-1))
                        else:
                            # The list we're creating is empty so
                            # take the last element of the main list and then take its last element and repeat it.
                            # If nothing precedes the trigger at all, no position is flagged.
                            if tmp and tmp[-1]:
                                last_ls = tmp[-1]
                                last_pair = last_ls[-1]
                                mask[-1][-1] = (-1, mask[-1][-1][1])
                                if num-1 > 0:
                                    last_ls.extend([last_pair] * (num-1))
                                    mask[-1].extend([ (-1, mask[-1][-1][1]) ] * (num-1))
                if p: tmp.append(p)
                if pmask: mask.append(pmask)
            token_weight_pairs = tmp
            token_weight_pairs_mask = mask
            # send it back to be batched evenly
            token_weight_pairs = tokenize_with_weights(clip_tokenizer, text, tokens=token_weight_pairs)
            token_weight_pairs_mask = tokenize_with_weights(clip_tokenizer, text, tokens=token_weight_pairs_mask)
            tokens[key] = token_weight_pairs
            # Finalize the mask: True wherever the mask copy carries a negative token id.
            class_tokens_mask[key] = list(map(lambda a: list(map(lambda b: b[0] < 0, a)), token_weight_pairs_mask))
        prompt_embeds, pooled = clip.encode_from_tokens(tokens, return_pooled=True)
        cond = prompt_embeds
        device_orig = prompt_embeds.device
        # NOTE(review): only the first tokenizer key's mask is passed to photomaker.
        first_key = next(iter(class_tokens_mask.keys()))
        class_tokens_mask = class_tokens_mask[first_key]
        # With several trigger occurrences, tile the image batch along dim 0 once per occurrence.
        if num_trigger_tokens_processed > 1:
            image = image.repeat([num_trigger_tokens_processed] + [1] * (len(image.shape) - 1))
        photomaker = photomaker.to(device=photomaker.load_device)
        _, h, w, _ = image.shape
        do_resize = (h, w) != (224, 224)
        image_bak = image
        try:
            if do_resize:
                # Geometry only: normalize/rescale/RGB conversion are disabled; resize uses
                # LANCZOS with the processor's default size/crop settings.
                clip_preprocess = CLIPImageProcessor(resample=PILImageResampling.LANCZOS, do_normalize=False, do_rescale=False, do_convert_rgb=False)
                image = clip_preprocess(image, return_tensors="pt").pixel_values.movedim(1,-1)
        except RuntimeError as e:
            # Fall back to the un-resized images (``e`` is not used).
            image = image_bak
        pixel_values = comfy.clip_vision.clip_preprocess(image.to(photomaker.load_device)).float()
        # unsqueeze(0) adds a leading batch axis to both the pixel values and the mask.
        cond = photomaker(id_pixel_values=pixel_values.unsqueeze(0), prompt_embeds=cond.to(photomaker.load_device),
                          class_tokens_mask=torch.tensor(class_tokens_mask, dtype=torch.bool, device=photomaker.load_device).unsqueeze(0))
        cond = cond.to(device=device_orig)
        return ([[cond, {"pooled_output": pooled}]],)
from .style_template import styles
class PhotoMakerStyles:
    """ComfyUI node: wrap positive/negative prompts in a named style template from ``styles``."""

    # Style used when ``style_name`` is not a key of ``styles``; also the UI default.
    _DEFAULT_STYLE = "Photographic (Default)"

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
                    "style_name": (list(styles.keys()), {"default": s._DEFAULT_STYLE}),
                },
                "optional": {
                    "positive": ("STRING", {"multiline": True, "forceInput": True, "dynamicPrompts": True}),
                    "negative": ("STRING", {"multiline": True, "forceInput": True, "dynamicPrompts": True}),
                },
        }

    RETURN_TYPES = ("STRING","STRING",)
    RETURN_NAMES = ("POSITIVE","NEGATIVE",)
    FUNCTION = "apply_photomaker_style"
    CATEGORY = "PhotoMaker"

    def apply_photomaker_style(self, style_name, positive: str = '', negative: str = ''):
        """Return ``(styled_positive, styled_negative)`` for ``style_name``.

        Each ``styles`` value unpacks into a (positive template, negative prompt) pair.
        Every ``{prompt}`` in the positive template is replaced by ``positive``, and
        ``negative`` is appended to the style's negative prompt after one space.
        Unknown style names fall back to ``_DEFAULT_STYLE``.

        Raises:
            KeyError: if neither ``style_name`` nor the default style exists in ``styles``.
        """
        # Fall back to the default style's *template pair*. Passing the default's name
        # as the dict.get() fallback yielded a plain string that cannot unpack into (p, n).
        if style_name in styles:
            p, n = styles[style_name]
        else:
            p, n = styles[self._DEFAULT_STYLE]
        return p.replace("{prompt}", positive), n + ' ' + negative
class PrepImagesForClipVisionFromPath:
    """ComfyUI node: load images from a file, a directory or a URL as a 224x224 IMAGE batch.

    ``path`` may be a single image path, an ``http://``/``https://`` URL, or a directory.
    For a directory, every non-hidden regular file with a supported image extension is
    loaded, in sorted filename order.
    """

    # Lower-case extensions accepted when ``path`` is a directory.
    _IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.webp', '.gif')
    # Prefixes that mark ``path`` as a URL rather than a local file or directory.
    _URL_PREFIXES = ("http://", "https://")

    def __init__(self) -> None:
        self.image_loader = LoadImageCustom()
        # NOTE(review): load_device / offload_device are set but not used within this class.
        self.load_device = comfy.model_management.text_encoder_device()
        self.offload_device = comfy.model_management.text_encoder_offload_device()

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
                    "path": ("STRING", {"multiline": False}),
                    "interpolation": (["nearest", "bilinear", "box", "bicubic", "lanczos", "hamming"], {"default": "lanczos"}),
                    "crop_position": (["top", "bottom", "left", "right", "center", "pad"], {"default": "center"}),
                },
        }

    @classmethod
    def IS_CHANGED(s, path: str, interpolation, crop_position):
        """Return a multiset (Counter) of SHA-256 hex digests of the resolved local image files.

        URLs are not fetched or hashed, so they contribute nothing to the result.
        """
        hashes = []
        for image_path in s.get_images_paths(path):
            # Test the stripped per-image path: testing the raw ``path`` let a URL with
            # leading whitespace fall through to open() and crash.
            if image_path.startswith(s._URL_PREFIXES):
                continue
            digest = hashlib.sha256()
            with open(image_path, 'rb') as f:
                # Hash in 1 MiB chunks instead of loading the whole file at once.
                for chunk in iter(lambda: f.read(1 << 20), b''):
                    digest.update(chunk)
            hashes.append(digest.hexdigest())
        return Counter(hashes)

    @classmethod
    def VALIDATE_INPUTS(s, path: str, interpolation, crop_position):
        """Return True, or an error message when ``path`` resolves to no images."""
        if not s.get_images_paths(path):
            return "No image provided or found."
        return True

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "prep_images_for_clip_vision_from_path"
    CATEGORY = "ipadapter"

    @classmethod
    def get_images_paths(cls, path: str):
        """Resolve ``path`` (whitespace-stripped) to a list of image paths/URLs.

        Returns ``[]`` for a blank path and ``[path]`` for a URL or any non-directory path
        (existence is not checked). For a directory, returns the sorted list of non-hidden
        regular files whose lower-cased name ends with one of ``_IMAGE_EXTENSIONS``.
        """
        path = path.strip()
        if not path:
            return []
        if path.startswith(cls._URL_PREFIXES) or not os.path.isdir(path):
            return [path]
        # os.listdir() order is arbitrary; sort so the batch order is reproducible.
        return [
            os.path.join(path, basename)
            for basename in sorted(os.listdir(path))
            if not basename.startswith('.')
            and basename.lower().endswith(cls._IMAGE_EXTENSIONS)
            and os.path.isfile(os.path.join(path, basename))
        ]

    def prep_images_for_clip_vision_from_path(self, path:str, interpolation:str, crop_position,):
        """Load every image resolved from ``path`` and return them as one batch in a 1-tuple.

        Raises:
            ValueError: if ``path`` resolves to no images.
        """
        image_path_list = self.get_images_paths(path)
        if not image_path_list:
            raise ValueError("No image provided or found.")
        # UI names such as "lanczos" map to PILImageResampling members such as LANCZOS.
        interpolation = interpolation.upper()
        size = (224, 224)
        try:
            # Images already 224x224 are kept; others go through crop_image_pil first.
            input_id_images = [img if (img := load_image(image_path)).size == size else crop_image_pil(img, crop_position)
                               for image_path in image_path_list]
            do_resize = not all(img.size == size for img in input_id_images)
            resample = getattr(PILImageResampling, interpolation)
            clip_preprocess = CLIPImageProcessor(resample=resample, do_normalize=False, do_resize=do_resize)
            # Move axis 1 to the end (presumably channels, to give ComfyUI's channel-last layout).
            id_pixel_values = clip_preprocess(input_id_images, return_tensors="pt").pixel_values.movedim(1, -1)
        except TypeError as err:
            print('[PhotoMaker]:', err)
            print('[PhotoMaker]: You may need to update transformers.')
            # Fallback path that avoids CLIPImageProcessor: tensor loading + prepImage.
            input_id_images = [self.image_loader.load_image(image_path)[0] for image_path in image_path_list]
            do_resize = not all(tuple(img.shape[-3:-1]) == size for img in input_id_images)
            if do_resize:
                id_pixel_values = torch.cat([prepImage(img, interpolation=interpolation, crop_position=crop_position) for img in input_id_images])
            else:
                id_pixel_values = torch.cat(input_id_images)
        return (id_pixel_values,)
# True when comfy_extras.nodes_photomaker.PhotoMakerLoader imports successfully; the
# node mappings use this flag to decide whether to register the local PhotoMakerLoader.
supported = False
try:
    from comfy_extras.nodes_photomaker import PhotoMakerLoader as _PhotoMakerLoader
    supported = True
except Exception:
    # Best-effort probe: any import failure just means that loader is unavailable.
    # Unlike a bare ``except:``, this does not swallow KeyboardInterrupt/SystemExit.
    pass
# Node registry exported to ComfyUI. The local PhotoMakerLoader is only included when
# ``supported`` is False; insertion order matches the listing below.
NODE_CLASS_MAPPINGS = {} if supported else {"PhotoMakerLoader": PhotoMakerLoader}
NODE_CLASS_MAPPINGS.update(
    PhotoMakerEncodePlus=PhotoMakerEncodePlus,
    PhotoMakerStyles=PhotoMakerStyles,
    PrepImagesForClipVisionFromPath=PrepImagesForClipVisionFromPath,
)

# Human-readable titles for the nodes registered above (same keys, same order).
NODE_DISPLAY_NAME_MAPPINGS = {} if supported else {"PhotoMakerLoader": "Load PhotoMaker"}
NODE_DISPLAY_NAME_MAPPINGS.update(
    PhotoMakerEncodePlus="PhotoMaker Encode Plus",
    PhotoMakerStyles="Apply PhotoMaker Style",
    PrepImagesForClipVisionFromPath="Prepare Images For CLIP Vision From Path",
)