language: - ru

Model Card for DeepPavlov/mbart-large-50-ru-persona-chat

Model Details

Model Description

  • Developed by: DeepPavlov team
  • Model type: text generation
  • Language(s) (NLP): Russian
  • License: OpenRAIL
  • Finetuned from model: facebook/mbart-large-50

Uses

Direct Use

from typing import List, TypedDict
from dataclasses import dataclass
from itertools import chain

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch


@dataclass
class H2PersonaChatHyperparametersV1:
    """
    Settings that control how a persona-chat sample is turned into model input.

    chat_history_pair_length: int - number of dialogue pairs taken from the
    end of the history
    """

    # Model name or path; the sample builder in this file checks it for the
    # substring "t5" to decide whether a BOS token is prepended.
    model_name: str = "facebook/bart-base"
    # The sample builder keeps the last `2 * chat_history_pair_length + 1`
    # history utterances.
    chat_history_pair_length: int = 7

    # Token limits passed as `max_length` (with truncation) when each persona
    # sentence / each history utterance is tokenized on its own.
    persona_max_length: int = 14
    chat_max_length: int = 25

    # Copied onto DialogBotV1 instances; not read anywhere else in this file.
    debug_status: int = 0


class PersonaChatDatasetSampleV1(TypedDict):
    """
    One persona-chat sample.

    persona: List[str] - sentences stating facts about the persona
    history: List[str] - sentences of the conversation history
    sample_id: str - presumably an identifier of the sample; it is not read
    by the code in this file
    """

    persona: List[str]
    history: List[str]
    sample_id: str


class H2Seq2SeqInferenceSampleDictV1(TypedDict):
    """Model input with token ids and attention mask stored as plain lists."""

    input_ids: List[int]
    attention_mask: List[int]


class H2Seq2SeqInferenceSampleDictV2(TypedDict):
    """Model input with the same keys as V1 but holding torch tensors."""

    input_ids: torch.Tensor
    attention_mask: torch.Tensor


def flat_list(list_of_lists: List[List]) -> List:
    """Concatenate the inner lists into one flat list, preserving order."""
    return [element for inner in list_of_lists for element in inner]


class H2Seq2SeqInferencePersonaSampleV1:
    """
    Builds the token-level model input for one persona-chat sample.

    The produced sequence has the layout
    ``[bos] [CONTEXT] <history> [KNOWLEDGE] <persona> [eos]``,
    where bos/eos are included only when the tokenizer defines them.
    """

    def __init__(
        self,
        dataset_sample: PersonaChatDatasetSampleV1,
        tokenizer: AutoTokenizer,
        hyperparameters: H2PersonaChatHyperparametersV1,
    ) -> None:
        """Store the sample, tokenizer and hyperparameters for `get_sample`."""
        self.dataset_sample = dataset_sample
        self.tokenizer = tokenizer
        self.hyperparameters = hyperparameters

    def add_spaces_after(
        self,
        items: List[str],
    ) -> List[str]:
        """Return a new list with a single space appended to every item."""
        return [item + " " for item in items]

    @property
    def bos_token_id(self):
        """BOS id wrapped in a list; empty for "t5" models or when undefined."""
        if "t5" in self.hyperparameters.model_name:
            return []

        if self.tokenizer.bos_token_id is None:
            return []

        return [self.tokenizer.bos_token_id]

    @property
    def eos_token_id(self):
        """EOS id wrapped in a list; empty when the tokenizer has none."""
        if self.tokenizer.eos_token_id is None:
            return []

        return [self.tokenizer.eos_token_id]

    def add_sep_beetween(self, items: List[str], sep=" EOS ") -> List[str]:
        """
        Return a new list where every item except the first is prefixed
        with `sep`.

        The input list is left untouched: callers pass in lists they keep
        using (e.g. the bot's persona), and mutating them in place would
        add one more separator on every call.
        """
        return [
            item if index == 0 else sep + item
            for index, item in enumerate(items)
        ]

    def add_spaces_between(self, items: List[str]) -> List[str]:
        """
        Return a new list with a trailing space on every item but the last.

        The last item is stripped of surrounding whitespace. An empty input
        yields an empty list.
        """
        if not items:
            return []

        spaced = self.add_spaces_after(items)
        spaced[-1] = spaced[-1].strip()
        return spaced

    def _encode_sentences(
        self,
        sentences: List[str],
        max_length: int,
    ) -> List[int]:
        """Tokenize each sentence (truncated to `max_length`) and join the ids."""
        if not sentences:
            # Nothing to encode; avoid calling the tokenizer on an empty batch.
            return []

        encoded = self.tokenizer.batch_encode_plus(
            sentences,
            add_special_tokens=False,
            truncation=True,
            max_length=max_length,
        )
        return flat_list(encoded["input_ids"])

    def get_sample(self) -> H2Seq2SeqInferenceSampleDictV1:
        """
        Build `input_ids` and `attention_mask` for the stored sample.

        Only the last `2 * chat_history_pair_length + 1` history utterances
        are used. History utterances are joined with " EOS ", persona
        sentences with a single space; each sentence is truncated separately.
        The attention mask is all ones.
        """
        # Slicing copies, so the stored history is never modified.
        history_tail = self.dataset_sample["history"][
            -self.hyperparameters.chat_history_pair_length * 2 - 1 :
        ]
        dialog_history = self.add_sep_beetween(history_tail)

        persona = self.add_sep_beetween(
            self.dataset_sample["persona"],
            sep=" ",
        )

        KNOWLEDGE_IDS = self.tokenizer.encode(
            " [KNOWLEDGE] ",
            add_special_tokens=False,
        )
        CONTEXT_IDS = self.tokenizer.encode(
            " [CONTEXT]",
            add_special_tokens=False,
        )

        encoded_history = self._encode_sentences(
            dialog_history,
            max_length=self.hyperparameters.chat_max_length,
        )
        encoded_persona = self._encode_sentences(
            persona,
            max_length=self.hyperparameters.persona_max_length,
        )

        input_ids = [
            *self.bos_token_id,
            *CONTEXT_IDS,
            *encoded_history,
            *KNOWLEDGE_IDS,
            *encoded_persona,
            *self.eos_token_id,
        ]

        attention_mask = [1] * len(input_ids)

        return H2Seq2SeqInferenceSampleDictV1(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )


class DialogBotV1:
    """
    Stateful persona chat bot on top of a seq2seq generation model.

    Holds the persona sentences and the running dialogue history; every call
    to `next_response` generates a reply from them and appends it to the
    history.
    """

    def __init__(
        self,
        model: AutoModelForSeq2SeqLM,
        tokenizer: AutoTokenizer,
        hyperparameters: H2PersonaChatHyperparametersV1,
        history: List[str] = None,
        persona: List[str] = None,
        device: str = "cuda",
        shuffle_persona: bool = True,
    ):
        """
        model: seq2seq model used for generation
        tokenizer: tokenizer matching the model
        hyperparameters: sample-building settings
        history: initial dialogue history; an empty list when omitted. The
            passed list is kept by reference and grows with each response.
        persona: persona sentences; an empty list when omitted
        device: device the input tensors are moved to
        shuffle_persona: stored on the instance; not used by this class
        """
        self.model = model

        self.tokenizer = tokenizer
        self.hyperparameters = hyperparameters
        self.device = device
        self.shuffle_persona = shuffle_persona

        self.debug_status = hyperparameters.debug_status

        # Use a fresh list per instance when nothing is passed; otherwise
        # keep the caller's list (the default must not be overwritten by None).
        self.history = [] if history is None else history
        self.persona = [] if persona is None else persona

    def _get_sample(
        self,
        persona: List[str],
        history: List[str],
    ) -> H2Seq2SeqInferenceSampleDictV1:
        """
        Tokenize persona and history and return batched tensors on
        `self.device`.

        The values of the returned dict are tensors with a leading batch
        dimension of size 1. The decoded input text is printed.
        """
        dataset_sample = PersonaChatDatasetSampleV1(
            persona=persona,
            history=history,
        )

        sample = H2Seq2SeqInferencePersonaSampleV1(
            tokenizer=self.tokenizer,
            hyperparameters=self.hyperparameters,
            dataset_sample=dataset_sample,
        )
        sample = sample.get_sample()
        print(self.tokenizer.decode(sample['input_ids']))

        for key in sample.keys():
            # unsqueeze(0) adds the batch dimension expected by generate().
            sample[key] = torch.tensor(sample[key]).unsqueeze(0).to(self.device)

        return sample

    def next_response(
        self,
        **generation_params,
    ) -> str:
        """
        Generate a reply from the current history and persona.

        The reply is appended to `self.history` and returned.
        `generation_params` are forwarded to the model's `generate`.
        """

        sample = self._get_sample(
            persona=self.persona,
            history=self.history,
        )
        answer = self.generate_response(
            sample,
            **generation_params,
        )
        answer = self.tokenizer.batch_decode(
            answer,
            skip_special_tokens=True,
        )
        self.history.append(answer[0])
        return answer[0]

    def generate_response(
        self,
        sample: H2Seq2SeqInferenceSampleDictV1,
        **generation_params,
    ):
        """
        Run `model.generate` on the sample without tracking gradients.

        generation_params - https://huggingface.co/docs/transformers/v4.24.0/en/main_classes/text_generation
        """
        with torch.no_grad():
            return self.model.generate(
                **sample,
                **generation_params,
            )


# Checkpoint fine-tuned from facebook/mbart-large-50.
PRETRAINED_MODEL_NAME_OR_PATH = "DeepPavlov/mbart-large-50-ru-persona-chat"

# Number of most recent dialogue pairs passed to the model.
PAIR_DIALOG_HISTORY_LENGTH = 2

# Token limit for a single history sentence.
CHAT_MAX_LENGTH = 25
# Token limit for a single persona sentence.
PERSONA_MAX_LENGTH = 19

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL_NAME_OR_PATH)

model = AutoModelForSeq2SeqLM.from_pretrained(PRETRAINED_MODEL_NAME_OR_PATH)
model.to(device)
model.eval()

# Half precision is only enabled when running on the GPU.
if device.type == "cuda":
    model.half()

hyperparameters = H2PersonaChatHyperparametersV1(
    model_name=PRETRAINED_MODEL_NAME_OR_PATH,
    chat_history_pair_length=PAIR_DIALOG_HISTORY_LENGTH,
    chat_max_length=CHAT_MAX_LENGTH,
    persona_max_length=PERSONA_MAX_LENGTH,
)

persona = [
    "Я люблю играть с милыми песиками",
    "Я ненавижу лук и броколли",
]

history = ["Привет. Ты любишь лук?"]

persona_bot = DialogBotV1(
    model=model,
    tokenizer=tokenizer,
    hyperparameters=hyperparameters,
    history=history,
    persona=persona,
    device=device,
)

# Keyword arguments forwarded to model.generate().
GENERATION_PARAMS = dict(
    max_new_tokens=60,
    penalty_alpha=0.15,
    top_k=10,
)

response = persona_bot.next_response(**GENERATION_PARAMS)

print(response)

Recommendations

Training Details

Training Data

[More Information Needed]

Preprocessing

  • The initial data was split with this script:
def ru_persona_chat_dataset_tranformer_v1(
    initial_dataset_path: str,
    output_folder: str,
) -> None:
    """
    Split the tab-separated dialogue file into train and validation CSVs.

    The first 95% of the rows, in file order and without shuffling, are
    written to ``train.csv`` inside `output_folder`; the remaining rows go
    to ``valid.csv``. The output folder is created when it does not exist.

    initial_dataset_path: path to the source TSV file
    output_folder: directory that receives train.csv and valid.csv

    example
        ru_persona_chat_dataset_tranformer_v1(
            initial_dataset_path="./datasets/ru_persona_chat/dialogues.tsv",
            output_folder="./datasets/ru_persona_chat",
        )
    """
    # Imported here so the function does not depend on file-level imports.
    import os

    import pandas as pd

    assert initial_dataset_path is not None, "initial_dataset_path is None"
    assert output_folder is not None, "output_folder is None"

    dataset = pd.read_csv(initial_dataset_path, sep="\t")
    split_index = int(len(dataset) * 0.95)
    # Positional row slices: head is the train part, tail the validation part.
    train_dataset = dataset.iloc[:split_index]
    valid_dataset = dataset.iloc[split_index:]

    print(f"Dataset lengths: train {len(train_dataset)}, valid {len(valid_dataset)}")
    # save csv files
    os.makedirs(output_folder, exist_ok=True)
    train_dataset.to_csv(os.path.join(output_folder, "train.csv"), index=False)
    valid_dataset.to_csv(os.path.join(output_folder, "valid.csv"), index=False)
    print("Datasets saved.")

Evaluation

Metrics

  • BLEU
  • chrF
  • RougeL
Downloads last month
190
Inference Examples
This model does not have enough activity to be deployed to Inference API (serverless) yet. Increase its social visibility and check back later, or deploy to Inference Endpoints (dedicated) instead.