|
from transformers import PretrainedConfig, BertConfig |
|
from typing import List |
|
|
|
class VGCNConfig(BertConfig):
    """Configuration for a VGCN model, built on top of ``BertConfig``.

    The constructor validates the VGCN-specific hyper-parameters, stores them
    as attributes on the instance, and forwards every remaining keyword
    argument to ``BertConfig.__init__``.

    Args:
        bert_model: Stored as ``self.bert_model``. Presumably the name or path
            of the underlying pretrained BERT checkpoint -- confirm against
            the model code.
        gcn_adj_matrix: Stored as ``self.gcn_adj_matrix``. Presumably a
            reference (e.g. a path) to the graph adjacency matrix -- confirm
            against the model code.
        max_seq_len: Must lie in ``[1, 512]``.
        npmi_threshold: Must lie in ``[0.0, 1.0]``.
        tf_threshold: Must lie in ``[0.0, 1.0]``.
        vocab_type: One of ``"all"``, ``"pmi"`` or ``"tf"``.
        gcn_embedding_dim: Must be at least 1.
        **kwargs: Passed through unchanged to ``BertConfig``.

    Raises:
        ValueError: If any of the validated arguments is outside its
            accepted set or range.
    """

    model_type = "vgcn"

    def __init__(
        self,
        bert_model: str = "readerbench/RoBERT-base",
        gcn_adj_matrix: str = "",
        max_seq_len: int = 256,
        npmi_threshold: float = 0.2,
        tf_threshold: float = 0.0,
        vocab_type: str = "all",
        gcn_embedding_dim: int = 32,
        **kwargs,
    ):
        if vocab_type not in ("all", "pmi", "tf"):
            raise ValueError(f"`vocab_type` must be 'all', 'pmi' or 'tf', got {vocab_type}.")

        # Range checks use the `not lo <= x <= hi` form on purpose: unlike
        # `x < lo or x > hi`, it also rejects NaN, for which every comparison
        # evaluates to False.
        if not 1 <= max_seq_len <= 512:
            raise ValueError(f"`max_seq_len` must be between 1 and 512, got {max_seq_len}.")

        for name, value in (
            ("npmi_threshold", npmi_threshold),
            ("tf_threshold", tf_threshold),
        ):
            if not 0.0 <= value <= 1.0:
                raise ValueError(f"`{name}` must be between 0.0 and 1.0, got {value}.")

        # Previously unchecked: a zero or negative dimension was accepted
        # silently even though every other numeric argument is validated.
        if not gcn_embedding_dim >= 1:
            raise ValueError(f"`gcn_embedding_dim` must be at least 1, got {gcn_embedding_dim}.")

        self.gcn_adj_matrix = gcn_adj_matrix
        self.max_seq_len = max_seq_len
        self.npmi_threshold = npmi_threshold
        self.tf_threshold = tf_threshold
        self.vocab_type = vocab_type
        self.gcn_embedding_dim = gcn_embedding_dim
        self.bert_model = bert_model

        # The base-class initialiser runs last and receives only the keyword
        # arguments that were not consumed above.
        super().__init__(**kwargs)