|
|
|
|
|
|
|
|
|
"""
|
|
utils.py
|
|
|
|
This module provides utility functions for various tasks such as setting random seeds,
|
|
importing modules from files, managing checkpoint files, and saving video files from
|
|
sequences of PIL images.
|
|
|
|
Functions:
|
|
seed_everything(seed)
|
|
import_filename(filename)
|
|
delete_additional_ckpt(base_path, num_keep)
|
|
save_videos_from_pil(pil_images, path, fps=8)
|
|
|
|
Dependencies:
|
|
importlib
|
|
os
|
|
os.path as osp
|
|
random
|
|
shutil
|
|
sys
|
|
pathlib.Path
|
|
av
|
|
cv2
|
|
mediapipe as mp
|
|
numpy as np
|
|
torch
|
|
torchvision
|
|
einops.rearrange
|
|
moviepy.editor.AudioFileClip, VideoClip
|
|
PIL.Image
|
|
|
|
Examples:
|
|
seed_everything(42)
|
|
imported_module = import_filename('path/to/your/module.py')
|
|
delete_additional_ckpt('path/to/checkpoints', 1)
|
|
save_videos_from_pil(pil_images, 'output/video.mp4', fps=12)
|
|
|
|
The functions in this module ensure reproducibility of experiments by seeding random number
|
|
generators, allow dynamic importing of modules, manage checkpoint files by deleting extra ones,
|
|
and provide a way to save sequences of images as video files.
|
|
|
|
Function Details:
|
|
seed_everything(seed)
|
|
Seeds all random number generators to ensure reproducibility.
|
|
|
|
import_filename(filename)
|
|
Imports a module from a given file location.
|
|
|
|
delete_additional_ckpt(base_path, num_keep)
|
|
Deletes additional checkpoint files in the given directory.
|
|
|
|
save_videos_from_pil(pil_images, path, fps=8)
|
|
Saves a sequence of PIL images as an .mp4 video (encoded with libx264 via PyAV) or as an animated .gif (written with Pillow).
|
|
|
|
Attributes:
|
|
_ (str): Placeholder for static type checking
|
|
"""
|
|
|
|
import importlib
|
|
import os
|
|
import os.path as osp
|
|
import random
|
|
import shutil
|
|
import subprocess
|
|
import sys
|
|
from pathlib import Path
|
|
from typing import List
|
|
|
|
import av
|
|
import cv2
|
|
import mediapipe as mp
|
|
import numpy as np
|
|
import torch
|
|
import torchvision
|
|
from einops import rearrange
|
|
from moviepy.editor import AudioFileClip, VideoClip
|
|
from PIL import Image
|
|
|
|
|
|
def seed_everything(seed):
    """
    Seed every random number generator used by the project.

    Python's ``random`` module, NumPy's global generator and PyTorch (CPU and
    all CUDA devices) are seeded so that experiments are reproducible.

    Args:
        seed (int): Seed value. NumPy only accepts 32-bit seeds, so the value
            is reduced modulo 2**32 before being handed to it.
    """
    numpy_seed = seed % (2**32)
    random.seed(seed)
    np.random.seed(numpy_seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
|
|
|
|
|
|
def import_filename(filename):
    """
    Import a module from a given file location.

    The module is registered in ``sys.modules`` under the fixed name
    ``"mymodule"``, so a later call replaces the module imported by an
    earlier one.

    Args:
        filename (str | os.PathLike): Path to the Python source file.

    Returns:
        module: The imported module.

    Raises:
        ImportError: If no import spec/loader can be built for ``filename``
            (for example because its extension is not a recognised source
            suffix).

    Example:
        >>> imported_module = import_filename('path/to/your/module.py')
    """
    # ``import importlib`` alone does not guarantee that the ``util``
    # submodule is loaded, so import it explicitly.
    import importlib.util

    spec = importlib.util.spec_from_file_location("mymodule", filename)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot import a module from {filename!r}")

    module = importlib.util.module_from_spec(spec)
    sys.modules[spec.name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Do not leave a half-initialised module registered.
        sys.modules.pop(spec.name, None)
        raise
    return module
|
|
|
|
|
|
def delete_additional_ckpt(base_path, num_keep):
    """
    Delete the oldest ``checkpoint-<step>`` directories in ``base_path``.

    Entries are ordered by the integer after their last ``-``; the
    ``num_keep`` entries with the highest step are kept and the remaining
    directories are removed with ``shutil.rmtree``. Entries whose suffix is
    not an integer (e.g. ``checkpoint-best``) and entries that are not
    directories are left untouched instead of aborting the cleanup.

    Args:
        base_path (str): Directory containing the checkpoint directories.
        num_keep (int): Number of most recent checkpoints to keep.

    Returns:
        None

    Raises:
        FileNotFoundError: If ``base_path`` does not exist.

    Example:
        >>> delete_additional_ckpt('path/to/checkpoints', 1)
        # Deletes all but the most recent checkpoint in 'path/to/checkpoints'.
    """
    numbered = []
    for name in os.listdir(base_path):
        if not name.startswith("checkpoint-"):
            continue
        try:
            step = int(name.split("-")[-1])
        except ValueError:
            # Not a numbered checkpoint; ignore it rather than crash.
            continue
        numbered.append((step, name))

    num_delete = len(numbered) - max(num_keep, 0)
    if num_delete <= 0:
        return

    # Oldest (lowest step) first; sort is stable for equal steps.
    numbered.sort(key=lambda item: item[0])
    for _, name in numbered[:num_delete]:
        path_to_dir = osp.join(base_path, name)
        # rmtree only works on directories.
        if osp.isdir(path_to_dir):
            shutil.rmtree(path_to_dir)
|
|
|
|
|
|
def save_videos_from_pil(pil_images, path, fps=8):
    """
    Save a sequence of PIL images as a video file.

    The output format is chosen from the extension of ``path``:

    * ``.mp4``: encoded with ``libx264`` through PyAV.
    * ``.gif``: written by Pillow as an infinitely looping animated GIF.

    Args:
        pil_images (List[PIL.Image.Image]): Frames of the video. The stream
            size is taken from the first frame.
        path (str): Output file path. Missing parent directories are created.
        fps (int, optional): Frames per second. Defaults to 8.

    Returns:
        None

    Raises:
        ValueError: If ``pil_images`` is empty or the extension of ``path`` is
            neither ``.mp4`` nor ``.gif``.
    """
    save_fmt = Path(path).suffix
    # Validate before touching the filesystem.
    if save_fmt not in (".mp4", ".gif"):
        raise ValueError("Unsupported file type. Use .mp4 or .gif.")
    if len(pil_images) == 0:
        raise ValueError("pil_images must contain at least one frame.")

    # dirname() is "" for a bare file name, and makedirs("") would fail.
    parent_dir = os.path.dirname(path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    if save_fmt == ".mp4":
        width, height = pil_images[0].size
        container = av.open(path, "w")
        try:
            stream = container.add_stream("libx264", rate=fps)
            stream.width = width
            stream.height = height
            for pil_image in pil_images:
                av_frame = av.VideoFrame.from_image(pil_image)
                container.mux(stream.encode(av_frame))
            # Flush the frames still buffered inside the encoder.
            container.mux(stream.encode())
        finally:
            container.close()
    else:
        pil_images[0].save(
            fp=path,
            format="GIF",
            append_images=pil_images[1:],
            save_all=True,
            duration=(1 / fps * 1000),
            loop=0,
        )
|
|
|
|
|
|
def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8):
    """
    Tile a batch of videos into a grid and save it as a video.

    Args:
        videos (torch.Tensor): Tensor of shape (batch, channels, time, height,
            width). Values are treated as lying in [0, 1] (or in [-1, 1] when
            ``rescale`` is True); anything outside that range is clipped.
        path (str): Output path. ``.mp4`` and ``.gif`` are supported (see
            ``save_videos_from_pil``).
        rescale (bool, optional): If True, map values from [-1, 1] to [0, 1]
            before converting to 8-bit. Defaults to False.
        n_rows (int, optional): Passed to ``torchvision.utils.make_grid`` as
            ``nrow`` (videos per grid row). Defaults to 6.
        fps (int, optional): Frame rate of the saved video. Defaults to 8.

    Raises:
        ValueError: If the video format is not supported.

    Returns:
        None
    """
    # Iterate over time: each step yields a (batch, c, h, w) tensor to tile.
    videos = rearrange(videos, "b c t h w -> t b c h w")

    outputs = []
    for x in videos:
        x = torchvision.utils.make_grid(x, nrow=n_rows)
        # (c, h, w) -> (h, w, c); drop a trailing singleton channel.
        x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
        if rescale:
            x = (x + 1.0) / 2.0
        # Clamp before the uint8 cast, because out-of-range values would
        # otherwise wrap around. detach()/cpu() let grad-tracking or CUDA
        # tensors be converted to NumPy.
        x = (x.clamp(0.0, 1.0) * 255).detach().cpu().numpy().astype(np.uint8)
        outputs.append(Image.fromarray(x))

    # dirname() is "" for a bare file name, and makedirs("") would fail.
    parent_dir = os.path.dirname(path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    save_videos_from_pil(outputs, path, fps)
|
|
|
|
|
|
def read_frames(video_path):
    """
    Decode every frame of the first video stream in a file.

    Args:
        video_path (str): Path to the video file.

    Returns:
        List[PIL.Image.Image]: The decoded frames converted to RGB, in decode
        order.

    Raises:
        ValueError: If the file contains no video stream.

    Errors raised by ``av.open`` (for example for a missing or unreadable
    file) propagate unchanged. The container is always closed before the
    function returns or raises.
    """
    frames = []
    with av.open(video_path) as container:
        video_stream = next(
            (s for s in container.streams if s.type == "video"), None)
        if video_stream is None:
            raise ValueError(f"No video stream found in {video_path}")

        for packet in container.demux(video_stream):
            for frame in packet.decode():
                frames.append(
                    Image.frombytes(
                        "RGB",
                        (frame.width, frame.height),
                        frame.to_rgb().to_ndarray(),
                    )
                )
    return frames
|
|
|
|
|
|
def get_fps(video_path):
    """
    Get the average frame rate of the first video stream in a file.

    Args:
        video_path (str): Path to the video file.

    Returns:
        The stream's ``average_rate`` exactly as reported by PyAV (documented
        by PyAV as a ``fractions.Fraction``; NOTE(review): it may be None when
        the rate is unknown, so confirm callers handle that). Use
        ``float()``/``int()`` on it if a plain number is required.

    Raises:
        ValueError: If the file contains no video stream.
    """
    # The context manager closes the container even when no video stream is
    # found.
    with av.open(video_path) as container:
        video_stream = next(
            (s for s in container.streams if s.type == "video"), None)
        if video_stream is None:
            raise ValueError(f"No video stream found in {video_path}")
        return video_stream.average_rate
|
|
|
|
|
|
def tensor_to_video(tensor, output_video_file, audio_source, fps=25):
    """
    Write a [c, f, h, w] tensor as a video file with an audio track.

    Args:
        tensor (torch.Tensor): Frames shaped [c, f, h, w]. Values are scaled
            by 255 and clipped to [0, 255], so a [0, 1] range is expected.
        output_video_file (str): Path of the video file to write.
        audio_source (str): Path to the audio file whose track is added.
        fps (int): Frame rate of the output video. Default is 25 fps.

    The audio is cut to the video duration. If the audio is shorter than the
    video, it is used in full instead of being read past its end. The audio
    file reader is closed when writing finishes or fails.
    """
    # [c, f, h, w] -> [f, h, w, c] uint8 frames for moviepy. detach() allows
    # grad-tracking tensors to be converted to NumPy.
    frames = tensor.detach().permute(1, 2, 3, 0).cpu().numpy()
    frames = np.clip(frames * 255, 0, 255).astype(np.uint8)
    num_frames = frames.shape[0]
    duration = num_frames / fps

    def make_frame(t):
        # moviepy requests frames by timestamp; clamp to the last frame.
        frame_index = min(int(t * fps), num_frames - 1)
        return frames[frame_index]

    video_clip = VideoClip(make_frame, duration=duration)
    source_audio = AudioFileClip(audio_source)
    try:
        # Never request audio beyond the end of the source file.
        audio_clip = source_audio.subclip(0, min(duration, source_audio.duration))
        video_clip = video_clip.set_audio(audio_clip)
        video_clip.write_videofile(output_video_file, fps=fps, audio_codec='aac')
    finally:
        source_audio.close()
        video_clip.close()
|
|
|
|
|
|
# Landmark indices into the per-face point list produced by
# compute_face_landmarks; np.take() selects these points before a bounding box
# is computed from them.

# Used by get_face_mask (named for the face silhouette).
silhouette_ids = [
    10, 338, 297, 332, 284, 251, 389, 356, 454,
    323, 361, 288, 397, 365, 379, 378, 400, 377,
    152, 148, 176, 149, 150, 136, 172, 58, 132,
    93, 234, 127, 162, 21, 54, 103, 67, 109,
]

# Used by get_lip_mask (named for the lips).
lip_ids = [
    61, 185, 40, 39, 37, 0, 267, 269, 270, 409,
    291, 146, 91, 181, 84, 17, 314, 405, 321, 375,
]
|
|
|
|
|
|
def compute_face_landmarks(detection_result, h, w):
    """
    Convert the landmarks of the single detected face to pixel coordinates.

    Args:
        detection_result: A face-landmarker result exposing ``face_landmarks``,
            a list with one entry per detected face. Each entry is a sequence
            of points with ``x``/``y`` attributes.
        h (int): Image height; each ``y`` is multiplied by it.
        w (int): Image width; each ``x`` is multiplied by it.

    Returns:
        list: ``[[x * w, y * h], ...]`` for the face, or ``[]`` (after printing
        a message) when the result does not contain exactly one face.
    """
    faces = detection_result.face_landmarks
    face_count = len(faces)
    if face_count == 1:
        return [[point.x * w, point.y * h] for point in faces[0]]
    print("#face is invalid:", face_count)
    return []
|
|
|
|
|
|
def get_landmark(
    file,
    model_path="pretrained_models/face_analysis/models/face_landmarker_v2_with_blendshapes.task",
):
    """
    Detect the facial landmarks of the single face in an image file.

    Args:
        file (str | pathlib.Path): Path to the image to analyse.
        model_path (str, optional): Path to the MediaPipe face-landmarker
            ``.task`` model. Defaults to the location that was previously
            hard-coded.

    Returns:
        Tuple[np.ndarray, int, int]: ``(landmarks, height, width)``.
        ``landmarks`` holds ``[x, y]`` pixel coordinates, or is empty when the
        image does not contain exactly one face. ``height`` and ``width`` are
        the image size in pixels.
    """
    options = mp.tasks.vision.FaceLandmarkerOptions(
        base_options=mp.tasks.BaseOptions(model_asset_path=model_path),
        running_mode=mp.tasks.vision.RunningMode.IMAGE,
    )

    with mp.tasks.vision.FaceLandmarker.create_from_options(options) as landmarker:
        image = mp.Image.create_from_file(str(file))
        height, width = image.height, image.width
        face_landmarker_result = landmarker.detect(image)
        face_landmark = compute_face_landmarks(
            face_landmarker_result, height, width)

    return np.array(face_landmark), height, width
|
|
|
|
|
|
def get_landmark_overframes(landmark_model, frames_path):
    """
    Detect facial landmarks in every frame image of a directory.

    Frames are processed in lexicographic file-name order.

    Args:
        landmark_model: Landmark model exposing ``detect(mp.Image)``, such as
            a MediaPipe face landmarker.
        frames_path (str): Directory containing the frame images.

    Returns:
        Tuple[List[list], int, int]: ``(face_landmarks, height, width)``.
        ``face_landmarks`` holds one landmark list per frame, as returned by
        ``compute_face_landmarks`` (empty for frames without exactly one
        face). ``height`` and ``width`` are the size of the last frame.

    Raises:
        ValueError: If ``frames_path`` contains no files.
    """
    frame_files = sorted(os.listdir(frames_path))
    if not frame_files:
        # Without any frame, height/width would be undefined.
        raise ValueError(f"No frames found in {frames_path}")

    face_landmarks = []
    height = width = None
    for file in frame_files:
        image = mp.Image.create_from_file(os.path.join(frames_path, file))
        height, width = image.height, image.width
        landmarker_result = landmark_model.detect(image)
        frame_landmark = compute_face_landmarks(
            landmarker_result, height, width)
        face_landmarks.append(frame_landmark)

    return face_landmarks, height, width
|
|
|
|
|
|
def get_lip_mask(landmarks, height, width, out_path=None, expand_ratio=2.0):
    """
    Build a rectangular mask covering the (expanded) lip bounding box.

    Parameters:
        landmarks (numpy.ndarray): Face landmarks as ``[x, y]`` pixel
            coordinates, indexed with ``lip_ids``.
        height (int): Height of the mask in pixels.
        width (int): Width of the mask in pixels.
        out_path (str | pathlib.Path, optional): If given, the mask is written
            there with ``cv2.imwrite`` and None is returned.
        expand_ratio (float): Factor by which the lip bounding box is enlarged
            around its centre (see ``expand_region``).

    Returns:
        numpy.ndarray | None: A ``uint8`` mask of shape (height, width) with 255
        inside the box and 0 elsewhere, or None when ``out_path`` is given.
    """
    lips = np.take(landmarks, lip_ids, 0)
    lo = np.round(np.min(lips, 0))
    hi = np.round(np.max(lips, 0))
    x0, x1, y0, y1 = expand_region(
        [lo[0], hi[0], lo[1], hi[1]], width, height, expand_ratio)

    mask = np.zeros((height, width), dtype=np.uint8)
    mask[round(y0):round(y1), round(x0):round(x1)] = 255

    if not out_path:
        return mask
    cv2.imwrite(str(out_path), mask)
    return None
|
|
|
|
|
|
def get_union_lip_mask(landmarks, height, width, expand_ratio=1):
    """
    Combine per-frame lip masks into a single union mask.

    Parameters:
        landmarks (Iterable[numpy.ndarray]): One landmark array per frame.
        height (int): Height of each mask.
        width (int): Width of each mask.
        expand_ratio (float): Lip bounding-box expansion factor passed to
            ``get_lip_mask``.

    Returns:
        The result of ``get_union_mask`` over the per-frame lip masks.
    """
    per_frame_masks = [
        get_lip_mask(landmarks=frame_landmarks, height=height,
                     width=width, expand_ratio=expand_ratio)
        for frame_landmarks in landmarks
    ]
    return get_union_mask(per_frame_masks)
|
|
|
|
|
|
def get_face_mask(landmarks, height, width, out_path=None, expand_ratio=1.2):
    """
    Build a rectangular mask covering the (expanded) face bounding box.

    Args:
        landmarks (numpy.ndarray): Face landmarks as ``[x, y]`` pixel
            coordinates, indexed with ``silhouette_ids``.
        height (int): Height of the mask in pixels.
        width (int): Width of the mask in pixels.
        out_path (str | pathlib.Path, optional): If given, the mask is written
            there with ``cv2.imwrite`` and None is returned.
        expand_ratio (float): Factor by which the face bounding box is enlarged
            around its centre (see ``expand_region``).

    Returns:
        numpy.ndarray | None: A ``uint8`` mask of shape (height, width) with 255
        inside the box and 0 elsewhere, or None when ``out_path`` is given.
    """
    outline = np.take(landmarks, silhouette_ids, 0)
    lo = np.round(np.min(outline, 0))
    hi = np.round(np.max(outline, 0))
    x0, x1, y0, y1 = expand_region(
        [lo[0], hi[0], lo[1], hi[1]], width, height, expand_ratio)

    mask = np.zeros((height, width), dtype=np.uint8)
    mask[round(y0):round(y1), round(x0):round(x1)] = 255

    if not out_path:
        return mask
    cv2.imwrite(str(out_path), mask)
    return None
|
|
|
|
|
|
def get_union_face_mask(landmarks, height, width, expand_ratio=1):
    """
    Combine per-frame face masks into a single union mask.

    Args:
        landmarks (Iterable[numpy.ndarray]): One landmark array per frame.
        height (int): Height of each mask.
        width (int): Width of each mask.
        expand_ratio (float): Face bounding-box expansion factor passed to
            ``get_face_mask``.

    Returns:
        The result of ``get_union_mask`` over the per-frame face masks.
        Nothing is written to disk.
    """
    per_frame_masks = [
        get_face_mask(landmarks=frame_landmarks, height=height,
                      width=width, expand_ratio=expand_ratio)
        for frame_landmarks in landmarks
    ]
    return get_union_mask(per_frame_masks)
|
|
|
|
def get_mask(file, cache_dir, face_expand_raio):
    """
    Generate and cache the lip, face, blurred and separated masks for an image.

    In the file names below, ``<name>`` is the image file name up to its first
    ``.``. The following files are written to ``cache_dir``:

    * ``<name>_lip_mask.png`` / ``<name>_face_mask.png``: box masks.
    * ``<name>_face_mask_blur.png``: blurred face mask.
    * ``<name>_sep_lip.png``: blurred lip mask.
    * ``<name>_sep_background.png``: inverse of the blurred face mask.
    * ``<name>_sep_face.png``: blurred face mask minus the blurred lip mask.

    Args:
        file (str): Path to the image in which landmarks are detected.
        cache_dir (str): Directory for the generated masks; created if missing.
        face_expand_raio (float): Expansion ratio for the face bounding box.
            The misspelt name is kept for backward compatibility.

    Returns:
        None
    """
    # cv2.imwrite fails silently (returns False) when the directory is
    # missing, which would break every later step.
    if cache_dir:
        os.makedirs(cache_dir, exist_ok=True)

    landmarks, height, width = get_landmark(file)
    file_name = os.path.basename(file).split(".")[0]

    def cache_path(suffix):
        # Path of "<file_name>_<suffix>.png" inside cache_dir.
        return os.path.join(cache_dir, f"{file_name}_{suffix}.png")

    lip_mask_path = cache_path("lip_mask")
    face_mask_path = cache_path("face_mask")
    face_blur_path = cache_path("face_mask_blur")
    sep_lip_path = cache_path("sep_lip")

    get_lip_mask(landmarks, height, width, lip_mask_path)
    get_face_mask(landmarks, height, width, face_mask_path, face_expand_raio)
    get_blur_mask(face_mask_path, face_blur_path, kernel_size=(51, 51))
    get_blur_mask(lip_mask_path, sep_lip_path, kernel_size=(31, 31))
    get_background_mask(face_blur_path, cache_path("sep_background"))
    get_sep_face_mask(face_blur_path, sep_lip_path, cache_path("sep_face"))
|
|
|
|
|
|
def expand_region(region, image_w, image_h, expand_ratio=1.0):
    """
    Expand a bounding box about its centre and fit it inside the image.

    The box is scaled by ``expand_ratio`` around its (floored) centre. If it
    then crosses an image border, it is shifted back inside. If it is larger
    than the image along an axis, its minimum is clipped to 0, so the result
    never holds a negative coordinate. A negative value would silently wrap
    around when used as a slice start.

    Args:
        region (Sequence[float]): ``(min_x, max_x, min_y, max_y)``.
        image_w (int): The width of the image.
        image_h (int): The height of the image.
        expand_ratio (float, optional): The ratio by which the region should
            be expanded. Defaults to 1.0.

    Returns:
        tuple: The expanded ``(min_x, max_x, min_y, max_y)`` as ints.
    """
    min_x, max_x, min_y, max_y = region
    mid_x = (max_x + min_x) // 2
    side_len_x = (max_x - min_x) * expand_ratio
    mid_y = (max_y + min_y) // 2
    side_len_y = (max_y - min_y) * expand_ratio

    min_x = mid_x - side_len_x // 2
    max_x = mid_x + side_len_x // 2
    min_y = mid_y - side_len_y // 2
    max_y = mid_y + side_len_y // 2

    # Shift the box back inside the image when it crosses a border.
    if min_x < 0:
        max_x -= min_x
        min_x = 0
    if max_x > image_w:
        min_x -= max_x - image_w
        max_x = image_w
    if min_y < 0:
        max_y -= min_y
        min_y = 0
    if max_y > image_h:
        min_y -= max_y - image_h
        max_y = image_h

    # A box wider/taller than the image is still out of range after the
    # shift (max is already capped above); clip the minimum.
    min_x = max(min_x, 0)
    min_y = max(min_y, 0)

    return round(min_x), round(max_x), round(min_y), round(max_y)
|
|
|
|
|
|
def get_blur_mask(file_path, output_file_path, resize_dim=(64, 64), kernel_size=(101, 101)):
    """
    Load a grayscale mask, blur it with ``blur_mask`` and write the result.

    Parameters:
        file_path (str): Path of the input mask image (read as grayscale).
        output_file_path (str): Path the blurred mask is written to.
        resize_dim (tuple): Size passed to ``cv2.resize``; the written mask
            has this size, not the input's.
        kernel_size (tuple): Size of the Gaussian blur kernel.

    Returns:
        str: A status message reporting success or the load failure.
    """
    source_mask = cv2.imread(file_path, cv2.IMREAD_GRAYSCALE)
    if source_mask is None:
        return f"Failed to load image: {file_path}"

    blurred = blur_mask(source_mask, resize_dim=resize_dim, kernel_size=kernel_size)
    cv2.imwrite(output_file_path, blurred)
    return f"Processed, normalized, and saved: {output_file_path}"
|
|
|
|
|
|
def blur_mask(mask, resize_dim=(64, 64), kernel_size=(51, 51)):
    """
    Resize, Gaussian-blur and min-max normalise a mask.

    Parameters:
        mask (numpy.ndarray | None): Input mask image.
        resize_dim (tuple): Target size passed to ``cv2.resize``.
        kernel_size (tuple): Gaussian kernel size. The sigma argument is 0, so
            OpenCV derives it from the kernel size.

    Returns:
        numpy.ndarray | None: The blurred mask stretched to [0, 255], or None
        when ``mask`` is None.
    """
    if mask is None:
        return None

    resized = cv2.resize(mask, resize_dim)
    blurred = cv2.GaussianBlur(resized, kernel_size, 0)
    return cv2.normalize(blurred, None, 0, 255, cv2.NORM_MINMAX)
|
|
|
|
def get_background_mask(file_path, output_file_path):
    """
    Read a grayscale mask, invert it and save the result.

    Each 8-bit pixel value ``v`` becomes ``255 - v``. In ``get_mask`` this
    turns the blurred face mask into a background mask.

    Parameters:
        file_path (str): Path to the input mask image.
        output_file_path (str): Path the inverted mask is written to.

    Returns:
        None. A status message is printed. If the input cannot be read,
        nothing is written.
    """
    image = cv2.imread(file_path, cv2.IMREAD_GRAYSCALE)
    if image is None:
        print(f"Failed to load image: {file_path}")
        return

    # Exact integer inversion (255 - v for uint8). The former float round trip,
    # (1 - v / 255) * 255 truncated by astype(uint8), was exposed to
    # floating-point truncation errors.
    inverted_image = cv2.bitwise_not(image)

    cv2.imwrite(output_file_path, inverted_image)
    print(f"Processed and saved: {output_file_path}")
|
|
|
|
|
|
def get_sep_face_mask(file_path1, file_path2, output_file_path):
    """
    Subtract one grayscale mask from another and save the result.

    ``cv2.subtract`` saturates at 0, so pixels where the second mask is
    brighter become 0 instead of wrapping around.

    Parameters:
        file_path1 (str): Mask to subtract from. ``get_mask`` passes the
            blurred face mask.
        file_path2 (str): Mask to subtract. ``get_mask`` passes the blurred
            lip mask.
        output_file_path (str): Path the difference mask is written to.

    Returns:
        None. Status messages are printed. Nothing is written if either input
        cannot be read or the two shapes differ.
    """
    minuend = cv2.imread(file_path1, cv2.IMREAD_GRAYSCALE)
    subtrahend = cv2.imread(file_path2, cv2.IMREAD_GRAYSCALE)

    if minuend is None or subtrahend is None:
        print(f"Failed to load images: {file_path1}")
        return None

    if minuend.shape != subtrahend.shape:
        print(
            f"Image shapes do not match for {file_path1}: {minuend.shape} vs {subtrahend.shape}"
        )
        return None

    cv2.imwrite(output_file_path, cv2.subtract(minuend, subtrahend))
    print(f"Processed and saved: {output_file_path}")
|
|
|
|
def resample_audio(input_audio_file: str, output_audio_file: str, sample_rate: int):
    """
    Resample an audio file to ``sample_rate`` Hz with ffmpeg.

    An existing output file is overwritten (``-y``), and ffmpeg only prints
    errors (``-v error``).

    Args:
        input_audio_file (str): Path to the source audio file.
        output_audio_file (str): Path of the resampled file to write.
        sample_rate (int): Target sample rate in Hz.

    Returns:
        str: ``output_audio_file``.

    Raises:
        AssertionError: If ffmpeg exits with a non-zero status. It is raised
            explicitly rather than via ``assert``, so the check is not
            stripped under ``python -O``, while the exception type stays the
            same for existing callers.
    """
    command = [
        "ffmpeg", "-y", "-v", "error",
        "-i", input_audio_file,
        "-ar", str(sample_rate),
        output_audio_file,
    ]
    result = subprocess.run(command, check=False)
    if result.returncode != 0:
        raise AssertionError("Resample audio failed!")
    return output_audio_file
|
|
|
|
def get_face_region(image_path: str, detector):
    """
    Detect faces in an image and save a rectangular face mask for it.

    The mask has the image's shape, with every detected face bounding box
    filled with white. It is written to ``image_path`` with every occurrence
    of ``"images"`` replaced by ``"face_masks"``. If the path does not contain
    ``"images"``, the mask is not written, so the input image is never
    overwritten.

    Args:
        image_path (str): Path to the input image.
        detector: A face detector exposing ``detect(mp.Image)`` whose result
            has ``detections`` with a ``bounding_box`` (origin_x, origin_y,
            width, height).

    Returns:
        Tuple[str | None, numpy.ndarray | None]: ``(image_path, mask)`` on
        success. ``(None, None)`` if the image cannot be read or any error
        occurs; errors are printed, not raised.
    """
    try:
        image = cv2.imread(image_path)
        if image is None:
            print(f"Failed to open image: {image_path}. Skipping...")
            return None, None

        # OpenCV loads BGR, whereas ImageFormat.SRGB expects RGB channel order.
        rgb_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=rgb_image)
        detection_result = detector.detect(mp_image)

        mask = np.zeros_like(image, dtype=np.uint8)
        for detection in detection_result.detections:
            bbox = detection.bounding_box
            start_point = (int(bbox.origin_x), int(bbox.origin_y))
            end_point = (int(bbox.origin_x + bbox.width),
                         int(bbox.origin_y + bbox.height))
            cv2.rectangle(mask, start_point, end_point,
                          (255, 255, 255), thickness=-1)

        save_path = image_path.replace("images", "face_masks")
        if save_path == image_path:
            print(f"Not saving face mask for {image_path}: path does not "
                  "contain 'images' and saving would overwrite the input.")
        else:
            save_dir = os.path.dirname(save_path)
            if save_dir:
                os.makedirs(save_dir, exist_ok=True)
            cv2.imwrite(save_path, mask)

        return image_path, mask

    except Exception as e:
        # Deliberate best-effort: report and skip this image.
        print(f"Error processing image {image_path}: {e}")
        return None, None
|
|
|
|
|
|
def save_checkpoint(model: torch.nn.Module, save_dir: str, prefix: str, ckpt_num: int, total_limit: int = -1) -> None:
    """
    Save the model's state_dict to ``<save_dir>/<prefix>-<ckpt_num>.pth``.

    If ``total_limit`` is positive, the oldest existing ``<prefix>-<N>.*``
    checkpoints (lowest integer ``N``) are removed before saving, so that at
    most ``total_limit`` remain afterwards. Only names consisting of
    ``<prefix>-`` followed by an integer are considered. This works for
    prefixes that themselves contain ``-``, and leaves alone files of other
    prefixes that merely share the same leading text (e.g.
    ``<prefix>_ema-3.pth``).

    Args:
        model (nn.Module): The model whose state_dict is to be saved.
        save_dir (str): The directory where the checkpoint will be saved.
        prefix (str): The prefix for the checkpoint file name.
        ckpt_num (int): The checkpoint number to be saved.
        total_limit (int, optional): The maximum number of checkpoints to keep.
            Defaults to -1. Any value <= 0 disables removal.

    Raises:
        FileNotFoundError: If the save directory does not exist.
        ValueError: If the checkpoint number is negative.
        OSError: If there is an error saving the checkpoint.
    """
    if not osp.exists(save_dir):
        raise FileNotFoundError(
            f"The save directory {save_dir} does not exist.")

    if ckpt_num < 0:
        raise ValueError(f"Checkpoint number {ckpt_num} must be non-negative.")

    save_path = osp.join(save_dir, f"{prefix}-{ckpt_num}.pth")

    if total_limit > 0:
        name_head = f"{prefix}-"
        numbered = []
        for name in os.listdir(save_dir):
            if not name.startswith(name_head):
                continue
            # "<prefix>-<N>.pth" -> "<N>"; skip anything that is not a number.
            step = name[len(name_head):].split(".")[0]
            if step.isdecimal():
                numbered.append((int(step), name))
        # Oldest (lowest checkpoint number) first.
        numbered.sort(key=lambda item: item[0])
        checkpoints = [name for _, name in numbered]

        if len(checkpoints) >= total_limit:
            # Leave room for the checkpoint about to be written.
            num_to_remove = len(checkpoints) - total_limit + 1
            removing_checkpoints = checkpoints[0:num_to_remove]
            print(
                f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
            )
            print(
                f"Removing checkpoints: {', '.join(removing_checkpoints)}"
            )

            for removing_checkpoint in removing_checkpoints:
                removing_checkpoint_path = osp.join(
                    save_dir, removing_checkpoint)
                try:
                    os.remove(removing_checkpoint_path)
                except OSError as e:
                    print(
                        f"Error removing checkpoint {removing_checkpoint_path}: {e}")

    state_dict = model.state_dict()
    try:
        torch.save(state_dict, save_path)
        print(f"Checkpoint saved at {save_path}")
    except OSError as e:
        raise OSError(f"Error saving checkpoint at {save_path}: {e}") from e
|
|
|
|
|
|
def init_output_dir(dir_list: List[str]):
    """
    Ensure every directory in ``dir_list`` exists.

    Missing parent directories are created too. Directories that already
    exist are left as they are.

    Args:
        dir_list (List[str]): Directory paths to create.
    """
    for directory in dir_list:
        os.makedirs(directory, exist_ok=True)
|
|
|
|
|
|
def load_checkpoint(cfg, save_dir, accelerator):
    """
    Load the most recent ``checkpoint-<step>`` state from a directory.

    The directory searched is ``save_dir`` when ``cfg.resume_from_checkpoint``
    is ``"latest"``; otherwise ``cfg.resume_from_checkpoint`` itself is used
    as the directory. The entry with the highest integer step is loaded via
    ``accelerator.load_state``. Entries starting with ``"checkpoint"`` whose
    step cannot be parsed (e.g. ``checkpoints`` or ``checkpoint-best``) are
    ignored. If no checkpoint is found, training starts from scratch.

    Args:
        cfg: Configuration object exposing ``resume_from_checkpoint``.
        save_dir (str): The directory where checkpoints are saved.
        accelerator: Accelerator object providing ``load_state`` and ``print``.

    Returns:
        int: The global step at which to resume training (0 if none found).
    """
    if cfg.resume_from_checkpoint != "latest":
        resume_dir = cfg.resume_from_checkpoint
    else:
        resume_dir = save_dir

    candidates = []
    for name in os.listdir(resume_dir):
        if not name.startswith("checkpoint"):
            continue
        parts = name.split("-")
        # Skip names without a numeric step instead of crashing on int().
        if len(parts) < 2 or not parts[1].isdecimal():
            continue
        candidates.append((int(parts[1]), name))

    if candidates:
        candidates.sort(key=lambda item: item[0])
        global_step, path = candidates[-1]
        accelerator.load_state(os.path.join(resume_dir, path))
        accelerator.print(f"Resuming from checkpoint {path}")
    else:
        accelerator.print(
            f"Could not find checkpoint under {resume_dir}, start training from scratch")
        global_step = 0

    return global_step
|
|
|
|
|
|
def compute_snr(noise_scheduler, timesteps):
    """
    Compute the signal-to-noise ratio for the given diffusion timesteps.

    ``SNR(t) = (alpha_t / sigma_t) ** 2``, where
    ``alpha_t = sqrt(alphas_cumprod[t])`` and
    ``sigma_t = sqrt(1 - alphas_cumprod[t])``. This follows
    https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/
    521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849

    Args:
        noise_scheduler: Object exposing an ``alphas_cumprod`` tensor.
        timesteps (torch.Tensor): Timestep indices into ``alphas_cumprod``.

    Returns:
        torch.Tensor: float32 SNR values with the shape of ``timesteps``.
    """
    alphas_cumprod = noise_scheduler.alphas_cumprod
    device = timesteps.device
    target_shape = timesteps.shape

    def _per_timestep(table):
        # Gather table[t] on the timesteps' device as float32 and broadcast it
        # to the shape of ``timesteps``.
        picked = table.to(device=device)[timesteps].float()
        while picked.dim() < len(target_shape):
            picked = picked.unsqueeze(-1)
        return picked.expand(target_shape)

    alpha = _per_timestep(alphas_cumprod**0.5)
    sigma = _per_timestep((1.0 - alphas_cumprod) ** 0.5)
    return (alpha / sigma) ** 2
|
|
|
|
|
|
def extract_audio_from_videos(video_path: Path, audio_output_path: Path) -> Path:
    """
    Extract the audio track of a video as a 16 kHz, 2-channel, 16-bit PCM WAV.

    Runs ``ffmpeg -y -i <video> -vn -acodec pcm_s16le -ar 16000 -ac 2 <out>``,
    so an existing ``audio_output_path`` is overwritten.

    Args:
        video_path (Path): The path to the input video file.
        audio_output_path (Path): Path of the WAV file to write.

    Returns:
        Path: ``audio_output_path``.

    Raises:
        subprocess.CalledProcessError: If the ffmpeg command fails to execute.
    """
    ffmpeg_command = ['ffmpeg', '-y', '-i', str(video_path)]
    ffmpeg_command += ['-vn', '-acodec', "pcm_s16le", '-ar', '16000', '-ac', '2']
    ffmpeg_command.append(str(audio_output_path))

    print(f"Running command: {' '.join(ffmpeg_command)}")
    try:
        subprocess.run(ffmpeg_command, check=True)
    except subprocess.CalledProcessError as e:
        print(f"Error extracting audio from video: {e}")
        raise

    return audio_output_path
|
|
|
|
|
|
def convert_video_to_images(video_path: Path, output_dir: Path, fps: int = 25) -> Path:
    """
    Convert a video file into a sequence of PNG images.

    ffmpeg's ``fps`` filter samples the video at ``fps`` frames per second.
    Frames are written directly into ``output_dir`` as ``0001.png``,
    ``0002.png``, and so on. ``output_dir`` (and its parents) is created if
    missing, because ffmpeg cannot write into a non-existent directory.

    Args:
        video_path (Path): The path to the input video file.
        output_dir (Path): The directory where the extracted images will be
            saved (a str is also accepted).
        fps (int, optional): Sampling frame rate. Defaults to 25, the value
            previously hard-coded.

    Returns:
        Path: The path to the directory containing the extracted images.

    Raises:
        subprocess.CalledProcessError: If the ffmpeg command fails to execute.
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    ffmpeg_command = [
        'ffmpeg', '-y',
        '-i', str(video_path),
        '-vf', f'fps={fps}',
        str(output_dir / '%04d.png'),
    ]

    try:
        print(f"Running command: {' '.join(ffmpeg_command)}")
        subprocess.run(ffmpeg_command, check=True)
    except subprocess.CalledProcessError as e:
        print(f"Error converting video to images: {e}")
        raise

    return output_dir
|
|
|
|
|
|
def get_union_mask(masks):
    """
    Compute the union of a list of masks as a single filled rectangle.

    The masks are combined with an element-wise maximum. The bounding box of
    the non-zero pixels of that union is then filled with the union's maximum
    value.

    Args:
        masks (Iterable[np.ndarray]): Masks of identical shape.

    Returns:
        np.ndarray | None: A new array holding the filled bounding box, or the
        all-zero union if every mask is empty, or None if ``masks`` is empty.
        The input masks are never modified.
    """
    union_mask = None
    for mask in masks:
        if union_mask is None:
            # Copy so the in-place fill below cannot modify the caller's mask.
            union_mask = np.array(mask, copy=True)
        else:
            union_mask = np.maximum(union_mask, mask)

    if union_mask is None:
        return None

    rows = np.nonzero(np.any(union_mask, axis=1))[0]
    cols = np.nonzero(np.any(union_mask, axis=0))[0]
    if rows.size == 0 or cols.size == 0:
        # Nothing set in any mask: there is no bounding box to fill.
        return union_mask

    ymin, ymax = rows[0], rows[-1]
    xmin, xmax = cols[0], cols[-1]
    union_mask[ymin: ymax + 1, xmin: xmax + 1] = np.max(union_mask)
    return union_mask
|
|
|
|
|
|
def move_final_checkpoint(save_dir, module_dir, prefix):
    """
    Copy the newest ``<prefix>-<N>`` checkpoint to ``<save_dir>/<prefix>.pth``.

    The checkpoint in ``module_dir`` with the highest integer ``N`` is copied
    with ``shutil.copy2``, so the source file is kept and its metadata is
    preserved. Only names made of ``<prefix>-``, an integer and an optional
    extension are considered. This works for prefixes containing ``-`` and
    skips files of other prefixes that share the same leading text.

    Args:
        save_dir (str): The directory where the final checkpoint file should be saved.
        module_dir (str): The directory containing the checkpoint files.
        prefix (str): The prefix used to identify checkpoint files.

    Raises:
        ValueError: If no checkpoint files are found with the specified prefix.
    """
    name_head = f"{prefix}-"
    numbered = []
    for name in os.listdir(module_dir):
        if not name.startswith(name_head):
            continue
        # "<prefix>-<N>.pth" -> "<N>"; skip anything that is not a number.
        step = name[len(name_head):].split(".")[0]
        if step.isdecimal():
            numbered.append((int(step), name))

    if not numbered:
        raise ValueError(
            f"No checkpoint files with prefix {prefix!r} found in {module_dir}")

    numbered.sort(key=lambda item: item[0])
    latest = numbered[-1][1]
    shutil.copy2(os.path.join(module_dir, latest),
                 os.path.join(save_dir, prefix + '.pth'))
|
|
|