In [1]:
import os
os.chdir('../')

In [2]:
%pwd

'c:\\mlops project\\image-colorization-mlops'

In [3]:
# Assuming all necessary imports are here
from dataclasses import dataclass
from pathlib import Path

@dataclass(frozen=True)
class ModelTrainerConfig:
    root_dir: Path
    test_data_path: Path
    train_data_path: Path
    LEARNING_RATE: float
    LAMBDA_RECON: int
    DISPLAY_STEP: int
    IMAGE_SIZE: list
    INPUT_CHANNELS: int
    OUTPUT_CHANNELS: int
    EPOCH: int
    BATCH_SIZE : int

In [4]:
from src.imagecolorization.constants import *
from src.imagecolorization.utils.common import read_yaml, create_directories

class ConfigurationManager:
    def __init__(self, config_filepath=CONFIG_FILE_PATH, params_filepath=PARAMS_FILE_PATH):
        self.config = read_yaml(config_filepath)
        self.params = read_yaml(params_filepath)
        create_directories([self.config.artifacts_root])

    def get_model_trainer_config(self) -> ModelTrainerConfig:
        config = self.config.model_trainer
        params = self.params
        
        create_directories([config.root_dir])
        
        # Convert LEARNING_RATE to float explicitly
        learning_rate = float(params.LEARNING_RATE)
        
        model_trainer_config = ModelTrainerConfig(
            root_dir=config.root_dir,
            test_data_path=config.test_data_path,
            train_data_path=config.train_data_path,
            LEARNING_RATE=learning_rate,  # Use the converted float value
            LAMBDA_RECON=params.LAMBDA_RECON,
            DISPLAY_STEP=params.DISPLAY_STEP,
            IMAGE_SIZE=params.IMAGE_SIZE,
            INPUT_CHANNELS=params.INPUT_CHANNELS,
            OUTPUT_CHANNELS=params.OUTPUT_CHANNELS,
            EPOCH=params.EPOCH,
            BATCH_SIZE= params.BATCH_SIZE
        )
        return model_trainer_config


In [5]:
import torch
import numpy as np
from skimage.color import rgb2lab, lab2rgb
def lab_to_rgb(L, ab):
    L = L * 100
    ab = (ab - 0.5) * 128 * 2
    Lab = torch.cat([L, ab], dim = 2).numpy()
    rgb_img = []
    for img in Lab:
        img_rgb = lab2rgb(img)
        rgb_img.append(img_rgb)
        
    return np.stack(rgb_img, axis = 0)

In [6]:
import matplotlib.pyplot as plt
def display_progress(cond, real, fake, current_epoch = 0, figsize=(20,15)):
    """
    Save cond, real (original) and generated (fake)
    images in one panel 
    """
    cond = cond.detach().cpu().permute(1, 2, 0)   
    real = real.detach().cpu().permute(1, 2, 0)
    fake = fake.detach().cpu().permute(1, 2, 0)
    
    images = [cond, real, fake]
    titles = ['input','real','generated']
    print(f'Epoch: {current_epoch}')
    fig, ax = plt.subplots(1, 3, figsize=figsize)
    for idx,img in enumerate(images):
        if idx == 0:
            ab = torch.zeros((224,224,2))
            img = torch.cat([images[0]* 100, ab], dim=2).numpy()
            imgan = lab2rgb(img)
        else:
            imgan = lab_to_rgb(images[0],img)
        ax[idx].imshow(imgan)
        ax[idx].axis("off")
    for idx, title in enumerate(titles):    
        ax[idx].set_title('{}'.format(title))
    plt.show()


In [7]:
import torch
from torch import nn, optim
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from torch.autograd import Variable
from torchvision import models
from torch.nn import functional as F
import torch.utils.data
from torchvision.models.inception import inception_v3
from scipy.stats import entropy
import pytorch_lightning as pl
from torchsummary import summary
from src.imagecolorization.conponents.model_building import Generator, Critic
from src.imagecolorization.conponents.data_tranformation import ImageColorizationDataset
from src.imagecolorization.logging import logger
import gc





class CWGAN(pl.LightningModule):
    def __init__(self, in_channels, out_channels, learning_rate=0.0002, lambda_recon=100, display_step=10, lambda_gp=10, lambda_r1=10):
        super().__init__()
        self.save_hyperparameters()
        self.display_step = display_step
        self.generator = Generator(in_channels, out_channels)
        self.critic = Critic(in_channels + out_channels)
        self.lambda_recon = lambda_recon
        self.lambda_gp = lambda_gp
        self.lambda_r1 = lambda_r1
        self.recon_criterion = nn.L1Loss()
        self.generator_losses, self.critic_losses = [], []
        self.automatic_optimization = False  # Disable automatic optimization

    def configure_optimizers(self):
        optimizer_G = optim.Adam(self.generator.parameters(), lr=self.hparams.learning_rate, betas=(0.5, 0.9))
        optimizer_C = optim.Adam(self.critic.parameters(), lr=self.hparams.learning_rate, betas=(0.5, 0.9))
        return [optimizer_C, optimizer_G]

    def generator_step(self, real_images, conditioned_images, optimizer_G):
        optimizer_G.zero_grad()
        fake_images = self.generator(conditioned_images)
        recon_loss = self.recon_criterion(fake_images, real_images)
        recon_loss.backward()
        optimizer_G.step()
        self.generator_losses.append(recon_loss.item())

    def critic_step(self, real_images, conditioned_images, optimizer_C):
        optimizer_C.zero_grad()
        fake_images = self.generator(conditioned_images)

        # Separate L and ab channels
        l_real = conditioned_images
        ab_real = real_images
        ab_fake = fake_images

        # Compute logits
        fake_logits = self.critic(ab_fake, l_real)  # Pass two arguments
        real_logits = self.critic(ab_real, l_real)  # Pass two arguments

        # Compute the loss for the critic
        loss_C = real_logits.mean() - fake_logits.mean()

        # Compute the gradient penalty
        alpha = torch.rand(real_images.size(0), 1, 1, 1, requires_grad=True).to(real_images.device)
        interpolated = (alpha * ab_real + (1 - alpha) * ab_fake.detach()).requires_grad_(True)
        interpolated_logits = self.critic(interpolated, l_real)

        gradients = torch.autograd.grad(outputs=interpolated_logits, inputs=interpolated,
                                        grad_outputs=torch.ones_like(interpolated_logits), create_graph=True, retain_graph=True)[0]
        gradients = gradients.view(len(gradients), -1)
        gradients_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
        loss_C += self.lambda_gp * gradients_penalty

        # Compute the R1 regularization loss
        r1_reg = gradients.pow(2).sum(1).mean()
        loss_C += self.lambda_r1 * r1_reg

        # Backpropagation
        loss_C.backward()
        optimizer_C.step()
        self.critic_losses.append(loss_C.item())

    def training_step(self, batch, batch_idx):
        optimizer_C, optimizer_G = self.optimizers()
        real_images, conditioned_images = batch

        self.critic_step(real_images, conditioned_images, optimizer_C)
        self.generator_step(real_images, conditioned_images, optimizer_G)

        if self.current_epoch % self.display_step == 0 and batch_idx == 0:
            with torch.no_grad():
                fake_images = self.generator(conditioned_images)
                display_progress(conditioned_images[0], real_images[0], fake_images[0], self.current_epoch)

    def on_epoch_end(self):
        gc.collect()
        torch.cuda.empty_cache()

    def forward(self, x):
        return self.generator(x)
    

In [10]:
import os
class ModelTrainer:
    def __init__(self, config):
        self.config = config
        
    def load_datasets(self):
        self.train_dataset = torch.load(self.config.train_data_path)
        self.test_dataset = torch.load(self.config.test_data_path)
    
        # Ensure these are actually datasets, not dataloaders
        if isinstance(self.train_dataset, DataLoader):
            self.train_dataset = self.train_dataset.dataset
        if isinstance(self.test_dataset, DataLoader):
            self.test_dataset = self.test_dataset.dataset
        
    def create_dataloaders(self):
        self.train_dataloader = DataLoader(
            self.train_dataset, 
            batch_size=self.config.BATCH_SIZE, 
            shuffle=True,
        )
        self.test_dataloader = DataLoader(
            self.test_dataset, 
            batch_size=self.config.BATCH_SIZE, 
            shuffle=False,
        )
        
        for batch in self.test_dataloader:
            ab, L = batch
            print(f"Train loader batch - ab shape: {ab.shape}, L shape: {L.shape}")
            break  # Remove break to check shapes for all batches if needed

        for batch in self.train_dataloader:
            ab, L = batch
            print(f"Test loader batch - ab shape: {ab.shape}, L shape: {L.shape}")
            break
        
        
        
    def initialize_model(self):
        self.model = CWGAN(
            in_channels=self.config.INPUT_CHANNELS,
            out_channels=self.config.OUTPUT_CHANNELS,
            learning_rate=self.config.LEARNING_RATE,
            lambda_recon=self.config.LAMBDA_RECON,
            display_step=self.config.DISPLAY_STEP,
            lambda_gp=10,  # Default value, you can make it configurable
            lambda_r1=10   # Default value, you can make it configurable
        )
        
    def train_model(self):
        checkpoint_callback = pl.callbacks.ModelCheckpoint(
            dirpath=self.config.root_dir,
            filename='cwgan-{epoch:02d}-{generator_loss:.2f}',
            save_top_k=-1,  # Save all checkpoints
            verbose=True
        )
        
        trainer = pl.Trainer(
            max_epochs=self.config.EPOCH,
            callbacks=[checkpoint_callback],
            
        )
        
        trainer.fit(self.model, self.train_dataloader)
        
    def save_model(self):
        trained_model_dir = self.config.root_dir
        os.makedirs(trained_model_dir, exist_ok=True)
    
        generator_path = os.path.join(trained_model_dir, "cwgan_generator_final.pt")
        critic_path = os.path.join(trained_model_dir, "cwgan_critic_final.pt")
    
        torch.save(self.model.generator.state_dict(), generator_path)
        torch.save(self.model.critic.state_dict(), critic_path)
        logger.info(f"Final models saved at {trained_model_dir}")

In [11]:
try:
    config_manager = ConfigurationManager()
    model_trainer_config = config_manager.get_model_trainer_config()
    model_trainer = ModelTrainer(config=model_trainer_config)
    model_trainer.load_datasets()
    model_trainer.create_dataloaders()
    
    

    model_trainer.initialize_model()
    model_trainer.train_model()
    model_trainer.save_model()
except Exception as e:
    logger.error(f"Error during model training: {e}")
    raise e

[2024-08-24 19:39:49,395: INFO: common: yaml file: config\config.yaml loaded successfully]
[2024-08-24 19:39:49,400: INFO: common: yaml file: params.yaml loaded successfully]
[2024-08-24 19:39:49,404: INFO: common: created directory at: artifacts]
[2024-08-24 19:39:49,405: INFO: common: created directory at: artifacts/trained_model]


  self.train_dataset = torch.load(self.config.train_data_path)
  self.test_dataset = torch.load(self.config.test_data_path)


[2024-08-24 19:39:58,617: ERROR: 129140692: Error during model training: pic should be 2/3 dimensional. Got 4 dimensions.]


ValueError: pic should be 2/3 dimensional. Got 4 dimensions.