|
import timm |
|
import os |
|
from typing import Any |
|
from pytorch_lightning.utilities.types import LRSchedulerTypeUnion |
|
import torch as t |
|
from torch import nn |
|
import transformers |
|
import pytorch_lightning as plight |
|
import torchmetrics |
|
import einops as eo |
|
from loss_functions import corn_loss, corn_label_from_logits |
|
|
|
t.set_float32_matmul_precision("medium") |
|
global_settings = dict(try_using_torch_compile=False) |
|
|
|
|
|
class EnsembleModel(plight.LightningModule): |
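    """Ensemble over two groups of trained models: one group receives the
    unnormalized inputs, the other the normalized inputs. Their predictions are
    combined either with a learned linear layer or by simple averaging."""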
|
def __init__(self, models_without_norm_df, models_with_norm_df, learning_rate=0.0002, use_simple_average=False): |
|
super().__init__() |
|
self.models_without_norm = nn.ModuleList(list(models_without_norm_df)) |
|
self.models_with_norm = nn.ModuleList(list(models_with_norm_df)) |
|
self.learning_rate = learning_rate |
|
self.use_simple_average = use_simple_average |
|
|
|
if not self.use_simple_average: |
|
self.combiner = nn.Linear( |
|
self.models_with_norm[0].num_classes * (len(self.models_with_norm) + len(self.models_without_norm)), |
|
self.models_with_norm[0].num_classes, |
|
) |
|
|
|
def forward(self, x): |
|
x_unnormed, x_normed = x |
|
if not self.use_simple_average: |
|
out_unnormed = t.cat([model.model_step(x_unnormed, 0)[0] for model in self.models_without_norm], dim=-1) |
|
out_normed = t.cat([model.model_step(x_normed, 0)[0] for model in self.models_with_norm], dim=-1) |
|
out_avg = self.combiner(t.cat((out_unnormed, out_normed), dim=-1)) |
|
else: |
|
out_unnormed = [model.model_step(x_unnormed, 0)[0] for model in self.models_without_norm] |
|
out_normed = [model.model_step(x_normed, 0)[0] for model in self.models_with_norm] |
|
|
|
out_avg = (t.stack(out_unnormed + out_normed, dim=-1) / 2).mean(-1) |
|
return {"out_avg": out_avg, "out_unnormed": out_unnormed, "out_normed": out_normed}, x_unnormed[-1] |
|
|
|
def training_step(self, batch, batch_idx): |
|
out, y = self(batch) |
|
loss = self.models_with_norm[0]._get_loss(out["out_avg"], y, batch[0]) |
|
self.log("train_loss", loss, on_epoch=True, on_step=True, sync_dist=True) |
|
return loss |
|
|
|
def validation_step(self, batch, batch_idx): |
|
out, y = self(batch) |
|
preds, y_onecold, ignore_index_val = self.models_with_norm[0]._get_preds_reals(out["out_avg"], y) |
|
acc = torchmetrics.functional.accuracy( |
|
preds, |
|
y_onecold.to(t.long), |
|
ignore_index=ignore_index_val, |
|
num_classes=self.models_with_norm[0].num_classes, |
|
task="multiclass", |
|
) |
|
self.log("acc", acc * 100, prog_bar=True, sync_dist=True) |
|
loss = self.models_with_norm[0]._get_loss(out["out_avg"], y, batch[0]) |
|
self.log("val_loss", loss, prog_bar=True, sync_dist=True) |
|
return loss |
|
|
|
def predict_step(self, batch, batch_idx: int, dataloader_idx: int = 0): |
|
out, y = self(batch) |
|
preds, y_onecold, ignore_index_val = self.models_with_norm[0]._get_preds_reals(out["out_avg"], y) |
|
return preds, out, y_onecold |
|
|
|
def configure_optimizers(self): |
|
return t.optim.Adam(self.parameters(), lr=self.learning_rate) |
|
|
|
|
|
class TimmHeadReplace(nn.Module): |
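    """Replacement head for timm backbones: either a pure identity, or an
    optional adaptive pooling layer followed by flattening."""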
|
def __init__(self, pooling=None, in_channels=512, pooling_output_dimension=1, all_identity=False) -> None: |
|
super().__init__() |
|
|
|
if all_identity: |
|
self.head = nn.Identity() |
|
self.pooling = None |
|
else: |
|
self.pooling = pooling |
|
if pooling is not None: |
|
self.pooling_output_dimension = pooling_output_dimension |
|
if self.pooling == "AdaptiveAvgPool2d": |
|
self.pooling_layer = nn.AdaptiveAvgPool2d(pooling_output_dimension) |
|
elif self.pooling == "AdaptiveMaxPool2d": |
|
self.pooling_layer = nn.AdaptiveMaxPool2d(pooling_output_dimension) |
|
self.head = nn.Flatten() |
|
|
|
def forward(self, x, pre_logits=False): |
|
if self.pooling is not None: |
|
if self.pooling == "stack_avg_max_attn": |
|
x = t.cat([layer(x) for layer in self.pooling_layer], dim=-1) |
|
else: |
|
x = self.pooling_layer(x) |
|
return self.head(x) |
|
|
|
|
|
class CVModel(nn.Module): |
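    """Pretrained timm backbone whose pooled features are projected to
    max_seq_length positions and then to the output classes."""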
|
def __init__( |
|
self, |
|
modelname, |
|
in_shape, |
|
num_classes, |
|
loss_func, |
|
last_activation: str, |
|
input_padding_val=10, |
|
char_dims=2, |
|
max_seq_length=1000, |
|
) -> None: |
|
super().__init__() |
|
self.modelname = modelname |
|
self.loss_func = loss_func |
|
self.in_shape = in_shape |
|
self.char_dims = char_dims |
|
self.x_shape = in_shape |
|
self.last_activation = last_activation |
|
self.max_seq_length = max_seq_length |
|
self.num_classes = num_classes |
|
if self.loss_func == "OrdinalRegLoss": |
|
self.out_shape = 1 |
|
else: |
|
self.out_shape = num_classes |
|
|
|
self.cv_model = timm.create_model(modelname, pretrained=True, num_classes=0) |
|
self.cv_model.classifier = nn.Identity() |
|
with t.inference_mode(): |
|
test_out = self.cv_model(t.ones(self.in_shape, dtype=t.float32)) |
|
self.cv_model_out_dim = test_out.shape[1] |
|
self.cv_model.classifier = nn.Sequential(nn.Flatten(), nn.Linear(self.cv_model_out_dim, self.max_seq_length)) |
|
if self.out_shape == 1: |
|
self.logit_norm = nn.Identity() |
|
self.out_project = nn.Identity() |
|
else: |
|
self.logit_norm = nn.LayerNorm(self.max_seq_length) |
|
self.out_project = nn.Linear(1, self.out_shape) |
|
|
|
if last_activation == "Softmax": |
|
self.final_activation = nn.Softmax(dim=-1) |
|
elif last_activation == "Sigmoid": |
|
self.final_activation = nn.Sigmoid() |
|
elif last_activation == "LogSigmoid": |
|
self.final_activation = nn.LogSigmoid() |
|
elif last_activation == "Identity": |
|
self.final_activation = nn.Identity() |
|
else: |
|
raise NotImplementedError(f"{last_activation} not implemented") |
|
|
|
def forward(self, x): |
|
if isinstance(x, list): |
|
x = x[0] |
|
x = self.cv_model(x) |
|
x = self.cv_model.classifier(x).unsqueeze(-1) |
|
x = self.out_project(x) |
|
return self.final_activation(x) |
|
|
|
|
|
class LitModel(plight.LightningModule): |
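    """LightningModule that projects the primary input stream, optionally fuses
    character-position information from a second stream (dense, BERT, or conv
    encoders), and classifies each sequence position with a BERT encoder or a
    CV-only backbone."""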
|
def __init__( |
|
self, |
|
in_shape: tuple, |
|
hidden_dim: int, |
|
num_attention_heads: int, |
|
num_layers: int, |
|
loss_func: str, |
|
learning_rate: float, |
|
weight_decay: float, |
|
cfg: dict, |
|
use_lr_warmup: bool, |
|
use_reduce_on_plateau: bool, |
|
track_gradient_histogram=False, |
|
register_forw_hook=False, |
|
char_dims=2, |
|
) -> None: |
|
super().__init__() |
|
if "only_use_2nd_input_stream" not in cfg: |
|
cfg["only_use_2nd_input_stream"] = False |
|
|
|
if "gamma_step_size" not in cfg: |
|
cfg["gamma_step_size"] = 5 |
|
if "gamma_step_factor" not in cfg: |
|
cfg["gamma_step_factor"] = 0.5 |
|
self.save_hyperparameters( |
|
dict( |
|
in_shape=in_shape, |
|
hidden_dim=hidden_dim, |
|
num_attention_heads=num_attention_heads, |
|
num_layers=num_layers, |
|
loss_func=loss_func, |
|
learning_rate=learning_rate, |
|
cfg=cfg, |
|
x_shape=in_shape, |
|
num_classes=cfg["num_classes"], |
|
use_lr_warmup=use_lr_warmup, |
|
num_warmup_steps=cfg["num_warmup_steps"], |
|
use_reduce_on_plateau=use_reduce_on_plateau, |
|
weight_decay=weight_decay, |
|
track_gradient_histogram=track_gradient_histogram, |
|
register_forw_hook=register_forw_hook, |
|
char_dims=char_dims, |
|
remove_timm_classifier_head_pooling=cfg["remove_timm_classifier_head_pooling"], |
|
change_pooling_for_timm_head_to=cfg["change_pooling_for_timm_head_to"], |
|
chars_conv_pooling_out_dim=cfg["chars_conv_pooling_out_dim"], |
|
) |
|
) |
|
self.model_to_use = cfg["model_to_use"] |
|
self.num_classes = cfg["num_classes"] |
|
self.x_shape = in_shape |
|
self.in_shape = in_shape |
|
self.hidden_dim = hidden_dim |
|
self.num_attention_heads = num_attention_heads |
|
self.num_layers = num_layers |
|
|
|
self.use_lr_warmup = use_lr_warmup |
|
self.num_warmup_steps = cfg["num_warmup_steps"] |
|
self.warmup_exponent = cfg["warmup_exponent"] |
|
|
|
self.use_reduce_on_plateau = use_reduce_on_plateau |
|
self.loss_func = loss_func |
|
self.learning_rate = learning_rate |
|
self.weight_decay = weight_decay |
|
self.using_one_hot_targets = cfg["one_hot_y"] |
|
self.track_gradient_histogram = track_gradient_histogram |
|
self.register_forw_hook = register_forw_hook |
|
if self.loss_func == "OrdinalRegLoss": |
|
self.ord_reg_loss_max = cfg["ord_reg_loss_max"] |
|
self.ord_reg_loss_min = cfg["ord_reg_loss_min"] |
|
|
|
self.num_lin_layers = cfg["num_lin_layers"] |
|
self.linear_activation = cfg["linear_activation"] |
|
self.last_activation = cfg["last_activation"] |
|
|
|
self.max_seq_length = cfg["manual_max_sequence_for_model"] |
|
|
|
self.use_char_embed_info = cfg["use_embedded_char_pos_info"] |
|
|
|
self.method_chars_into_model = cfg["method_chars_into_model"] |
|
self.source_for_pretrained_cv_model = cfg["source_for_pretrained_cv_model"] |
|
self.method_to_include_char_positions = cfg["method_to_include_char_positions"] |
|
|
|
self.char_dims = char_dims |
|
self.char_sequence_length = cfg["max_len_chars_list"] if self.use_char_embed_info else 0 |
|
|
|
self.chars_conv_lr_reduction_factor = cfg["chars_conv_lr_reduction_factor"] |
|
if self.use_char_embed_info: |
|
self.chars_bert_reduction_factor = cfg["chars_bert_reduction_factor"] |
|
|
|
self.use_in_projection_bias = cfg["use_in_projection_bias"] |
|
self.add_layer_norm_to_in_projection = cfg["add_layer_norm_to_in_projection"] |
|
|
|
self.hidden_dropout_prob = cfg["hidden_dropout_prob"] |
|
self.layer_norm_after_in_projection = cfg["layer_norm_after_in_projection"] |
|
self.method_chars_into_model = cfg["method_chars_into_model"] |
|
self.input_padding_val = cfg["input_padding_val"] |
|
self.cv_char_modelname = cfg["cv_char_modelname"] |
|
self.char_plot_shape = cfg["char_plot_shape"] |
|
|
|
self.remove_timm_classifier_head_pooling = cfg["remove_timm_classifier_head_pooling"] |
|
self.change_pooling_for_timm_head_to = cfg["change_pooling_for_timm_head_to"] |
|
self.chars_conv_pooling_out_dim = cfg["chars_conv_pooling_out_dim"] |
|
|
|
self.add_layer_norm_to_char_mlp = cfg["add_layer_norm_to_char_mlp"] |
|
if "profile_torch_run" in cfg: |
|
self.profile_torch_run = cfg["profile_torch_run"] |
|
else: |
|
self.profile_torch_run = False |
|
if self.loss_func == "OrdinalRegLoss": |
|
self.out_shape = 1 |
|
else: |
|
self.out_shape = cfg["num_classes"] |
|
|
|
if not self.hparams.cfg["only_use_2nd_input_stream"]: |
|
if ( |
|
self.method_chars_into_model == "dense" |
|
and self.use_char_embed_info |
|
and self.method_to_include_char_positions == "concat" |
|
): |
|
self.project = nn.Linear(self.x_shape[-1], self.hidden_dim // 2, bias=self.use_in_projection_bias) |
|
elif ( |
|
self.method_chars_into_model == "bert" |
|
and self.use_char_embed_info |
|
and self.method_to_include_char_positions == "concat" |
|
): |
|
self.hidden_dim_chars = self.hidden_dim // 2 |
|
self.project = nn.Linear(self.x_shape[-1], self.hidden_dim_chars, bias=self.use_in_projection_bias) |
|
elif ( |
|
self.method_chars_into_model == "resnet" |
|
and self.method_to_include_char_positions == "concat" |
|
and self.use_char_embed_info |
|
): |
|
self.project = nn.Linear(self.x_shape[-1], self.hidden_dim // 2, bias=self.use_in_projection_bias) |
|
elif self.model_to_use == "cv_only_model": |
|
self.project = nn.Identity() |
|
else: |
|
self.project = nn.Linear(self.x_shape[-1], self.hidden_dim, bias=self.use_in_projection_bias) |
|
if self.add_layer_norm_to_in_projection: |
|
self.project = nn.Sequential( |
|
nn.Linear(self.project.in_features, self.project.out_features, bias=self.use_in_projection_bias), |
|
nn.LayerNorm(self.project.out_features), |
|
) |
|
|
|
if hasattr(self, "project") and "posix" in os.name and global_settings["try_using_torch_compile"]: |
|
self.project = t.compile(self.project) |
|
|
|
if self.use_char_embed_info: |
|
self._create_char_model() |
|
|
|
if self.layer_norm_after_in_projection: |
|
if self.hparams.cfg["only_use_2nd_input_stream"]: |
|
self.layer_norm_in = nn.LayerNorm(self.hidden_dim // 2) |
|
else: |
|
self.layer_norm_in = nn.LayerNorm(self.hidden_dim) |
|
|
|
if "posix" in os.name and global_settings["try_using_torch_compile"]: |
|
self.layer_norm_in = t.compile(self.layer_norm_in) |
|
|
|
self._create_main_seq_model(cfg) |
|
|
|
if register_forw_hook: |
|
self.register_hooks() |
|
if self.hparams.cfg["only_use_2nd_input_stream"]: |
|
linear_in_dim = self.hidden_dim // 2 |
|
else: |
|
linear_in_dim = self.hidden_dim |
|
|
|
if self.num_lin_layers == 1: |
|
self.linear = nn.Linear(linear_in_dim, self.out_shape) |
|
else: |
|
lin_layers = [] |
|
for _ in range(self.num_lin_layers - 1): |
|
lin_layers.extend( |
|
[ |
|
nn.Linear(linear_in_dim, linear_in_dim), |
|
getattr(nn, self.linear_activation)(), |
|
] |
|
) |
|
self.linear = nn.Sequential(*lin_layers, nn.Linear(linear_in_dim, self.out_shape)) |
|
|
|
if "posix" in os.name and global_settings["try_using_torch_compile"]: |
|
self.linear = t.compile(self.linear) |
|
|
|
if self.last_activation == "Softmax": |
|
self.final_activation = nn.Softmax(dim=-1) |
|
elif self.last_activation == "Sigmoid": |
|
self.final_activation = nn.Sigmoid() |
|
elif self.last_activation == "Identity": |
|
self.final_activation = nn.Identity() |
|
else: |
|
raise NotImplementedError(f"{self.last_activation} not implemented") |
|
|
|
if self.profile_torch_run: |
|
self.profilerr = t.profiler.profile( |
|
schedule=t.profiler.schedule(wait=1, warmup=10, active=10, repeat=1), |
|
on_trace_ready=t.profiler.tensorboard_trace_handler("tblogs"), |
|
with_stack=True, |
|
record_shapes=True, |
|
profile_memory=False, |
|
) |
|
|
|
def _create_main_seq_model(self, cfg): |
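        """Instantiate the main sequence model (``self.bert_model``): a
        transformers BERT encoder or a CVModel, depending on cfg['model_to_use']."""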
|
if self.hparams.cfg["only_use_2nd_input_stream"]: |
|
hidden_dim = self.hidden_dim // 2 |
|
else: |
|
hidden_dim = self.hidden_dim |
|
if self.model_to_use == "BERT": |
|
self.bert_config = transformers.BertConfig( |
|
vocab_size=self.x_shape[-1], |
|
hidden_size=hidden_dim, |
|
num_hidden_layers=self.num_layers, |
|
intermediate_size=hidden_dim, |
|
num_attention_heads=self.num_attention_heads, |
|
max_position_embeddings=self.max_seq_length, |
|
) |
|
self.bert_model = transformers.BertModel(self.bert_config) |
|
elif self.model_to_use == "cv_only_model": |
|
self.bert_model = CVModel( |
|
modelname=cfg["cv_modelname"], |
|
in_shape=self.in_shape, |
|
num_classes=cfg["num_classes"], |
|
loss_func=cfg["loss_function"], |
|
last_activation=cfg["last_activation"], |
|
input_padding_val=cfg["input_padding_val"], |
|
char_dims=self.char_dims, |
|
max_seq_length=cfg["manual_max_sequence_for_model"], |
|
) |
|
else: |
|
raise NotImplementedError(f"{self.model_to_use} not implemented") |
|
if "posix" in os.name and global_settings["try_using_torch_compile"]: |
|
self.bert_model = t.compile(self.bert_model) |
|
return 0 |
|
|
|
def _create_char_model(self): |
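        """Instantiate the encoder for character-position information: a dense
        projection, a small BERT, or a pretrained conv backbone, depending on
        ``method_chars_into_model``."""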
|
if self.method_chars_into_model == "dense": |
|
self.chars_project_0 = nn.Linear(self.char_dims, 1, bias=self.use_in_projection_bias) |
|
if "posix" in os.name and global_settings["try_using_torch_compile"]: |
|
self.chars_project_0 = t.compile(self.chars_project_0) |
|
if self.method_to_include_char_positions == "concat": |
|
self.chars_project_1 = nn.Linear( |
|
self.char_sequence_length, self.hidden_dim // 2, bias=self.use_in_projection_bias |
|
) |
|
else: |
|
self.chars_project_1 = nn.Linear( |
|
self.char_sequence_length, self.hidden_dim, bias=self.use_in_projection_bias |
|
) |
|
|
|
if "posix" in os.name and global_settings["try_using_torch_compile"]: |
|
self.chars_project_1 = t.compile(self.chars_project_1) |
|
elif not self.method_chars_into_model == "resnet": |
|
self.chars_project = nn.Linear(self.char_dims, self.hidden_dim_chars, bias=self.use_in_projection_bias) |
|
if "posix" in os.name and global_settings["try_using_torch_compile"]: |
|
self.chars_project = t.compile(self.chars_project) |
|
|
|
if self.method_chars_into_model == "bert": |
|
if not hasattr(self, "hidden_dim_chars"): |
|
if self.hidden_dim // self.chars_bert_reduction_factor > 1: |
|
self.hidden_dim_chars = self.hidden_dim // self.chars_bert_reduction_factor |
|
else: |
|
self.hidden_dim_chars = self.hidden_dim |
|
self.num_attention_heads_chars = self.hidden_dim_chars // (self.hidden_dim // self.num_attention_heads) |
|
self.chars_bert_config = transformers.BertConfig( |
|
vocab_size=self.x_shape[-1], |
|
hidden_size=self.hidden_dim_chars, |
|
num_hidden_layers=self.num_layers, |
|
intermediate_size=self.hidden_dim_chars, |
|
num_attention_heads=self.num_attention_heads_chars, |
|
max_position_embeddings=self.char_sequence_length + 1, |
|
num_labels=1, |
|
) |
|
self.chars_bert = transformers.BertForSequenceClassification(self.chars_bert_config) |
|
|
|
if "posix" in os.name and global_settings["try_using_torch_compile"]: |
|
self.chars_bert = t.compile(self.chars_bert) |
|
self.chars_project_class_output = nn.Linear(1, self.hidden_dim_chars, bias=self.use_in_projection_bias) |
|
if "posix" in os.name and global_settings["try_using_torch_compile"]: |
|
self.chars_project_class_output = t.compile(self.chars_project_class_output) |
|
elif self.method_chars_into_model == "resnet": |
|
if self.source_for_pretrained_cv_model == "timm": |
|
self.chars_conv = timm.create_model( |
|
self.cv_char_modelname, |
|
pretrained=True, |
|
num_classes=0, |
|
) |
|
if self.remove_timm_classifier_head_pooling: |
|
self.chars_conv.head = TimmHeadReplace(all_identity=True) |
|
with t.inference_mode(): |
|
test_out = self.chars_conv( |
|
t.ones((1, 3, self.char_plot_shape[0], self.char_plot_shape[1]), dtype=t.float32) |
|
) |
|
if test_out.ndim > 3: |
|
self.chars_conv.head = TimmHeadReplace( |
|
self.change_pooling_for_timm_head_to, |
|
test_out.shape[1], |
|
) |
|
elif self.source_for_pretrained_cv_model == "huggingface": |
|
self.chars_conv = transformers.AutoModelForImageClassification.from_pretrained(self.cv_char_modelname) |
|
elif self.source_for_pretrained_cv_model == "torch_hub": |
|
self.chars_conv = t.hub.load(*self.cv_char_modelname.split(",")) |
|
|
|
if hasattr(self.chars_conv, "classifier"): |
|
self.chars_conv.classifier = nn.Identity() |
|
elif hasattr(self.chars_conv, "cls_classifier"): |
|
self.chars_conv.cls_classifier = nn.Identity() |
|
elif hasattr(self.chars_conv, "fc"): |
|
self.chars_conv.fc = nn.Identity() |
|
|
|
if hasattr(self.chars_conv, "distillation_classifier"): |
|
self.chars_conv.distillation_classifier = nn.Identity() |
|
with t.inference_mode(): |
|
test_out = self.chars_conv( |
|
t.ones((1, 3, self.char_plot_shape[0], self.char_plot_shape[1]), dtype=t.float32) |
|
) |
|
if hasattr(test_out, "last_hidden_state"): |
|
self.chars_conv_out_dim = test_out.last_hidden_state.shape[1] |
|
elif hasattr(test_out, "logits"): |
|
self.chars_conv_out_dim = test_out.logits.shape[1] |
|
elif isinstance(test_out, list): |
|
self.chars_conv_out_dim = test_out[0].shape[1] |
|
else: |
|
self.chars_conv_out_dim = test_out.shape[1] |
|
|
|
char_lin_layers = [nn.Flatten(), nn.Linear(self.chars_conv_out_dim, self.hidden_dim // 2)] |
|
if self.add_layer_norm_to_char_mlp: |
|
char_lin_layers.append(nn.LayerNorm(self.hidden_dim // 2)) |
|
self.chars_classifier = nn.Sequential(*char_lin_layers) |
|
if hasattr(self.chars_conv, "distillation_classifier"): |
|
self.chars_conv.distillation_classifier = nn.Sequential( |
|
nn.Flatten(), nn.Linear(self.chars_conv_out_dim, self.hidden_dim // 2) |
|
) |
|
|
|
if "posix" in os.name and global_settings["try_using_torch_compile"]: |
|
self.chars_classifier = t.compile(self.chars_classifier) |
|
if "posix" in os.name and global_settings["try_using_torch_compile"]: |
|
self.chars_conv = t.compile(self.chars_conv) |
|
return 0 |
|
|
|
def register_hooks(self): |
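        """Register forward hooks on all submodules that log activation
        histograms to every attached logger exposing ``add_histogram``."""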
|
def add_to_tb(layer): |
|
def hook(model, input, output): |
|
if hasattr(output, "detach"): |
|
for logger in self.loggers: |
|
if hasattr(logger.experiment, "add_histogram"): |
|
logger.experiment.add_histogram( |
|
tag=f"{layer}_{str(list(output.shape))}", |
|
values=output.detach(), |
|
global_step=self.trainer.global_step, |
|
) |
|
|
|
return hook |
|
|
|
for layer_id, layer in dict([*self.named_modules()]).items(): |
|
layer.register_forward_hook(add_to_tb(f"act_{layer_id}")) |
|
|
|
def on_after_backward(self) -> None: |
|
if self.track_gradient_histogram: |
|
if self.trainer.global_step % 200 == 0: |
|
for logger in self.loggers: |
|
if hasattr(logger.experiment, "add_histogram"): |
|
for layer_id, layer in dict([*self.named_modules()]).items(): |
|
parameters = layer.parameters() |
|
for idx2, p in enumerate(parameters): |
|
grad_val = p.grad |
|
if grad_val is not None: |
|
grad_name = f"grad_{idx2}_{layer_id}_{str(list(p.grad.shape))}" |
|
logger.experiment.add_histogram( |
|
tag=grad_name, values=grad_val, global_step=self.trainer.global_step |
|
) |
|
|
|
return super().on_after_backward() |
|
|
|
def _fold_in_seq_dim(self, out, y): |
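        """Flatten the batch and sequence dimensions of ``out`` (and ``y``, if
        given) so per-position losses and metrics see 2-D inputs."""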
|
batch_size, seq_len, num_classes = out.shape |
|
out = eo.rearrange(out, "b s c -> (b s) c", s=seq_len) |
|
if y is None: |
|
return out, None |
|
if len(y.shape) > 2: |
|
y = eo.rearrange(y, "b s c -> (b s) c", s=seq_len) |
|
else: |
|
y = eo.rearrange(y, "b s -> (b s)", s=seq_len) |
|
return out, y |
|
|
|
def _get_loss(self, out, y, batch): |
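        """Compute the configured loss. ``batch[-2]`` must hold the attention
        mask, which is used to mask padded positions for BCELoss and
        OrdinalRegLoss."""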
|
attention_mask = batch[-2] |
|
if self.loss_func == "BCELoss": |
|
if self.last_activation == "Identity": |
|
loss = t.nn.functional.binary_cross_entropy_with_logits(out, y, reduction="none") |
|
else: |
|
loss = t.nn.functional.binary_cross_entropy(out, y, reduction="none") |
|
|
|
            replace_tensor = t.zeros(loss[0, 0, :].shape, device=loss.device, dtype=loss.dtype, requires_grad=False)
|
loss[~attention_mask.bool()] = replace_tensor |
|
loss = loss.mean() |
|
elif self.loss_func == "CrossEntropyLoss": |
|
if len(out.shape) > 2: |
|
out, y = self._fold_in_seq_dim(out, y) |
|
loss = t.nn.functional.cross_entropy(out, y, reduction="mean", ignore_index=-100) |
|
else: |
|
loss = t.nn.functional.cross_entropy(out, y, reduction="mean", ignore_index=-100) |
|
|
|
elif self.loss_func == "OrdinalRegLoss": |
|
loss = t.nn.functional.mse_loss(out, y, reduction="none") |
|
loss = loss[attention_mask.bool()].sum() * 10.0 / attention_mask.sum() |
|
elif self.loss_func == "corn_loss": |
|
out, y = self._fold_in_seq_dim(out, y) |
|
loss = corn_loss(out, y.squeeze(), self.out_shape) |
|
else: |
|
raise ValueError("Loss Function not reckognized") |
|
return loss |
|
|
|
def training_step(self, batch, batch_idx): |
|
if self.profile_torch_run: |
|
self.profilerr.step() |
|
out, y = self.model_step(batch, batch_idx) |
|
loss = self._get_loss(out, y, batch) |
|
self.log("train_loss", loss, on_epoch=True, on_step=True, sync_dist=True) |
|
return loss |
|
|
|
    def forward(self, *args):
        # Delegate to the module-level forward(); the remaining positional
        # arguments are passed along as a tuple, which prep_model_input unwraps.
        return forward(self, args)
|
|
|
def model_step(self, batch, batch_idx): |
|
out = self.forward(batch) |
|
return out, batch[-1] |
|
|
|
def optimizer_step( |
|
self, |
|
epoch, |
|
batch_idx, |
|
optimizer, |
|
optimizer_closure, |
|
): |
|
optimizer.step(closure=optimizer_closure) |
|
|
|
if self.use_lr_warmup and self.hparams["cfg"]["lr_scheduling"] != "OneCycleLR": |
|
if self.trainer.global_step < self.num_warmup_steps: |
|
lr_scale = min(1.0, float(self.trainer.global_step + 1) / self.num_warmup_steps) ** self.warmup_exponent |
|
for pg in optimizer.param_groups: |
|
pg["lr"] = lr_scale * self.hparams.learning_rate |
|
if self.trainer.global_step % 10 == 0 or self.trainer.global_step == 0: |
|
for idx, pg in enumerate(optimizer.param_groups): |
|
self.log(f"lr_{idx}", pg["lr"], prog_bar=True, sync_dist=True) |
|
|
|
def lr_scheduler_step(self, scheduler: LRSchedulerTypeUnion, metric: Any | None) -> None: |
|
if self.use_lr_warmup and self.hparams["cfg"]["lr_scheduling"] != "OneCycleLR": |
|
if self.trainer.global_step > self.num_warmup_steps: |
|
if metric is None: |
|
scheduler.step() |
|
else: |
|
scheduler.step(metric) |
|
else: |
|
if metric is None: |
|
scheduler.step() |
|
else: |
|
scheduler.step(metric) |
|
|
|
def _get_preds_reals(self, out, y): |
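        """Convert model outputs (and targets, if given) to integer class
        predictions and targets, plus the ignore_index to use for metrics."""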
|
if self.loss_func == "corn_loss": |
|
seq_len = out.shape[1] |
|
out, y = self._fold_in_seq_dim(out, y) |
|
preds = corn_label_from_logits(out) |
|
preds = eo.rearrange(preds, "(b s) -> b s", s=seq_len) |
|
if y is not None: |
|
y = eo.rearrange(y.squeeze(), "(b s) -> b s", s=seq_len) |
|
|
|
elif self.loss_func == "OrdinalRegLoss": |
|
preds = out * (self.ord_reg_loss_max - self.ord_reg_loss_min) |
|
preds = (preds + self.ord_reg_loss_min).round().to(t.long) |
|
|
|
else: |
|
preds = t.argmax(out, dim=-1) |
|
if y is None: |
|
return preds, y, -100 |
|
else: |
|
if self.using_one_hot_targets: |
|
y_onecold = t.argmax(y, dim=-1) |
|
ignore_index_val = 0 |
|
elif self.loss_func == "OrdinalRegLoss": |
|
y_onecold = (y * self.num_classes).round().to(t.long) |
|
|
|
y_onecold = y * (self.ord_reg_loss_max - self.ord_reg_loss_min) |
|
y_onecold = (y_onecold + self.ord_reg_loss_min).round().to(t.long) |
|
ignore_index_val = t.min(y_onecold).to(t.long) |
|
else: |
|
y_onecold = y |
|
ignore_index_val = -100 |
|
|
|
if len(preds.shape) > len(y_onecold.shape): |
|
preds = preds.squeeze() |
|
return preds, y_onecold, ignore_index_val |
|
|
|
def validation_step(self, batch, batch_idx): |
|
out, y = self.model_step(batch, batch_idx) |
|
preds, y_onecold, ignore_index_val = self._get_preds_reals(out, y) |
|
|
|
if self.loss_func == "OrdinalRegLoss": |
|
y_onecold = y_onecold.flatten() |
|
preds = preds.flatten()[y_onecold != ignore_index_val] |
|
y_onecold = y_onecold[y_onecold != ignore_index_val] |
|
acc = (preds == y_onecold).sum() / len(y_onecold) |
|
else: |
|
acc = torchmetrics.functional.accuracy( |
|
preds, |
|
y_onecold.to(t.long), |
|
ignore_index=ignore_index_val, |
|
num_classes=self.num_classes, |
|
task="multiclass", |
|
) |
|
self.log("acc", acc * 100, prog_bar=True, sync_dist=True) |
|
loss = self._get_loss(out, y, batch) |
|
self.log("val_loss", loss, prog_bar=True, sync_dist=True) |
|
|
|
return loss |
|
|
|
def predict_step(self, batch, batch_idx): |
|
out, y = self.model_step(batch, batch_idx) |
|
preds, y_onecold, ignore_index_val = self._get_preds_reals(out, y) |
|
return preds, y_onecold |
|
|
|
def configure_optimizers(self): |
|
params = list(self.named_parameters()) |
|
|
|
        def is_chars_conv(n):
            # The char conv backbone gets a reduced learning rate; its
            # classifier head stays in the full-learning-rate group.
            return "chars_conv" in n and "classifier" not in n
|
|
|
grouped_parameters = [ |
|
{ |
|
"params": [p for n, p in params if is_chars_conv(n)], |
|
"lr": self.learning_rate / self.chars_conv_lr_reduction_factor, |
|
"weight_decay": self.weight_decay, |
|
}, |
|
{ |
|
"params": [p for n, p in params if not is_chars_conv(n)], |
|
"lr": self.learning_rate, |
|
"weight_decay": self.weight_decay, |
|
}, |
|
] |
|
opti = t.optim.AdamW(grouped_parameters, lr=self.learning_rate, weight_decay=self.weight_decay) |
|
if self.use_reduce_on_plateau: |
|
opti_dict = { |
|
"optimizer": opti, |
|
"lr_scheduler": { |
|
"scheduler": t.optim.lr_scheduler.ReduceLROnPlateau(opti, mode="min", patience=2, factor=0.5), |
|
"monitor": "val_loss", |
|
"frequency": 1, |
|
"interval": "epoch", |
|
}, |
|
} |
|
return opti_dict |
|
else: |
|
cfg = self.hparams["cfg"] |
|
if cfg["use_reduce_on_plateau"]: |
|
scheduler = None |
|
elif cfg["lr_scheduling"] == "multistep": |
|
scheduler = t.optim.lr_scheduler.MultiStepLR( |
|
opti, milestones=cfg["multistep_milestones"], gamma=cfg["gamma_multistep"], verbose=False |
|
) |
|
interval = "step" if cfg["use_training_steps_for_end_and_lr_decay"] else "epoch" |
|
elif cfg["lr_scheduling"] == "StepLR": |
|
scheduler = t.optim.lr_scheduler.StepLR( |
|
opti, step_size=cfg["gamma_step_size"], gamma=cfg["gamma_step_factor"] |
|
) |
|
interval = "step" if cfg["use_training_steps_for_end_and_lr_decay"] else "epoch" |
|
elif cfg["lr_scheduling"] == "anneal": |
|
scheduler = t.optim.lr_scheduler.CosineAnnealingLR( |
|
opti, 250, eta_min=cfg["min_lr_anneal"], last_epoch=-1, verbose=False |
|
) |
|
interval = "step" |
|
elif cfg["lr_scheduling"] == "ExponentialLR": |
|
scheduler = t.optim.lr_scheduler.ExponentialLR(opti, gamma=cfg["lr_sched_exp_fac"]) |
|
interval = "step" |
|
else: |
|
scheduler = None |
|
if scheduler is None: |
|
return [opti] |
|
else: |
|
opti_dict = { |
|
"optimizer": opti, |
|
"lr_scheduler": { |
|
"scheduler": scheduler, |
|
"monitor": "global_step", |
|
"frequency": 1, |
|
"interval": interval, |
|
}, |
|
} |
|
return opti_dict |
|
|
|
def on_fit_start(self) -> None: |
|
if self.profile_torch_run: |
|
self.profilerr.start() |
|
return super().on_fit_start() |
|
|
|
def on_fit_end(self) -> None: |
|
if self.profile_torch_run: |
|
self.profilerr.stop() |
|
return super().on_fit_end() |
|
|
|
|
|
def prep_model_input(self, batch): |
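    """Unpack a batch, project the primary input stream, and optionally encode
    and merge character-position information; returns (x_embedded, attention_mask)."""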
|
if len(batch) == 1: |
|
batch = batch[0] |
|
if self.use_char_embed_info: |
|
if len(batch) == 5: |
|
x, chars_coords, ims, attention_mask, _ = batch |
|
elif batch[1].ndim == 4: |
|
x, ims, attention_mask, _ = batch |
|
else: |
|
x, chars_coords, attention_mask, _ = batch |
|
padding_list = None |
|
else: |
|
if len(batch) > 3: |
|
x = batch[0] |
|
y = batch[-1] |
|
attention_mask = batch[1] |
|
else: |
|
x, attention_mask, y = batch |
|
|
|
if self.model_to_use != "cv_only_model" and not self.hparams.cfg["only_use_2nd_input_stream"]: |
|
x_embedded = self.project(x) |
|
else: |
|
x_embedded = x |
|
if self.use_char_embed_info: |
|
if self.method_chars_into_model == "dense": |
|
bool_mask = chars_coords == self.input_padding_val |
|
bool_mask = bool_mask[:, :, 0] |
|
chars_coords_projected = self.chars_project_0(chars_coords).squeeze(-1) |
|
chars_coords_projected = chars_coords_projected * bool_mask |
|
if self.chars_project_1.in_features == chars_coords_projected.shape[-1]: |
|
chars_coords_projected = self.chars_project_1(chars_coords_projected) |
|
else: |
|
chars_coords_projected = chars_coords_projected.mean(dim=-1) |
|
chars_coords_projected = chars_coords_projected.unsqueeze(1).repeat(1, x_embedded.shape[2]) |
|
elif self.method_chars_into_model == "bert": |
|
chars_mask = chars_coords != self.input_padding_val |
|
chars_mask = t.cat( |
|
( |
|
t.ones(chars_mask[:, :1, 0].shape, dtype=t.long, device=chars_coords.device), |
|
chars_mask[:, :, 0].to(t.long), |
|
), |
|
dim=1, |
|
) |
|
chars_coords_projected = self.chars_project(chars_coords) |
|
|
|
position_ids = t.arange( |
|
0, chars_coords_projected.shape[1] + 1, dtype=t.long, device=chars_coords_projected.device |
|
) |
|
token_type_ids = t.zeros( |
|
(chars_coords_projected.size()[0], chars_coords_projected.size()[1] + 1), |
|
dtype=t.long, |
|
device=chars_coords_projected.device, |
|
) |
|
chars_coords_projected = t.cat( |
|
(t.ones_like(chars_coords_projected[:, :1, :]), chars_coords_projected), dim=1 |
|
) |
|
chars_coords_projected = self.chars_bert( |
|
position_ids=position_ids, |
|
inputs_embeds=chars_coords_projected, |
|
token_type_ids=token_type_ids, |
|
attention_mask=chars_mask, |
|
) |
|
if hasattr(chars_coords_projected, "last_hidden_state"): |
|
chars_coords_projected = chars_coords_projected.last_hidden_state[:, 0, :] |
|
elif hasattr(chars_coords_projected, "logits"): |
|
chars_coords_projected = chars_coords_projected.logits |
|
else: |
|
chars_coords_projected = chars_coords_projected.hidden_states[-1][:, 0, :] |
|
elif self.method_chars_into_model == "resnet": |
|
chars_conv_out = self.chars_conv(ims) |
|
if isinstance(chars_conv_out, list): |
|
chars_conv_out = chars_conv_out[0] |
|
if hasattr(chars_conv_out, "logits"): |
|
chars_conv_out = chars_conv_out.logits |
|
chars_coords_projected = self.chars_classifier(chars_conv_out) |
|
|
|
chars_coords_projected = chars_coords_projected.unsqueeze(1).repeat(1, x_embedded.shape[1], 1) |
|
if hasattr(self, "chars_project_class_output"): |
|
chars_coords_projected = self.chars_project_class_output(chars_coords_projected) |
|
|
|
if self.hparams.cfg["only_use_2nd_input_stream"]: |
|
x_embedded = chars_coords_projected |
|
elif self.method_to_include_char_positions == "concat": |
|
x_embedded = t.cat((x_embedded, chars_coords_projected), dim=-1) |
|
else: |
|
x_embedded = x_embedded + chars_coords_projected |
|
return x_embedded, attention_mask |
|
|
|
|
|
def forward(self, batch): |
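    """Shared forward pass used by ``LitModel.forward``: prepares the inputs,
    runs the configured sequence model, and applies the output head and final
    activation."""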
|
prepped_input = prep_model_input(self, batch) |
|
|
|
if len(batch) > 5: |
|
x_embedded, padding_list, attention_mask, attention_mask_for_prediction = prepped_input |
|
elif len(batch) > 2: |
|
x_embedded, attention_mask = prepped_input |
|
else: |
|
x_embedded = prepped_input[0] |
|
attention_mask = prepped_input[-1] |
|
|
|
position_ids = t.arange(0, x_embedded.shape[1], dtype=t.long, device=x_embedded.device) |
|
token_type_ids = t.zeros(x_embedded.size()[:-1], dtype=t.long, device=x_embedded.device) |
|
|
|
if self.layer_norm_after_in_projection: |
|
x_embedded = self.layer_norm_in(x_embedded) |
|
|
|
if self.model_to_use == "LSTM": |
|
bert_out = self.bert_model(x_embedded) |
|
elif self.model_to_use in ["ProphetNet", "T5", "FunnelModel"]: |
|
bert_out = self.bert_model(inputs_embeds=x_embedded, attention_mask=attention_mask) |
|
elif self.model_to_use == "xBERT": |
|
bert_out = self.bert_model(x_embedded, mask=attention_mask.to(bool)) |
|
elif self.model_to_use == "cv_only_model": |
|
bert_out = self.bert_model(x_embedded) |
|
else: |
|
bert_out = self.bert_model( |
|
position_ids=position_ids, |
|
inputs_embeds=x_embedded, |
|
token_type_ids=token_type_ids, |
|
attention_mask=attention_mask, |
|
) |
|
if hasattr(bert_out, "last_hidden_state"): |
|
last_hidden_state = bert_out.last_hidden_state |
|
out = self.linear(last_hidden_state) |
|
elif hasattr(bert_out, "logits"): |
|
out = bert_out.logits |
|
else: |
|
out = bert_out |
|
out = self.final_activation(out) |
|
return out |
|
|