import os import time import copy import numpy as np import torch import torch.nn as nn import torch.optim as optim from torch.optim import lr_scheduler from torchvision import datasets, transforms, models from sklearn.metrics import confusion_matrix, classification_report, top_k_accuracy_score import matplotlib.pyplot as plt import seaborn as sns # Параметры обучения num_epochs = 25 batch_size = 32 learning_rate = 0.001 data_dir = '/kaggle/input/centraasia' # Корневая папка с поддиректориями train, val и test # Трансформации для train, validation и test data_transforms = { 'train': transforms.Compose([ transforms.RandomResizedCrop(224), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]), 'val': transforms.Compose([ transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]), 'test': transforms.Compose([ transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]), } # Загрузка датасетов с помощью ImageFolder image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in ['train', 'val', 'test']} dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True, num_workers=4) for x in ['train', 'val', 'test']} dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val', 'test']} class_names = image_datasets['train'].classes print("Классы: ", class_names) print("Размеры датасетов: ", dataset_sizes) # Вычисление количества изображений по классам (для гистограммы и расчёта весов) class_counts = {} for _, label in image_datasets['train'].imgs: class_counts[label] = class_counts.get(label, 0) + 1 print("\nРаспределение по классам:") for idx, count in class_counts.items(): print(f"{class_names[idx]}: {count}") # Отобразим гистограмму распределения plt.figure(figsize=(10, 8)) counts = [class_counts[i] for i in range(len(class_names))] sns.barplot(x=counts, y=class_names, orient='h') plt.xlabel("Number of images") plt.ylabel("Class names") plt.title("Распределение изображений по классам") plt.show() # Определение устройства device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print("Используем устройство: ", device) # Инициализация модели ResNet50 model_ft = models.resnet50(pretrained=True) num_ftrs = model_ft.fc.in_features model_ft.fc = nn.Linear(num_ftrs, len(class_names)) model_ft = model_ft.to(device) # Если доступны несколько GPU, оборачиваем модель в DataParallel if torch.cuda.device_count() > 1: print("Используем", torch.cuda.device_count(), "GPU!") model_ft = nn.DataParallel(model_ft) # Расчет весов для классов для учёта дисбаланса. # Можно вычислить веса как обратную пропорциональность количеству примеров в классе. class_sample_counts = np.array([class_counts[i] for i in range(len(class_names))]) # Вес для каждого класса: чем меньше примеров, тем больший вес. class_weights = 1. / class_sample_counts # Приводим веса к тензору и перемещаем на устройство class_weights_tensor = torch.FloatTensor(class_weights).to(device) # Определение функции потерь с учетом весов классов criterion = nn.CrossEntropyLoss(weight=class_weights_tensor) # Определение оптимизатора и планировщика optimizer_ft = optim.SGD(model_ft.parameters(), lr=learning_rate, momentum=0.9) exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1) # Функция обучения модели def train_model(model, criterion, optimizer, scheduler, num_epochs=25): since = time.time() best_model_wts = copy.deepcopy(model.state_dict()) best_acc = 0.0 for epoch in range(num_epochs): print('-' * 30) print(f'Эпоха {epoch+1}/{num_epochs}') # Каждая эпоха проходит фазы обучения и валидации for phase in ['train', 'val']: if phase == 'train': model.train() # режим обучения else: model.eval() # режим валидации running_loss = 0.0 running_corrects = 0 # Итерация по данным for inputs, labels in dataloaders[phase]: inputs = inputs.to(device) labels = labels.to(device) optimizer.zero_grad() # Прямой проход with torch.set_grad_enabled(phase == 'train'): outputs = model(inputs) _, preds = torch.max(outputs, 1) loss = criterion(outputs, labels) # Обратный проход и оптимизация в режиме обучения if phase == 'train': loss.backward() optimizer.step() running_loss += loss.item() * inputs.size(0) running_corrects += torch.sum(preds == labels.data) if phase == 'train': scheduler.step() epoch_loss = running_loss / dataset_sizes[phase] epoch_acc = running_corrects.double() / dataset_sizes[phase] print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}') # Сохраняем лучшую модель по точности валидации if phase == 'val' and epoch_acc > best_acc: best_acc = epoch_acc best_model_wts = copy.deepcopy(model.state_dict()) print() time_elapsed = time.time() - since print(f'Обучение завершено за {time_elapsed//60:.0f}m {time_elapsed%60:.0f}s') print(f'Лучшая точность на валидации: {best_acc:.4f}') # Загружаем лучшие веса модели model.load_state_dict(best_model_wts) return model # Обучаем модель model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler, num_epochs=num_epochs) # Функция тестирования модели с вычислением дополнительных метрик def test_model(model, dataloader): model.eval() running_corrects = 0 all_preds = [] all_labels = [] # Для расчёта top-5 точности all_outputs = [] with torch.no_grad(): for inputs, labels in dataloader: inputs = inputs.to(device) labels = labels.to(device) outputs = model(inputs) _, preds = torch.max(outputs, 1) running_corrects += torch.sum(preds == labels.data) all_preds.extend(preds.cpu().numpy()) all_labels.extend(labels.cpu().numpy()) all_outputs.extend(outputs.cpu().numpy()) test_acc = running_corrects.double() / dataset_sizes['test'] print(f'\nTop-1 точность на тестовом наборе: {test_acc:.4f}') # Вычисление top-5 точности all_outputs = np.array(all_outputs) all_labels_np = np.array(all_labels) top5_acc = top_k_accuracy_score(all_labels_np, all_outputs, k=5) print(f'Top-5 точность на тестовом наборе: {top5_acc:.4f}') # Матрица ошибок cm = confusion_matrix(all_labels, all_preds) plt.figure(figsize=(12, 10)) sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names) plt.xlabel("Предсказанный класс") plt.ylabel("Истинный класс") plt.title("Матрица ошибок") plt.show() # Отчет по классам (precision, recall, f1-score) print("Отчет по классам:") print(classification_report(all_labels, all_preds, target_names=class_names)) test_model(model_ft, dataloaders['test'])