Spaces:
Runtime error
Runtime error
HAMIM-ML
commited on
Commit
·
e623e7e
1
Parent(s):
cf8627d
model trainer fixed
Browse files- config/config.yaml +5 -0
- lightning_logs/version_0/events.out.tfevents.1724493340.Hakim.9156.0 +0 -0
- lightning_logs/version_0/hparams.yaml +7 -0
- lightning_logs/version_1/events.out.tfevents.1724493420.Hakim.2840.0 +0 -0
- lightning_logs/version_1/hparams.yaml +7 -0
- lightning_logs/version_10/events.out.tfevents.1724505838.Hakim.4444.2 +0 -0
- lightning_logs/version_10/hparams.yaml +7 -0
- lightning_logs/version_11/events.out.tfevents.1724505997.Hakim.4444.3 +0 -0
- lightning_logs/version_11/hparams.yaml +7 -0
- lightning_logs/version_12/events.out.tfevents.1724506091.Hakim.4444.4 +0 -0
- lightning_logs/version_12/hparams.yaml +7 -0
- lightning_logs/version_13/events.out.tfevents.1724506200.Hakim.4444.5 +0 -0
- lightning_logs/version_13/hparams.yaml +7 -0
- lightning_logs/version_14/events.out.tfevents.1724506514.Hakim.19820.0 +0 -0
- lightning_logs/version_14/hparams.yaml +7 -0
- lightning_logs/version_15/events.out.tfevents.1724521354.Hakim.8444.0 +0 -0
- lightning_logs/version_15/hparams.yaml +7 -0
- lightning_logs/version_16/events.out.tfevents.1724521651.Hakim.8444.1 +0 -0
- lightning_logs/version_16/hparams.yaml +7 -0
- lightning_logs/version_17/events.out.tfevents.1724521762.Hakim.8444.2 +0 -0
- lightning_logs/version_17/hparams.yaml +7 -0
- lightning_logs/version_18/events.out.tfevents.1724521935.Hakim.8444.3 +0 -0
- lightning_logs/version_18/hparams.yaml +7 -0
- lightning_logs/version_19/events.out.tfevents.1724522010.Hakim.28412.0 +0 -0
- lightning_logs/version_19/hparams.yaml +7 -0
- lightning_logs/version_2/events.out.tfevents.1724493545.Hakim.2840.1 +0 -0
- lightning_logs/version_2/hparams.yaml +7 -0
- lightning_logs/version_3/events.out.tfevents.1724494109.Hakim.14856.0 +0 -0
- lightning_logs/version_3/hparams.yaml +7 -0
- lightning_logs/version_4/events.out.tfevents.1724494965.Hakim.16704.0 +0 -0
- lightning_logs/version_4/hparams.yaml +7 -0
- lightning_logs/version_5/events.out.tfevents.1724495112.Hakim.17000.0 +0 -0
- lightning_logs/version_5/hparams.yaml +7 -0
- lightning_logs/version_6/events.out.tfevents.1724495583.Hakim.2896.0 +0 -0
- lightning_logs/version_6/hparams.yaml +7 -0
- lightning_logs/version_7/events.out.tfevents.1724505035.Hakim.3288.0 +0 -0
- lightning_logs/version_7/hparams.yaml +7 -0
- lightning_logs/version_8/events.out.tfevents.1724505263.Hakim.4444.0 +0 -0
- lightning_logs/version_8/hparams.yaml +7 -0
- lightning_logs/version_9/events.out.tfevents.1724505733.Hakim.4444.1 +0 -0
- lightning_logs/version_9/hparams.yaml +7 -0
- params.yaml +8 -0
- research/data_transformation.ipynb +33 -51
- research/model_trainer.ipynb +0 -0
- src/imagecolorization/conponents/data_tranformation.py +27 -18
- src/imagecolorization/pipeline/stage02_data_transformation.py +8 -2
config/config.yaml
CHANGED
@@ -14,5 +14,10 @@ data_transformation:
|
|
14 |
|
15 |
model_building:
|
16 |
root_dir : artifacts/model
|
|
|
|
|
|
|
|
|
|
|
17 |
|
18 |
|
|
|
14 |
|
15 |
model_building:
|
16 |
root_dir : artifacts/model
|
17 |
+
|
18 |
+
model_trainer:
|
19 |
+
root_dir : artifacts/trained_model
|
20 |
+
test_data_path : artifacts/data_transformation/test_dataset.pt
|
21 |
+
train_data_path : artifacts/data_transformation/train_dataset.pt
|
22 |
|
23 |
|
lightning_logs/version_0/events.out.tfevents.1724493340.Hakim.9156.0
ADDED
Binary file (683 Bytes). View file
|
|
lightning_logs/version_0/hparams.yaml
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
display_step: 10
|
2 |
+
in_channels: 1
|
3 |
+
lambda_gp: 10
|
4 |
+
lambda_r1: 10
|
5 |
+
lambda_recon: 100
|
6 |
+
learning_rate: 0.0002
|
7 |
+
out_channels: 2
|
lightning_logs/version_1/events.out.tfevents.1724493420.Hakim.2840.0
ADDED
Binary file (683 Bytes). View file
|
|
lightning_logs/version_1/hparams.yaml
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
display_step: 10
|
2 |
+
in_channels: 1
|
3 |
+
lambda_gp: 10
|
4 |
+
lambda_r1: 10
|
5 |
+
lambda_recon: 100
|
6 |
+
learning_rate: 0.0002
|
7 |
+
out_channels: 2
|
lightning_logs/version_10/events.out.tfevents.1724505838.Hakim.4444.2
ADDED
Binary file (683 Bytes). View file
|
|
lightning_logs/version_10/hparams.yaml
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
display_step: 10
|
2 |
+
in_channels: 1
|
3 |
+
lambda_gp: 10
|
4 |
+
lambda_r1: 10
|
5 |
+
lambda_recon: 100
|
6 |
+
learning_rate: 0.0002
|
7 |
+
out_channels: 2
|
lightning_logs/version_11/events.out.tfevents.1724505997.Hakim.4444.3
ADDED
Binary file (683 Bytes). View file
|
|
lightning_logs/version_11/hparams.yaml
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
display_step: 10
|
2 |
+
in_channels: 1
|
3 |
+
lambda_gp: 10
|
4 |
+
lambda_r1: 10
|
5 |
+
lambda_recon: 100
|
6 |
+
learning_rate: 0.0002
|
7 |
+
out_channels: 2
|
lightning_logs/version_12/events.out.tfevents.1724506091.Hakim.4444.4
ADDED
Binary file (683 Bytes). View file
|
|
lightning_logs/version_12/hparams.yaml
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
display_step: 10
|
2 |
+
in_channels: 1
|
3 |
+
lambda_gp: 10
|
4 |
+
lambda_r1: 10
|
5 |
+
lambda_recon: 100
|
6 |
+
learning_rate: 0.0002
|
7 |
+
out_channels: 2
|
lightning_logs/version_13/events.out.tfevents.1724506200.Hakim.4444.5
ADDED
Binary file (683 Bytes). View file
|
|
lightning_logs/version_13/hparams.yaml
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
display_step: 10
|
2 |
+
in_channels: 1
|
3 |
+
lambda_gp: 10
|
4 |
+
lambda_r1: 10
|
5 |
+
lambda_recon: 100
|
6 |
+
learning_rate: 0.0002
|
7 |
+
out_channels: 2
|
lightning_logs/version_14/events.out.tfevents.1724506514.Hakim.19820.0
ADDED
Binary file (683 Bytes). View file
|
|
lightning_logs/version_14/hparams.yaml
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
display_step: 10
|
2 |
+
in_channels: 1
|
3 |
+
lambda_gp: 10
|
4 |
+
lambda_r1: 10
|
5 |
+
lambda_recon: 100
|
6 |
+
learning_rate: 0.0002
|
7 |
+
out_channels: 2
|
lightning_logs/version_15/events.out.tfevents.1724521354.Hakim.8444.0
ADDED
Binary file (683 Bytes). View file
|
|
lightning_logs/version_15/hparams.yaml
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
display_step: 10
|
2 |
+
in_channels: 1
|
3 |
+
lambda_gp: 10
|
4 |
+
lambda_r1: 10
|
5 |
+
lambda_recon: 100
|
6 |
+
learning_rate: 0.0002
|
7 |
+
out_channels: 2
|
lightning_logs/version_16/events.out.tfevents.1724521651.Hakim.8444.1
ADDED
Binary file (683 Bytes). View file
|
|
lightning_logs/version_16/hparams.yaml
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
display_step: 10
|
2 |
+
in_channels: 1
|
3 |
+
lambda_gp: 10
|
4 |
+
lambda_r1: 10
|
5 |
+
lambda_recon: 100
|
6 |
+
learning_rate: 0.0002
|
7 |
+
out_channels: 2
|
lightning_logs/version_17/events.out.tfevents.1724521762.Hakim.8444.2
ADDED
Binary file (683 Bytes). View file
|
|
lightning_logs/version_17/hparams.yaml
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
display_step: 10
|
2 |
+
in_channels: 1
|
3 |
+
lambda_gp: 10
|
4 |
+
lambda_r1: 10
|
5 |
+
lambda_recon: 100
|
6 |
+
learning_rate: 0.0002
|
7 |
+
out_channels: 2
|
lightning_logs/version_18/events.out.tfevents.1724521935.Hakim.8444.3
ADDED
Binary file (683 Bytes). View file
|
|
lightning_logs/version_18/hparams.yaml
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
display_step: 10
|
2 |
+
in_channels: 1
|
3 |
+
lambda_gp: 10
|
4 |
+
lambda_r1: 10
|
5 |
+
lambda_recon: 100
|
6 |
+
learning_rate: 0.0002
|
7 |
+
out_channels: 2
|
lightning_logs/version_19/events.out.tfevents.1724522010.Hakim.28412.0
ADDED
Binary file (683 Bytes). View file
|
|
lightning_logs/version_19/hparams.yaml
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
display_step: 10
|
2 |
+
in_channels: 1
|
3 |
+
lambda_gp: 10
|
4 |
+
lambda_r1: 10
|
5 |
+
lambda_recon: 100
|
6 |
+
learning_rate: 0.0002
|
7 |
+
out_channels: 2
|
lightning_logs/version_2/events.out.tfevents.1724493545.Hakim.2840.1
ADDED
Binary file (683 Bytes). View file
|
|
lightning_logs/version_2/hparams.yaml
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
display_step: 10
|
2 |
+
in_channels: 1
|
3 |
+
lambda_gp: 10
|
4 |
+
lambda_r1: 10
|
5 |
+
lambda_recon: 100
|
6 |
+
learning_rate: 0.0002
|
7 |
+
out_channels: 2
|
lightning_logs/version_3/events.out.tfevents.1724494109.Hakim.14856.0
ADDED
Binary file (683 Bytes). View file
|
|
lightning_logs/version_3/hparams.yaml
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
display_step: 10
|
2 |
+
in_channels: 1
|
3 |
+
lambda_gp: 10
|
4 |
+
lambda_r1: 10
|
5 |
+
lambda_recon: 100
|
6 |
+
learning_rate: 0.0002
|
7 |
+
out_channels: 2
|
lightning_logs/version_4/events.out.tfevents.1724494965.Hakim.16704.0
ADDED
Binary file (683 Bytes). View file
|
|
lightning_logs/version_4/hparams.yaml
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
display_step: 10
|
2 |
+
in_channels: 1
|
3 |
+
lambda_gp: 10
|
4 |
+
lambda_r1: 10
|
5 |
+
lambda_recon: 100
|
6 |
+
learning_rate: 0.0002
|
7 |
+
out_channels: 2
|
lightning_logs/version_5/events.out.tfevents.1724495112.Hakim.17000.0
ADDED
Binary file (683 Bytes). View file
|
|
lightning_logs/version_5/hparams.yaml
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
display_step: 10
|
2 |
+
in_channels: 1
|
3 |
+
lambda_gp: 10
|
4 |
+
lambda_r1: 10
|
5 |
+
lambda_recon: 100
|
6 |
+
learning_rate: 0.0002
|
7 |
+
out_channels: 2
|
lightning_logs/version_6/events.out.tfevents.1724495583.Hakim.2896.0
ADDED
Binary file (683 Bytes). View file
|
|
lightning_logs/version_6/hparams.yaml
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
display_step: 10
|
2 |
+
in_channels: 1
|
3 |
+
lambda_gp: 10
|
4 |
+
lambda_r1: 10
|
5 |
+
lambda_recon: 100
|
6 |
+
learning_rate: 0.0002
|
7 |
+
out_channels: 2
|
lightning_logs/version_7/events.out.tfevents.1724505035.Hakim.3288.0
ADDED
Binary file (683 Bytes). View file
|
|
lightning_logs/version_7/hparams.yaml
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
display_step: 10
|
2 |
+
in_channels: 1
|
3 |
+
lambda_gp: 10
|
4 |
+
lambda_r1: 10
|
5 |
+
lambda_recon: 100
|
6 |
+
learning_rate: 0.0002
|
7 |
+
out_channels: 2
|
lightning_logs/version_8/events.out.tfevents.1724505263.Hakim.4444.0
ADDED
Binary file (683 Bytes). View file
|
|
lightning_logs/version_8/hparams.yaml
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
display_step: 10
|
2 |
+
in_channels: 1
|
3 |
+
lambda_gp: 10
|
4 |
+
lambda_r1: 10
|
5 |
+
lambda_recon: 100
|
6 |
+
learning_rate: 0.0002
|
7 |
+
out_channels: 2
|
lightning_logs/version_9/events.out.tfevents.1724505733.Hakim.4444.1
ADDED
Binary file (683 Bytes). View file
|
|
lightning_logs/version_9/hparams.yaml
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
display_step: 10
|
2 |
+
in_channels: 1
|
3 |
+
lambda_gp: 10
|
4 |
+
lambda_r1: 10
|
5 |
+
lambda_recon: 100
|
6 |
+
learning_rate: 0.0002
|
7 |
+
out_channels: 2
|
params.yaml
CHANGED
@@ -23,3 +23,11 @@ OUTPUT_CHANNELS: 2
|
|
23 |
|
24 |
# Critic Parameters
|
25 |
IN_CHANNELS: 3
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
|
24 |
# Critic Parameters
|
25 |
IN_CHANNELS: 3
|
26 |
+
|
27 |
+
# model train
|
28 |
+
LEARNING_RATE : 2e-4
|
29 |
+
LAMBDA_RECON : 100
|
30 |
+
DISPLAY_STEP : 10
|
31 |
+
INPUT_CHANNELS : 1
|
32 |
+
OUTPUT_CHANNELS : 2
|
33 |
+
EPOCH : 1
|
research/data_transformation.ipynb
CHANGED
@@ -2,7 +2,7 @@
|
|
2 |
"cells": [
|
3 |
{
|
4 |
"cell_type": "code",
|
5 |
-
"execution_count":
|
6 |
"metadata": {},
|
7 |
"outputs": [],
|
8 |
"source": [
|
@@ -12,7 +12,7 @@
|
|
12 |
},
|
13 |
{
|
14 |
"cell_type": "code",
|
15 |
-
"execution_count":
|
16 |
"metadata": {},
|
17 |
"outputs": [
|
18 |
{
|
@@ -21,7 +21,7 @@
|
|
21 |
"'c:\\\\mlops project\\\\image-colorization-mlops'"
|
22 |
]
|
23 |
},
|
24 |
-
"execution_count":
|
25 |
"metadata": {},
|
26 |
"output_type": "execute_result"
|
27 |
}
|
@@ -32,7 +32,7 @@
|
|
32 |
},
|
33 |
{
|
34 |
"cell_type": "code",
|
35 |
-
"execution_count":
|
36 |
"metadata": {},
|
37 |
"outputs": [],
|
38 |
"source": [
|
@@ -51,7 +51,7 @@
|
|
51 |
},
|
52 |
{
|
53 |
"cell_type": "code",
|
54 |
-
"execution_count":
|
55 |
"metadata": {},
|
56 |
"outputs": [],
|
57 |
"source": [
|
@@ -88,13 +88,13 @@
|
|
88 |
},
|
89 |
{
|
90 |
"cell_type": "code",
|
91 |
-
"execution_count":
|
92 |
"metadata": {},
|
93 |
"outputs": [],
|
94 |
"source": [
|
95 |
"import numpy as np\n",
|
96 |
"import torch\n",
|
97 |
-
"from torch.utils.data import
|
98 |
"from torchvision import transforms\n",
|
99 |
"\n",
|
100 |
"class ImageColorizationDataset(Dataset):\n",
|
@@ -106,24 +106,26 @@
|
|
106 |
" return len(self.dataset[0])\n",
|
107 |
" \n",
|
108 |
" def __getitem__(self, idx):\n",
|
109 |
-
" L = np.array(self.dataset[0][idx]).reshape(self.image_size
|
110 |
" L = transforms.ToTensor()(L)\n",
|
111 |
" \n",
|
112 |
" ab = np.array(self.dataset[1][idx])\n",
|
113 |
" ab = transforms.ToTensor()(ab)\n",
|
114 |
" \n",
|
115 |
-
" return ab, L
|
116 |
]
|
117 |
},
|
118 |
{
|
119 |
"cell_type": "code",
|
120 |
-
"execution_count":
|
121 |
"metadata": {},
|
122 |
"outputs": [],
|
123 |
"source": [
|
124 |
"from torch.utils.data import DataLoader\n",
|
125 |
"import gc\n",
|
126 |
"import os\n",
|
|
|
|
|
127 |
"from src.imagecolorization.logging import logger\n",
|
128 |
"\n",
|
129 |
"class DataTransformation:\n",
|
@@ -137,71 +139,51 @@
|
|
137 |
" gc.collect()\n",
|
138 |
" return dataset\n",
|
139 |
" \n",
|
140 |
-
"
|
141 |
-
" def get_dataloader(self, dataset):\n",
|
142 |
" train_dataset = ImageColorizationDataset(\n",
|
143 |
" dataset=dataset,\n",
|
144 |
-
" image_size=self.config.IMAGE_SIZE\n",
|
145 |
" )\n",
|
146 |
" test_dataset = ImageColorizationDataset(\n",
|
147 |
" dataset=dataset,\n",
|
148 |
-
" image_size=self.config.IMAGE_SIZE\n",
|
149 |
-
" )\n",
|
150 |
-
" \n",
|
151 |
-
" train_loader = DataLoader(\n",
|
152 |
-
" train_dataset,\n",
|
153 |
-
" batch_size=self.config.BATCH_SIZE,\n",
|
154 |
-
" shuffle=True,\n",
|
155 |
-
" pin_memory=True\n",
|
156 |
" )\n",
|
157 |
-
" test_loader = DataLoader(\n",
|
158 |
-
" test_dataset,\n",
|
159 |
-
" batch_size=self.config.BATCH_SIZE,\n",
|
160 |
-
" shuffle=True,\n",
|
161 |
-
" pin_memory=True\n",
|
162 |
-
" )\n",
|
163 |
-
" \n",
|
164 |
-
" return train_loader, test_loader\n",
|
165 |
-
" \n",
|
166 |
-
" \n",
|
167 |
" \n",
|
|
|
168 |
" \n",
|
169 |
" \n",
|
170 |
" \n",
|
171 |
-
"
|
172 |
-
" def save_dataloaders(self, train_loader, test_loader):\n",
|
173 |
" # Ensure the directory exists\n",
|
174 |
" os.makedirs(self.config.root_dir, exist_ok=True)\n",
|
175 |
"\n",
|
176 |
-
"
|
177 |
-
"
|
178 |
"\n",
|
179 |
" try:\n",
|
180 |
-
" # Save the
|
181 |
-
" torch.save(
|
182 |
-
" torch.save(
|
183 |
"\n",
|
184 |
-
" logger.info(f\"Train
|
185 |
-
" logger.info(f\"Test
|
186 |
" except Exception as e:\n",
|
187 |
-
" logger.error(f\"Error saving
|
188 |
-
" raise e
|
189 |
]
|
190 |
},
|
191 |
{
|
192 |
"cell_type": "code",
|
193 |
-
"execution_count":
|
194 |
"metadata": {},
|
195 |
"outputs": [
|
196 |
{
|
197 |
"name": "stdout",
|
198 |
"output_type": "stream",
|
199 |
"text": [
|
200 |
-
"[2024-08-
|
201 |
-
"[2024-08-
|
202 |
-
"[2024-08-
|
203 |
-
"[2024-08-
|
204 |
-
"[2024-08-
|
205 |
]
|
206 |
}
|
207 |
],
|
@@ -215,10 +197,10 @@
|
|
215 |
" dataset = data_transformation.load_data()\n",
|
216 |
" \n",
|
217 |
" # Get the dataloader using the loaded dataset\n",
|
218 |
-
"
|
219 |
" \n",
|
220 |
-
" # Perform any further operations (e.g., saving the
|
221 |
-
" data_transformation.
|
222 |
" \n",
|
223 |
"except Exception as e:\n",
|
224 |
" raise e\n"
|
|
|
2 |
"cells": [
|
3 |
{
|
4 |
"cell_type": "code",
|
5 |
+
"execution_count": 2,
|
6 |
"metadata": {},
|
7 |
"outputs": [],
|
8 |
"source": [
|
|
|
12 |
},
|
13 |
{
|
14 |
"cell_type": "code",
|
15 |
+
"execution_count": 3,
|
16 |
"metadata": {},
|
17 |
"outputs": [
|
18 |
{
|
|
|
21 |
"'c:\\\\mlops project\\\\image-colorization-mlops'"
|
22 |
]
|
23 |
},
|
24 |
+
"execution_count": 3,
|
25 |
"metadata": {},
|
26 |
"output_type": "execute_result"
|
27 |
}
|
|
|
32 |
},
|
33 |
{
|
34 |
"cell_type": "code",
|
35 |
+
"execution_count": 4,
|
36 |
"metadata": {},
|
37 |
"outputs": [],
|
38 |
"source": [
|
|
|
51 |
},
|
52 |
{
|
53 |
"cell_type": "code",
|
54 |
+
"execution_count": 5,
|
55 |
"metadata": {},
|
56 |
"outputs": [],
|
57 |
"source": [
|
|
|
88 |
},
|
89 |
{
|
90 |
"cell_type": "code",
|
91 |
+
"execution_count": 6,
|
92 |
"metadata": {},
|
93 |
"outputs": [],
|
94 |
"source": [
|
95 |
"import numpy as np\n",
|
96 |
"import torch\n",
|
97 |
+
"from torch.utils.data import Dataset\n",
|
98 |
"from torchvision import transforms\n",
|
99 |
"\n",
|
100 |
"class ImageColorizationDataset(Dataset):\n",
|
|
|
106 |
" return len(self.dataset[0])\n",
|
107 |
" \n",
|
108 |
" def __getitem__(self, idx):\n",
|
109 |
+
" L = np.array(self.dataset[0][idx]).reshape(self.image_size)\n",
|
110 |
" L = transforms.ToTensor()(L)\n",
|
111 |
" \n",
|
112 |
" ab = np.array(self.dataset[1][idx])\n",
|
113 |
" ab = transforms.ToTensor()(ab)\n",
|
114 |
" \n",
|
115 |
+
" return ab, L"
|
116 |
]
|
117 |
},
|
118 |
{
|
119 |
"cell_type": "code",
|
120 |
+
"execution_count": 11,
|
121 |
"metadata": {},
|
122 |
"outputs": [],
|
123 |
"source": [
|
124 |
"from torch.utils.data import DataLoader\n",
|
125 |
"import gc\n",
|
126 |
"import os\n",
|
127 |
+
"import numpy as np\n",
|
128 |
+
"import torch\n",
|
129 |
"from src.imagecolorization.logging import logger\n",
|
130 |
"\n",
|
131 |
"class DataTransformation:\n",
|
|
|
139 |
" gc.collect()\n",
|
140 |
" return dataset\n",
|
141 |
" \n",
|
142 |
+
" def get_datasets(self, dataset):\n",
|
|
|
143 |
" train_dataset = ImageColorizationDataset(\n",
|
144 |
" dataset=dataset,\n",
|
|
|
145 |
" )\n",
|
146 |
" test_dataset = ImageColorizationDataset(\n",
|
147 |
" dataset=dataset,\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
148 |
" )\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
149 |
" \n",
|
150 |
+
" return train_dataset, test_dataset\n",
|
151 |
" \n",
|
152 |
" \n",
|
153 |
" \n",
|
154 |
+
" def save_datasets(self, train_dataset, test_dataset):\n",
|
|
|
155 |
" # Ensure the directory exists\n",
|
156 |
" os.makedirs(self.config.root_dir, exist_ok=True)\n",
|
157 |
"\n",
|
158 |
+
" train_dataset_path = os.path.join(self.config.root_dir, 'train_dataset.pt')\n",
|
159 |
+
" test_dataset_path = os.path.join(self.config.root_dir, 'test_dataset.pt')\n",
|
160 |
"\n",
|
161 |
" try:\n",
|
162 |
+
" # Save the datasets\n",
|
163 |
+
" torch.save(train_dataset, train_dataset_path)\n",
|
164 |
+
" torch.save(test_dataset, test_dataset_path)\n",
|
165 |
"\n",
|
166 |
+
" logger.info(f\"Train dataset saved at: {train_dataset_path}\")\n",
|
167 |
+
" logger.info(f\"Test dataset saved at: {test_dataset_path}\")\n",
|
168 |
" except Exception as e:\n",
|
169 |
+
" logger.error(f\"Error saving datasets: {str(e)}\")\n",
|
170 |
+
" raise e"
|
171 |
]
|
172 |
},
|
173 |
{
|
174 |
"cell_type": "code",
|
175 |
+
"execution_count": 12,
|
176 |
"metadata": {},
|
177 |
"outputs": [
|
178 |
{
|
179 |
"name": "stdout",
|
180 |
"output_type": "stream",
|
181 |
"text": [
|
182 |
+
"[2024-08-24 19:09:25,021: INFO: common: yaml file: config\\config.yaml loaded successfully]\n",
|
183 |
+
"[2024-08-24 19:09:25,024: INFO: common: yaml file: params.yaml loaded successfully]\n",
|
184 |
+
"[2024-08-24 19:09:25,026: INFO: common: created directory at: artifacts]\n",
|
185 |
+
"[2024-08-24 19:09:43,417: INFO: 3400243030: Train dataset saved at: artifacts/data_transformation\\train_dataset.pt]\n",
|
186 |
+
"[2024-08-24 19:09:43,440: INFO: 3400243030: Test dataset saved at: artifacts/data_transformation\\test_dataset.pt]\n"
|
187 |
]
|
188 |
}
|
189 |
],
|
|
|
197 |
" dataset = data_transformation.load_data()\n",
|
198 |
" \n",
|
199 |
" # Get the dataloader using the loaded dataset\n",
|
200 |
+
" train_dataset, test_dataset = data_transformation.get_datasets(dataset)\n",
|
201 |
" \n",
|
202 |
+
" # Perform any further operations (e.g., saving the dataset)\n",
|
203 |
+
" data_transformation.save_datasets(train_dataset, test_dataset)\n",
|
204 |
" \n",
|
205 |
"except Exception as e:\n",
|
206 |
" raise e\n"
|
research/model_trainer.ipynb
CHANGED
The diff for this file is too large to render.
See raw diff
|
|
src/imagecolorization/conponents/data_tranformation.py
CHANGED
@@ -18,7 +18,7 @@ class ImageColorizationDataset(Dataset):
|
|
18 |
return len(self.dataset[0])
|
19 |
|
20 |
def __getitem__(self, idx):
|
21 |
-
L = np.array(self.dataset[0][idx]).reshape(self.image_size
|
22 |
L = transforms.ToTensor()(L)
|
23 |
|
24 |
ab = np.array(self.dataset[1][idx])
|
@@ -27,6 +27,13 @@ class ImageColorizationDataset(Dataset):
|
|
27 |
return ab, L
|
28 |
|
29 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
class DataTransformation:
|
31 |
def __init__(self, config: DataTransformationConfig):
|
32 |
self.config = config
|
@@ -38,8 +45,7 @@ class DataTransformation:
|
|
38 |
gc.collect()
|
39 |
return dataset
|
40 |
|
41 |
-
|
42 |
-
def get_dataloader(self, dataset):
|
43 |
train_dataset = ImageColorizationDataset(
|
44 |
dataset=dataset,
|
45 |
image_size=self.config.IMAGE_SIZE
|
@@ -49,21 +55,25 @@ class DataTransformation:
|
|
49 |
image_size=self.config.IMAGE_SIZE
|
50 |
)
|
51 |
|
52 |
-
|
53 |
-
train_dataset,
|
54 |
-
batch_size=self.config.BATCH_SIZE,
|
55 |
-
shuffle=True,
|
56 |
-
pin_memory=True
|
57 |
-
)
|
58 |
-
test_loader = DataLoader(
|
59 |
-
test_dataset,
|
60 |
-
batch_size=self.config.BATCH_SIZE,
|
61 |
-
shuffle=True,
|
62 |
-
pin_memory=True
|
63 |
-
)
|
64 |
-
|
65 |
-
return train_loader, test_loader
|
66 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
67 |
|
68 |
|
69 |
|
@@ -87,4 +97,3 @@ class DataTransformation:
|
|
87 |
except Exception as e:
|
88 |
logger.error(f"Error saving dataloaders: {str(e)}")
|
89 |
raise e
|
90 |
-
|
|
|
18 |
return len(self.dataset[0])
|
19 |
|
20 |
def __getitem__(self, idx):
|
21 |
+
L = np.array(self.dataset[0][idx]).reshape(self.image_size)
|
22 |
L = transforms.ToTensor()(L)
|
23 |
|
24 |
ab = np.array(self.dataset[1][idx])
|
|
|
27 |
return ab, L
|
28 |
|
29 |
|
30 |
+
from torch.utils.data import DataLoader
|
31 |
+
import gc
|
32 |
+
import os
|
33 |
+
import numpy as np
|
34 |
+
import torch
|
35 |
+
from src.imagecolorization.logging import logger
|
36 |
+
|
37 |
class DataTransformation:
|
38 |
def __init__(self, config: DataTransformationConfig):
|
39 |
self.config = config
|
|
|
45 |
gc.collect()
|
46 |
return dataset
|
47 |
|
48 |
+
def get_datasets(self, dataset):
|
|
|
49 |
train_dataset = ImageColorizationDataset(
|
50 |
dataset=dataset,
|
51 |
image_size=self.config.IMAGE_SIZE
|
|
|
55 |
image_size=self.config.IMAGE_SIZE
|
56 |
)
|
57 |
|
58 |
+
return train_dataset, test_dataset
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
59 |
|
60 |
+
def save_datasets(self, train_dataset, test_dataset):
|
61 |
+
# Ensure the directory exists
|
62 |
+
os.makedirs(self.config.root_dir, exist_ok=True)
|
63 |
+
|
64 |
+
train_dataset_path = os.path.join(self.config.root_dir, 'train_dataset.pt')
|
65 |
+
test_dataset_path = os.path.join(self.config.root_dir, 'test_dataset.pt')
|
66 |
+
|
67 |
+
try:
|
68 |
+
# Save the datasets
|
69 |
+
torch.save(train_dataset, train_dataset_path)
|
70 |
+
torch.save(test_dataset, test_dataset_path)
|
71 |
+
|
72 |
+
logger.info(f"Train dataset saved at: {train_dataset_path}")
|
73 |
+
logger.info(f"Test dataset saved at: {test_dataset_path}")
|
74 |
+
except Exception as e:
|
75 |
+
logger.error(f"Error saving datasets: {str(e)}")
|
76 |
+
raise e
|
77 |
|
78 |
|
79 |
|
|
|
97 |
except Exception as e:
|
98 |
logger.error(f"Error saving dataloaders: {str(e)}")
|
99 |
raise e
|
|
src/imagecolorization/pipeline/stage02_data_transformation.py
CHANGED
@@ -10,6 +10,12 @@ class DataTransformationPipeline:
|
|
10 |
config = ConfigurationManager()
|
11 |
data_transformation_config = config.get_data_transformation_config()
|
12 |
data_transformation = DataTransformation(config=data_transformation_config)
|
|
|
|
|
13 |
dataset = data_transformation.load_data()
|
14 |
-
|
15 |
-
|
|
|
|
|
|
|
|
|
|
10 |
config = ConfigurationManager()
|
11 |
data_transformation_config = config.get_data_transformation_config()
|
12 |
data_transformation = DataTransformation(config=data_transformation_config)
|
13 |
+
|
14 |
+
# Load the dataset
|
15 |
dataset = data_transformation.load_data()
|
16 |
+
|
17 |
+
# Get the dataloader using the loaded dataset
|
18 |
+
train_dataset, test_dataset = data_transformation.get_datasets(dataset)
|
19 |
+
|
20 |
+
# Perform any further operations (e.g., saving the dataset)
|
21 |
+
data_transformation.save_datasets(train_dataset, test_dataset)
|