# Spaces:
# Sleeping
# Sleeping
import argparse
import os
import subprocess
import sys

import mlflow
import mlflow.pytorch
import numpy as np
import torch
from PIL import Image
from skimage.color import lab2rgb, rgb2lab
from torchvision import transforms

from model import Generator
# Name of the MLflow experiment that setup_mlflow() looks up (or creates) and
# that the inference runs are recorded under.
EXPERIMENT_NAME: str = "Colorizer_Experiment"
def setup_mlflow(experiment_name=None):
    """Return the ID of an MLflow experiment, creating the experiment if missing.

    Args:
        experiment_name: Name of the experiment to look up. When ``None``
            (the default) the module-level ``EXPERIMENT_NAME`` is used; it is
            read at call time, so later reassignments of that constant are
            honoured.

    Returns:
        The experiment ID, either of the existing experiment or of the newly
        created one.
    """
    name = EXPERIMENT_NAME if experiment_name is None else experiment_name
    experiment = mlflow.get_experiment_by_name(name)
    if experiment is not None:
        return experiment.experiment_id
    return mlflow.create_experiment(name)
def load_model(run_id, device, artifact_path="generator_model"):
    """Load a model logged to an MLflow run and prepare it for inference.

    Args:
        run_id: MLflow run ID whose artifacts contain the model.
        device: Device (``torch.device`` or string) passed as ``map_location``
            so the stored tensors are loaded onto it.
        artifact_path: Artifact sub-path the model was logged under inside the
            run. Defaults to ``"generator_model"``.

    Returns:
        The loaded model, switched to eval mode.
    """
    print(f"Loading model from run: {run_id}")
    model_uri = f"runs:/{run_id}/{artifact_path}"
    model = mlflow.pytorch.load_model(model_uri, map_location=device)
    # Inference must run in eval mode; in train mode BatchNorm normalizes with
    # the statistics of the current (single-image) batch and Dropout stays
    # active, which makes predictions noisy and nondeterministic.
    model.eval()
    return model
# Configuration defaults.
# NOTE(review): none of these names are read by the code in this file; the
# __main__ entry point takes its settings from argparse instead.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
RUN_ID = "your_run_id_here"  # Replace with the actual run ID
IMAGE_PATH = "path/to/your/image.jpg"  # Replace with the path to your input image
# Presumably toggles for save_model() / serve_model(); not wired up here.
SAVE_MODEL, SERVE_MODEL = False, False
SERVE_PORT = 5000  # Port intended for model serving
def preprocess_image(image_path, size=(256, 256)):
    """Load an image and return its normalized Lab lightness channel.

    The image is converted to RGB, resized, converted to CIE Lab with
    scikit-image, and only the L channel is kept.

    Args:
        image_path: Path of the image file to read.
        size: Target ``(height, width)`` passed to ``transforms.Resize``.
            Defaults to ``(256, 256)``.

    Returns:
        A float32 tensor of shape ``(1, 1, H, W)`` holding L mapped by
        ``(L - 50) / 50`` (i.e. Lab's [0, 100] range onto [-1, 1]).
    """
    transform = transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),
    ])
    # Close the underlying file promptly instead of leaving the handle open
    # until garbage collection; convert() loads the pixels into a new image.
    with Image.open(image_path) as src:
        img = src.convert("RGB")
    img_tensor = transform(img)
    # ToTensor yields a CHW float tensor in [0, 1]; rgb2lab expects HWC.
    lab_img = rgb2lab(img_tensor.permute(1, 2, 0).numpy())
    L = (lab_img[:, :, 0] - 50) / 50
    return torch.from_numpy(L).unsqueeze(0).unsqueeze(0).float()
def postprocess_output(L, ab):
    """Combine a lightness channel with predicted ab channels into an RGB image.

    Args:
        L: Normalized lightness tensor (as produced by ``preprocess_image``)
            whose squeezed shape is ``(H, W)``; mapped back via ``(L + 1) * 50``.
        ab: Predicted chrominance tensor, scaled by 128 to Lab a/b units
            (presumably in [-1, 1]). After squeezing it may be channels-first
            ``(2, H, W)`` - what a conv net fed ``(1, 1, H, W)`` input
            returns - or channels-last ``(H, W, 2)``.

    Returns:
        A ``uint8`` array of shape ``(H, W, 3)`` in RGB order.

    Raises:
        ValueError: If ``ab`` cannot be aligned with the spatial shape of ``L``.
    """
    L = L.squeeze().cpu().numpy()
    ab = ab.squeeze().cpu().numpy()
    # lab2rgb needs a channels-last array. A channels-first (2, H, W)
    # prediction must have its channel axis moved before it can be stacked
    # with L[..., None] of shape (H, W, 1); the ambiguous H == W == 2 case is
    # left as channels-last, as before.
    channels_last = L.shape + (2,)
    if ab.shape != channels_last and ab.shape == (2,) + L.shape:
        ab = np.moveaxis(ab, 0, -1)
    if ab.shape != channels_last:
        raise ValueError(
            f"ab must have shape (2, H, W) or (H, W, 2) matching L {L.shape}, "
            f"got {ab.shape}"
        )
    L = (L + 1.) * 50.
    ab = ab * 128.
    Lab = np.concatenate([L[..., np.newaxis], ab], axis=2)
    rgb_img = lab2rgb(Lab)
    return (rgb_img * 255).astype(np.uint8)
def colorize_image(model, image_path, device):
    """Colorize one image file with ``model``.

    Args:
        model: Callable that maps the normalized lightness tensor to ab
            predictions.
        image_path: Path of the input image.
        device: Device the lightness tensor is moved to before the forward pass.

    Returns:
        Whatever ``postprocess_output`` produces for the input and prediction.
    """
    lightness = preprocess_image(image_path)
    lightness = lightness.to(device)
    # Inference only: no autograd graph is required for the forward pass.
    with torch.no_grad():
        predicted_ab = model(lightness)
    return postprocess_output(lightness, predicted_ab)
def save_model(model, run_id):
    """Log ``model`` into an existing MLflow run and register it.

    The run identified by ``run_id`` is re-opened, the model is logged under
    the ``"model"`` artifact path, and that artifact is registered under the
    name ``"colorizer_model"``.

    Args:
        model: PyTorch model to log.
        run_id: ID of the MLflow run that receives the model artifact.
    """
    registered_uri = f"runs:/{run_id}/model"
    with mlflow.start_run(run_id=run_id):
        mlflow.pytorch.log_model(model, "model")
        mlflow.register_model(registered_uri, "colorizer_model")
        print(f"Model saved and registered with run_id: {run_id}")
def serve_model(run_id, port=5000, artifact_path="model"):
    """Serve a model logged to an MLflow run over HTTP.

    ``mlflow.pytorch`` provides no ``serve`` function; MLflow exposes model
    serving through the ``mlflow models serve`` CLI. It is launched here with
    the current interpreter and in the current environment. The call blocks
    until the server process exits.

    Args:
        run_id: MLflow run ID holding the model.
        port: TCP port the server listens on. Defaults to 5000.
        artifact_path: Artifact sub-path of the model inside the run.
            Defaults to ``"model"`` (the path ``save_model`` logs to).

    Raises:
        subprocess.CalledProcessError: If the server exits with a non-zero
            status.
    """
    model_uri = f"runs:/{run_id}/{artifact_path}"
    command = [
        sys.executable, "-m", "mlflow", "models", "serve",
        "-m", model_uri,
        "-p", str(port),
        # Reuse the already-configured Python environment rather than having
        # MLflow build a fresh virtualenv/conda env for the model.
        "--env-manager", "local",
    ]
    # Argument list with shell=False (default): run_id is never shell-parsed.
    subprocess.run(command, check=True)
def main(argv=None):
    """Colorize one image and record the inference as an MLflow run.

    Command-line arguments:
        --run_id: MLflow run ID of the trained model. When omitted, it is read
            from ``latest_run_id.txt`` in the working directory.
        --image_path: Path to the input image (required).
        --device: Torch device string; defaults to ``cuda`` when available,
            otherwise ``cpu``.

    The colorized image is saved as ``colorized_<input basename>`` in the
    working directory and logged as an artifact of an ``inference_run`` run.
    On failure the error is printed to stderr and logged as the ``error``
    param, and the process exits with status 1. Exiting from inside the
    ``with`` block lets ``start_run``'s context manager end the run as FAILED.

    Args:
        argv: Arguments to parse; ``None`` means ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(description="Colorizer Inference")
    parser.add_argument("--run_id", type=str, help="MLflow run ID of the trained model")
    parser.add_argument("--image_path", type=str, required=True, help="Path to the input grayscale image")
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu",
                        help="Device to use for inference (cuda/cpu)")
    args = parser.parse_args(argv)
    device = torch.device(args.device)
    print(f"Using device: {device}")

    # If run_id is not provided, fall back to the ID stored in latest_run_id.txt.
    if not args.run_id:
        try:
            with open("latest_run_id.txt", "r", encoding="utf-8") as f:
                args.run_id = f.read().strip()
        except FileNotFoundError:
            print("No run ID provided and couldn't find latest_run_id.txt", file=sys.stderr)
            sys.exit(1)
        if not args.run_id:
            print("latest_run_id.txt is empty; pass --run_id explicitly", file=sys.stderr)
            sys.exit(1)

    experiment_id = setup_mlflow()
    # The context manager ends the run itself; no explicit end_run() needed.
    with mlflow.start_run(experiment_id=experiment_id, run_name="inference_run"):
        # Log inputs first so a failed run still records what was attempted.
        mlflow.log_param("input_image", args.image_path)
        mlflow.log_param("model_run_id", args.run_id)
        try:
            model = load_model(args.run_id, device)
            colorized = colorize_image(model, args.image_path, device)
            output_path = f"colorized_{os.path.basename(args.image_path)}"
            Image.fromarray(colorized).save(output_path)
            print(f"Colorized image saved as: {output_path}")
            mlflow.log_artifact(output_path)
        except Exception as e:
            print(f"Error during inference: {str(e)}", file=sys.stderr)
            # Param values are length-limited by the tracking server; truncate
            # so recording the error cannot itself fail and hide the cause.
            mlflow.log_param("error", str(e)[:500])
            sys.exit(1)


if __name__ == "__main__":
    main()