# Spaces: Runtime error
import os

import numpy as np
import torch
from torch import nn
from torch import optim
from torch.utils.data import DataLoader
from torchvision.utils import save_image
from tqdm import trange

# NOTE(review): ImageDBData, CARN_V2 and network_to_half are not imported by
# name anywhere here; presumably they come from these star imports — verify.
from Dataloader import *
# NOTE(review): the relative ".utils" imports only resolve when this file is
# executed as a module inside a package (python -m <package>.<module>); run as
# a plain script they raise ImportError.
from .utils import image_quality
from .utils.cls import CyclicLR
from .utils.prepare_images import *
train_folder = './dataset/train'
test_folder = "./dataset/test"

# Save a mid-epoch checkpoint every this many optimizer steps.
CHECKPOINT_EVERY = 500
# Directory that receives the side-by-side (prediction, target) test images.
TEST_IMG_DIR = "check_test_imgs"


def _backward(optimizer, loss):
    """Back-propagate ``loss``, delegating to the optimizer when it supports it.

    Loss-scaling wrappers such as ``FP16_Optimizer`` (commented out in
    ``main``) expose ``optimizer.backward(loss)``. A plain ``torch.optim``
    optimizer such as ``optim.Adam`` has no ``backward`` method, so calling it
    unconditionally raised ``AttributeError``; fall back to ``loss.backward()``.
    """
    backward = getattr(optimizer, "backward", None)
    if backward is None:
        loss.backward()
    else:
        backward(loss)


def _save_train_state(model, optimizer, scheduler, train_loss, train_ssim, train_psnr):
    """Persist everything needed to resume training.

    Writes the model and optimizer state dicts, the per-batch training metric
    histories and the scheduler's ``last_batch_iteration`` to fixed file names
    in the current working directory (overwriting previous checkpoints).
    """
    torch.save(model.state_dict(), 'CARN_model_checkpoint.pt')
    torch.save(optimizer.state_dict(), 'CARN_adam_checkpoint.pt')
    torch.save(train_loss, 'train_loss.pt')
    torch.save(train_ssim, "train_ssim.pt")
    torch.save(train_psnr, 'train_psnr.pt')
    torch.save(scheduler.last_batch_iteration, "CARN_scheduler_last_iter.pt")


def _evaluate(model, test_data, criteria):
    """Score ``model`` on ``test_data``; return the mean (loss, msssim, psnr).

    Runs in eval mode without gradients (restoring train mode afterwards) and
    saves each (prediction, target) pair as ``TEST_IMG_DIR/<batch index>.png``.
    """
    # save_image writes to the path as given; the directory must already exist.
    os.makedirs(TEST_IMG_DIR, exist_ok=True)
    losses, ssims, psnrs = [], [], []
    model.eval()
    try:
        with torch.no_grad():
            for index, (lr_img, hr_img) in enumerate(test_data):
                # Feed half-precision inputs exactly as the training loop does.
                lr_img = lr_img.cuda().half()
                hr_img = hr_img.cuda()
                lr_img_up = model(lr_img).float()
                losses.append(criteria(lr_img_up, hr_img).item())
                save_image([lr_img_up[0], hr_img[0]], f"{TEST_IMG_DIR}/{index}.png")
                ssims.append(image_quality.msssim(lr_img_up, hr_img).item())
                psnrs.append(image_quality.psnr(lr_img_up, hr_img).item())
    finally:
        model.train()
    return np.mean(losses), np.mean(ssims), np.mean(psnrs)


def main():
    """Resume CARN_V2 training from ``CARN_model_checkpoint.pt``, evaluating after every epoch."""
    img_dataset = ImageDBData(db_file='dataset/images.db', db_table="train_images_size_128_noise_1_rgb",
                              max_images=24)
    img_data = DataLoader(img_dataset, batch_size=6, shuffle=True, num_workers=6)
    total_batch = len(img_data)
    print(len(img_dataset))

    test_dataset = ImageDBData(db_file='dataset/test2.db', db_table="test_images_size_128_noise_1_rgb",
                               max_images=None)
    test_data = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=1)

    criteria = nn.L1Loss()
    model = CARN_V2(color_channels=3, mid_channels=64, conv=nn.Conv2d,
                    single_conv_size=3, single_conv_group=1,
                    scale=2, activation=nn.LeakyReLU(0.1),
                    SEBlock=True, repeat_blocks=3, atrous=(1, 1, 1))
    model.total_parameters()
    # model.initialize_weights_xavier_uniform()

    # fp16 training is available in GPU only
    model = network_to_half(model)
    model = model.cuda()
    model.load_state_dict(torch.load("CARN_model_checkpoint.pt"))

    learning_rate = 1e-4
    weight_decay = 1e-6
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay, amsgrad=True)
    # optimizer = optim.SGD(model.parameters(), momentum=0.9, nesterov=True, weight_decay=weight_decay, lr=learning_rate)
    # NOTE(review): with the wrapper below disabled, Adam updates the (presumably
    # fp16) parameters directly, without fp32 master weights or loss scaling.
    # optimizer = FP16_Optimizer(optimizer, static_loss_scale=128.0, verbose=False)
    # optimizer.load_state_dict(torch.load("CARN_adam_checkpoint.pt"))

    last_iter = -1  # to resume: torch.load("CARN_scheduler_last_iter.pt")
    # base_lr == max_lr, so the cycle presumably keeps the learning rate constant.
    scheduler = CyclicLR(optimizer, base_lr=1e-4, max_lr=1e-4,
                         step_size=3 * total_batch, mode="triangular",
                         last_batch_iteration=last_iter)

    # Per-batch training metrics and per-epoch test metrics. To continue an
    # earlier history instead, load them back, e.g. torch.load("train_loss.pt").
    train_loss, train_ssim, train_psnr = [], [], []
    test_loss, test_ssim, test_psnr = [], [], []

    counter = 0  # optimizer steps taken in this run
    num_epochs = 2
    ibar = trange(num_epochs, ascii=True, maxinterval=1,
                  postfix={"avg_loss": 0, "train_ssim": 0, "test_ssim": 0})
    for _ in ibar:
        model.train()
        for index, (lr_img, hr_img) in enumerate(img_data):
            scheduler.batch_step()  # set this batch's learning rate before stepping
            lr_img = lr_img.cuda().half()
            hr_img = hr_img.cuda()

            optimizer.zero_grad()
            # Cast the network output back to fp32 before computing the loss.
            outputs = model(lr_img).float()
            loss = criteria(outputs, hr_img)
            _backward(optimizer, loss)
            # nn.utils.clip_grad_norm_(model.parameters(), 5)
            optimizer.step()
            counter += 1

            # Monitoring metrics only; do not extend the autograd graph.
            with torch.no_grad():
                ssim = image_quality.msssim(outputs, hr_img).item()
                psnr = image_quality.psnr(outputs, hr_img).item()
            loss_value = loss.item()
            ibar.set_postfix(ratio=index / total_batch, loss=loss_value,
                             ssim=ssim, batch=index,
                             psnr=psnr,
                             lr=scheduler.current_lr)
            train_loss.append(loss_value)
            train_ssim.append(ssim)
            train_psnr.append(psnr)

            if counter % CHECKPOINT_EVERY == 0:
                _save_train_state(model, optimizer, scheduler, train_loss, train_ssim, train_psnr)

        # End of one epoch: checkpoint, then evaluate on the test set.
        _save_train_state(model, optimizer, scheduler, train_loss, train_ssim, train_psnr)
        epoch_loss, epoch_ssim, epoch_psnr = _evaluate(model, test_data, criteria)
        test_loss.append(epoch_loss)
        test_ssim.append(epoch_ssim)
        test_psnr.append(epoch_psnr)
        torch.save(test_loss, 'test_loss.pt')
        torch.save(test_ssim, "test_ssim.pt")
        torch.save(test_psnr, "test_psnr.pt")

    # Optionally power the machine off once training finishes (Windows):
    # import subprocess
    # subprocess.call(["shutdown", "/s"])


if __name__ == "__main__":
    # DataLoader workers re-import this module under the "spawn" start method
    # (the default on Windows); the guard keeps them from re-running training.
    main()