from typing import Dict, List, Tuple

import torch
from dotenv import load_dotenv
from transformers import (
    LongformerForQuestionAnswering,
    LongformerTokenizer,
)

from src.readers.base_reader import Reader


load_dotenv()


class LongformerReader(Reader):
    """Extractive QA reader built on a Longformer fine-tuned on SQuAD v1."""

    def __init__(self) -> None:
        checkpoint = "valhalla/longformer-base-4096-finetuned-squadv1"
        self.tokenizer = LongformerTokenizer.from_pretrained(checkpoint)
        self.model = LongformerForQuestionAnswering.from_pretrained(checkpoint)

    def read(self,
             query: str,
             context: Dict[str, List[str]],
             num_answers: int = 5) -> List[Tuple]:
        """Extract one answer span from each of the first `num_answers` texts."""
        answers = []

        for text in context['texts'][:num_answers]:
            # Encode the question/passage pair; the tokenizer inserts the
            # separator tokens between the two segments itself.
            encoding = self.tokenizer(query, text, return_tensors="pt")
            input_ids = encoding["input_ids"]
            attention_mask = encoding["attention_mask"]

            # Inference only: no gradients needed.
            with torch.no_grad():
                outputs = self.model(input_ids, attention_mask=attention_mask)

            # The predicted span runs from the argmax of the start logits to
            # the argmax of the end logits, inclusive. (If the end index
            # precedes the start index, the decoded answer is empty.)
            start_index = torch.argmax(outputs.start_logits)
            end_index = torch.argmax(outputs.end_logits)
            answer = self.tokenizer.decode(
                input_ids[0][start_index:end_index + 1],
                skip_special_tokens=True,
            )
            # Append a tuple to match the declared List[Tuple] return type;
            # the two extra fields are left empty by this reader.
            answers.append((answer, [], []))

        return answers
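

# A minimal usage sketch, not part of the original module: it assumes the
# retrieval step upstream produces a context dict shaped like
# {"texts": [...]}, which is what read() indexes above. The example passage
# and question here are illustrative only.
if __name__ == "__main__":
    reader = LongformerReader()
    context = {
        "texts": [
            "The Eiffel Tower was completed in 1889 and is located in Paris.",
        ]
    }
    for answer, _, _ in reader.read(
            "When was the Eiffel Tower completed?", context, num_answers=1):
        print(answer)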