HAMIM-ML commited on
Commit
e623e7e
·
1 Parent(s): cf8627d

model trainer fixed

Browse files
Files changed (46) hide show
  1. config/config.yaml +5 -0
  2. lightning_logs/version_0/events.out.tfevents.1724493340.Hakim.9156.0 +0 -0
  3. lightning_logs/version_0/hparams.yaml +7 -0
  4. lightning_logs/version_1/events.out.tfevents.1724493420.Hakim.2840.0 +0 -0
  5. lightning_logs/version_1/hparams.yaml +7 -0
  6. lightning_logs/version_10/events.out.tfevents.1724505838.Hakim.4444.2 +0 -0
  7. lightning_logs/version_10/hparams.yaml +7 -0
  8. lightning_logs/version_11/events.out.tfevents.1724505997.Hakim.4444.3 +0 -0
  9. lightning_logs/version_11/hparams.yaml +7 -0
  10. lightning_logs/version_12/events.out.tfevents.1724506091.Hakim.4444.4 +0 -0
  11. lightning_logs/version_12/hparams.yaml +7 -0
  12. lightning_logs/version_13/events.out.tfevents.1724506200.Hakim.4444.5 +0 -0
  13. lightning_logs/version_13/hparams.yaml +7 -0
  14. lightning_logs/version_14/events.out.tfevents.1724506514.Hakim.19820.0 +0 -0
  15. lightning_logs/version_14/hparams.yaml +7 -0
  16. lightning_logs/version_15/events.out.tfevents.1724521354.Hakim.8444.0 +0 -0
  17. lightning_logs/version_15/hparams.yaml +7 -0
  18. lightning_logs/version_16/events.out.tfevents.1724521651.Hakim.8444.1 +0 -0
  19. lightning_logs/version_16/hparams.yaml +7 -0
  20. lightning_logs/version_17/events.out.tfevents.1724521762.Hakim.8444.2 +0 -0
  21. lightning_logs/version_17/hparams.yaml +7 -0
  22. lightning_logs/version_18/events.out.tfevents.1724521935.Hakim.8444.3 +0 -0
  23. lightning_logs/version_18/hparams.yaml +7 -0
  24. lightning_logs/version_19/events.out.tfevents.1724522010.Hakim.28412.0 +0 -0
  25. lightning_logs/version_19/hparams.yaml +7 -0
  26. lightning_logs/version_2/events.out.tfevents.1724493545.Hakim.2840.1 +0 -0
  27. lightning_logs/version_2/hparams.yaml +7 -0
  28. lightning_logs/version_3/events.out.tfevents.1724494109.Hakim.14856.0 +0 -0
  29. lightning_logs/version_3/hparams.yaml +7 -0
  30. lightning_logs/version_4/events.out.tfevents.1724494965.Hakim.16704.0 +0 -0
  31. lightning_logs/version_4/hparams.yaml +7 -0
  32. lightning_logs/version_5/events.out.tfevents.1724495112.Hakim.17000.0 +0 -0
  33. lightning_logs/version_5/hparams.yaml +7 -0
  34. lightning_logs/version_6/events.out.tfevents.1724495583.Hakim.2896.0 +0 -0
  35. lightning_logs/version_6/hparams.yaml +7 -0
  36. lightning_logs/version_7/events.out.tfevents.1724505035.Hakim.3288.0 +0 -0
  37. lightning_logs/version_7/hparams.yaml +7 -0
  38. lightning_logs/version_8/events.out.tfevents.1724505263.Hakim.4444.0 +0 -0
  39. lightning_logs/version_8/hparams.yaml +7 -0
  40. lightning_logs/version_9/events.out.tfevents.1724505733.Hakim.4444.1 +0 -0
  41. lightning_logs/version_9/hparams.yaml +7 -0
  42. params.yaml +8 -0
  43. research/data_transformation.ipynb +33 -51
  44. research/model_trainer.ipynb +0 -0
  45. src/imagecolorization/conponents/data_tranformation.py +27 -18
  46. src/imagecolorization/pipeline/stage02_data_transformation.py +8 -2
config/config.yaml CHANGED
@@ -14,5 +14,10 @@ data_transformation:
14
 
15
  model_building:
16
  root_dir : artifacts/model
 
 
 
 
 
17
 
18
 
 
14
 
15
  model_building:
16
  root_dir : artifacts/model
17
+
18
+ model_trainer:
19
+ root_dir : artifacts/trained_model
20
+ test_data_path : artifacts/data_transformation/test_dataset.pt
21
+ train_data_path : artifacts/data_transformation/train_dataset.pt
22
 
23
 
lightning_logs/version_0/events.out.tfevents.1724493340.Hakim.9156.0 ADDED
Binary file (683 Bytes). View file
 
lightning_logs/version_0/hparams.yaml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ display_step: 10
2
+ in_channels: 1
3
+ lambda_gp: 10
4
+ lambda_r1: 10
5
+ lambda_recon: 100
6
+ learning_rate: 0.0002
7
+ out_channels: 2
lightning_logs/version_1/events.out.tfevents.1724493420.Hakim.2840.0 ADDED
Binary file (683 Bytes). View file
 
lightning_logs/version_1/hparams.yaml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ display_step: 10
2
+ in_channels: 1
3
+ lambda_gp: 10
4
+ lambda_r1: 10
5
+ lambda_recon: 100
6
+ learning_rate: 0.0002
7
+ out_channels: 2
lightning_logs/version_10/events.out.tfevents.1724505838.Hakim.4444.2 ADDED
Binary file (683 Bytes). View file
 
lightning_logs/version_10/hparams.yaml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ display_step: 10
2
+ in_channels: 1
3
+ lambda_gp: 10
4
+ lambda_r1: 10
5
+ lambda_recon: 100
6
+ learning_rate: 0.0002
7
+ out_channels: 2
lightning_logs/version_11/events.out.tfevents.1724505997.Hakim.4444.3 ADDED
Binary file (683 Bytes). View file
 
lightning_logs/version_11/hparams.yaml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ display_step: 10
2
+ in_channels: 1
3
+ lambda_gp: 10
4
+ lambda_r1: 10
5
+ lambda_recon: 100
6
+ learning_rate: 0.0002
7
+ out_channels: 2
lightning_logs/version_12/events.out.tfevents.1724506091.Hakim.4444.4 ADDED
Binary file (683 Bytes). View file
 
lightning_logs/version_12/hparams.yaml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ display_step: 10
2
+ in_channels: 1
3
+ lambda_gp: 10
4
+ lambda_r1: 10
5
+ lambda_recon: 100
6
+ learning_rate: 0.0002
7
+ out_channels: 2
lightning_logs/version_13/events.out.tfevents.1724506200.Hakim.4444.5 ADDED
Binary file (683 Bytes). View file
 
lightning_logs/version_13/hparams.yaml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ display_step: 10
2
+ in_channels: 1
3
+ lambda_gp: 10
4
+ lambda_r1: 10
5
+ lambda_recon: 100
6
+ learning_rate: 0.0002
7
+ out_channels: 2
lightning_logs/version_14/events.out.tfevents.1724506514.Hakim.19820.0 ADDED
Binary file (683 Bytes). View file
 
lightning_logs/version_14/hparams.yaml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ display_step: 10
2
+ in_channels: 1
3
+ lambda_gp: 10
4
+ lambda_r1: 10
5
+ lambda_recon: 100
6
+ learning_rate: 0.0002
7
+ out_channels: 2
lightning_logs/version_15/events.out.tfevents.1724521354.Hakim.8444.0 ADDED
Binary file (683 Bytes). View file
 
lightning_logs/version_15/hparams.yaml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ display_step: 10
2
+ in_channels: 1
3
+ lambda_gp: 10
4
+ lambda_r1: 10
5
+ lambda_recon: 100
6
+ learning_rate: 0.0002
7
+ out_channels: 2
lightning_logs/version_16/events.out.tfevents.1724521651.Hakim.8444.1 ADDED
Binary file (683 Bytes). View file
 
lightning_logs/version_16/hparams.yaml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ display_step: 10
2
+ in_channels: 1
3
+ lambda_gp: 10
4
+ lambda_r1: 10
5
+ lambda_recon: 100
6
+ learning_rate: 0.0002
7
+ out_channels: 2
lightning_logs/version_17/events.out.tfevents.1724521762.Hakim.8444.2 ADDED
Binary file (683 Bytes). View file
 
lightning_logs/version_17/hparams.yaml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ display_step: 10
2
+ in_channels: 1
3
+ lambda_gp: 10
4
+ lambda_r1: 10
5
+ lambda_recon: 100
6
+ learning_rate: 0.0002
7
+ out_channels: 2
lightning_logs/version_18/events.out.tfevents.1724521935.Hakim.8444.3 ADDED
Binary file (683 Bytes). View file
 
lightning_logs/version_18/hparams.yaml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ display_step: 10
2
+ in_channels: 1
3
+ lambda_gp: 10
4
+ lambda_r1: 10
5
+ lambda_recon: 100
6
+ learning_rate: 0.0002
7
+ out_channels: 2
lightning_logs/version_19/events.out.tfevents.1724522010.Hakim.28412.0 ADDED
Binary file (683 Bytes). View file
 
lightning_logs/version_19/hparams.yaml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ display_step: 10
2
+ in_channels: 1
3
+ lambda_gp: 10
4
+ lambda_r1: 10
5
+ lambda_recon: 100
6
+ learning_rate: 0.0002
7
+ out_channels: 2
lightning_logs/version_2/events.out.tfevents.1724493545.Hakim.2840.1 ADDED
Binary file (683 Bytes). View file
 
lightning_logs/version_2/hparams.yaml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ display_step: 10
2
+ in_channels: 1
3
+ lambda_gp: 10
4
+ lambda_r1: 10
5
+ lambda_recon: 100
6
+ learning_rate: 0.0002
7
+ out_channels: 2
lightning_logs/version_3/events.out.tfevents.1724494109.Hakim.14856.0 ADDED
Binary file (683 Bytes). View file
 
lightning_logs/version_3/hparams.yaml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ display_step: 10
2
+ in_channels: 1
3
+ lambda_gp: 10
4
+ lambda_r1: 10
5
+ lambda_recon: 100
6
+ learning_rate: 0.0002
7
+ out_channels: 2
lightning_logs/version_4/events.out.tfevents.1724494965.Hakim.16704.0 ADDED
Binary file (683 Bytes). View file
 
lightning_logs/version_4/hparams.yaml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ display_step: 10
2
+ in_channels: 1
3
+ lambda_gp: 10
4
+ lambda_r1: 10
5
+ lambda_recon: 100
6
+ learning_rate: 0.0002
7
+ out_channels: 2
lightning_logs/version_5/events.out.tfevents.1724495112.Hakim.17000.0 ADDED
Binary file (683 Bytes). View file
 
lightning_logs/version_5/hparams.yaml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ display_step: 10
2
+ in_channels: 1
3
+ lambda_gp: 10
4
+ lambda_r1: 10
5
+ lambda_recon: 100
6
+ learning_rate: 0.0002
7
+ out_channels: 2
lightning_logs/version_6/events.out.tfevents.1724495583.Hakim.2896.0 ADDED
Binary file (683 Bytes). View file
 
lightning_logs/version_6/hparams.yaml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ display_step: 10
2
+ in_channels: 1
3
+ lambda_gp: 10
4
+ lambda_r1: 10
5
+ lambda_recon: 100
6
+ learning_rate: 0.0002
7
+ out_channels: 2
lightning_logs/version_7/events.out.tfevents.1724505035.Hakim.3288.0 ADDED
Binary file (683 Bytes). View file
 
lightning_logs/version_7/hparams.yaml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ display_step: 10
2
+ in_channels: 1
3
+ lambda_gp: 10
4
+ lambda_r1: 10
5
+ lambda_recon: 100
6
+ learning_rate: 0.0002
7
+ out_channels: 2
lightning_logs/version_8/events.out.tfevents.1724505263.Hakim.4444.0 ADDED
Binary file (683 Bytes). View file
 
lightning_logs/version_8/hparams.yaml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ display_step: 10
2
+ in_channels: 1
3
+ lambda_gp: 10
4
+ lambda_r1: 10
5
+ lambda_recon: 100
6
+ learning_rate: 0.0002
7
+ out_channels: 2
lightning_logs/version_9/events.out.tfevents.1724505733.Hakim.4444.1 ADDED
Binary file (683 Bytes). View file
 
lightning_logs/version_9/hparams.yaml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ display_step: 10
2
+ in_channels: 1
3
+ lambda_gp: 10
4
+ lambda_r1: 10
5
+ lambda_recon: 100
6
+ learning_rate: 0.0002
7
+ out_channels: 2
params.yaml CHANGED
@@ -23,3 +23,11 @@ OUTPUT_CHANNELS: 2
23
 
24
  # Critic Parameters
25
  IN_CHANNELS: 3
 
 
 
 
 
 
 
 
 
23
 
24
  # Critic Parameters
25
  IN_CHANNELS: 3
26
+
27
+ # model train
28
+ LEARNING_RATE : 2e-4
29
+ LAMBDA_RECON : 100
30
+ DISPLAY_STEP : 10
31
+ INPUT_CHANNELS : 1
32
+ OUTPUT_CHANNELS : 2
33
+ EPOCH : 1
research/data_transformation.ipynb CHANGED
@@ -2,7 +2,7 @@
2
  "cells": [
3
  {
4
  "cell_type": "code",
5
- "execution_count": 1,
6
  "metadata": {},
7
  "outputs": [],
8
  "source": [
@@ -12,7 +12,7 @@
12
  },
13
  {
14
  "cell_type": "code",
15
- "execution_count": 2,
16
  "metadata": {},
17
  "outputs": [
18
  {
@@ -21,7 +21,7 @@
21
  "'c:\\\\mlops project\\\\image-colorization-mlops'"
22
  ]
23
  },
24
- "execution_count": 2,
25
  "metadata": {},
26
  "output_type": "execute_result"
27
  }
@@ -32,7 +32,7 @@
32
  },
33
  {
34
  "cell_type": "code",
35
- "execution_count": 3,
36
  "metadata": {},
37
  "outputs": [],
38
  "source": [
@@ -51,7 +51,7 @@
51
  },
52
  {
53
  "cell_type": "code",
54
- "execution_count": 4,
55
  "metadata": {},
56
  "outputs": [],
57
  "source": [
@@ -88,13 +88,13 @@
88
  },
89
  {
90
  "cell_type": "code",
91
- "execution_count": 5,
92
  "metadata": {},
93
  "outputs": [],
94
  "source": [
95
  "import numpy as np\n",
96
  "import torch\n",
97
- "from torch.utils.data import DataLoader, Dataset\n",
98
  "from torchvision import transforms\n",
99
  "\n",
100
  "class ImageColorizationDataset(Dataset):\n",
@@ -106,24 +106,26 @@
106
  " return len(self.dataset[0])\n",
107
  " \n",
108
  " def __getitem__(self, idx):\n",
109
- " L = np.array(self.dataset[0][idx]).reshape(self.image_size + (1,))\n",
110
  " L = transforms.ToTensor()(L)\n",
111
  " \n",
112
  " ab = np.array(self.dataset[1][idx])\n",
113
  " ab = transforms.ToTensor()(ab)\n",
114
  " \n",
115
- " return ab, L\n"
116
  ]
117
  },
118
  {
119
  "cell_type": "code",
120
- "execution_count": 6,
121
  "metadata": {},
122
  "outputs": [],
123
  "source": [
124
  "from torch.utils.data import DataLoader\n",
125
  "import gc\n",
126
  "import os\n",
 
 
127
  "from src.imagecolorization.logging import logger\n",
128
  "\n",
129
  "class DataTransformation:\n",
@@ -137,71 +139,51 @@
137
  " gc.collect()\n",
138
  " return dataset\n",
139
  " \n",
140
- " \n",
141
- " def get_dataloader(self, dataset):\n",
142
  " train_dataset = ImageColorizationDataset(\n",
143
  " dataset=dataset,\n",
144
- " image_size=self.config.IMAGE_SIZE\n",
145
  " )\n",
146
  " test_dataset = ImageColorizationDataset(\n",
147
  " dataset=dataset,\n",
148
- " image_size=self.config.IMAGE_SIZE\n",
149
- " )\n",
150
- " \n",
151
- " train_loader = DataLoader(\n",
152
- " train_dataset,\n",
153
- " batch_size=self.config.BATCH_SIZE,\n",
154
- " shuffle=True,\n",
155
- " pin_memory=True\n",
156
  " )\n",
157
- " test_loader = DataLoader(\n",
158
- " test_dataset,\n",
159
- " batch_size=self.config.BATCH_SIZE,\n",
160
- " shuffle=True,\n",
161
- " pin_memory=True\n",
162
- " )\n",
163
- " \n",
164
- " return train_loader, test_loader\n",
165
- " \n",
166
- " \n",
167
  " \n",
 
168
  " \n",
169
  " \n",
170
  " \n",
171
- " \n",
172
- " def save_dataloaders(self, train_loader, test_loader):\n",
173
  " # Ensure the directory exists\n",
174
  " os.makedirs(self.config.root_dir, exist_ok=True)\n",
175
  "\n",
176
- " train_loader_path = os.path.join(self.config.root_dir, 'train_loader.pt')\n",
177
- " test_loader_path = os.path.join(self.config.root_dir, 'test_loader.pt')\n",
178
  "\n",
179
  " try:\n",
180
- " # Save the dataloaders\n",
181
- " torch.save(train_loader, train_loader_path)\n",
182
- " torch.save(test_loader, test_loader_path)\n",
183
  "\n",
184
- " logger.info(f\"Train Loader saved at: {train_loader_path}\")\n",
185
- " logger.info(f\"Test Loader saved at: {test_loader_path}\")\n",
186
  " except Exception as e:\n",
187
- " logger.error(f\"Error saving dataloaders: {str(e)}\")\n",
188
- " raise e\n"
189
  ]
190
  },
191
  {
192
  "cell_type": "code",
193
- "execution_count": 7,
194
  "metadata": {},
195
  "outputs": [
196
  {
197
  "name": "stdout",
198
  "output_type": "stream",
199
  "text": [
200
- "[2024-08-18 17:50:45,127: INFO: common: yaml file: config\\config.yaml loaded successfully]\n",
201
- "[2024-08-18 17:50:45,129: INFO: common: yaml file: params.yaml loaded successfully]\n",
202
- "[2024-08-18 17:50:45,129: INFO: common: created directory at: artifacts]\n",
203
- "[2024-08-18 17:50:57,600: INFO: 2567581832: Train Loader saved at: artifacts/data_transformation\\train_loader.pt]\n",
204
- "[2024-08-18 17:50:57,605: INFO: 2567581832: Test Loader saved at: artifacts/data_transformation\\test_loader.pt]\n"
205
  ]
206
  }
207
  ],
@@ -215,10 +197,10 @@
215
  " dataset = data_transformation.load_data()\n",
216
  " \n",
217
  " # Get the dataloader using the loaded dataset\n",
218
- " train_loader, test_loader = data_transformation.get_dataloader(dataset)\n",
219
  " \n",
220
- " # Perform any further operations (e.g., saving the dataloaders)\n",
221
- " data_transformation.save_dataloaders(train_loader, test_loader)\n",
222
  " \n",
223
  "except Exception as e:\n",
224
  " raise e\n"
 
2
  "cells": [
3
  {
4
  "cell_type": "code",
5
+ "execution_count": 2,
6
  "metadata": {},
7
  "outputs": [],
8
  "source": [
 
12
  },
13
  {
14
  "cell_type": "code",
15
+ "execution_count": 3,
16
  "metadata": {},
17
  "outputs": [
18
  {
 
21
  "'c:\\\\mlops project\\\\image-colorization-mlops'"
22
  ]
23
  },
24
+ "execution_count": 3,
25
  "metadata": {},
26
  "output_type": "execute_result"
27
  }
 
32
  },
33
  {
34
  "cell_type": "code",
35
+ "execution_count": 4,
36
  "metadata": {},
37
  "outputs": [],
38
  "source": [
 
51
  },
52
  {
53
  "cell_type": "code",
54
+ "execution_count": 5,
55
  "metadata": {},
56
  "outputs": [],
57
  "source": [
 
88
  },
89
  {
90
  "cell_type": "code",
91
+ "execution_count": 6,
92
  "metadata": {},
93
  "outputs": [],
94
  "source": [
95
  "import numpy as np\n",
96
  "import torch\n",
97
+ "from torch.utils.data import Dataset\n",
98
  "from torchvision import transforms\n",
99
  "\n",
100
  "class ImageColorizationDataset(Dataset):\n",
 
106
  " return len(self.dataset[0])\n",
107
  " \n",
108
  " def __getitem__(self, idx):\n",
109
+ " L = np.array(self.dataset[0][idx]).reshape(self.image_size)\n",
110
  " L = transforms.ToTensor()(L)\n",
111
  " \n",
112
  " ab = np.array(self.dataset[1][idx])\n",
113
  " ab = transforms.ToTensor()(ab)\n",
114
  " \n",
115
+ " return ab, L"
116
  ]
117
  },
118
  {
119
  "cell_type": "code",
120
+ "execution_count": 11,
121
  "metadata": {},
122
  "outputs": [],
123
  "source": [
124
  "from torch.utils.data import DataLoader\n",
125
  "import gc\n",
126
  "import os\n",
127
+ "import numpy as np\n",
128
+ "import torch\n",
129
  "from src.imagecolorization.logging import logger\n",
130
  "\n",
131
  "class DataTransformation:\n",
 
139
  " gc.collect()\n",
140
  " return dataset\n",
141
  " \n",
142
+ " def get_datasets(self, dataset):\n",
 
143
  " train_dataset = ImageColorizationDataset(\n",
144
  " dataset=dataset,\n",
 
145
  " )\n",
146
  " test_dataset = ImageColorizationDataset(\n",
147
  " dataset=dataset,\n",
 
 
 
 
 
 
 
 
148
  " )\n",
 
 
 
 
 
 
 
 
 
 
149
  " \n",
150
+ " return train_dataset, test_dataset\n",
151
  " \n",
152
  " \n",
153
  " \n",
154
+ " def save_datasets(self, train_dataset, test_dataset):\n",
 
155
  " # Ensure the directory exists\n",
156
  " os.makedirs(self.config.root_dir, exist_ok=True)\n",
157
  "\n",
158
+ " train_dataset_path = os.path.join(self.config.root_dir, 'train_dataset.pt')\n",
159
+ " test_dataset_path = os.path.join(self.config.root_dir, 'test_dataset.pt')\n",
160
  "\n",
161
  " try:\n",
162
+ " # Save the datasets\n",
163
+ " torch.save(train_dataset, train_dataset_path)\n",
164
+ " torch.save(test_dataset, test_dataset_path)\n",
165
  "\n",
166
+ " logger.info(f\"Train dataset saved at: {train_dataset_path}\")\n",
167
+ " logger.info(f\"Test dataset saved at: {test_dataset_path}\")\n",
168
  " except Exception as e:\n",
169
+ " logger.error(f\"Error saving datasets: {str(e)}\")\n",
170
+ " raise e"
171
  ]
172
  },
173
  {
174
  "cell_type": "code",
175
+ "execution_count": 12,
176
  "metadata": {},
177
  "outputs": [
178
  {
179
  "name": "stdout",
180
  "output_type": "stream",
181
  "text": [
182
+ "[2024-08-24 19:09:25,021: INFO: common: yaml file: config\\config.yaml loaded successfully]\n",
183
+ "[2024-08-24 19:09:25,024: INFO: common: yaml file: params.yaml loaded successfully]\n",
184
+ "[2024-08-24 19:09:25,026: INFO: common: created directory at: artifacts]\n",
185
+ "[2024-08-24 19:09:43,417: INFO: 3400243030: Train dataset saved at: artifacts/data_transformation\\train_dataset.pt]\n",
186
+ "[2024-08-24 19:09:43,440: INFO: 3400243030: Test dataset saved at: artifacts/data_transformation\\test_dataset.pt]\n"
187
  ]
188
  }
189
  ],
 
197
  " dataset = data_transformation.load_data()\n",
198
  " \n",
199
  " # Get the dataloader using the loaded dataset\n",
200
+ " train_dataset, test_dataset = data_transformation.get_datasets(dataset)\n",
201
  " \n",
202
+ " # Perform any further operations (e.g., saving the dataset)\n",
203
+ " data_transformation.save_datasets(train_dataset, test_dataset)\n",
204
  " \n",
205
  "except Exception as e:\n",
206
  " raise e\n"
research/model_trainer.ipynb CHANGED
The diff for this file is too large to render. See raw diff
 
src/imagecolorization/conponents/data_tranformation.py CHANGED
@@ -18,7 +18,7 @@ class ImageColorizationDataset(Dataset):
18
  return len(self.dataset[0])
19
 
20
  def __getitem__(self, idx):
21
- L = np.array(self.dataset[0][idx]).reshape(self.image_size + (1,))
22
  L = transforms.ToTensor()(L)
23
 
24
  ab = np.array(self.dataset[1][idx])
@@ -27,6 +27,13 @@ class ImageColorizationDataset(Dataset):
27
  return ab, L
28
 
29
 
 
 
 
 
 
 
 
30
  class DataTransformation:
31
  def __init__(self, config: DataTransformationConfig):
32
  self.config = config
@@ -38,8 +45,7 @@ class DataTransformation:
38
  gc.collect()
39
  return dataset
40
 
41
-
42
- def get_dataloader(self, dataset):
43
  train_dataset = ImageColorizationDataset(
44
  dataset=dataset,
45
  image_size=self.config.IMAGE_SIZE
@@ -49,21 +55,25 @@ class DataTransformation:
49
  image_size=self.config.IMAGE_SIZE
50
  )
51
 
52
- train_loader = DataLoader(
53
- train_dataset,
54
- batch_size=self.config.BATCH_SIZE,
55
- shuffle=True,
56
- pin_memory=True
57
- )
58
- test_loader = DataLoader(
59
- test_dataset,
60
- batch_size=self.config.BATCH_SIZE,
61
- shuffle=True,
62
- pin_memory=True
63
- )
64
-
65
- return train_loader, test_loader
66
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
 
68
 
69
 
@@ -87,4 +97,3 @@ class DataTransformation:
87
  except Exception as e:
88
  logger.error(f"Error saving dataloaders: {str(e)}")
89
  raise e
90
-
 
18
  return len(self.dataset[0])
19
 
20
  def __getitem__(self, idx):
21
+ L = np.array(self.dataset[0][idx]).reshape(self.image_size)
22
  L = transforms.ToTensor()(L)
23
 
24
  ab = np.array(self.dataset[1][idx])
 
27
  return ab, L
28
 
29
 
30
+ from torch.utils.data import DataLoader
31
+ import gc
32
+ import os
33
+ import numpy as np
34
+ import torch
35
+ from src.imagecolorization.logging import logger
36
+
37
  class DataTransformation:
38
  def __init__(self, config: DataTransformationConfig):
39
  self.config = config
 
45
  gc.collect()
46
  return dataset
47
 
48
+ def get_datasets(self, dataset):
 
49
  train_dataset = ImageColorizationDataset(
50
  dataset=dataset,
51
  image_size=self.config.IMAGE_SIZE
 
55
  image_size=self.config.IMAGE_SIZE
56
  )
57
 
58
+ return train_dataset, test_dataset
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
+ def save_datasets(self, train_dataset, test_dataset):
61
+ # Ensure the directory exists
62
+ os.makedirs(self.config.root_dir, exist_ok=True)
63
+
64
+ train_dataset_path = os.path.join(self.config.root_dir, 'train_dataset.pt')
65
+ test_dataset_path = os.path.join(self.config.root_dir, 'test_dataset.pt')
66
+
67
+ try:
68
+ # Save the datasets
69
+ torch.save(train_dataset, train_dataset_path)
70
+ torch.save(test_dataset, test_dataset_path)
71
+
72
+ logger.info(f"Train dataset saved at: {train_dataset_path}")
73
+ logger.info(f"Test dataset saved at: {test_dataset_path}")
74
+ except Exception as e:
75
+ logger.error(f"Error saving datasets: {str(e)}")
76
+ raise e
77
 
78
 
79
 
 
97
  except Exception as e:
98
  logger.error(f"Error saving dataloaders: {str(e)}")
99
  raise e
 
src/imagecolorization/pipeline/stage02_data_transformation.py CHANGED
@@ -10,6 +10,12 @@ class DataTransformationPipeline:
10
  config = ConfigurationManager()
11
  data_transformation_config = config.get_data_transformation_config()
12
  data_transformation = DataTransformation(config=data_transformation_config)
 
 
13
  dataset = data_transformation.load_data()
14
- train_loader, test_loader = data_transformation.get_dataloader(dataset)
15
- data_transformation.save_dataloaders(train_loader, test_loader)
 
 
 
 
 
10
  config = ConfigurationManager()
11
  data_transformation_config = config.get_data_transformation_config()
12
  data_transformation = DataTransformation(config=data_transformation_config)
13
+
14
+ # Load the dataset
15
  dataset = data_transformation.load_data()
16
+
17
+ # Get the dataloader using the loaded dataset
18
+ train_dataset, test_dataset = data_transformation.get_datasets(dataset)
19
+
20
+ # Perform any further operations (e.g., saving the dataset)
21
+ data_transformation.save_datasets(train_dataset, test_dataset)