MuseV-test / mmcm /vision /process /image_process.py
kevinwang676's picture
Upload folder using huggingface_hub
6755a2d verified
raw
history blame
9.56 kB
from einops import rearrange
import requests
from io import BytesIO
from typing import Literal, Union
import math
from PIL import Image
import numpy as np
from diffusers.utils import load_image
import cv2
import torch
from mmcm.vision.utils.data_type_util import convert_images
from transformers.models.clip.image_processing_clip import to_numpy_array
from ..utils.vision_util import round_up_to_even
def get_image_from_input(image: Union[str, Image.Image]) -> Image.Image:
if isinstance(image, str):
if "http" in image:
image = BytesIO(requests.get(image).content)
image = Image.open(image).convert("RGB")
else:
image = Image.open(image).convert("RGB")
else:
image = image.convert("RGB")
assert type(image) == Image.Image
return image
def dynamic_resize_image(
image: Image.Image,
target_height: int,
target_width: int,
image_max_length: int = 910,
) -> Image.Image:
"""对图像进行预处理,目前会将短边resize到目标长度,同时限制长边长度
Args:
image (Image.Image): _description_
target_height (int): _description_
target_width (int): _description_
image_max_length (int): _description_
Returns:
Image.Image: _description_
"""
w, h = image.size
if w > h:
target_width = min(math.ceil(w * target_height / h), image_max_length)
target_height = math.ceil(target_width / w * h)
else:
target_height = min(math.ceil(h * target_width / w), image_max_length)
target_width = math.ceil(target_height / h * w)
target_width = round_up_to_even(target_width)
target_height = round_up_to_even(target_height)
image = image.resize((target_width, target_height))
return image
def dynamic_crop_resize_image(
image: Image.Image,
target_height: int,
target_width: int,
resample=None,
) -> Image.Image:
"""获取图像有效部分,并resize到对应目标宽度和高度。
如果图像宽高比大于 target_width / target_height,则保留全部高,截取宽的中心部位;
如果图像宽高比小于 target_width / target_height,则保留全部宽,截取高的中心部位;
最后,将截取的图像resize到目标宽高
Args:
image (Image.Image): 输入图像
target_height (int): 目标高
target_width (int): 目标宽
Returns:
Image.Image: 动态截取、resize生成的图像
"""
w, h = image.size
image_width_heigt_ratio = w / h
target_width_height_ratio = target_width / target_height
if image_width_heigt_ratio >= target_width_height_ratio:
y1 = 0
y2 = h - 1
x1 = math.ceil((w - h * target_width / target_height) / 2)
x2 = math.ceil(w - (w - h * target_width / target_height) / 2)
else:
x1 = 0
x2 = w - 1
y1 = math.ceil((h - w * target_height / target_width) / 2)
y2 = math.ceil(h - (h - w * target_height / target_width) / 2)
x1 = max(0, x1)
x2 = min(x2, w - 1)
y1 = max(0, y1)
y2 = min(y2, h - 1)
image = image.crop((x1, y1, x2, y2))
image = image.resize((target_width, target_height), resample=resample)
return image
def get_canny(
image: np.ndarray, low_threshold: float, high_threshold: float
) -> np.ndarray:
image = cv2.Canny(image, low_threshold, high_threshold)
image = image[:, :, None]
image = np.concatenate([image, image, image], axis=2)
return image
def pad_matrix(matrix, target_shape):
h, w, c = matrix.shape
h1, w1 = target_shape
if h1 < h or w1 < w:
raise ValueError("Target shape must be larger than original shape.")
pad_h = (h1 - h) // 2
pad_w = (w1 - w) // 2
padded_matrix = np.zeros((h1, w1, c))
padded_matrix[pad_h : pad_h + h, pad_w : pad_w + w, :] = matrix
return padded_matrix
def pad_tensor(tensor, shape):
"""
将输入的numpy array tensor进行0填充,直到其尺寸达到目标尺寸shape。
参数:
tensor: numpy array,输入的tensor
shape: tuple,目标尺寸
返回值:
numpy array,填充后的tensor
"""
# 获取tensor的尺寸
tensor_shape = tensor.shape
# 计算需要填充的尺寸
pad_shape = tuple(
np.maximum(np.zeros_like(shape), np.array(shape) - np.array(tensor_shape))
)
# pad_shape = (np.max(0, shape[i] - tensor_shape[i]) for i in range(len(shape)))
# 构造填充后的tensor
pad_shape_ = ((0, x) for x in pad_shape)
padded_tensor = np.pad(
tensor,
((0, pad_shape[0]), (0, pad_shape[1]), (0, pad_shape[2]), (0, pad_shape[3])),
# pad_shape_,
mode="constant",
)
return padded_tensor
def batch_dynamic_crop_resize_images_v2(
images: Union[torch.Tensor, np.ndarray],
target_height: int,
target_width: int,
mode=Image.Resampling.LANCZOS,
) -> np.ndarray:
"""获取图像中心有效部分,并resize到对应目标宽度和高度。
如果图像宽高比大于 target_width / target_height,则保留全部高,截取宽的中心部位;
如果图像宽高比小于 target_width / target_height,则保留全部宽,截取高的中心部位;
最后,将截取的图像resize到目标宽高
Args:
image (Image.Image): 输入图像
target_height (int): 目标高
target_width (int): 目标宽
Returns:
Image.Image: 动态截取、resize生成的图像
"""
ndim = images.ndim
if ndim == 4:
b, c, h, w = images.shape
elif ndim == 5:
b, c, t, h, w = images.shape
images = rearrange(images, "b c t h w->(b t) c h w")
else:
raise ValueError(f"ndim only support 4, 5 but given {ndim}")
images = convert_images(
images, data_channel_order="b c h w", return_type="pil", input_rgb_order="rgb"
)
images = [
dynamic_crop_resize_image(
image,
target_height=target_height,
target_width=target_width,
resample=mode,
)
for image in images
]
images = [to_numpy_array(x) for x in images]
images = np.stack(images, axis=0)
images = rearrange(images, "b h w c-> b c h w")
if ndim == 5:
images = rearrange(images, "(b t) c h w->b c t h w", b=b, t=t)
return images
def batch_dynamic_crop_resize_images(
images: Union[torch.Tensor, np.ndarray],
target_height: int,
target_width: int,
mode: Literal[
"nearest", "linear", "bilinear", "bicubic", "trilinear", "area", "nearest-exact"
] = "bilinear",
# ] = "nearest",
align_corners=False,
) -> torch.TensorType:
"""获取图像中心有效部分,并resize到对应目标宽度和高度。
如果图像宽高比大于 target_width / target_height,则保留全部高,截取宽的中心部位;
如果图像宽高比小于 target_width / target_height,则保留全部宽,截取高的中心部位;
最后,将截取的图像resize到目标宽高
Warning: 该方法对于 b c t h w t=1时 会出现图像像素错位问题,所以新增了个使用Image.Resize的V2版本
Args:
image (Image.Image): 输入图像
target_height (int): 目标高
target_width (int): 目标宽
Returns:
Image.Image: 动态截取、resize生成的图像
"""
if isinstance(images, np.ndarray):
images = torch.from_numpy(images)
ndim = images.ndim
if ndim == 4:
b, c, h, w = images.shape
elif ndim == 5:
b, c, t, h, w = images.shape
images = rearrange(images, "b c t h w->(b t) c h w")
else:
raise ValueError(f"ndim only support 4, 5 but given {ndim}")
image_width_heigt_ratio = w / h
target_width_height_ratio = target_width / target_height
if image_width_heigt_ratio >= target_width_height_ratio:
y1 = 0
y2 = h - 1
x1 = math.ceil((w - h * target_width / target_height) / 2)
x2 = math.ceil(w - (w - h * target_width / target_height) / 2)
else:
x1 = 0
x2 = w - 1
y1 = math.ceil((h - w * target_height / target_width) / 2)
y2 = math.ceil(h - (h - w * target_height / target_width) / 2)
x1 = max(0, x1)
x2 = min(x2, w - 1)
y1 = max(0, y1)
y2 = min(y2, h - 1)
images = images[:, :, y1:y2, x1:x2]
images = torch.nn.functional.interpolate(
images,
(target_height, target_width),
mode=mode, # align_corners=align_corners
)
if ndim == 5:
images = rearrange(images, "(b t) c h w->b c t h w", b=b, t=t)
return images
def his_match(src: np.ndarray, dst: np.ndarray) -> np.ndarray:
src = src * 255.0
dst = dst * 255.0
src = src.astype(np.uint8)
dst = dst.astype(np.uint8)
res = np.zeros_like(dst)
cdf_src = np.zeros((3, 256))
cdf_dst = np.zeros((3, 256))
cdf_res = np.zeros((3, 256))
kw = dict(bins=256, range=(0, 256), density=True)
for ch in range(3):
his_src, _ = np.histogram(src[:, :, ch], **kw)
hist_dst, _ = np.histogram(dst[:, :, ch], **kw)
cdf_src[ch] = np.cumsum(his_src)
cdf_dst[ch] = np.cumsum(hist_dst)
index = np.searchsorted(cdf_src[ch], cdf_dst[ch], side="left")
np.clip(index, 0, 255, out=index)
res[:, :, ch] = index[dst[:, :, ch]]
his_res, _ = np.histogram(res[:, :, ch], **kw)
cdf_res[ch] = np.cumsum(his_res)
return res / 255.0