import os
from typing import Any, Dict

from transformers import (Pipeline, T5ForConditionalGeneration, T5Tokenizer,
                          pipeline)

# Access token for the clarin-knext models, read from the environment.
auth_token = os.environ.get("CLARIN_KNEXT")


# Example dialogue state tracking (DST) inputs: a user utterance ([U]) followed
# by the domain and slot descriptions the model should fill.
DEFAULT_DST_INPUTS: Dict[str, str] = {
    "polish": (
        "[U] Chciałbym zarezerwować stolik na 4 osoby na piątek o godzinie 18:30. "
        "[Dziedzina] Restauracje: Popularna usługa wyszukiwania i rezerwacji restauracji "
        "[Atrybut] Czas: Wstępny czas rezerwacji restauracji"
    ),
    "english": (
        "[U] I want to book a table for 4 people on Friday, 6:30 pm. "
        "[Domain] Restaurants: A popular restaurant search and reservation service "
        "[Slot] Time: Tentative time of restaurant reservation"
    ),
}


# Available DST models: each entry bundles the model, its tokenizer, and a default input.
DST_MODELS: Dict[str, Dict[str, Any]] = {
    "plt5-large-poquad-dst-v2": {
        "model": T5ForConditionalGeneration.from_pretrained("clarin-knext/plt5-large-poquad-dst-v2", use_auth_token=auth_token),
        "tokenizer": T5Tokenizer.from_pretrained("clarin-knext/plt5-large-poquad-dst-v2", use_auth_token=auth_token),
        "default_input": DEFAULT_DST_INPUTS["polish"],
    }
}


# One text2text-generation pipeline per registered model.
PIPELINES: Dict[str, Pipeline] = {
    model_name: pipeline(
        "text2text-generation", model=DST_MODELS[model_name]["model"], tokenizer=DST_MODELS[model_name]["tokenizer"]
    )
    for model_name in DST_MODELS
}
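

# Example usage (a minimal sketch): run each configured pipeline on its default
# input and print the model output. The "generated_text" key is the standard
# output field of transformers text2text-generation pipelines.
if __name__ == "__main__":
    for model_name, dst_pipeline in PIPELINES.items():
        default_input = DST_MODELS[model_name]["default_input"]
        result = dst_pipeline(default_input)
        print(f"{model_name}: {result[0]['generated_text']}")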