from typing import Dict, Optional

from torch import nn


class TextEmbExtractor(nn.Module):
    """Wraps a tokenizer and text encoder to produce text embeddings for prompts."""

    def __init__(self, tokenizer, text_encoder) -> None:
        super().__init__()
        self.tokenizer = tokenizer
        self.text_encoder = text_encoder

    def forward(
        self,
        texts,
        text_params: Optional[Dict] = None,
    ):
        if text_params is None:
            text_params = {}

        # Tokenize to fixed-length, padded PyTorch tensors.
        special_prompt_input = self.tokenizer(
            texts,
            max_length=self.tokenizer.model_max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )

        # Only pass an attention mask if the text encoder is configured to use one.
        if (
            hasattr(self.text_encoder.config, "use_attention_mask")
            and self.text_encoder.config.use_attention_mask
        ):
            attention_mask = special_prompt_input.attention_mask.to(
                self.text_encoder.device
            )
        else:
            attention_mask = None

        # Encode the tokenized prompts on the encoder's device.
        embeddings = self.text_encoder(
            special_prompt_input.input_ids.to(self.text_encoder.device),
            attention_mask=attention_mask,
            **text_params
        )
        return embeddings
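

if __name__ == "__main__":
    # Usage sketch (an assumption, not part of the original file): wires
    # TextEmbExtractor to a Hugging Face CLIP text encoder. The checkpoint
    # name below is only an illustrative choice; any tokenizer/encoder pair
    # with the same interface should work.
    from transformers import CLIPTextModel, CLIPTokenizer

    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

    extractor = TextEmbExtractor(tokenizer, text_encoder)
    outputs = extractor(["a photo of a cat"])
    # For this checkpoint, last_hidden_state has shape (1, 77, 768).
    print(outputs.last_hidden_state.shape)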