import timm import os from typing import Any from pytorch_lightning.utilities.types import LRSchedulerTypeUnion import torch as t from torch import nn import numpy as np import transformers import pytorch_lightning as plight import torchmetrics import einops as eo from loss_functions import coral_loss, corn_loss, corn_label_from_logits, macro_soft_f1 t.set_float32_matmul_precision("medium") global_settings = dict(try_using_torch_compile=False) class EnsembleModel(plight.LightningModule): def __init__(self, models_without_norm_df, models_with_norm_df, learning_rate=0.0002, use_simple_average=False): super().__init__() self.models_without_norm = nn.ModuleList(list(models_without_norm_df)) self.models_with_norm = nn.ModuleList(list(models_with_norm_df)) self.learning_rate = learning_rate self.use_simple_average = use_simple_average if not self.use_simple_average: self.combiner = nn.Linear( self.models_with_norm[0].num_classes * (len(self.models_with_norm) + len(self.models_without_norm)), self.models_with_norm[0].num_classes, ) def forward(self, x): x_unnormed, x_normed = x if not self.use_simple_average: out_unnormed = t.cat([model.model_step(x_unnormed, 0)[0] for model in self.models_without_norm], dim=-1) out_normed = t.cat([model.model_step(x_normed, 0)[0] for model in self.models_with_norm], dim=-1) out_avg = self.combiner(t.cat((out_unnormed, out_normed), dim=-1)) else: out_unnormed = [model.model_step(x_unnormed, 0)[0] for model in self.models_without_norm] out_normed = [model.model_step(x_normed, 0)[0] for model in self.models_with_norm] out_avg = (t.stack(out_unnormed + out_normed, dim=-1) / 2).mean(-1) return {"out_avg": out_avg, "out_unnormed": out_unnormed, "out_normed": out_normed}, x_unnormed[-1] def training_step(self, batch, batch_idx): out, y = self(batch) loss = self.models_with_norm[0]._get_loss(out["out_avg"], y, batch[0]) self.log("train_loss", loss, on_epoch=True, on_step=True, sync_dist=True) return loss def validation_step(self, batch, batch_idx): out, y = self(batch) preds, y_onecold, ignore_index_val = self.models_with_norm[0]._get_preds_reals(out["out_avg"], y) acc = torchmetrics.functional.accuracy( preds, y_onecold.to(t.long), ignore_index=ignore_index_val, num_classes=self.models_with_norm[0].num_classes, task="multiclass", ) self.log("acc", acc * 100, prog_bar=True, sync_dist=True) loss = self.models_with_norm[0]._get_loss(out["out_avg"], y, batch[0]) self.log("val_loss", loss, prog_bar=True, sync_dist=True) return loss def predict_step(self, batch, batch_idx: int, dataloader_idx: int = 0): out, y = self(batch) preds, y_onecold, ignore_index_val = self.models_with_norm[0]._get_preds_reals(out["out_avg"], y) return preds, out, y_onecold def configure_optimizers(self): return t.optim.Adam(self.parameters(), lr=self.learning_rate) class TimmHeadReplace(nn.Module): def __init__(self, pooling=None, in_channels=512, pooling_output_dimension=1, all_identity=False) -> None: super().__init__() if all_identity: self.head = nn.Identity() self.pooling = None else: self.pooling = pooling if pooling is not None: self.pooling_output_dimension = pooling_output_dimension if self.pooling == "AdaptiveAvgPool2d": self.pooling_layer = nn.AdaptiveAvgPool2d(pooling_output_dimension) elif self.pooling == "AdaptiveMaxPool2d": self.pooling_layer = nn.AdaptiveMaxPool2d(pooling_output_dimension) self.head = nn.Flatten() def forward(self, x, pre_logits=False): if self.pooling is not None: if self.pooling == "stack_avg_max_attn": x = t.cat([layer(x) for layer in self.pooling_layer], dim=-1) else: x = self.pooling_layer(x) return self.head(x) class CVModel(nn.Module): def __init__( self, modelname, in_shape, num_classes, loss_func, last_activation: str, input_padding_val=10, char_dims=2, max_seq_length=1000, ) -> None: super().__init__() self.modelname = modelname self.loss_func = loss_func self.in_shape = in_shape self.char_dims = char_dims self.x_shape = in_shape self.last_activation = last_activation self.max_seq_length = max_seq_length self.num_classes = num_classes if self.loss_func == "OrdinalRegLoss": self.out_shape = 1 else: self.out_shape = num_classes self.cv_model = timm.create_model(modelname, pretrained=True, num_classes=0) self.cv_model.classifier = nn.Identity() with t.inference_mode(): test_out = self.cv_model(t.ones(self.in_shape, dtype=t.float32)) self.cv_model_out_dim = test_out.shape[1] self.cv_model.classifier = nn.Sequential(nn.Flatten(), nn.Linear(self.cv_model_out_dim, self.max_seq_length)) if self.out_shape == 1: self.logit_norm = nn.Identity() self.out_project = nn.Identity() else: self.logit_norm = nn.LayerNorm(self.max_seq_length) self.out_project = nn.Linear(1, self.out_shape) if last_activation == "Softmax": self.final_activation = nn.Softmax(dim=-1) elif last_activation == "Sigmoid": self.final_activation = nn.Sigmoid() elif last_activation == "LogSigmoid": self.final_activation = nn.LogSigmoid() elif last_activation == "Identity": self.final_activation = nn.Identity() else: raise NotImplementedError(f"{last_activation} not implemented") def forward(self, x): if isinstance(x, list): x = x[0] x = self.cv_model(x) x = self.cv_model.classifier(x).unsqueeze(-1) x = self.out_project(x) return self.final_activation(x) class LitModel(plight.LightningModule): def __init__( self, in_shape: tuple, hidden_dim: int, num_attention_heads: int, num_layers: int, loss_func: str, learning_rate: float, weight_decay: float, cfg: dict, use_lr_warmup: bool, use_reduce_on_plateau: bool, track_gradient_histogram=False, register_forw_hook=False, char_dims=2, ) -> None: super().__init__() if "only_use_2nd_input_stream" not in cfg: cfg["only_use_2nd_input_stream"] = False if "gamma_step_size" not in cfg: cfg["gamma_step_size"] = 5 if "gamma_step_factor" not in cfg: cfg["gamma_step_factor"] = 0.5 self.save_hyperparameters( dict( in_shape=in_shape, hidden_dim=hidden_dim, num_attention_heads=num_attention_heads, num_layers=num_layers, loss_func=loss_func, learning_rate=learning_rate, cfg=cfg, x_shape=in_shape, num_classes=cfg["num_classes"], use_lr_warmup=use_lr_warmup, num_warmup_steps=cfg["num_warmup_steps"], use_reduce_on_plateau=use_reduce_on_plateau, weight_decay=weight_decay, track_gradient_histogram=track_gradient_histogram, register_forw_hook=register_forw_hook, char_dims=char_dims, remove_timm_classifier_head_pooling=cfg["remove_timm_classifier_head_pooling"], change_pooling_for_timm_head_to=cfg["change_pooling_for_timm_head_to"], chars_conv_pooling_out_dim=cfg["chars_conv_pooling_out_dim"], ) ) self.model_to_use = cfg["model_to_use"] self.num_classes = cfg["num_classes"] self.x_shape = in_shape self.in_shape = in_shape self.hidden_dim = hidden_dim self.num_attention_heads = num_attention_heads self.num_layers = num_layers self.use_lr_warmup = use_lr_warmup self.num_warmup_steps = cfg["num_warmup_steps"] self.warmup_exponent = cfg["warmup_exponent"] self.use_reduce_on_plateau = use_reduce_on_plateau self.loss_func = loss_func self.learning_rate = learning_rate self.weight_decay = weight_decay self.using_one_hot_targets = cfg["one_hot_y"] self.track_gradient_histogram = track_gradient_histogram self.register_forw_hook = register_forw_hook if self.loss_func == "OrdinalRegLoss": self.ord_reg_loss_max = cfg["ord_reg_loss_max"] self.ord_reg_loss_min = cfg["ord_reg_loss_min"] self.num_lin_layers = cfg["num_lin_layers"] self.linear_activation = cfg["linear_activation"] self.last_activation = cfg["last_activation"] self.max_seq_length = cfg["manual_max_sequence_for_model"] self.use_char_embed_info = cfg["use_embedded_char_pos_info"] self.method_chars_into_model = cfg["method_chars_into_model"] self.source_for_pretrained_cv_model = cfg["source_for_pretrained_cv_model"] self.method_to_include_char_positions = cfg["method_to_include_char_positions"] self.char_dims = char_dims self.char_sequence_length = cfg["max_len_chars_list"] if self.use_char_embed_info else 0 self.chars_conv_lr_reduction_factor = cfg["chars_conv_lr_reduction_factor"] if self.use_char_embed_info: self.chars_bert_reduction_factor = cfg["chars_bert_reduction_factor"] self.use_in_projection_bias = cfg["use_in_projection_bias"] self.add_layer_norm_to_in_projection = cfg["add_layer_norm_to_in_projection"] self.hidden_dropout_prob = cfg["hidden_dropout_prob"] self.layer_norm_after_in_projection = cfg["layer_norm_after_in_projection"] self.method_chars_into_model = cfg["method_chars_into_model"] self.input_padding_val = cfg["input_padding_val"] self.cv_char_modelname = cfg["cv_char_modelname"] self.char_plot_shape = cfg["char_plot_shape"] self.remove_timm_classifier_head_pooling = cfg["remove_timm_classifier_head_pooling"] self.change_pooling_for_timm_head_to = cfg["change_pooling_for_timm_head_to"] self.chars_conv_pooling_out_dim = cfg["chars_conv_pooling_out_dim"] self.add_layer_norm_to_char_mlp = cfg["add_layer_norm_to_char_mlp"] if "profile_torch_run" in cfg: self.profile_torch_run = cfg["profile_torch_run"] else: self.profile_torch_run = False if self.loss_func == "OrdinalRegLoss": self.out_shape = 1 else: self.out_shape = cfg["num_classes"] if not self.hparams.cfg["only_use_2nd_input_stream"]: if ( self.method_chars_into_model == "dense" and self.use_char_embed_info and self.method_to_include_char_positions == "concat" ): self.project = nn.Linear(self.x_shape[-1], self.hidden_dim // 2, bias=self.use_in_projection_bias) elif ( self.method_chars_into_model == "bert" and self.use_char_embed_info and self.method_to_include_char_positions == "concat" ): self.hidden_dim_chars = self.hidden_dim // 2 self.project = nn.Linear(self.x_shape[-1], self.hidden_dim_chars, bias=self.use_in_projection_bias) elif ( self.method_chars_into_model == "resnet" and self.method_to_include_char_positions == "concat" and self.use_char_embed_info ): self.project = nn.Linear(self.x_shape[-1], self.hidden_dim // 2, bias=self.use_in_projection_bias) elif self.model_to_use == "cv_only_model": self.project = nn.Identity() else: self.project = nn.Linear(self.x_shape[-1], self.hidden_dim, bias=self.use_in_projection_bias) if self.add_layer_norm_to_in_projection: self.project = nn.Sequential( nn.Linear(self.project.in_features, self.project.out_features, bias=self.use_in_projection_bias), nn.LayerNorm(self.project.out_features), ) if hasattr(self, "project") and "posix" in os.name and global_settings["try_using_torch_compile"]: self.project = t.compile(self.project) if self.use_char_embed_info: self._create_char_model() if self.layer_norm_after_in_projection: if self.hparams.cfg["only_use_2nd_input_stream"]: self.layer_norm_in = nn.LayerNorm(self.hidden_dim // 2) else: self.layer_norm_in = nn.LayerNorm(self.hidden_dim) if "posix" in os.name and global_settings["try_using_torch_compile"]: self.layer_norm_in = t.compile(self.layer_norm_in) self._create_main_seq_model(cfg) if register_forw_hook: self.register_hooks() if self.hparams.cfg["only_use_2nd_input_stream"]: linear_in_dim = self.hidden_dim // 2 else: linear_in_dim = self.hidden_dim if self.num_lin_layers == 1: self.linear = nn.Linear(linear_in_dim, self.out_shape) else: lin_layers = [] for _ in range(self.num_lin_layers - 1): lin_layers.extend( [ nn.Linear(linear_in_dim, linear_in_dim), getattr(nn, self.linear_activation)(), ] ) self.linear = nn.Sequential(*lin_layers, nn.Linear(linear_in_dim, self.out_shape)) if "posix" in os.name and global_settings["try_using_torch_compile"]: self.linear = t.compile(self.linear) if self.last_activation == "Softmax": self.final_activation = nn.Softmax(dim=-1) elif self.last_activation == "Sigmoid": self.final_activation = nn.Sigmoid() elif self.last_activation == "Identity": self.final_activation = nn.Identity() else: raise NotImplementedError(f"{self.last_activation} not implemented") if self.profile_torch_run: self.profilerr = t.profiler.profile( schedule=t.profiler.schedule(wait=1, warmup=10, active=10, repeat=1), on_trace_ready=t.profiler.tensorboard_trace_handler("tblogs"), with_stack=True, record_shapes=True, profile_memory=False, ) def _create_main_seq_model(self, cfg): if self.hparams.cfg["only_use_2nd_input_stream"]: hidden_dim = self.hidden_dim // 2 else: hidden_dim = self.hidden_dim if self.model_to_use == "BERT": self.bert_config = transformers.BertConfig( vocab_size=self.x_shape[-1], hidden_size=hidden_dim, num_hidden_layers=self.num_layers, intermediate_size=hidden_dim, num_attention_heads=self.num_attention_heads, max_position_embeddings=self.max_seq_length, ) self.bert_model = transformers.BertModel(self.bert_config) elif self.model_to_use == "cv_only_model": self.bert_model = CVModel( modelname=cfg["cv_modelname"], in_shape=self.in_shape, num_classes=cfg["num_classes"], loss_func=cfg["loss_function"], last_activation=cfg["last_activation"], input_padding_val=cfg["input_padding_val"], char_dims=self.char_dims, max_seq_length=cfg["manual_max_sequence_for_model"], ) else: raise NotImplementedError(f"{self.model_to_use} not implemented") if "posix" in os.name and global_settings["try_using_torch_compile"]: self.bert_model = t.compile(self.bert_model) return 0 def _create_char_model(self): if self.method_chars_into_model == "dense": self.chars_project_0 = nn.Linear(self.char_dims, 1, bias=self.use_in_projection_bias) if "posix" in os.name and global_settings["try_using_torch_compile"]: self.chars_project_0 = t.compile(self.chars_project_0) if self.method_to_include_char_positions == "concat": self.chars_project_1 = nn.Linear( self.char_sequence_length, self.hidden_dim // 2, bias=self.use_in_projection_bias ) else: self.chars_project_1 = nn.Linear( self.char_sequence_length, self.hidden_dim, bias=self.use_in_projection_bias ) if "posix" in os.name and global_settings["try_using_torch_compile"]: self.chars_project_1 = t.compile(self.chars_project_1) elif not self.method_chars_into_model == "resnet": self.chars_project = nn.Linear(self.char_dims, self.hidden_dim_chars, bias=self.use_in_projection_bias) if "posix" in os.name and global_settings["try_using_torch_compile"]: self.chars_project = t.compile(self.chars_project) if self.method_chars_into_model == "bert": if not hasattr(self, "hidden_dim_chars"): if self.hidden_dim // self.chars_bert_reduction_factor > 1: self.hidden_dim_chars = self.hidden_dim // self.chars_bert_reduction_factor else: self.hidden_dim_chars = self.hidden_dim self.num_attention_heads_chars = self.hidden_dim_chars // (self.hidden_dim // self.num_attention_heads) self.chars_bert_config = transformers.BertConfig( vocab_size=self.x_shape[-1], hidden_size=self.hidden_dim_chars, num_hidden_layers=self.num_layers, intermediate_size=self.hidden_dim_chars, num_attention_heads=self.num_attention_heads_chars, max_position_embeddings=self.char_sequence_length + 1, num_labels=1, ) self.chars_bert = transformers.BertForSequenceClassification(self.chars_bert_config) if "posix" in os.name and global_settings["try_using_torch_compile"]: self.chars_bert = t.compile(self.chars_bert) self.chars_project_class_output = nn.Linear(1, self.hidden_dim_chars, bias=self.use_in_projection_bias) if "posix" in os.name and global_settings["try_using_torch_compile"]: self.chars_project_class_output = t.compile(self.chars_project_class_output) elif self.method_chars_into_model == "resnet": if self.source_for_pretrained_cv_model == "timm": self.chars_conv = timm.create_model( self.cv_char_modelname, pretrained=True, num_classes=0, # remove classifier nn.Linear ) if self.remove_timm_classifier_head_pooling: self.chars_conv.head = TimmHeadReplace(all_identity=True) with t.inference_mode(): test_out = self.chars_conv( t.ones((1, 3, self.char_plot_shape[0], self.char_plot_shape[1]), dtype=t.float32) ) if test_out.ndim > 3: self.chars_conv.head = TimmHeadReplace( self.change_pooling_for_timm_head_to, test_out.shape[1], ) elif self.source_for_pretrained_cv_model == "huggingface": self.chars_conv = transformers.AutoModelForImageClassification.from_pretrained(self.cv_char_modelname) elif self.source_for_pretrained_cv_model == "torch_hub": self.chars_conv = t.hub.load(*self.cv_char_modelname.split(",")) if hasattr(self.chars_conv, "classifier"): self.chars_conv.classifier = nn.Identity() elif hasattr(self.chars_conv, "cls_classifier"): self.chars_conv.cls_classifier = nn.Identity() elif hasattr(self.chars_conv, "fc"): self.chars_conv.fc = nn.Identity() if hasattr(self.chars_conv, "distillation_classifier"): self.chars_conv.distillation_classifier = nn.Identity() with t.inference_mode(): test_out = self.chars_conv( t.ones((1, 3, self.char_plot_shape[0], self.char_plot_shape[1]), dtype=t.float32) ) if hasattr(test_out, "last_hidden_state"): self.chars_conv_out_dim = test_out.last_hidden_state.shape[1] elif hasattr(test_out, "logits"): self.chars_conv_out_dim = test_out.logits.shape[1] elif isinstance(test_out, list): self.chars_conv_out_dim = test_out[0].shape[1] else: self.chars_conv_out_dim = test_out.shape[1] char_lin_layers = [nn.Flatten(), nn.Linear(self.chars_conv_out_dim, self.hidden_dim // 2)] if self.add_layer_norm_to_char_mlp: char_lin_layers.append(nn.LayerNorm(self.hidden_dim // 2)) self.chars_classifier = nn.Sequential(*char_lin_layers) if hasattr(self.chars_conv, "distillation_classifier"): self.chars_conv.distillation_classifier = nn.Sequential( nn.Flatten(), nn.Linear(self.chars_conv_out_dim, self.hidden_dim // 2) ) if "posix" in os.name and global_settings["try_using_torch_compile"]: self.chars_classifier = t.compile(self.chars_classifier) if "posix" in os.name and global_settings["try_using_torch_compile"]: self.chars_conv = t.compile(self.chars_conv) return 0 def register_hooks(self): def add_to_tb(layer): def hook(model, input, output): if hasattr(output, "detach"): for logger in self.loggers: if hasattr(logger.experiment, "add_histogram"): logger.experiment.add_histogram( tag=f"{layer}_{str(list(output.shape))}", values=output.detach(), global_step=self.trainer.global_step, ) return hook for layer_id, layer in dict([*self.named_modules()]).items(): layer.register_forward_hook(add_to_tb(f"act_{layer_id}")) def on_after_backward(self) -> None: if self.track_gradient_histogram: if self.trainer.global_step % 200 == 0: for logger in self.loggers: if hasattr(logger.experiment, "add_histogram"): for layer_id, layer in dict([*self.named_modules()]).items(): parameters = layer.parameters() for idx2, p in enumerate(parameters): grad_val = p.grad if grad_val is not None: grad_name = f"grad_{idx2}_{layer_id}_{str(list(p.grad.shape))}" logger.experiment.add_histogram( tag=grad_name, values=grad_val, global_step=self.trainer.global_step ) return super().on_after_backward() def _fold_in_seq_dim(self, out, y): batch_size, seq_len, num_classes = out.shape out = eo.rearrange(out, "b s c -> (b s) c", s=seq_len) if y is None: return out, None if len(y.shape) > 2: y = eo.rearrange(y, "b s c -> (b s) c", s=seq_len) else: y = eo.rearrange(y, "b s -> (b s)", s=seq_len) return out, y def _get_loss(self, out, y, batch): attention_mask = batch[-2] if self.loss_func == "BCELoss": if self.last_activation == "Identity": loss = t.nn.functional.binary_cross_entropy_with_logits(out, y, reduction="none") else: loss = t.nn.functional.binary_cross_entropy(out, y, reduction="none") replace_tensor = t.zeros(loss[1, 1, :].shape, device=loss.device, dtype=loss.dtype, requires_grad=False) loss[~attention_mask.bool()] = replace_tensor loss = loss.mean() elif self.loss_func == "CrossEntropyLoss": if len(out.shape) > 2: out, y = self._fold_in_seq_dim(out, y) loss = t.nn.functional.cross_entropy(out, y, reduction="mean", ignore_index=-100) else: loss = t.nn.functional.cross_entropy(out, y, reduction="mean", ignore_index=-100) elif self.loss_func == "OrdinalRegLoss": loss = t.nn.functional.mse_loss(out, y, reduction="none") loss = loss[attention_mask.bool()].sum() * 10.0 / attention_mask.sum() elif self.loss_func == "macro_soft_f1": loss = macro_soft_f1(y, out, reduction="mean") elif self.loss_func == "coral_loss": loss = coral_loss(out, y) elif self.loss_func == "corn_loss": out, y = self._fold_in_seq_dim(out, y) loss = corn_loss(out, y.squeeze(), self.out_shape) else: raise ValueError("Loss Function not reckognized") return loss def training_step(self, batch, batch_idx): if self.profile_torch_run: self.profilerr.step() out, y = self.model_step(batch, batch_idx) loss = self._get_loss(out, y, batch) self.log("train_loss", loss, on_epoch=True, on_step=True, sync_dist=True) return loss def forward(*args): return forward(args[0], args[1:]) def model_step(self, batch, batch_idx): out = self.forward(batch) return out, batch[-1] def optimizer_step( self, epoch, batch_idx, optimizer, optimizer_closure, ): optimizer.step(closure=optimizer_closure) if self.use_lr_warmup and self.hparams["cfg"]["lr_scheduling"] != "OneCycleLR": if self.trainer.global_step < self.num_warmup_steps: lr_scale = min(1.0, float(self.trainer.global_step + 1) / self.num_warmup_steps) ** self.warmup_exponent for pg in optimizer.param_groups: pg["lr"] = lr_scale * self.hparams.learning_rate if self.trainer.global_step % 10 == 0 or self.trainer.global_step == 0: for idx, pg in enumerate(optimizer.param_groups): self.log(f"lr_{idx}", pg["lr"], prog_bar=True, sync_dist=True) def lr_scheduler_step(self, scheduler: LRSchedulerTypeUnion, metric: Any | None) -> None: if self.use_lr_warmup and self.hparams["cfg"]["lr_scheduling"] != "OneCycleLR": if self.trainer.global_step > self.num_warmup_steps: if metric is None: scheduler.step() else: scheduler.step(metric) else: if metric is None: scheduler.step() else: scheduler.step(metric) def _get_preds_reals(self, out, y): if self.loss_func == "corn_loss": seq_len = out.shape[1] out, y = self._fold_in_seq_dim(out, y) preds = corn_label_from_logits(out) preds = eo.rearrange(preds, "(b s) -> b s", s=seq_len) if y is not None: y = eo.rearrange(y.squeeze(), "(b s) -> b s", s=seq_len) elif self.loss_func == "OrdinalRegLoss": preds = out * (self.ord_reg_loss_max - self.ord_reg_loss_min) preds = (preds + self.ord_reg_loss_min).round().to(t.long) else: preds = t.argmax(out, dim=-1) if y is None: return preds, y, -100 else: if self.using_one_hot_targets: y_onecold = t.argmax(y, dim=-1) ignore_index_val = 0 elif self.loss_func == "OrdinalRegLoss": y_onecold = (y * self.num_classes).round().to(t.long) y_onecold = y * (self.ord_reg_loss_max - self.ord_reg_loss_min) y_onecold = (y_onecold + self.ord_reg_loss_min).round().to(t.long) ignore_index_val = t.min(y_onecold).to(t.long) else: y_onecold = y ignore_index_val = -100 if len(preds.shape) > len(y_onecold.shape): preds = preds.squeeze() return preds, y_onecold, ignore_index_val def validation_step(self, batch, batch_idx): out, y = self.model_step(batch, batch_idx) preds, y_onecold, ignore_index_val = self._get_preds_reals(out, y) if self.loss_func == "OrdinalRegLoss": y_onecold = y_onecold.flatten() preds = preds.flatten()[y_onecold != ignore_index_val] y_onecold = y_onecold[y_onecold != ignore_index_val] acc = (preds == y_onecold).sum() / len(y_onecold) else: acc = torchmetrics.functional.accuracy( preds, y_onecold.to(t.long), ignore_index=ignore_index_val, num_classes=self.num_classes, task="multiclass", ) self.log("acc", acc * 100, prog_bar=True, sync_dist=True) loss = self._get_loss(out, y, batch) self.log("val_loss", loss, prog_bar=True, sync_dist=True) return loss def predict_step(self, batch, batch_idx): out, y = self.model_step(batch, batch_idx) preds, y_onecold, ignore_index_val = self._get_preds_reals(out, y) return preds, y_onecold def configure_optimizers(self): params = list(self.named_parameters()) def is_chars_conv(n): if "chars_conv" not in n: return False if "chars_conv" in n and "classifier" in n: return False else: return True grouped_parameters = [ { "params": [p for n, p in params if is_chars_conv(n)], "lr": self.learning_rate / self.chars_conv_lr_reduction_factor, "weight_decay": self.weight_decay, }, { "params": [p for n, p in params if not is_chars_conv(n)], "lr": self.learning_rate, "weight_decay": self.weight_decay, }, ] opti = t.optim.AdamW(grouped_parameters, lr=self.learning_rate, weight_decay=self.weight_decay) if self.use_reduce_on_plateau: opti_dict = { "optimizer": opti, "lr_scheduler": { "scheduler": t.optim.lr_scheduler.ReduceLROnPlateau(opti, mode="min", patience=2, factor=0.5), "monitor": "val_loss", "frequency": 1, "interval": "epoch", }, } return opti_dict else: cfg = self.hparams["cfg"] if cfg["use_reduce_on_plateau"]: scheduler = None elif cfg["lr_scheduling"] == "multistep": scheduler = t.optim.lr_scheduler.MultiStepLR( opti, milestones=cfg["multistep_milestones"], gamma=cfg["gamma_multistep"], verbose=False ) interval = "step" if cfg["use_training_steps_for_end_and_lr_decay"] else "epoch" elif cfg["lr_scheduling"] == "StepLR": scheduler = t.optim.lr_scheduler.StepLR( opti, step_size=cfg["gamma_step_size"], gamma=cfg["gamma_step_factor"] ) interval = "step" if cfg["use_training_steps_for_end_and_lr_decay"] else "epoch" elif cfg["lr_scheduling"] == "anneal": scheduler = t.optim.lr_scheduler.CosineAnnealingLR( opti, 250, eta_min=cfg["min_lr_anneal"], last_epoch=-1, verbose=False ) interval = "step" elif cfg["lr_scheduling"] == "ExponentialLR": scheduler = t.optim.lr_scheduler.ExponentialLR(opti, gamma=cfg["lr_sched_exp_fac"]) interval = "step" else: scheduler = None if scheduler is None: return [opti] else: opti_dict = { "optimizer": opti, "lr_scheduler": { "scheduler": scheduler, "monitor": "global_step", "frequency": 1, "interval": interval, }, } return opti_dict def on_fit_start(self) -> None: if self.profile_torch_run: self.profilerr.start() return super().on_fit_start() def on_fit_end(self) -> None: if self.profile_torch_run: self.profilerr.stop() return super().on_fit_end() def prep_model_input(self, batch): if len(batch) == 1: batch = batch[0] if self.use_char_embed_info: if len(batch) == 5: x, chars_coords, ims, attention_mask, _ = batch elif batch[1].ndim == 4: x, ims, attention_mask, _ = batch else: x, chars_coords, attention_mask, _ = batch padding_list = None else: if len(batch) > 3: x = batch[0] y = batch[-1] attention_mask = batch[1] else: x, attention_mask, y = batch if self.model_to_use != "cv_only_model" and not self.hparams.cfg["only_use_2nd_input_stream"]: x_embedded = self.project(x) else: x_embedded = x if self.use_char_embed_info: if self.method_chars_into_model == "dense": bool_mask = chars_coords == self.input_padding_val bool_mask = bool_mask[:, :, 0] chars_coords_projected = self.chars_project_0(chars_coords).squeeze(-1) chars_coords_projected = chars_coords_projected * bool_mask if self.chars_project_1.in_features == chars_coords_projected.shape[-1]: chars_coords_projected = self.chars_project_1(chars_coords_projected) else: chars_coords_projected = chars_coords_projected.mean(dim=-1) chars_coords_projected = chars_coords_projected.unsqueeze(1).repeat(1, x_embedded.shape[2]) elif self.method_chars_into_model == "bert": chars_mask = chars_coords != self.input_padding_val chars_mask = t.cat( ( t.ones(chars_mask[:, :1, 0].shape, dtype=t.long, device=chars_coords.device), chars_mask[:, :, 0].to(t.long), ), dim=1, ) chars_coords_projected = self.chars_project(chars_coords) position_ids = t.arange( 0, chars_coords_projected.shape[1] + 1, dtype=t.long, device=chars_coords_projected.device ) token_type_ids = t.zeros( (chars_coords_projected.size()[0], chars_coords_projected.size()[1] + 1), dtype=t.long, device=chars_coords_projected.device, ) # +1 for CLS chars_coords_projected = t.cat( (t.ones_like(chars_coords_projected[:, :1, :]), chars_coords_projected), dim=1 ) # to add CLS token chars_coords_projected = self.chars_bert( position_ids=position_ids, inputs_embeds=chars_coords_projected, token_type_ids=token_type_ids, attention_mask=chars_mask, ) if hasattr(chars_coords_projected, "last_hidden_state"): chars_coords_projected = chars_coords_projected.last_hidden_state[:, 0, :] elif hasattr(chars_coords_projected, "logits"): chars_coords_projected = chars_coords_projected.logits else: chars_coords_projected = chars_coords_projected.hidden_states[-1][:, 0, :] elif self.method_chars_into_model == "resnet": chars_conv_out = self.chars_conv(ims) if isinstance(chars_conv_out, list): chars_conv_out = chars_conv_out[0] if hasattr(chars_conv_out, "logits"): chars_conv_out = chars_conv_out.logits chars_coords_projected = self.chars_classifier(chars_conv_out) chars_coords_projected = chars_coords_projected.unsqueeze(1).repeat(1, x_embedded.shape[1], 1) if hasattr(self, "chars_project_class_output"): chars_coords_projected = self.chars_project_class_output(chars_coords_projected) if self.hparams.cfg["only_use_2nd_input_stream"]: x_embedded = chars_coords_projected elif self.method_to_include_char_positions == "concat": x_embedded = t.cat((x_embedded, chars_coords_projected), dim=-1) else: x_embedded = x_embedded + chars_coords_projected return x_embedded, attention_mask def forward(self, batch): prepped_input = prep_model_input(self, batch) if len(batch) > 5: x_embedded, padding_list, attention_mask, attention_mask_for_prediction = prepped_input elif len(batch) > 2: x_embedded, attention_mask = prepped_input else: x_embedded = prepped_input[0] attention_mask = prepped_input[-1] position_ids = t.arange(0, x_embedded.shape[1], dtype=t.long, device=x_embedded.device) token_type_ids = t.zeros(x_embedded.size()[:-1], dtype=t.long, device=x_embedded.device) if self.layer_norm_after_in_projection: x_embedded = self.layer_norm_in(x_embedded) if self.model_to_use == "LSTM": bert_out = self.bert_model(x_embedded) elif self.model_to_use in ["ProphetNet", "T5", "FunnelModel"]: bert_out = self.bert_model(inputs_embeds=x_embedded, attention_mask=attention_mask) elif self.model_to_use == "xBERT": bert_out = self.bert_model(x_embedded, mask=attention_mask.to(bool)) elif self.model_to_use == "cv_only_model": bert_out = self.bert_model(x_embedded) else: bert_out = self.bert_model( position_ids=position_ids, inputs_embeds=x_embedded, token_type_ids=token_type_ids, attention_mask=attention_mask, ) if hasattr(bert_out, "last_hidden_state"): last_hidden_state = bert_out.last_hidden_state out = self.linear(last_hidden_state) elif hasattr(bert_out, "logits"): out = bert_out.logits else: out = bert_out out = self.final_activation(out) return out