File size: 4,788 Bytes
3175ce6
 
 
9235b7f
 
371bdca
50bfc5a
 
 
 
a7bee92
 
 
 
 
 
50bfc5a
c97fcf1
 
 
a7bee92
371bdca
50bfc5a
 
9235b7f
a7bee92
9235b7f
 
 
d033e91
ab0b470
6ef3309
50bfc5a
 
 
 
6ef3309
b3a0761
9235b7f
 
e6730cb
 
9235b7f
fc8037f
3175ce6
fc8037f
9235b7f
 
 
 
3175ce6
 
 
 
 
3d23955
 
 
 
 
3175ce6
3d23955
 
 
3175ce6
3d23955
 
3175ce6
3d23955
 
3175ce6
3d23955
 
 
 
 
 
 
 
 
3175ce6
3d23955
 
 
 
3175ce6
9235b7f
114a69f
 
9235b7f
114a69f
 
 
9235b7f
9cda2f8
 
9235b7f
 
 
9cda2f8
9235b7f
 
 
 
 
 
3175ce6
9235b7f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import cv2
import numpy as np
from PIL import Image
import os
import gradio as gr
from huggingface_hub import hf_hub_download, snapshot_download

def resolve_hf_path(path):
    """
    Resolve an ``hf://`` URI to a local path by downloading from the Hugging Face Hub.

    Supported URI forms:
      * ``hf://<repo_id>``            -> snapshot of the whole repository
      * ``hf://<repo_id>@<filename>`` -> a single file from the repository

    Any other value (non-``hf://`` strings, ``None``, non-strings) is returned
    unchanged.

    Args:
        path: The path or URI to resolve.

    Returns:
        The local path returned by ``hf_hub_download`` (single file) or
        ``snapshot_download`` (whole repo) for ``hf://`` URIs; otherwise ``path``.

    Raises:
        ValueError: If the URI is malformed (more than one ``@``, or an empty
            repo id / filename), or if ``HUGGINGFACE_HUB_TOKEN`` is not set.
    """
    if not (isinstance(path, str) and path.startswith("hf://")):
        return path

    parts = path[len("hf://"):].split("@")
    # Reject "hf://a@b@c" as well as empty components such as "hf://",
    # "hf://@file" or "hf://repo@". The last one used to fall through to a
    # full-repository snapshot because an empty filename is falsy.
    if len(parts) > 2 or not all(parts):
        raise ValueError(f"Invalid HF URI format: {path}")
    repo_id = parts[0]
    filename = parts[1] if len(parts) == 2 else None

    token = os.environ.get("HUGGINGFACE_HUB_TOKEN")
    if token is None:
        raise ValueError("HUGGINGFACE_HUB_TOKEN environment variable not set!")

    if filename is not None:
        return hf_hub_download(repo_id=repo_id, filename=filename, token=token)
    return snapshot_download(repo_id=repo_id, token=token)

# Model locations for the inference stack. Only the "hf://" entries are
# resolved in this file; the "ms://" entries are passed through untouched
# (presumably resolved by the inference code — NOTE(review): confirm).
# This runs before the `inference` / `modules` imports further down; keep it
# there in case those modules read these variables at import time (TODO confirm).
os.environ.update({
    "FLUX_FILL_PATH": "hf://black-forest-labs/FLUX.1-Fill-dev",
    "PORTRAIT_MODEL_PATH": "ms://iic/ACE_Plus@portrait/comfyui_portrait_lora64.safetensors",
    "SUBJECT_MODEL_PATH": "ms://iic/ACE_Plus@subject/comfyui_subject_lora16.safetensors",
    "LOCAL_MODEL_PATH": "ms://iic/ACE_Plus@local_editing/comfyui_local_lora16.safetensors",
    "ACE_PLUS_FFT_MODEL": "hf://ali-vilab/ACE_Plus@ace_plus_fft.safetensors",
})

# Fetch the Hub assets: first the FLUX fill repository snapshot, then the
# ACE++ FFT checkpoint. Then point both variables at the local copies.
flux_full = resolve_hf_path(os.environ["FLUX_FILL_PATH"])
ace_plus_fft_model_path = resolve_hf_path(os.environ["ACE_PLUS_FFT_MODEL"])
os.environ.update({
    "ACE_PLUS_FFT_MODEL": ace_plus_fft_model_path,
    "FLUX_FILL_PATH": flux_full,
})

from inference.ace_plus_inference import ACEInference
from scepter.modules.utils.config import Config
from modules.flux import FluxMRModiACEPlus
from inference.registry import INFERENCES


# Build the module-wide ACE++ inference object from its YAML config.
# NOTE(review): the config path is relative to the current working directory,
# not to this file, so the app must be launched from the directory that
# contains `config/`.
config_path = os.path.join("config", "ace_plus_fft.yaml")
cfg = Config(cfg_file=config_path, load=True)
ace_infer = ACEInference(cfg)

def create_face_mask(pil_image, *, padding=0.2):
    """
    Create a binary face mask for ``pil_image`` using OpenCV's Haar cascade detector.

    Each detected face box is enlarged by ``padding`` times its width/height on
    every side and clipped to the image borders. It is then filled with 255; all
    other pixels are 0. If no face is detected, the returned mask is entirely black.

    Args:
        pil_image: Input PIL image. It is converted to RGB before detection.
        padding: Fractional margin added around each detected face box
            (keyword-only). Defaults to 0.2, the value previously hard-coded.

    Returns:
        A single-channel uint8 PIL image with the same width/height as
        ``pil_image``.

    Raises:
        ValueError: If ``padding`` is negative, if the Haar cascade cannot be
            loaded, or if conversion/detection fails. The underlying error is
            chained as ``__cause__`` instead of being replaced by a generic
            message.
    """
    if padding < 0:
        raise ValueError(f"padding must be non-negative, got {padding!r}")

    try:
        # Haar cascades operate on single-channel images.
        gray = cv2.cvtColor(np.array(pil_image.convert("RGB")), cv2.COLOR_RGB2GRAY)

        cascade_path = cv2.data.haarcascades + "haarcascade_frontalface_default.xml"
        face_cascade = cv2.CascadeClassifier(cascade_path)
        if face_cascade.empty():
            # A missing or corrupt cascade file produces an empty classifier
            # rather than an exception. Report it clearly here instead of via an
            # opaque assertion inside detectMultiScale.
            raise ValueError(f"Could not load Haar cascade from {cascade_path!r}")

        faces = face_cascade.detectMultiScale(gray, scaleFactor=1.1, minNeighbors=5)
    except Exception as err:
        raise ValueError(f"Face detection failed: {err}") from err

    height, width = gray.shape
    mask = np.zeros((height, width), dtype=np.uint8)
    for x, y, w, h in faces:
        x1 = max(0, int(x - w * padding))
        y1 = max(0, int(y - h * padding))
        x2 = min(width, int(x + w * (1 + padding)))
        y2 = min(height, int(y + h * (1 + padding)))
        mask[y1:y2, x1:x2] = 255

    return Image.fromarray(mask)

def face_swap_app(
    target_img,
    face_img,
    *,
    prompt="Face swap",
    output_height=1024,
    output_width=1024,
    sampler="flow_euler",
    sample_steps=28,
    guide_scale=50,
    seed=-1,
):
    """
    Gradio callback: run ACE++ on ``face_img`` with ``target_img`` as the reference.

    A face mask is detected on ``face_img``. That image and its mask are then
    passed to ``ace_infer`` as the image and region to edit.

    Args:
        target_img: PIL image passed to ``ace_infer`` as ``reference_image``.
        face_img: PIL image passed as ``edit_image``; the edit mask is detected on it.
        prompt, output_height, output_width, sampler, sample_steps, guide_scale,
        seed: Keyword-only arguments forwarded unchanged to ``ace_infer``. The
            defaults are the values previously hard-coded here.

    Returns:
        The generated image, i.e. the first element of ``ace_infer``'s result.

    Raises:
        ValueError: If either image is missing, or if no face is detected in
            ``face_img``.
    """
    if target_img is None or face_img is None:
        raise ValueError("Both a target image and a face image must be provided.")

    target_img = target_img.convert("RGB")
    face_img = face_img.convert("RGB")

    # NOTE(review): here the *face* image is the one being edited and the target
    # only serves as the reference. "Swap the face into the target" semantics
    # would normally reverse these roles, and would compute the mask on the
    # target. Confirm the intended behaviour.
    edit_mask = create_face_mask(face_img)
    if edit_mask.getbbox() is None:
        # An all-black mask marks no region to edit. Fail fast with a clear
        # message rather than spend a full sampling run on it.
        raise ValueError("No face detected in the face image.")

    output_img, _edit_image, _change_image, _mask, _seed = ace_infer(
        reference_image=target_img,
        edit_image=face_img,
        edit_mask=edit_mask,
        prompt=prompt,
        output_height=output_height,
        output_width=output_width,
        sampler=sampler,
        sample_steps=sample_steps,
        guide_scale=guide_scale,
        seed=seed,
    )
    return output_img

# Gradio UI: two image uploads in, one generated image out. Images are passed
# to and returned from face_swap_app as PIL images (type="pil").
iface = gr.Interface(
    face_swap_app,
    [
        gr.Image(label="Target Image", type="pil"),
        gr.Image(label="Face Image", type="pil"),
    ],
    gr.Image(label="Swapped Face Output", type="pil"),
    description="Upload a target image and a face image to swap the face using the ACE++ model.",
    title="ACE++ Face Swap Demo",
)

if __name__ == "__main__":
    # Start the web server only when run as a script, not when imported.
    iface.launch()