import gradio as gr
import torch
from torchvision import transforms
from PIL import Image
import numpy as np
from skimage.color import lab2rgb
import os
from torch import nn

# Define the model architecture (same as in the training script)
class UNetBlock(nn.Module):
    """One U-Net stage: a stride-2 conv (encoder) or transposed conv (decoder),
    with optional BatchNorm and Dropout."""

    def __init__(self, in_channels, out_channels, down=True, bn=True, dropout=False):
        super(UNetBlock, self).__init__()
        # 4x4 kernel, stride 2, padding 1: halves (encoder) or doubles (decoder) H and W.
        self.conv = nn.Conv2d(in_channels, out_channels, 4, 2, 1, bias=False) if down \
            else nn.ConvTranspose2d(in_channels, out_channels, 4, 2, 1, bias=False)
        self.bn = nn.BatchNorm2d(out_channels) if bn else None
        self.dropout = nn.Dropout(0.5) if dropout else None
        self.down = down

    def forward(self, x):
        x = self.conv(x)
        if self.bn is not None:
            x = self.bn(x)
        if self.dropout is not None:
            x = self.dropout(x)
        # Both encoder and decoder branches apply a plain ReLU here (the canonical
        # pix2pix encoder uses LeakyReLU(0.2), but this must match the training script).
        return torch.relu(x)

class Generator(nn.Module):
    """pix2pix-style U-Net generator: maps a 1-channel L input (256x256) to a
    2-channel ab prediction. Skip connections double the decoder input widths."""

    def __init__(self):
        super(Generator, self).__init__()
        self.down1 = UNetBlock(1, 64, bn=False)
        self.down2 = UNetBlock(64, 128)
        self.down3 = UNetBlock(128, 256)
        self.down4 = UNetBlock(256, 512)
        self.down5 = UNetBlock(512, 512)
        self.down6 = UNetBlock(512, 512)
        self.down7 = UNetBlock(512, 512)
        self.down8 = UNetBlock(512, 512, bn=False)
        
        self.up1 = UNetBlock(512, 512, down=False, dropout=True)
        self.up2 = UNetBlock(1024, 512, down=False, dropout=True)
        self.up3 = UNetBlock(1024, 512, down=False, dropout=True)
        self.up4 = UNetBlock(1024, 512, down=False)
        self.up5 = UNetBlock(1024, 256, down=False)
        self.up6 = UNetBlock(512, 128, down=False)
        self.up7 = UNetBlock(256, 64, down=False)
        self.up8 = nn.ConvTranspose2d(128, 2, 4, 2, 1)

    def forward(self, x):
        # Encoder: spatial size halves at each stage, 256 -> 1 at the bottleneck.
        d1 = self.down1(x)
        d2 = self.down2(d1)
        d3 = self.down3(d2)
        d4 = self.down4(d3)
        d5 = self.down5(d4)
        d6 = self.down6(d5)
        d7 = self.down7(d6)
        d8 = self.down8(d7)
        
        # Decoder: concatenate each upsampled map with the mirrored encoder
        # feature (skip connection) along the channel dimension.
        u1 = self.up1(d8)
        u2 = self.up2(torch.cat([u1, d7], 1))
        u3 = self.up3(torch.cat([u2, d6], 1))
        u4 = self.up4(torch.cat([u3, d5], 1))
        u5 = self.up5(torch.cat([u4, d4], 1))
        u6 = self.up6(torch.cat([u5, d3], 1))
        u7 = self.up7(torch.cat([u6, d2], 1))
        # tanh bounds the ab prediction to [-1, 1]; it is rescaled by 128 at inference.
        return torch.tanh(self.up8(torch.cat([u7, d1], 1)))
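
# Optional shape sanity check (an illustrative sketch, not part of the original
# app): one 1x256x256 L-channel input should yield a 2x256x256 ab prediction,
# the bottleneck having reached 1x1 after eight stride-2 downsamplings.
def _check_generator_shapes():
    g = Generator().eval()  # eval mode so BatchNorm works with a single sample
    x = torch.randn(1, 1, 256, 256)  # one grayscale (L) image
    with torch.no_grad():
        y = g(x)
    assert y.shape == (1, 2, 256, 256), f"unexpected output shape {y.shape}"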

# Load generator weights from a training checkpoint. The file is expected to be
# a dict with at least 'generator_state_dict' and 'epoch' keys (as read below);
# map_location lets a GPU-trained checkpoint load on a CPU-only machine.
def load_checkpoint(filename, generator, map_location):
    if os.path.isfile(filename):
        print(f"Loading checkpoint '{filename}'")
        checkpoint = torch.load(filename, map_location=map_location)
        generator.load_state_dict(checkpoint['generator_state_dict'])
        print(f"Loaded checkpoint '{filename}' (epoch {checkpoint['epoch']})")
    else:
        print(f"No checkpoint found at '{filename}'")

# Initialize the model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
generator = Generator().to(device)
checkpoint_path = "checkpoints/latest_checkpoint.pth.tar"
load_checkpoint(checkpoint_path, generator, map_location=device)
generator.eval()

# Define the transformation
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.Grayscale(num_output_channels=1),  # Convert to grayscale
    transforms.ToTensor()
])
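
# Quick preprocessing sanity check (illustrative, not part of the original app):
# a mid-gray input should land near 0 after the [-1, 1] rescaling above.
def _check_preprocess_range():
    gray = Image.new("RGB", (300, 300), (128, 128, 128))
    t = transform(gray)
    assert t.shape == (1, 256, 256)
    assert t.min() >= -1.0 and t.max() <= 1.0
    assert abs(t.mean().item()) < 0.05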

# Inference: predict ab channels from the L channel, then rebuild an RGB image
# in LAB space and convert back with skimage.
def colorize_image(input_image):
    try:
        original_size = input_image.size
        L_tensor = transform(input_image).unsqueeze(0).to(device)  # (1, 1, 256, 256)
        with torch.no_grad():
            output = generator(L_tensor)
        ab = output.squeeze(0).cpu().numpy() * 128.          # tanh output -> ab in [-128, 128]
        L = (L_tensor.squeeze(0).cpu().numpy() + 1.) * 50.   # [-1, 1] -> L in [0, 100]
        Lab = np.concatenate([L, ab], axis=0).transpose(1, 2, 0)
        rgb_image = Image.fromarray((lab2rgb(Lab) * 255).astype(np.uint8))
        return rgb_image.resize(original_size, Image.LANCZOS)
    except Exception as e:
        print(f"Error in colorize_image: {e}")
        return None
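
# Convenience wrapper for local testing (a hypothetical helper, not in the
# original app): colorize an image file on disk and save the result beside it.
def colorize_file(input_path, output_path=None):
    image = Image.open(input_path).convert("RGB")
    result = colorize_image(image)
    if result is None:
        return None
    output_path = output_path or os.path.splitext(input_path)[0] + "_color.png"
    result.save(output_path)
    return output_path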

# Create the Gradio interface
iface = gr.Interface(
    fn=colorize_image,
    inputs=gr.Image(type="pil"),
    outputs=gr.Image(type="pil"),
    title="Image Colorizer",
    description="Upload a grayscale image to colorize it."
)

# Launch the app
if __name__ == "__main__":
    iface.launch()