# Unico3D / gradio_app/utils.py
# Source: upload by cavargas10 ("Upload 56 files", commit 1f30907, verified).
import torch
import numpy as np
from PIL import Image
import gc
import numpy as np
import numpy as np
from PIL import Image
from scripts.refine_lr_to_sr import run_sr_fast
# Filesystem path string for Gradio's scratch/cache directory.
# NOTE(review): not referenced in this chunk — presumably read by the app
# code that imports this module; confirm before changing the path.
GRADIO_CACHE = "/tmp/gradio/"
def clean_up():
    """Free unreachable Python objects, then release cached CUDA memory.

    ``gc.collect()`` runs first so that tensors kept alive only by reference
    cycles are actually freed before ``torch.cuda.empty_cache()`` returns the
    caching allocator's unoccupied blocks. In the reverse order, that memory
    would remain cached until the next call.
    """
    gc.collect()
    torch.cuda.empty_cache()
def remove_color(arr, threshold=80):
    """Make the background of an image array transparent by keying on its corner pixel.

    The color of the top-left pixel ``arr[0, 0]`` is taken as the background.
    Every pixel whose summed absolute per-channel difference from that color
    is ``<= threshold`` is painted white (255) and marked transparent. All
    other pixels keep their color and are marked opaque.

    Args:
        arr: Channel-last image array, expected ``(H, W, C)``. If ``C == 4``,
            the 4th (alpha) channel is discarded. The input is NOT modified.
        threshold: Maximum summed channel difference for a pixel to count as
            background. Defaults to 80, the original hard-coded value.

    Returns:
        A new array with the (possibly whitened) color channels followed by
        an alpha channel of 0 or 255. Its dtype is NumPy's promotion of
        ``arr.dtype`` with int32 (int32 for uint8 input), so callers cast as
        needed (e.g. ``.astype(np.uint8)``).
    """
    if arr.shape[-1] == 4:
        arr = arr[..., :3]
    # Copy before writing: the slice above is a view, and the in-place write
    # below would otherwise clobber the caller's array (or raise on read-only
    # arrays such as np.asarray(PIL.Image)).
    arr = arr.copy()
    base = arr[0, 0].astype(np.int32)
    # int32 avoids uint8 wrap-around when subtracting.
    diffs = np.abs(arr.astype(np.int32) - base).sum(axis=-1)
    background = diffs <= threshold
    arr[background] = 255
    alpha = (~background)[..., None].astype(np.int32) * 255
    return np.concatenate([arr, alpha], axis=-1)
def simple_remove(imgs, run_sr=True):
    """Key out the flat background of one or more images, returning RGBA images.

    Per the original author's note ("Only works for normal"), this is meant
    for normal-map renders. The background is keyed on the top-left pixel
    color by ``remove_color``.

    Args:
        imgs: A single image, or a ``list`` of images. Anything that is not a
            ``list`` is treated as one image.
        run_sr: If true, the batch is first passed through ``run_sr_fast``
            (presumably a super-resolution pass, judging by its module
            ``scripts.refine_lr_to_sr``).

    Returns:
        A single ``PIL.Image`` when a single image was given, otherwise a list
        of them, each built from a uint8 array with 4 channels.
    """
    is_single = not isinstance(imgs, list)
    batch = [imgs] if is_single else imgs
    if run_sr:
        batch = run_sr_fast(batch)
    keyed = [
        Image.fromarray(remove_color(np.array(img)).astype(np.uint8))
        for img in batch
    ]
    return keyed[0] if is_single else keyed
def rgba_to_rgb(rgba: Image.Image, bkgd="WHITE"):
    """Composite ``rgba`` over a solid ``bkgd`` color and return an RGB image.

    Args:
        rgba: Source image. Non-RGBA inputs (RGB, P, L, ...) are converted to
            RGBA first, so an image without alpha comes out fully opaque.
        bkgd: Any color accepted by ``Image.new`` (default ``"WHITE"``).

    Returns:
        A new ``"RGB"`` image of the same size.
    """
    if rgba.mode != "RGBA":
        # paste() rejects RGB/P images as masks ("bad transparency mask") and
        # would treat an "L" image's luminance as alpha; normalize instead.
        rgba = rgba.convert("RGBA")
    canvas = Image.new("RGBA", rgba.size, bkgd)
    # Using the image as its own mask blends by its alpha channel.
    canvas.paste(rgba, (0, 0), rgba)
    return canvas.convert("RGB")
def change_rgba_bg(rgba: Image.Image, bkgd="WHITE"):
    """Swap the color underneath the transparent regions of ``rgba`` for ``bkgd``.

    The color channels are composited onto ``bkgd`` via ``rgba_to_rgb``. The
    original alpha channel (index 3) is then re-attached, so transparency is
    unchanged. Expects a 4-channel (RGBA) image.
    """
    flattened = np.asarray(rgba_to_rgb(rgba, bkgd))
    alpha_channel = np.asarray(rgba)[:, :, 3:4]
    stacked = np.concatenate((flattened, alpha_channel), axis=-1)
    return Image.fromarray(stacked)
def split_image(image, rows=None, cols=None):
    """Cut a grid image into equal square tiles (inverse of ``make_image_grid``).

    The tile side ``tile`` is derived as follows:

    * neither ``rows`` nor ``cols``: a single row; ``tile`` is the image
      height, and the width must be an exact multiple of it;
    * only ``cols``: ``tile = width // cols``; the height must be an exact
      multiple of ``tile``;
    * only ``rows``: ``tile = height // rows``; the width must be an exact
      multiple of ``tile``;
    * both: ``tile = height // rows``; only ``cols * tile == width`` is checked.

    Args:
        image: Object with PIL-style ``size`` (``(width, height)``) and ``crop``.
        rows: Optional number of grid rows.
        cols: Optional number of grid columns.

    Returns:
        List of cropped tiles in row-major order.

    Raises:
        AssertionError: If a checked dimension is not an exact multiple of the
            tile side.
    """
    width, height = image.size
    if rows is None and cols is None:
        rows = 1
        cols = width // height
        assert cols * height == width
        tile = height
    elif rows is None:
        tile = width // cols
        rows = height // tile
        assert rows * tile == height
    elif cols is None:
        tile = height // rows
        cols = width // tile
        assert cols * tile == width
    else:
        tile = height // rows
        assert cols * tile == width
    return [
        image.crop((c * tile, r * tile, (c + 1) * tile, (r + 1) * tile))
        for r in range(rows)
        for c in range(cols)
    ]
def make_image_grid(images, rows=None, cols=None, resize=None):
    """Paste ``images`` into a single ``rows`` x ``cols`` grid image, row-major.

    With neither dimension given, the images form one row. With only one
    given, the other is the ceiling needed to fit every image. Unused cells
    are padded with blank images (``Image.new`` default color) of the first
    image's mode and size. The cell size is taken from the first image (after
    the optional resize). This assumes all images share that size; TODO
    confirm with callers.

    Args:
        images: Sequence of PIL images. It is NOT modified; padding is applied
            to a local copy.
        rows: Optional number of grid rows.
        cols: Optional number of grid columns.
        resize: If given, every image is resized to ``(resize, resize)`` first.

    Returns:
        A new PIL image in the first image's mode.

    Raises:
        IndexError: If ``images`` is empty (there is no first image to size
            the grid from).
    """
    # Copy so padding below never leaks into the caller's list (the previous
    # ``images += ...`` mutated it in place) and tuples are accepted too.
    images = list(images)
    if rows is None and cols is None:
        rows, cols = 1, len(images)
    if rows is None:
        rows = -(-len(images) // cols)  # ceiling division
    if cols is None:
        cols = -(-len(images) // rows)
    missing = rows * cols - len(images)
    if missing > 0:
        blank_mode, blank_size = images[0].mode, images[0].size
        images += [Image.new(blank_mode, blank_size) for _ in range(missing)]
    if resize is not None:
        images = [img.resize((resize, resize)) for img in images]
    w, h = images[0].size
    grid = Image.new(images[0].mode, size=(cols * w, rows * h))
    for i, img in enumerate(images):
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid