Text Classification
Transformers
PyTorch
bert
Inference Endpoints
foody-bert / foodybert.py
rttl's picture
Upload foodybert.py
79ff564
raw
history blame
5.95 kB
import os
from sklearn.metrics import classification_report
import torch.nn as nn
import transformers
from transformers import BertModel, BertTokenizer, BertForSequenceClassification
import numpy as np
from datasets import load_dataset, load_metric
import math
import warnings
from dataclasses import dataclass
import torch
import torch.utils.checkpoint
from packaging import version
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from typing import List, Optional, Tuple, Union
import torch
from transformers.modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
BaseModelOutputWithPoolingAndCrossAttentions,
CausalLMOutputWithCrossAttentions,
MaskedLMOutput,
MultipleChoiceModelOutput,
NextSentencePredictorOutput,
QuestionAnsweringModelOutput,
SequenceClassifierOutput,
TokenClassifierOutput,
)
class FoodyBertForSequenceClassification(BertForSequenceClassification):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.config = config
self.bert = BertModel(config)
classifier_dropout = (
config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
)
self.pre_classifier = torch.nn.Linear(4*config.hidden_size, 4*config.hidden_size)
self.tanh = nn.Tanh()
#self.relu = nn.ReLU()
self.dropout = nn.Dropout(classifier_dropout)
self.classifier = nn.Linear(4*config.hidden_size, config.num_labels)
self.post_init()
def post_init(self):
pass
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.bert(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
# outputs is a tuple contains(last_hidden_state, pooler_output,hidden_states...+3more)
# [0] last_hidden_state -> tensor [batch, #tokens, 768]
# [1] pooler_output -> tensor [1, 768]
# [2] hidden_states -> tuple with 13 tensors of size [batch, #tokens,768]
# use only CLS
#pooled_output = outputs[1]
#average actoss tokens at the last layer
#last_state = outputs[0]
#pooled_output = torch.mean(last_state,1)
# use hidden_states and concatenate layers -> change classifier dimensions!
hidden_states = outputs[2]
#concatenate 4 layers and average tokens
pooled_output = torch.cat(tuple([hidden_states[i] for i in [-4,-3,-2,-1]]), dim = -1)
pooled_output = torch.mean(pooled_output,1)
#concatenate 4 layers and use CLS
#pooled_output = torch.cat(tuple([hidden_states[i] for i in [-4,-3,-2,-1]]), dim = -1)
#pooled_output = pooled_output[:, 0, :]
pooled_output = self.pre_classifier(pooled_output)
pooled_output = self.tanh(pooled_output)
#pooled_output = self.relu(pooled_output)
pooled_output = self.dropout(pooled_output)
logits = self.classifier(pooled_output)
loss = None
if labels is not None:
if self.config.problem_type is None:
if self.num_labels == 1:
self.config.problem_type = "regression"
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
self.config.problem_type = "single_label_classification"
else:
self.config.problem_type = "multi_label_classification"
if self.config.problem_type == "regression":
loss_fct = MSELoss()
if self.num_labels == 1:
loss = loss_fct(logits.squeeze(), labels.squeeze())
else:
loss = loss_fct(logits, labels)
elif self.config.problem_type == "single_label_classification":
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
elif self.config.problem_type == "multi_label_classification":
loss_fct = BCEWithLogitsLoss()
loss = loss_fct(logits, labels)
if not return_dict:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return SequenceClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)