from __future__ import annotations

import torch
import torch.nn.functional as F

from transformers import LlamaModel, LlamaConfig, PreTrainedTokenizer
from transformers.modeling_attn_mask_utils import AttentionMaskConverter


class DramaModel(LlamaModel):
    """
    DramaModel is a modified version of the LlamaModel that supports bi-directional attention
    and provides query and document encoding functionalities.
    """

    def __init__(self, config: LlamaConfig):
        """
        Initializes the DramaModel by disabling causal masking in self-attention layers.
        """
        super().__init__(config)
        for layer in self.layers:
            layer.self_attn.is_causal = False
        # query prefix
        self.query_prefix = "Query: "
        self.max_seq_len = 8192
        self.hidden_size = config.hidden_size
    
    def _update_causal_mask(
        self,
        attention_mask: torch.Tensor,
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_seen_tokens=None,
        output_attentions=False,
    ):
        """
        Updates the causal mask for attention computations.
        """
        if self.config._attn_implementation == "flash_attention_2":
            if attention_mask is not None and (attention_mask == 0.0).any():
                return attention_mask
            return None
        if attention_mask is None or attention_mask.dim() == 4:
            return attention_mask
        
        return AttentionMaskConverter._expand_mask(
            mask=attention_mask,
            dtype=input_tensor.dtype,
        )

    def _average_pool(
        self, last_hidden_states: torch.Tensor, attention_mask: torch.Tensor
    ) -> torch.Tensor:
        """
        Computes the average pooled representation of the last hidden states.
        """
        last_hidden = last_hidden_states.masked_fill(
            ~attention_mask[..., None].bool(), 0.0
        )
        return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]

    def _tokenize(
            self,
            tokenizer: PreTrainedTokenizer,
            texts: list[str],
            max_seq_len: int | None = None,
        ):
        """
        Tokenizes input text sequences with optional sequence length restriction.
        """
        if max_seq_len is None:
            max_seq_len = self.max_seq_len
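        # Truncate to max_seq_len - 1 so an EOS token can be appended to every
        # sequence; the batch is then padded to a common length.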
        tokenized = tokenizer(
            texts,
            padding=False,
            truncation=True,
            max_length=max_seq_len - 1,
            return_attention_mask=False,
            return_token_type_ids=False,
            add_special_tokens=True
        )
        tokenized['input_ids'] = [
            t + [tokenizer.eos_token_id] for t in tokenized['input_ids']
        ]
        tokenized = tokenizer.pad(
            tokenized,
            padding=True, 
            return_attention_mask=True,
            return_tensors='pt',
        ).to(self.device)
        return tokenized

    def forward(self, input_ids, attention_mask, dim=None, *args, **kwargs):
        """
        Forward pass through the model.

        Args:
            input_ids (torch.Tensor): Input token IDs.
            attention_mask (torch.Tensor): Attention mask tensor.
            dim (int, optional): Dimensionality of the output embeddings;
                if None, the full hidden size is used.

        Returns:
            torch.Tensor: L2-normalized embeddings of shape (batch_size, dim).
        """
        outputs = super().forward(
            input_ids, attention_mask, *args, **kwargs
        )
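        # Keep only the first `dim` hidden dimensions before pooling
        # (dim=None keeps the full hidden size).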
        embeddings = self._average_pool(
            outputs.last_hidden_state[:, :, :dim], attention_mask
        )
        # normalize embeddings
        embeddings = F.normalize(embeddings, p=2, dim=1)
        return embeddings

    def encode_queries(
            self,
            tokenizer: PreTrainedTokenizer,
            queries: list[str],
            max_seq_len: int | None = None,
            dim: int | None = None,
        ):
        """
        Encodes a list of queries into embeddings.
        
        Args:
            tokenizer (PreTrainedTokenizer): Tokenizer for text processing.
            queries (list[str]): List of query texts.
            max_seq_len (int, optional): Maximum sequence length.
            dim (int, optional): Dimensionality of the output embeddings;
                if None, the full hidden size is used.

        Returns:
            torch.Tensor: Encoded query embeddings of shape (num_queries, dim).
        """
        if not queries:
            raise ValueError("queries must not be empty.")
        if not isinstance(queries, list) or not all(isinstance(q, str) for q in queries):
            raise ValueError("queries must be a list of strings.")
        if tokenizer is None:
            raise ValueError("tokenizer must not be None.")
        if dim is not None and (dim < 1 or dim > self.hidden_size):
            raise ValueError(f"dim must be in range [1, {self.hidden_size}].")
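        # Queries are prefixed (default "Query: ") so that queries and documents
        # are encoded asymmetrically; documents receive no prefix.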
        queries = [self.query_prefix + query for query in queries]
        tokenized_queries = self._tokenize(tokenizer, queries, max_seq_len)
        embeddings = self(**tokenized_queries, dim=dim)
        return embeddings

    def encode_documents(
            self,
            tokenizer: PreTrainedTokenizer,
            documents: list[str],
            max_seq_len: int | None = None,
            dim: int | None = None,
        ):
        """
        Encodes a list of documents into embeddings.
        
        Args:
            tokenizer (PreTrainedTokenizer): Tokenizer for text processing.
            documents (list[str]): List of document texts.
            max_seq_len (int, optional): Maximum sequence length.
            dim (int, optional): Dimensionality of the output embeddings;
                if None, the full hidden size is used.

        Returns:
            torch.Tensor: Encoded document embeddings of shape (num_documents, dim).
        """
        if not documents:
            raise ValueError("documents must not be empty.")
        if not isinstance(documents, list) or not all(isinstance(d, str) for d in documents):
            raise ValueError("documents must be a list of strings.")
        if tokenizer is None:
            raise ValueError("tokenizer must not be None.")
        if dim is not None and (dim < 1 or dim > self.hidden_size):
            raise ValueError(f"dim must be in range [1, {self.hidden_size}].")
        tokenized_documents = self._tokenize(tokenizer, documents, max_seq_len)
        embeddings = self(**tokenized_documents, dim=dim)
        return embeddings
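

# ---------------------------------------------------------------------------
# Usage sketch (not part of the model definition). A minimal, illustrative
# example of how the encoding API above is intended to be used. The checkpoint
# name "facebook/drama-base" and the dim=256 value mentioned below are
# assumptions for illustration, not requirements of this file; running the
# block downloads model weights.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoTokenizer

    checkpoint = "facebook/drama-base"  # assumed example checkpoint
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = DramaModel.from_pretrained(checkpoint)
    model.eval()

    queries = ["What is dense retrieval?"]
    documents = [
        "Dense retrieval encodes queries and documents as vectors and ranks "
        "documents by vector similarity.",
        "The weather tomorrow is expected to be sunny.",
    ]

    with torch.no_grad():
        # dim=None keeps the full hidden size; pass e.g. dim=256 for
        # truncated embeddings.
        query_emb = model.encode_queries(tokenizer, queries)
        doc_emb = model.encode_documents(tokenizer, documents)

    # Embeddings are L2-normalized, so the dot product equals cosine similarity.
    scores = query_emb @ doc_emb.T
    print(scores)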