Spaces:
No application file
No application file
from typing import Any, Dict, Optional

from torch import nn
class TextEmbExtractor(nn.Module):
    """Tokenize text prompts and encode them with a text encoder.

    Holds a ``tokenizer`` / ``text_encoder`` pair and, on ``forward``, turns a
    batch of prompts into whatever the encoder returns for the tokenized ids.

    Interface the collaborators must provide (as used below):
      * ``tokenizer``: callable accepting ``max_length``, ``padding``,
        ``truncation`` and ``return_tensors`` keywords, exposing
        ``model_max_length``, and returning an object with ``input_ids`` and
        ``attention_mask`` tensors (presumably a Hugging Face tokenizer —
        confirm against callers).
      * ``text_encoder``: callable taking ``input_ids`` positionally plus
        keyword arguments, exposing ``.config`` and ``.device``.
    """

    def __init__(self, tokenizer, text_encoder) -> None:
        super().__init__()
        self.tokenizer = tokenizer
        self.text_encoder = text_encoder

    def forward(
        self,
        texts,
        text_params: Optional[Dict[str, Any]] = None,
    ):
        """Encode ``texts`` and return the text encoder's raw output.

        Args:
            texts: Prompt(s) passed straight to the tokenizer.
            text_params: Extra keyword arguments forwarded to the text
                encoder. An ``attention_mask`` entry here overrides the mask
                this method would otherwise pass. ``None`` means no extras.

        Returns:
            Whatever ``self.text_encoder(...)`` returns; it is not post-processed.
        """
        # Pad and truncate to the tokenizer's own maximum so every batch has
        # the same fixed sequence length.
        # NOTE(review): some tokenizers report a huge sentinel for
        # model_max_length when unset — confirm the tokenizers used here set it.
        tokenized = self.tokenizer(
            texts,
            max_length=self.tokenizer.model_max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        device = self.text_encoder.device

        # Only hand the padding mask to encoders whose config opts in;
        # a missing flag is treated the same as a false one.
        attention_mask = None
        if getattr(self.text_encoder.config, "use_attention_mask", False):
            attention_mask = tokenized.attention_mask.to(device)

        # Merge caller kwargs over the default mask so an explicit
        # ``attention_mask`` in text_params does not collide with ours
        # (a duplicate keyword argument would raise TypeError).
        encoder_kwargs: Dict[str, Any] = {"attention_mask": attention_mask}
        if text_params:
            encoder_kwargs.update(text_params)

        return self.text_encoder(tokenized.input_ids.to(device), **encoder_kwargs)