---
language:
- ru
---
# Model Card for DeepPavlov/mbart-large-50-ru-persona-chat

## Model Details

### Model Description

- Developed by: DeepPavlov team
- Model type: text generation
- Language(s) (NLP): Russian
- License: openrail
- Finetuned from model: facebook/mbart-large-50
## Uses

### Direct Use
```python
from typing import List, Optional, TypedDict
from dataclasses import dataclass
from itertools import chain

import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM


@dataclass
class H2PersonaChatHyperparametersV1:
    """
    chat_history_pair_length: int - number of dialog pairs taken from the end of the history
    """

    model_name: str = "facebook/bart-base"
    chat_history_pair_length: int = 7

    persona_max_length: int = 14
    chat_max_length: int = 25

    debug_status: int = 0


class PersonaChatDatasetSampleV1(TypedDict):
    """
    persona: List[str] - the persona's fact sentences
    history: List[str] - the dialog history sentences
    """

    persona: List[str]
    history: List[str]
    sample_id: str


class H2Seq2SeqInferenceSampleDictV1(TypedDict):
    input_ids: List[int]
    attention_mask: List[int]


class H2Seq2SeqInferenceSampleDictV2(TypedDict):
    input_ids: torch.Tensor
    attention_mask: torch.Tensor


def flat_list(list_of_lists: List[List]) -> List:
    return list(chain.from_iterable(list_of_lists))


class H2Seq2SeqInferencePersonaSampleV1:
    def __init__(
        self,
        dataset_sample: PersonaChatDatasetSampleV1,
        tokenizer: AutoTokenizer,
        hyperparameters: H2PersonaChatHyperparametersV1,
    ) -> None:
        self.dataset_sample = dataset_sample
        self.tokenizer = tokenizer
        self.hyperparameters = hyperparameters

    def add_spaces_after(
        self,
        items: List[str],
    ) -> List[str]:
        items = [item + " " for item in items]
        return items

    @property
    def bos_token_id(self):
        if "t5" in self.hyperparameters.model_name:
            return []

        if self.tokenizer.bos_token_id is None:
            return []

        return [self.tokenizer.bos_token_id]

    @property
    def eos_token_id(self):
        if self.tokenizer.eos_token_id is None:
            return []

        return [self.tokenizer.eos_token_id]

    def add_sep_between(self, items: List[str], sep=" EOS ") -> List[str]:
        for i in range(1, len(items)):
            items[i] = sep + items[i]

        return items

    def add_spaces_between(self, items: List[str]) -> List[str]:
        items = self.add_spaces_after(items)
        items[-1] = items[-1].strip()
        return items

    def get_sample(self) -> H2Seq2SeqInferenceSampleDictV1:
        dialog_history = self.dataset_sample["history"]
        dialog_history = dialog_history[-self.hyperparameters.chat_history_pair_length * 2 - 1 :]
        dialog_history = self.add_sep_between(dialog_history)

        persona = self.dataset_sample["persona"]
        persona = self.add_sep_between(
            persona,
            sep=" ",
        )

        KNOWLEDGE_IDS = self.tokenizer.encode(
            " [KNOWLEDGE] ",
            add_special_tokens=False,
        )
        CONTEXT_IDS = self.tokenizer.encode(
            " [CONTEXT]",
            add_special_tokens=False,
        )

        encoded_history = self.tokenizer.batch_encode_plus(
            dialog_history,
            add_special_tokens=False,
            truncation=True,
            max_length=self.hyperparameters.chat_max_length,
        )
        encoded_history = flat_list(encoded_history["input_ids"])

        encoded_persona = self.tokenizer.batch_encode_plus(
            persona,
            add_special_tokens=False,
            truncation=True,
            max_length=self.hyperparameters.persona_max_length,
        )
        encoded_persona = flat_list(encoded_persona["input_ids"])

        # prompt layout: <bos> [CONTEXT] history [KNOWLEDGE] persona <eos>
        input_ids = [
            *self.bos_token_id,
            *CONTEXT_IDS,
            *encoded_history,
            *KNOWLEDGE_IDS,
            *encoded_persona,
            *self.eos_token_id,
        ]

        attention_mask = [1] * len(input_ids)

        return H2Seq2SeqInferenceSampleDictV1(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )


class DialogBotV1:
    def __init__(
        self,
        model: AutoModelForSeq2SeqLM,
        tokenizer: AutoTokenizer,
        hyperparameters: H2PersonaChatHyperparametersV1,
        history: Optional[List[str]] = None,
        persona: Optional[List[str]] = None,
        device: str = "cuda",
        shuffle_persona: bool = True,
    ):
        self.model = model
        self.tokenizer = tokenizer
        self.hyperparameters = hyperparameters
        self.device = device
        self.shuffle_persona = shuffle_persona

        self.debug_status = hyperparameters.debug_status

        # fall back to empty lists instead of overwriting them with None
        self.history = history if history is not None else []
        self.persona = persona if persona is not None else []

    def _get_sample(
        self,
        persona: List[str],
        history: List[str],
    ) -> H2Seq2SeqInferenceSampleDictV1:
        dataset_sample = PersonaChatDatasetSampleV1(
            persona=persona,
            history=history,
            sample_id="",
        )

        sample = H2Seq2SeqInferencePersonaSampleV1(
            tokenizer=self.tokenizer,
            hyperparameters=self.hyperparameters,
            dataset_sample=dataset_sample,
        )
        sample = sample.get_sample()

        if self.debug_status:
            print(self.tokenizer.decode(sample["input_ids"]))

        for key in sample.keys():
            sample[key] = torch.tensor(sample[key]).unsqueeze(0).to(self.device)

        return sample

    def next_response(
        self,
        **generation_params,
    ) -> str:
        """
        Generates the next utterance from the current history and persona.
        """
        sample = self._get_sample(
            persona=self.persona,
            history=self.history,
        )
        answer = self.generate_response(
            sample,
            **generation_params,
        )
        answer = self.tokenizer.batch_decode(
            answer,
            skip_special_tokens=True,
        )
        self.history.append(answer[0])
        return answer[0]

    def generate_response(
        self,
        sample: H2Seq2SeqInferenceSampleDictV1,
        **generation_params,
    ):
        """
        generation_params - https://huggingface.co/docs/transformers/v4.24.0/en/main_classes/text_generation
        """
        with torch.no_grad():
            return self.model.generate(
                **sample,
                **generation_params,
            )


# finetuned from facebook/mbart-large-50
PRETRAINED_MODEL_NAME_OR_PATH = "DeepPavlov/mbart-large-50-ru-persona-chat"

PAIR_DIALOG_HISTORY_LENGTH = 2

# max token length for a single dialog sentence
CHAT_MAX_LENGTH = 25
# max token length for a single persona sentence
PERSONA_MAX_LENGTH = 19

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = AutoModelForSeq2SeqLM.from_pretrained(PRETRAINED_MODEL_NAME_OR_PATH)
model.to(device)
model.eval()

tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL_NAME_OR_PATH)

if torch.cuda.is_available():
    model.half()

hyperparameters = H2PersonaChatHyperparametersV1(
    chat_history_pair_length=PAIR_DIALOG_HISTORY_LENGTH,
    persona_max_length=PERSONA_MAX_LENGTH,
    chat_max_length=CHAT_MAX_LENGTH,
    model_name=PRETRAINED_MODEL_NAME_OR_PATH,
)

persona = [
    "Я люблю играть с милыми песиками",  # "I love playing with cute doggies"
    "Я ненавижу лук и брокколи",  # "I hate onions and broccoli"
]

history = [
    "Привет. Ты любишь лук?",  # "Hi. Do you like onions?"
]

persona_bot = DialogBotV1(
    model=model,
    tokenizer=tokenizer,
    hyperparameters=hyperparameters,
    history=history,
    persona=persona,
    device=device,
)

# penalty_alpha together with top_k enables contrastive search decoding
GENERATION_PARAMS = {
    "max_new_tokens": 60,
    "penalty_alpha": 0.15,
    "top_k": 10,
}
response = persona_bot.next_response(
    **GENERATION_PARAMS,
)

print(response)
```
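Since `next_response` appends the generated reply to `persona_bot.history`, continuing the conversation only requires appending the next user utterance before generating again. A minimal sketch (the user message here is made up for illustration):

```python
# continue the dialog: add the user's next utterance, then generate again
persona_bot.history.append("А что ты любишь есть?")  # "And what do you like to eat?"
print(persona_bot.next_response(**GENERATION_PARAMS))
```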
## Recommendations

## Training Details

### Training Data

[More Information Needed]
### Preprocessing

- Initial data was split by this script:
```python
import pandas as pd


def ru_persona_chat_dataset_tranformer_v1(
    initial_dataset_path: str,
    output_folder: str,
) -> None:
    """
    Example:
        ru_persona_chat_dataset_tranformer_v1(
            initial_dataset_path="./datasets/ru_persona_chat/dialogues.tsv",
            output_folder="./datasets/ru_persona_chat",
        )
    """
    assert initial_dataset_path is not None, "initial_dataset_path is None"
    assert output_folder is not None, "output_folder is None"

    dataset = pd.read_csv(initial_dataset_path, sep="\t")
    # 95/5 train/validation split
    split_ratio = int(len(dataset) * 0.95)
    train_dataset = dataset[:split_ratio]
    valid_dataset = dataset[split_ratio:]

    print(f"Dataset lengths: train {len(train_dataset)}, valid {len(valid_dataset)}")
    # save csv files
    train_dataset.to_csv(output_folder + "/train.csv", index=False)
    valid_dataset.to_csv(output_folder + "/valid.csv", index=False)
    print("Datasets saved.")
```
## Evaluation

### Metrics

- BLEU
- chrF
- ROUGE-L
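The card does not report scores. For reference, a minimal sketch of how these metrics could be computed with the Hugging Face `evaluate` library; the library choice and the prediction/reference strings are assumptions, not part of the original card:

```python
import evaluate

# hypothetical predictions/references, purely for illustration
predictions = ["Нет, я совсем не люблю лук."]
references = ["Нет, я ненавижу лук и брокколи."]

bleu = evaluate.load("sacrebleu")
chrf = evaluate.load("chrf")
rouge = evaluate.load("rouge")  # NB: the default ROUGE tokenizer is English-centric;
                                # consider a custom tokenizer for Russian text

print(bleu.compute(predictions=predictions, references=[[r] for r in references])["score"])
print(chrf.compute(predictions=predictions, references=[[r] for r in references])["score"])
print(rouge.compute(predictions=predictions, references=references)["rougeL"])
```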