# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
from typing import List, Tuple, Union

import cv2
import face_alignment
import mediapipe as mp
import numpy as np
import torch
from einops import rearrange
from torchvision import transforms

from .affine_transform import AlignRestore, laplacianSmooth

# If you are enlarging the image, you should prefer to use INTER_LINEAR or INTER_CUBIC interpolation.
# If you are shrinking the image, you should prefer to use INTER_AREA interpolation.
# https://stackoverflow.com/questions/23853632/which-kind-of-interpolation-best-for-resizing-image


def load_fixed_mask(resolution: int, mask_image_path: str = "latentsync/utils/mask.png") -> torch.Tensor:
    """Load the fixed mask image and return it as a ``(c, h, w)`` tensor scaled to ``[0, 1]``.

    Args:
        resolution: Side length (in pixels) the mask is resized to; the result is square.
        mask_image_path: Location of the mask image. The default is relative to the
            current working directory.

    Returns:
        A tensor laid out as ``c h w``. Because the scaling is done by NumPy division,
        the tensor is float64.

    Raises:
        FileNotFoundError: If the image cannot be read from ``mask_image_path``.
    """
    mask_image = cv2.imread(mask_image_path)
    if mask_image is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f"Could not read mask image: {mask_image_path}")
    mask_image = cv2.cvtColor(mask_image, cv2.COLOR_BGR2RGB)
    mask_image = cv2.resize(mask_image, (resolution, resolution), interpolation=cv2.INTER_AREA) / 255.0
    mask_image = rearrange(torch.from_numpy(mask_image), "h w c -> c h w")
    return mask_image


class ImageProcessor:
    """Resize, normalize and mask face images.

    Supported ``mask`` modes:

    * ``"fix_mask"``: multiply by a fixed mask image (see :func:`load_fixed_mask`).
    * ``"mouth"`` / ``"face"``: mask the polygon spanned by MediaPipe landmarks.
    * ``"eye"``: mask everything below MediaPipe landmark 195.
    * ``"half"``: mask the lower half of the image.
    """

    def __init__(self, resolution: int = 512, mask: str = "fix_mask", device: str = "cpu", mask_image=None):
        """Build the transforms and the detectors required by ``mask``.

        Args:
            resolution: Side length of the square output images.
            mask: Mask mode, see the class docstring.
            device: Device passed to ``face_alignment``; the ``face_alignment`` detector
                is only created for ``mask == "fix_mask"`` and ``device != "cpu"``.
            mask_image: Optional pre-loaded mask used instead of loading it from disk
                (only consulted for ``mask == "fix_mask"``).
        """
        self.resolution = resolution
        self.resize = transforms.Resize(
            (resolution, resolution), interpolation=transforms.InterpolationMode.BILINEAR, antialias=True
        )
        self.normalize = transforms.Normalize([0.5], [0.5], inplace=True)
        self.mask = mask

        # Always define the optional members so that close() and the error checks below
        # work for every mask mode (e.g. mask="half" previously left face_mesh undefined).
        self.face_mesh = None
        self.fa = None
        self.smoother = None
        self.restorer = None
        self.mask_image = None

        if mask in ["mouth", "face", "eye"]:
            self.face_mesh = mp.solutions.face_mesh.FaceMesh(static_image_mode=True)  # Process single image

        if mask == "fix_mask":
            self.smoother = laplacianSmooth()
            self.restorer = AlignRestore()
            self.mask_image = load_fixed_mask(resolution) if mask_image is None else mask_image

            if device != "cpu":
                self.fa = face_alignment.FaceAlignment(
                    face_alignment.LandmarksType.TWO_D, flip_input=False, device=device
                )
            # NOTE: on cpu no landmark detector is created in this mode, so
            # affine_transform() cannot detect landmarks and raises RuntimeError.

    @staticmethod
    def _to_hwc_numpy(image: Union[torch.Tensor, np.ndarray]) -> np.ndarray:
        """Return ``image`` as an ``h w c`` numpy array; tensors are assumed to be ``c h w``."""
        if isinstance(image, torch.Tensor):
            image = rearrange(image, "c h w -> h w c").detach().cpu().numpy()
            # The rearranged view is not contiguous; OpenCV/MediaPipe expect packed memory.
            image = np.ascontiguousarray(image)
        return image

    @staticmethod
    def _to_bchw_tensor(images: Union[torch.Tensor, np.ndarray]) -> torch.Tensor:
        """Convert a batch to a tensor and move a trailing 3-channel axis to position 1."""
        if isinstance(images, np.ndarray):
            images = torch.from_numpy(images)
        if images.shape[3] == 3:
            images = rearrange(images, "b h w c -> b c h w")
        return images

    def detect_facial_landmarks(self, image: np.ndarray) -> List[Tuple[int, int]]:
        """Detect MediaPipe face-mesh landmarks of the first face in ``image``.

        Args:
            image: ``h w c`` numpy image handed to ``FaceMesh.process``.

        Returns:
            A list of ``(x, y)`` pixel coordinates, one per landmark.

        Raises:
            RuntimeError: If no face mesh detector was created for this processor, or
                if no face is detected.
        """
        if self.face_mesh is None:
            raise RuntimeError(
                "MediaPipe face mesh is not initialized for this ImageProcessor "
                f"(mask={self.mask!r}); landmark detection is unavailable"
            )

        height, width, _ = image.shape
        results = self.face_mesh.process(image)
        if not results.multi_face_landmarks:  # Face not detected
            raise RuntimeError("Face not detected")
        face_landmarks = results.multi_face_landmarks[0]  # Only use the first face in the image
        landmark_coordinates = [
            (int(landmark.x * width), int(landmark.y * height)) for landmark in face_landmarks.landmark
        ]  # x means width, y means height
        return landmark_coordinates

    def preprocess_one_masked_image(self, image: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Resize, normalize and mask one ``c h w`` image for the non-fixed mask modes.

        Args:
            image: ``c h w`` image tensor with values in ``[0, 255]``.

        Returns:
            ``(pixel_values, masked_pixel_values, mask)`` where ``mask`` is 1 inside the
            masked-out region and 0 elsewhere (the inverse of the multiplicative mask).

        Raises:
            ValueError: If ``self.mask`` is not one of ``mouth``, ``face``, ``half``, ``eye``.
            RuntimeError: If landmarks are needed but no face is detected.
        """
        image = self.resize(image)
        size = (self.resolution, self.resolution)

        if self.mask == "mouth" or self.mask == "face":
            # Landmark detection works on h w c numpy images, not c h w tensors.
            landmark_coordinates = self.detect_facial_landmarks(self._to_hwc_numpy(image))
            surround_landmarks = mouth_surround_landmarks if self.mask == "mouth" else face_surround_landmarks
            # cv2.fillPoly only accepts int32 polygon vertices.
            points = np.array([landmark_coordinates[landmark] for landmark in surround_landmarks], dtype=np.int32)
            mask = np.ones(size)
            mask = cv2.fillPoly(mask, pts=[points], color=(0, 0, 0))
            # Match the float32 dtype of the other mask modes.
            mask = torch.from_numpy(mask).to(dtype=torch.float32).unsqueeze(0)
        elif self.mask == "half":
            mask = torch.ones(size)
            height = mask.shape[0]
            mask[height // 2 :, :] = 0
            mask = mask.unsqueeze(0)
        elif self.mask == "eye":
            mask = torch.ones(size)
            landmark_coordinates = self.detect_facial_landmarks(self._to_hwc_numpy(image))
            y = landmark_coordinates[195][1]
            mask[y:, :] = 0
            mask = mask.unsqueeze(0)
        else:
            raise ValueError("Invalid mask type")

        image = image.to(dtype=torch.float32)
        pixel_values = self.normalize(image / 255.0)
        masked_pixel_values = pixel_values * mask
        mask = 1 - mask

        return pixel_values, masked_pixel_values, mask

    def affine_transform(self, image: Union[torch.Tensor, np.ndarray]) -> Tuple[torch.Tensor, list, np.ndarray]:
        """Align the face in ``image`` and resize it to ``resolution x resolution``.

        Args:
            image: ``h w c`` numpy image, or a ``c h w`` tensor (converted internally).

        Returns:
            ``(face, box, affine_matrix)``: the aligned face as a ``c h w`` tensor, the box
            ``[0, 0, width, height]`` of the warped face before resizing, and the affine
            matrix returned by ``AlignRestore.align_warp_face``.

        Raises:
            RuntimeError: If the processor was not created with ``mask="fix_mask"``, if no
                landmark detector is available, or if no face is detected.
        """
        if self.smoother is None or self.restorer is None:
            raise RuntimeError('affine_transform requires an ImageProcessor created with mask="fix_mask"')

        image = self._to_hwc_numpy(image)

        if self.fa is None:
            landmark_coordinates = np.array(self.detect_facial_landmarks(image))
            lm68 = mediapipe_lm478_to_face_alignment_lm68(landmark_coordinates)
        else:
            detected_faces = self.fa.get_landmarks(image)
            if detected_faces is None or len(detected_faces) == 0:
                raise RuntimeError("Face not detected")
            lm68 = detected_faces[0]

        points = self.smoother.smooth(lm68)
        # Three reference points: the means of landmark slices 17:22, 22:27 and 27:36.
        lmk3_ = np.zeros((3, 2))
        lmk3_[0] = points[17:22].mean(0)
        lmk3_[1] = points[22:27].mean(0)
        lmk3_[2] = points[27:36].mean(0)

        face, affine_matrix = self.restorer.align_warp_face(
            image.copy(), lmks3=lmk3_, smooth=True, border_mode="constant"
        )
        box = [0, 0, face.shape[1], face.shape[0]]  # x1, y1, x2, y2
        face = cv2.resize(face, (self.resolution, self.resolution), interpolation=cv2.INTER_CUBIC)
        face = rearrange(torch.from_numpy(face), "h w c -> c h w")
        return face, box, affine_matrix

    def preprocess_fixed_mask_image(self, image: torch.Tensor, affine_transform=False):
        """Normalize one image and apply the fixed mask.

        Args:
            image: Image to process; resized with ``self.resize`` unless
                ``affine_transform`` is true, in which case it is face-aligned instead.
            affine_transform: Whether to align the face via :meth:`affine_transform`.

        Returns:
            ``(pixel_values, masked_pixel_values, mask)`` where ``mask`` is the first
            channel of the fixed mask image.
        """
        if affine_transform:
            image, _, _ = self.affine_transform(image)
        else:
            image = self.resize(image)
        pixel_values = self.normalize(image / 255.0)
        masked_pixel_values = pixel_values * self.mask_image
        return pixel_values, masked_pixel_values, self.mask_image[0:1]

    def prepare_masks_and_masked_images(self, images: Union[torch.Tensor, np.ndarray], affine_transform=False):
        """Preprocess a batch of images according to the configured mask mode.

        Args:
            images: Batch of images, either ``b h w c`` (3 channels last) or ``b c h w``.
            affine_transform: Forwarded to :meth:`preprocess_fixed_mask_image`; ignored
                for the other mask modes.

        Returns:
            Stacked ``(pixel_values, masked_pixel_values, masks)`` tensors.
        """
        images = self._to_bchw_tensor(images)
        if self.mask == "fix_mask":
            results = [self.preprocess_fixed_mask_image(image, affine_transform=affine_transform) for image in images]
        else:
            results = [self.preprocess_one_masked_image(image) for image in images]

        pixel_values_list, masked_pixel_values_list, masks_list = list(zip(*results))
        return torch.stack(pixel_values_list), torch.stack(masked_pixel_values_list), torch.stack(masks_list)

    def process_images(self, images: Union[torch.Tensor, np.ndarray]):
        """Resize a batch of images and normalize it with mean 0.5 and std 0.5."""
        images = self._to_bchw_tensor(images)
        images = self.resize(images)
        pixel_values = self.normalize(images / 255.0)
        return pixel_values

    def close(self):
        """Release the MediaPipe face mesh, if one was created. Safe to call repeatedly."""
        if self.face_mesh is not None:
            self.face_mesh.close()
            self.face_mesh = None


def mediapipe_lm478_to_face_alignment_lm68(lm478, return_2d=True):
    """Select the 68 landmarks listed in ``landmark_points_68`` from a MediaPipe landmark array.

    Args:
        lm478: Landmarks shaped ``[478, C]`` or ``[B, 478, C]`` with ``C >= 2``.
        return_2d: Unused; kept for backward compatibility. Only the first two
            coordinates are ever returned.

    Returns:
        Array shaped ``[68, 2]`` (or ``[B, 68, 2]`` for batched input) holding the
        ``(x, y)`` coordinates of the selected landmarks.
    """
    # Index the landmark axis (second to last) so batched input works as documented.
    return np.asarray(lm478)[..., landmark_points_68, :2]


landmark_points_68 = [
    162,
    234,
    93,
    58,
    172,
    136,
    149,
    148,
    152,
    377,
    378,
    365,
    397,
    288,
    323,
    454,
    389,
    71,
    63,
    105,
    66,
    107,
    336,
    296,
    334,
    293,
    301,
    168,
    197,
    5,
    4,
    75,
    97,
    2,
    326,
    305,
    33,
    160,
    158,
    133,
    153,
    144,
    362,
    385,
    387,
    263,
    373,
    380,
    61,
    39,
    37,
    0,
    267,
    269,
    291,
    405,
    314,
    17,
    84,
    181,
    78,
    82,
    13,
    312,
    308,
    317,
    14,
    87,
]


# Refer to https://storage.googleapis.com/mediapipe-assets/documentation/mediapipe_face_landmark_fullsize.png
mouth_surround_landmarks = [
    164,
    165,
    167,
    92,
    186,
    57,
    43,
    106,
    182,
    83,
    18,
    313,
    406,
    335,
    273,
    287,
    410,
    322,
    391,
    393,
]

face_surround_landmarks = [
    152,
    377,
    400,
    378,
    379,
    365,
    397,
    288,
    435,
    433,
    411,
    425,
    423,
    327,
    326,
    94,
    97,
    98,
    203,
    205,
    187,
    213,
    215,
    58,
    172,
    136,
    150,
    149,
    176,
    148,
]

if __name__ == "__main__":
    # Demo: align the face in the first frame of a video and write it to face.jpg.
    # The video path may be overridden by the first command-line argument.
    default_video_path = "/mnt/bn/maliva-gen-ai-v2/chunyu.li/HDTF/original/val/RD_Radio57_000.mp4"
    video_path = sys.argv[1] if len(sys.argv) > 1 else default_video_path

    # In fix_mask mode a landmark detector is only created for non-cpu devices.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    image_processor = ImageProcessor(512, mask="fix_mask", device=device)

    video = cv2.VideoCapture(video_path)
    try:
        ret, frame = video.read()
    finally:
        video.release()
    if not ret:
        raise RuntimeError(f"Could not read a frame from {video_path}")

    face, _, _ = image_processor.affine_transform(frame)
    face = (rearrange(face, "c h w -> h w c").detach().cpu().numpy()).astype(np.uint8)
    cv2.imwrite("face.jpg", face)
    image_processor.close()