File size: 1,126 Bytes
ab5dfc2 0157dfd ab5dfc2 be1f224 ab5dfc2 be1f224 0157dfd be1f224 ab5dfc2 492106d ab5dfc2 112c207 ab5dfc2 112c207 ab5dfc2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 |
from typing import Dict, List, Tuple

import torch
from dotenv import load_dotenv
from transformers import DPRReader, DPRReaderTokenizer

from src.readers.base_reader import Reader
load_dotenv()
class DprReader(Reader):
    """Extractive reader that answers a query with a pretrained DPR reader.

    The tokenizer and the model are both loaded from the same checkpoint
    name when the instance is constructed.
    """

    # Checkpoint used when the caller does not supply one.
    _DEFAULT_MODEL_NAME = "facebook/dpr-reader-single-nq-base"
    # Upper bound on the number of spans taken from any single passage.
    _SPANS_PER_PASSAGE = 2
    # Upper bound on answer length (in tokens) passed to the span decoder.
    _MAX_ANSWER_LENGTH = 512

    def __init__(self, model_name: str = _DEFAULT_MODEL_NAME) -> None:
        """Load the DPR reader tokenizer and model.

        Args:
            model_name: Checkpoint identifier or path handed to
                ``from_pretrained`` for both the tokenizer and the model.
        """
        self._tokenizer = DPRReaderTokenizer.from_pretrained(model_name)
        self._model = DPRReader.from_pretrained(model_name)

    def read(self,
             query: str,
             context: Dict[str, List[str]],
             num_answers: int = 5) -> List[Tuple]:
        """Extract the best answer spans for ``query`` from ``context``.

        Args:
            query: The question to answer.
            context: Mapping with a ``'titles'`` entry and a ``'texts'``
                entry holding the candidate passages.
            num_answers: Maximum number of spans to return overall.

        Returns:
            The span predictions produced by the tokenizer's
            ``decode_best_spans``; an empty list when the context holds no
            passages at all.

        Raises:
            KeyError: If ``context`` lacks ``'titles'`` or ``'texts'``.
        """
        titles = context['titles']
        texts = context['texts']

        # Nothing to read from: skip tokenization and the forward pass
        # rather than failing inside the tokenizer on an empty batch.
        if not titles and not texts:
            return []

        encoded_inputs = self._tokenizer(
            questions=query,
            titles=titles,
            texts=texts,
            return_tensors='pt',
            truncation=True,
            padding=True,
        )

        # Inference only: disable autograd so no graph is built and the
        # returned scores do not keep intermediate activations alive.
        with torch.no_grad():
            outputs = self._model(**encoded_inputs)

        return self._tokenizer.decode_best_spans(
            encoded_inputs,
            outputs,
            num_spans=num_answers,
            num_spans_per_passage=self._SPANS_PER_PASSAGE,
            max_answer_length=self._MAX_ANSWER_LENGTH,
        )
|