Spaces:
Sleeping
Sleeping
File size: 2,172 Bytes
02f3f24 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 |
import gradio as gr
import torch
import mlflow
import numpy as np
from PIL import Image
from skimage.color import rgb2lab, lab2rgb
from torchvision import transforms
from model import Generator
# MLflow experiment looked up (or created) by setup_mlflow().
EXPERIMENT_NAME: str = "Colorizer_Experiment"
# Placeholder MLflow run ID -- replace with your actual run ID.
RUN_ID: str = "your_run_id_here"
def setup_mlflow():
    """Return the ID of the ``EXPERIMENT_NAME`` MLflow experiment, creating it if absent.

    Returns:
        The ``experiment_id`` of the existing experiment when
        ``mlflow.get_experiment_by_name`` finds one, otherwise the ID returned
        by ``mlflow.create_experiment``.
    """
    existing = mlflow.get_experiment_by_name(EXPERIMENT_NAME)
    if existing is not None:
        return existing.experiment_id
    # No experiment with this name yet -- create it and hand back its new ID.
    return mlflow.create_experiment(EXPERIMENT_NAME)
def load_model(run_id, device):
    """Load the generator logged under an MLflow run and prepare it for inference.

    Args:
        run_id: ID of the MLflow run whose ``generator_model`` artifact is loaded.
        device: Passed to ``mlflow.pytorch.load_model`` as ``map_location`` so
            the weights are materialised on this device.

    Returns:
        The loaded model, switched to evaluation mode.
    """
    print(f"Loading model from run: {run_id}")
    model_uri = f"runs:/{run_id}/generator_model"
    model = mlflow.pytorch.load_model(model_uri, map_location=device)
    # This model is only used for inference: switch modules such as
    # BatchNorm/Dropout (if the Generator contains any) from training
    # behaviour to inference behaviour. Without this, outputs depend on
    # per-request batch statistics and random dropout masks.
    model.eval()
    return model
def preprocess_image(image):
    """Turn an input image array into a normalised lightness tensor for the model.

    The image is converted to RGB, resized to 256x256, converted to CIE Lab,
    and only its L* channel is kept, rescaled from [0, 100] to [-1, 1].

    Args:
        image: Array accepted by ``PIL.Image.fromarray`` (e.g. the numpy image
            Gradio supplies).

    Returns:
        float32 tensor of shape (1, 1, 256, 256).
    """
    to_resized_tensor = transforms.Compose(
        [transforms.Resize((256, 256)), transforms.ToTensor()]
    )
    rgb_pil = Image.fromarray(image).convert("RGB")
    # ToTensor yields CHW floats in [0, 1]; rgb2lab wants HWC.
    hwc_rgb = to_resized_tensor(rgb_pil).permute(1, 2, 0).numpy()
    lightness = rgb2lab(hwc_rgb)[..., 0]
    # L* lies in [0, 100]; centre and scale it into [-1, 1].
    lightness = (lightness - 50) / 50
    # Add batch and channel axes -> (1, 1, H, W).
    return torch.from_numpy(lightness)[None, None].float()
def postprocess_output(L, ab):
    """Recombine the input lightness with predicted ab channels into an RGB image.

    Args:
        L: Lightness tensor normalised to [-1, 1], as built by
            ``preprocess_image`` -- shape (1, 1, H, W); any leading singleton
            dimensions are accepted.
        ab: Predicted chrominance tensor in PyTorch channels-first layout,
            e.g. (1, 2, H, W), scaled so that ``ab * 128`` gives Lab a*/b*
            values.

    Returns:
        ``np.uint8`` RGB array of shape (H, W, 3).
    """
    # Drop only the leading (batch/channel) singleton dims; a bare squeeze()
    # would also collapse a spatial dimension of size 1.
    l_plane = L.detach().cpu().numpy().reshape(L.shape[-2:])
    ab_planes = ab.detach().cpu().numpy().reshape(-1, *ab.shape[-2:])
    # Model output is channels-first (2, H, W); skimage's Lab arrays are
    # channels-last, so move the channel axis to the end -> (H, W, 2).
    # Concatenating the channels-first array directly with the (H, W, 1)
    # L plane along axis 2 raises a shape-mismatch ValueError.
    ab_planes = np.moveaxis(ab_planes, 0, -1)
    # Undo the normalisation: L* back to [0, 100], a*/b* back to ~[-128, 128].
    l_plane = (l_plane + 1.0) * 50.0
    ab_planes = ab_planes * 128.0
    lab = np.concatenate([l_plane[..., np.newaxis], ab_planes], axis=2)
    rgb = lab2rgb(lab)
    return (rgb * 255).astype(np.uint8)
def colorize_image(image, model, device):
    """Run the full colorization pipeline on one image.

    Args:
        image: Input image array, as accepted by ``preprocess_image``.
        model: Callable mapping the lightness tensor to predicted ab channels.
        device: Device the lightness tensor is moved to before inference.

    Returns:
        The colorized image produced by ``postprocess_output``.
    """
    lightness = preprocess_image(image).to(device)
    # Inference only -- no autograd graph needed.
    with torch.no_grad():
        predicted_ab = model(lightness)
    return postprocess_output(lightness, predicted_ab)
def setup_gradio_app(run_id, device):
    """Build the Gradio interface that colorizes uploaded images.

    The model is loaded once here and captured by the handler closure, so
    every request reuses the same instance.

    Args:
        run_id: MLflow run ID forwarded to ``load_model``.
        device: Device used for loading and running the model.

    Returns:
        The (not yet launched) ``gr.Interface``.
    """
    model = load_model(run_id, device)

    def gradio_colorize(input_image):
        # Gradio passes None when the user submits without uploading an image;
        # return an empty output instead of crashing inside PIL.
        if input_image is None:
            return None
        colorized = colorize_image(input_image, model, device)
        return Image.fromarray(colorized)

    iface = gr.Interface(
        fn=gradio_colorize,
        # "numpy" is gr.Image's default type; stated explicitly because
        # preprocess_image feeds the input straight into Image.fromarray.
        inputs=gr.Image(type="numpy", label="Upload a grayscale image"),
        outputs=gr.Image(label="Colorized Image"),
        title="Image Colorizer",
        description="Upload a grayscale image and get a colorized version!",
    )
    return iface