# Hugging Face Space (status: sleeping)
# File size: 4,458 bytes | revision 02f3f24
import gradio as gr
import torch
from torchvision import transforms
from PIL import Image
import numpy as np
from skimage.color import rgb2lab, lab2rgb
import os
from torch import nn
# Define the model architecture (same as in the training script)
class UNetBlock(nn.Module):
    """Single stride-2 stage of the U-Net generator.

    The stage applies a 4x4 convolution with stride 2 and padding 1
    (a regular convolution when ``down`` is true, a transposed one
    otherwise), followed by optional batch normalisation, optional
    dropout (p=0.5) and a ReLU.

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels produced by the convolution.
        down: Use ``nn.Conv2d`` when true, ``nn.ConvTranspose2d`` otherwise.
        bn: Whether to add a ``BatchNorm2d`` layer after the convolution.
        dropout: Whether to add a ``Dropout(0.5)`` layer before the ReLU.
    """

    def __init__(self, in_channels, out_channels, down=True, bn=True, dropout=False):
        super().__init__()
        # The sub-module attribute names (conv / bn / dropout) are part of
        # the state-dict keys, so they must not be renamed.
        conv_cls = nn.Conv2d if down else nn.ConvTranspose2d
        self.conv = conv_cls(
            in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False
        )
        if bn:
            self.bn = nn.BatchNorm2d(out_channels)
        else:
            self.bn = None
        if dropout:
            self.dropout = nn.Dropout(0.5)
        else:
            self.dropout = None
        self.down = down

    def forward(self, x):
        """Run conv -> [batch norm] -> [dropout] -> ReLU on ``x``."""
        out = self.conv(x)
        if self.bn is not None:
            out = self.bn(out)
        if self.dropout is not None:
            out = self.dropout(out)
        # Up-sampling stages apply the ReLU in place; down-sampling stages
        # allocate a new tensor. The resulting values are the same.
        return nn.functional.relu(out, inplace=not self.down)
class Generator(nn.Module):
    """U-Net generator mapping a 1-channel input to a 2-channel output.

    Eight stride-2 down-sampling stages are followed by eight stride-2
    up-sampling stages. Every up-sampling stage except the first receives
    the previous decoder output concatenated (along the channel axis) with
    the encoder feature map of matching resolution. The final layer is a
    plain transposed convolution followed by ``tanh``.

    Because the encoder halves the resolution eight times, the input height
    and width need to be multiples of 256 for the skip connections to line up.
    """

    def __init__(self):
        super().__init__()
        # Attribute names and creation order are kept stable so that saved
        # state dicts keep loading.
        # Encoder.
        self.down1 = UNetBlock(1, 64, bn=False)
        self.down2 = UNetBlock(64, 128)
        self.down3 = UNetBlock(128, 256)
        self.down4 = UNetBlock(256, 512)
        self.down5 = UNetBlock(512, 512)
        self.down6 = UNetBlock(512, 512)
        self.down7 = UNetBlock(512, 512)
        self.down8 = UNetBlock(512, 512, bn=False)
        # Decoder; input widths from up2 onwards include the skip channels.
        self.up1 = UNetBlock(512, 512, down=False, dropout=True)
        self.up2 = UNetBlock(1024, 512, down=False, dropout=True)
        self.up3 = UNetBlock(1024, 512, down=False, dropout=True)
        self.up4 = UNetBlock(1024, 512, down=False)
        self.up5 = UNetBlock(1024, 256, down=False)
        self.up6 = UNetBlock(512, 128, down=False)
        self.up7 = UNetBlock(256, 64, down=False)
        self.up8 = nn.ConvTranspose2d(128, 2, 4, 2, 1)

    def forward(self, x):
        """Return the ``tanh``-squashed 2-channel prediction for ``x``."""
        encoder = (
            self.down1, self.down2, self.down3, self.down4,
            self.down5, self.down6, self.down7, self.down8,
        )
        decoder = (self.up2, self.up3, self.up4, self.up5, self.up6, self.up7)

        # Run the encoder, remembering each stage's output for the skips.
        skips = []
        features = x
        for stage in encoder:
            features = stage(features)
            skips.append(features)

        # The bottleneck (last encoder output) feeds up1 without a skip.
        features = self.up1(skips.pop())

        # Remaining skips are consumed deepest-first: d7, d6, ..., d2.
        for stage in decoder:
            features = stage(torch.cat([features, skips.pop()], 1))

        # Only d1 is left for the output layer.
        return torch.tanh(self.up8(torch.cat([features, skips.pop()], 1)))
# Load the checkpoint
def load_checkpoint(filename, generator, map_location):
    """Load generator weights from a checkpoint file, if it exists.

    Args:
        filename: Path of the checkpoint written by the training script.
        generator: Module whose ``load_state_dict`` receives the
            ``'generator_state_dict'`` entry of the checkpoint.
        map_location: Forwarded to ``torch.load`` to place the tensors.

    Returns:
        ``True`` when the weights were loaded, ``False`` when ``filename``
        does not exist (in which case ``generator`` is left untouched).

    Raises:
        KeyError: If the checkpoint has no ``'generator_state_dict'`` entry.
    """
    if not os.path.isfile(filename):
        print(f"No checkpoint found at '{filename}'")
        return False

    print(f"Loading checkpoint '{filename}'")
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from a trusted source (consider weights_only=True where supported).
    checkpoint = torch.load(filename, map_location=map_location)
    generator.load_state_dict(checkpoint['generator_state_dict'])
    # The epoch is informational only, so a checkpoint without it must not
    # fail after the weights have already been loaded.
    epoch = checkpoint.get('epoch', 'unknown')
    print(f"Loaded checkpoint '{filename}' (epoch {epoch})")
    return True
# Build the generator on the GPU when one is available (CPU otherwise),
# restore the trained weights and switch to inference mode.
checkpoint_path = "checkpoints/latest_checkpoint.pth.tar"
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
generator = Generator()
generator.to(device)
load_checkpoint(checkpoint_path, generator, map_location=device)
# eval() makes dropout a no-op and batch norm use its running statistics.
generator.eval()
# Preprocessing applied to an uploaded PIL image: resize to 256x256,
# collapse to a single grayscale channel, then convert to a tensor.
transform = transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.Grayscale(num_output_channels=1),
        transforms.ToTensor(),
    ]
)
# Define the inference function
def colorize_image(input_image):
    """Colorize a PIL image with the generator and return a PIL RGB image.

    The image is resized to 256x256, its CIELAB lightness channel is fed to
    the generator, and the predicted two colour channels are recombined with
    that lightness, converted back to RGB and resized to the original size.

    Args:
        input_image: PIL image from the Gradio input, or ``None`` when the
            user submitted without uploading anything.

    Returns:
        The colorized ``PIL.Image`` at the input's original size, or
        ``None`` when there was no input or the conversion failed.
    """
    if input_image is None:
        # Nothing uploaded: return an empty output without logging an error.
        return None

    try:
        original_size = input_image.size

        # Take the true Lab lightness (range [0, 100]) and do not use
        # ToTensor's [0, 1] grayscale: the de-normalisation (L + 1) * 50
        # used here only makes sense for a network input in [-1, 1].
        # NOTE(review): assumes the training script normalised L as
        # L / 50 - 1 -- confirm against the training code.
        resized = input_image.convert("RGB").resize((256, 256), Image.BILINEAR)
        lightness = rgb2lab(np.asarray(resized))[:, :, 0].astype(np.float32)
        net_input = torch.from_numpy(lightness / 50.0 - 1.0)[None, None].to(device)

        with torch.no_grad():
            prediction = generator(net_input)

        # tanh output in [-1, 1] is scaled to the a/b channel range.
        ab = prediction.squeeze(0).cpu().numpy() * 128.0

        # Stack to (3, H, W), then move channels last for skimage.
        lab = np.concatenate([lightness[None], ab], axis=0).transpose(1, 2, 0)
        rgb = lab2rgb(lab)

        # Round (not truncate) when quantising the [0, 1] floats to uint8.
        rgb_uint8 = np.clip(np.round(rgb * 255.0), 0, 255).astype(np.uint8)
        rgb_image = Image.fromarray(rgb_uint8)
        return rgb_image.resize(original_size, Image.LANCZOS)
    except Exception as e:
        # Top-level UI boundary: report the failure and show an empty output
        # so one bad upload does not bring the app down.
        print(f"Error in colorize_image: {str(e)}")
        return None
# Web UI: one image upload in, one colorized image out, both as PIL images.
iface = gr.Interface(
    colorize_image,
    gr.Image(type="pil"),
    gr.Image(type="pil"),
    title="Image Colorizer",
    description="Upload a grayscale image to colorize it.",
)

# Start the server only when the file is run as a script.
if __name__ == "__main__":
    iface.launch()