# drama-large / modeling_drama.py
from __future__ import annotations
import torch
import torch.nn.functional as F
from transformers import LlamaModel, LlamaConfig, PreTrainedTokenizer
from transformers.modeling_attn_mask_utils import AttentionMaskConverter


class DramaModel(LlamaModel):
"""
DramaModel is a modified version of the LlamaModel that supports bi-directional attention
and provides query and document encoding functionalities.
"""

    def __init__(self, config: LlamaConfig):
"""
Initializes the DramaModel by disabling causal masking in self-attention layers.
"""
super().__init__(config)
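        # Switch every self-attention layer to bidirectional (non-causal) attention.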
for layer in self.layers:
layer.self_attn.is_causal = False
# query prefix
self.query_prefix = "Query: "
self.max_seq_len = 8192
self.hidden_size = config.hidden_size

    def _update_causal_mask(
self,
attention_mask: torch.Tensor,
input_tensor: torch.Tensor,
cache_position: torch.Tensor,
past_seen_tokens=None,
output_attentions=False,
):
"""
Updates the causal mask for attention computations.
"""
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None
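        # Other attention backends expect a 4-D additive mask; expand the 2-D padding
        # mask without introducing any causal (lower-triangular) structure.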
if attention_mask is None or attention_mask.dim() == 4:
return attention_mask
return AttentionMaskConverter._expand_mask(
mask=attention_mask,
dtype=input_tensor.dtype,
)

    def _average_pool(
self, last_hidden_states: torch.Tensor, attention_mask: torch.Tensor
) -> torch.Tensor:
"""
Computes the average pooled representation of the last hidden states.
"""
last_hidden = last_hidden_states.masked_fill(
~attention_mask[..., None].bool(), 0.0
)
return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]

    def _tokenize(
        self,
        tokenizer: PreTrainedTokenizer,
        texts: list[str],
        max_seq_len: int | None = None,
    ):
"""
Tokenizes input text sequences with optional sequence length restriction.
"""
if max_seq_len is None:
max_seq_len = self.max_seq_len
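        # Truncate to max_seq_len - 1 so that one position is left for the EOS token
        # appended below.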
tokenized = tokenizer(
texts,
padding=False,
truncation=True,
max_length=max_seq_len - 1,
return_attention_mask=False,
return_token_type_ids=False,
add_special_tokens=True
)
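        # Append the EOS token to every sequence so it is present even after truncation.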
tokenized['input_ids'] = [
t + [tokenizer.eos_token_id] for t in tokenized['input_ids']
]
tokenized = tokenizer.pad(
tokenized,
padding=True,
return_attention_mask=True,
return_tensors='pt',
).to(self.device)
return tokenized

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        dim: int | None = None,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        """
        Forward pass through the model.
        Args:
            input_ids (torch.Tensor): Input token IDs.
            attention_mask (torch.Tensor): Attention mask tensor.
            dim (int, optional): Number of leading hidden dimensions to keep in the
                output embeddings; None keeps the full hidden size.
        Returns:
            torch.Tensor: L2-normalized output embeddings.
        """
outputs = super().forward(
input_ids, attention_mask, *args, **kwargs
)
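        # Keep only the first `dim` hidden dimensions before pooling (dim=None keeps
        # the full hidden size), then mean-pool over the non-padded tokens.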
embeddings = self._average_pool(
outputs.last_hidden_state[:, :, :dim], attention_mask
)
# normalize embeddings
embeddings = F.normalize(embeddings, p=2, dim=1)
return embeddings

    def encode_queries(
        self,
        tokenizer: PreTrainedTokenizer,
        queries: list[str],
        max_seq_len: int | None = None,
        dim: int | None = None,
    ) -> torch.Tensor:
"""
Encodes a list of queries into embeddings.
Args:
tokenizer (PreTrainedTokenizer): Tokenizer for text processing.
queries (list[str]): List of query texts.
max_seq_len (int, optional): Maximum sequence length.
dim (int, optional): Dimensionality for output embeddings.
Returns:
            torch.Tensor: Encoded query embeddings of shape (num_queries, dim).
"""
if not queries:
raise ValueError("queries must not be empty.")
if not isinstance(queries, list) or not all(isinstance(q, str) for q in queries):
raise ValueError("queries must be a list of strings.")
if tokenizer is None:
raise ValueError("tokenizer must not be None.")
if dim is not None and (dim < 1 or dim > self.hidden_size):
raise ValueError(f"dim must be in range [1, {self.hidden_size}].")
queries = [self.query_prefix + query for query in queries]
tokenized_queries = self._tokenize(tokenizer, queries, max_seq_len)
embeddings = self(**tokenized_queries, dim=dim)
return embeddings

    def encode_documents(
        self,
        tokenizer: PreTrainedTokenizer,
        documents: list[str],
        max_seq_len: int | None = None,
        dim: int | None = None,
    ) -> torch.Tensor:
"""
Encodes a list of documents into embeddings.
Args:
tokenizer (PreTrainedTokenizer): Tokenizer for text processing.
documents (list[str]): List of document texts.
max_seq_len (int, optional): Maximum sequence length.
dim (int, optional): Dimensionality for output embeddings.
Returns:
            torch.Tensor: Encoded document embeddings of shape (num_documents, dim).
"""
if not documents:
raise ValueError("documents must not be empty.")
if not isinstance(documents, list) or not all(isinstance(d, str) for d in documents):
raise ValueError("documents must be a list of strings.")
if tokenizer is None:
raise ValueError("tokenizer must not be None.")
if dim is not None and (dim < 1 or dim > self.hidden_size):
raise ValueError(f"dim must be in range [1, {self.hidden_size}].")
tokenized_documents = self._tokenize(tokenizer, documents, max_seq_len)
embeddings = self(**tokenized_documents, dim=dim)
return embeddings
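

# Usage sketch (illustrative, not part of the model implementation). It assumes the
# tokenizer and weights are published under a Hugging Face checkpoint id such as
# "facebook/drama-large" whose auto_map resolves AutoModel to DramaModel; adjust the
# checkpoint id, device, and embedding dimension to your setup.
if __name__ == "__main__":
    from transformers import AutoModel, AutoTokenizer

    checkpoint = "facebook/drama-large"  # assumed checkpoint id
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModel.from_pretrained(checkpoint, trust_remote_code=True).eval()

    queries = ["What is the capital of France?"]
    documents = ["Paris is the capital and largest city of France."]

    with torch.no_grad():
        # Queries are automatically prefixed with "Query: "; documents are encoded as-is.
        query_emb = model.encode_queries(tokenizer, queries, dim=256)
        doc_emb = model.encode_documents(tokenizer, documents, dim=256)

    # Embeddings are L2-normalized, so the dot product equals cosine similarity.
    print(query_emb @ doc_emb.T)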