from __future__ import annotations

import torch
import torch.nn.functional as F

from transformers import LlamaModel, LlamaConfig, PreTrainedTokenizer
from transformers.modeling_attn_mask_utils import AttentionMaskConverter


class DramaModel(LlamaModel):
    """
    DramaModel is a modified LlamaModel that uses bi-directional (non-causal)
    self-attention and provides query and document encoding utilities.
    """

    def __init__(self, config: LlamaConfig):
        """
        Initializes the DramaModel by disabling causal masking in every self-attention layer.
        """
        super().__init__(config)
        for layer in self.layers:
            layer.self_attn.is_causal = False

        self.query_prefix = "Query: "  # prepended to every query before encoding
        self.max_seq_len = 8192        # default maximum sequence length in tokens
        self.hidden_size = config.hidden_size

    def _update_causal_mask(
        self,
        attention_mask: torch.Tensor,
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_seen_tokens=None,
        output_attentions=False,
    ):
        """
        Overrides the causal mask construction so attention is bi-directional:
        only padding positions are masked, never future tokens.
        """
        if self.config._attn_implementation == "flash_attention_2":
            # Flash attention consumes the 2D padding mask directly; return None
            # when nothing is padded so the kernel can skip masking entirely.
            if attention_mask is not None and (attention_mask == 0.0).any():
                return attention_mask
            return None
        if attention_mask is None or attention_mask.dim() == 4:
            return attention_mask

        # Expand the 2D padding mask to a 4D additive mask without adding a causal triangle.
        return AttentionMaskConverter._expand_mask(
            mask=attention_mask,
            dtype=input_tensor.dtype,
        )

    def _average_pool(
        self, last_hidden_states: torch.Tensor, attention_mask: torch.Tensor
    ) -> torch.Tensor:
        """
        Computes the average-pooled representation of the last hidden states,
        counting only non-padded positions.
        """
        last_hidden = last_hidden_states.masked_fill(
            ~attention_mask[..., None].bool(), 0.0
        )
        return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
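
    # Illustrative sketch of the pooling above (shapes and values are hypothetical,
    # not taken from the original module):
    #
    #     hidden = torch.randn(2, 3, 4)                    # (batch, seq, hidden)
    #     mask = torch.tensor([[1, 1, 0], [1, 1, 1]])      # 0 marks padding
    #     pooled = model._average_pool(hidden, mask)       # -> shape (2, 4)
    #
    # Padded positions are zeroed before the sum, so each row is the mean of the
    # non-padded token states only.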

    def _tokenize(
        self,
        tokenizer: PreTrainedTokenizer,
        texts: list[str],
        max_seq_len: int | None = None,
    ):
        """
        Tokenizes input text sequences, truncating to max_seq_len and appending an EOS token.
        """
        if max_seq_len is None:
            max_seq_len = self.max_seq_len
        tokenized = tokenizer(
            texts,
            padding=False,
            truncation=True,
            max_length=max_seq_len - 1,  # reserve one position for the appended EOS token
            return_attention_mask=False,
            return_token_type_ids=False,
            add_special_tokens=True,
        )
        # Append EOS to every sequence, then pad the batch and move it to the model device.
        tokenized['input_ids'] = [
            t + [tokenizer.eos_token_id] for t in tokenized['input_ids']
        ]
        tokenized = tokenizer.pad(
            tokenized,
            padding=True,
            return_attention_mask=True,
            return_tensors='pt',
        ).to(self.device)
        return tokenized
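
    # Example of the resulting batch (illustrative only; texts and shapes are made up):
    #
    #     batch = model._tokenize(tokenizer, ["short text", "a somewhat longer text"])
    #     batch["input_ids"].shape       # (2, longest_sequence_length)
    #     batch["attention_mask"].shape  # (2, longest_sequence_length)
    #
    # Both tensors live on model.device and are ready to pass to forward().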

    def forward(self, input_ids, attention_mask, dim: int | None = None, *args, **kwargs):
        """
        Forward pass through the model.

        Args:
            input_ids (torch.Tensor): Input token IDs.
            attention_mask (torch.Tensor): Attention mask tensor.
            dim (int, optional): Number of leading hidden dimensions to keep in the
                output embeddings; None keeps the full hidden size.

        Returns:
            torch.Tensor: L2-normalized embeddings of shape (batch_size, dim),
            or (batch_size, hidden_size) if dim is None.
        """
        outputs = super().forward(
            input_ids, attention_mask, *args, **kwargs
        )
        # Truncate the hidden dimension before pooling, then mean-pool and L2-normalize.
        embeddings = self._average_pool(
            outputs.last_hidden_state[:, :, :dim], attention_mask
        )
        embeddings = F.normalize(embeddings, p=2, dim=1)
        return embeddings
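
    # Example call (illustrative; the batch comes from _tokenize above):
    #
    #     emb = model(batch["input_ids"], batch["attention_mask"], dim=256)
    #     emb.shape           # (batch_size, 256)
    #     emb.norm(dim=1)     # ~1.0 everywhere, since embeddings are L2-normalized
    #
    # Truncating to the first `dim` hidden features assumes the checkpoint was trained
    # to support reduced-dimension embeddings; omit `dim` to keep the full hidden size.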

    def encode_queries(
        self,
        tokenizer: PreTrainedTokenizer,
        queries: list[str],
        max_seq_len: int | None = None,
        dim: int | None = None,
    ):
        """
        Encodes a list of queries into embeddings.

        Args:
            tokenizer (PreTrainedTokenizer): Tokenizer for text processing.
            queries (list[str]): List of query texts.
            max_seq_len (int, optional): Maximum sequence length.
            dim (int, optional): Dimensionality of the output embeddings.

        Returns:
            torch.Tensor: Encoded query embeddings of shape (num_queries, dim),
            or (num_queries, hidden_size) if dim is None.
        """
        if not queries:
            raise ValueError("queries must not be empty.")
        if not isinstance(queries, list) or not all(isinstance(q, str) for q in queries):
            raise ValueError("queries must be a list of strings.")
        if tokenizer is None:
            raise ValueError("tokenizer must not be None.")
        if dim is not None and (dim < 1 or dim > self.hidden_size):
            raise ValueError(f"dim must be in range [1, {self.hidden_size}].")
        # Prepend the query prefix before tokenizing.
        queries = [self.query_prefix + query for query in queries]
        tokenized_queries = self._tokenize(tokenizer, queries, max_seq_len)
        embeddings = self(**tokenized_queries, dim=dim)
        return embeddings

    def encode_documents(
        self,
        tokenizer: PreTrainedTokenizer,
        documents: list[str],
        max_seq_len: int | None = None,
        dim: int | None = None,
    ):
        """
        Encodes a list of documents into embeddings.

        Args:
            tokenizer (PreTrainedTokenizer): Tokenizer for text processing.
            documents (list[str]): List of document texts.
            max_seq_len (int, optional): Maximum sequence length.
            dim (int, optional): Dimensionality of the output embeddings.

        Returns:
            torch.Tensor: Encoded document embeddings of shape (num_documents, dim),
            or (num_documents, hidden_size) if dim is None.
        """
        if not documents:
            raise ValueError("documents must not be empty.")
        if not isinstance(documents, list) or not all(isinstance(d, str) for d in documents):
            raise ValueError("documents must be a list of strings.")
        if tokenizer is None:
            raise ValueError("tokenizer must not be None.")
        if dim is not None and (dim < 1 or dim > self.hidden_size):
            raise ValueError(f"dim must be in range [1, {self.hidden_size}].")
        # Documents are encoded without the query prefix.
        tokenized_documents = self._tokenize(tokenizer, documents, max_seq_len)
        embeddings = self(**tokenized_documents, dim=dim)
        return embeddings
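

if __name__ == "__main__":
    # Minimal usage sketch. The checkpoint identifier and the example texts below are
    # assumptions for illustration only; substitute the tokenizer/checkpoint you actually use.
    from transformers import AutoTokenizer

    checkpoint = "path/to/drama-checkpoint"  # hypothetical identifier
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = DramaModel.from_pretrained(checkpoint).eval()

    queries = ["Where is the Eiffel Tower?"]
    documents = ["The Eiffel Tower is located in Paris, France."]

    with torch.no_grad():
        query_emb = model.encode_queries(tokenizer, queries)
        doc_emb = model.encode_documents(tokenizer, documents)

    # Embeddings are L2-normalized, so the dot product equals cosine similarity.
    scores = query_emb @ doc_emb.T
    print(scores)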