import torch as t


def corn_loss(logits: t.Tensor, y_train: t.Tensor, num_classes: int) -> t.Tensor:
    """Compute the CORN loss for rank-consistent ordinal regression.

    CORN is described in 'Deep Neural Networks for Rank Consistent Ordinal
    Regression based on Conditional Probabilities'. The loss is a sum of
    ``num_classes - 1`` binary cross-entropy terms; binary task ``k`` is
    evaluated only on the examples selected by ``y_train > k - 1`` and its
    target is ``y_train > k``. The summed loss is divided by the total
    number of (example, task) pairs that were evaluated.

    Parameters
    ----------
    logits : torch.tensor, shape=(num_examples, num_classes-1)
        Outputs of the CORN layer.

    y_train : torch.tensor, shape=(num_examples)
        Torch tensor containing the class labels.

    num_classes : int
        Number of unique class labels (class labels should start at 0).

    Returns
    ----------
    loss : torch.tensor
        A torch.tensor containing a single loss value. If no example takes
        part in any binary task (for instance an empty batch), a zero scalar
        that is still connected to ``logits`` in the autograd graph is
        returned instead of dividing by zero.

    Examples
    ----------
    >>> import torch
    >>> from coral_pytorch.losses import corn_loss
    >>> # Consider 8 training examples
    >>> _ = torch.manual_seed(123)
    >>> X_train = torch.rand(8, 99)
    >>> y_train = torch.tensor([0, 1, 2, 2, 2, 3, 4, 4])
    >>> NUM_CLASSES = 5
    >>> # In the model's __init__(self):
    >>> corn_net = torch.nn.Linear(99, NUM_CLASSES-1)
    >>> # In the model's forward(self, X_train):
    >>> logits = corn_net(X_train)
    >>> logits.shape
    torch.Size([8, 4])
    >>> corn_loss(logits, y_train, NUM_CLASSES)
    tensor(0.7127, grad_fn=<DivBackward0>)

    https://github.com/Raschka-research-group/coral-pytorch/blob/c6ab93afd555a6eac708c95ae1feafa15f91c5aa/coral_pytorch/losses.py
    """
    total_loss = 0.0
    num_examples = 0

    for task_index in range(num_classes - 1):
        # Conditional training subset for this task: only examples whose
        # label passed the previous threshold are included.
        task_mask = y_train > task_index - 1
        task_labels = (y_train[task_mask] > task_index).to(t.int64)

        task_size = task_labels.numel()
        if task_size == 0:
            continue
        num_examples += task_size

        pred = logits[task_mask, task_index]
        # log(sigmoid(x)) is needed for both the positive and the negative
        # term, so compute it once.
        log_sig = t.nn.functional.logsigmoid(pred)
        # log(1 - sigmoid(x)) == logsigmoid(x) - x
        task_loss = -t.sum(
            log_sig * task_labels + (log_sig - pred) * (1 - task_labels)
        )
        total_loss = total_loss + task_loss

    if num_examples == 0:
        # Nothing contributed to the loss; avoid 0.0 / 0 (ZeroDivisionError)
        # while keeping the result attached to the graph so that
        # ``.backward()`` on it remains valid.
        return logits.sum() * 0.0

    return total_loss / num_examples


def corn_label_from_logits(logits: t.Tensor) -> t.Tensor:
    """Return the predicted rank label from logits of a CORN-trained network.

    The sigmoid of each logit is treated as a conditional probability; the
    cumulative product along dim 1 is thresholded at 0.5, and the number of
    thresholds exceeded is the predicted label.

    Parameters
    ----------
    logits : torch.tensor, shape=(n_examples, n_classes-1)
        Torch tensor consisting of logits returned by the
        neural net (one column per binary task).

    Returns
    ----------
    labels : torch.tensor, shape=(n_examples)
        Integer tensor containing the predicted rank (class) labels

    Examples
    ----------
    >>> import torch
    >>> # 2 training examples, 5 classes
    >>> logits = torch.tensor([[14.152, -6.1942, 0.47710, 0.96850],
    ...                        [65.667, 0.303, 11.500, -4.524]])
    >>> corn_label_from_logits(logits)
    tensor([1, 3])

    https://github.com/Raschka-research-group/coral-pytorch/blob/c6ab93afd555a6eac708c95ae1feafa15f91c5aa/coral_pytorch/dataset.py
    """
    probas = t.sigmoid(logits)
    probas = t.cumprod(probas, dim=1)
    predict_levels = probas > 0.5
    predicted_labels = t.sum(predict_levels, dim=1)
    return predicted_labels