# pix2pixcolorizer / app-mlflow.py
# Rohil Bansal
# huggingface spaces commit.
# 02f3f24
import gradio as gr
import torch
import mlflow
import numpy as np
from PIL import Image
from skimage.color import rgb2lab, lab2rgb
from torchvision import transforms
from model import Generator
# Name of the MLflow experiment that setup_mlflow() looks up or creates.
EXPERIMENT_NAME = "Colorizer_Experiment"
# Placeholder MLflow run ID; where it is consumed is not visible in this
# part of the file.
RUN_ID = "your_run_id_here" # Replace with your actual run ID
def setup_mlflow():
    """Return the ID of the MLflow experiment named ``EXPERIMENT_NAME``.

    The experiment is looked up by name. When the lookup yields ``None``,
    a new experiment with that name is created and its ID is returned
    instead.

    Returns:
        The experiment ID of the existing or newly created experiment.
    """
    existing = mlflow.get_experiment_by_name(EXPERIMENT_NAME)
    if existing is not None:
        return existing.experiment_id
    # Nothing is registered under this name yet, so create it now.
    return mlflow.create_experiment(EXPERIMENT_NAME)
def load_model(run_id, device):
    """Load the logged generator of an MLflow run, ready for inference.

    Args:
        run_id: ID of the MLflow run holding the ``generator_model``
            artifact.
        device: Forwarded as ``map_location`` to the MLflow loader so the
            weights are mapped to this device while loading.

    Returns:
        The loaded model, switched to evaluation mode.
    """
    print(f"Loading model from run: {run_id}")
    model_uri = f"runs:/{run_id}/generator_model"
    model = mlflow.pytorch.load_model(model_uri, map_location=device)
    # The restored module may still be in training mode (it keeps the mode
    # it was saved in). Switch to evaluation mode so Dropout/BatchNorm
    # layers, if the generator has any, behave deterministically and do
    # not update running statistics while serving requests.
    model.eval()
    return model
def preprocess_image(image):
    """Turn an input image array into the normalised lightness tensor.

    The array is converted to an RGB PIL image, resized to 256x256 and
    converted to Lab. Only the L channel is kept, normalised as
    ``(L - 50) / 50`` (which maps the nominal 0-100 lightness range onto
    ``[-1, 1]``).

    Args:
        image: Array accepted by ``PIL.Image.fromarray``.

    Returns:
        A float tensor of shape ``(1, 1, 256, 256)``.
    """
    pil_rgb = Image.fromarray(image).convert("RGB")
    resized = transforms.Resize((256, 256))(pil_rgb)
    chw_tensor = transforms.ToTensor()(resized)

    # rgb2lab expects channel-last input, so move channels to the end.
    hwc_array = chw_tensor.permute(1, 2, 0).numpy()
    lightness = rgb2lab(hwc_array)[..., 0]
    normalised = (lightness - 50) / 50

    # Prepend the batch and channel dimensions the model input needs.
    return torch.from_numpy(normalised)[None, None].float()
def postprocess_output(L, ab):
    """Combine a lightness tensor and predicted colour channels into RGB.

    Args:
        L: Lightness tensor normalised as ``(L - 50) / 50``. Singleton
            dimensions are squeezed away, leaving ``(H, W)``.
        ab: Colour-channel tensor; it is multiplied by 128 here
            (presumably the inverse of the scaling used in training;
            confirm against the training code). After squeezing it may be
            channel-first ``(2, H, W)`` or channel-last ``(H, W, 2)``.

    Returns:
        ``uint8`` RGB array of shape ``(H, W, 3)``.
    """
    # detach() keeps the numpy conversion working even if a tensor that
    # still requires grad is passed in.
    lightness = (L.detach().squeeze().cpu().numpy() + 1.) * 50.
    chroma = ab.detach().squeeze().cpu().numpy() * 128.

    # The lightness plane is stacked channel-last below, so a channel-first
    # (2, H, W) prediction must be moved to (H, W, 2) first. Otherwise the
    # concatenation fails with a shape mismatch. Channel-last input is
    # left untouched.
    is_channel_first = (
        chroma.ndim == 3
        and chroma.shape[1:] == lightness.shape
        and chroma.shape[:2] != lightness.shape
    )
    if is_channel_first:
        chroma = np.transpose(chroma, (1, 2, 0))

    lab = np.concatenate([lightness[..., np.newaxis], chroma], axis=2)
    rgb = lab2rgb(lab)
    # Scale the float RGB result to 8-bit; astype truncates toward zero.
    return (rgb * 255).astype(np.uint8)
def colorize_image(image, model, device):
    """Run the model on one image and return the colourised result.

    Args:
        image: Input image array, as accepted by ``preprocess_image``.
        model: Callable that maps the lightness tensor to colour channels.
        device: Device the lightness tensor is moved to before inference.

    Returns:
        The value produced by ``postprocess_output`` for the lightness
        tensor and the model's prediction.
    """
    lightness = preprocess_image(image).to(device)
    # Inference only, so gradient tracking is switched off.
    with torch.no_grad():
        predicted_ab = model(lightness)
        return postprocess_output(lightness, predicted_ab)
def setup_gradio_app(run_id, device):
    """Build the Gradio interface around the generator of an MLflow run.

    Args:
        run_id: MLflow run ID forwarded to ``load_model``.
        device: Device used for loading the model and for inference.

    Returns:
        The configured ``gr.Interface``; launching it is left to the
        caller.
    """
    # Load once; the closure below reuses the model for every request.
    model = load_model(run_id, device)

    def gradio_colorize(input_image):
        """Colourise one uploaded image and return it as a PIL image."""
        # Submitting the form without an image hands None to the callback.
        # Report that clearly instead of failing inside PIL.
        if input_image is None:
            raise gr.Error("Please upload an image before submitting.")
        colorized = colorize_image(input_image, model, device)
        return Image.fromarray(colorized)

    iface = gr.Interface(
        fn=gradio_colorize,
        inputs=gr.Image(label="Upload a grayscale image"),
        outputs=gr.Image(label="Colorized Image"),
        title="Image Colorizer",
        description="Upload a grayscale image and get a colorized version!",
    )
    return iface