import json |
import logging |
import os |
import re |
import shutil |
import sys |
from dataclasses import dataclass, field |
from pathlib import Path |
from typing import Any, Dict, List, Optional, Tuple, Union |
import numpy as np |
import torch |
from accelerate import Accelerator, skip_first_batches |
from accelerate.logging import get_logger |
from datasets import DatasetDict, load_dataset |
from torch.utils.data import DataLoader |
from tqdm import tqdm |
from transformers import ( |
AutoModelForCausalLM, |
AutoTokenizer, |
BitsAndBytesConfig, |
HfArgumentParser, |
) |
# Module-level logger created through accelerate's `get_logger`, with its level set to INFO.
logger = get_logger(__name__, log_level="INFO")
@dataclass
class ModelArguments:
    """
    Arguments pertaining to the language model used for the prompt annotation: which checkpoint to load,
    how to load it (dtype, attention implementation, quantization) and how to generate with it.
    """

    # --- required arguments (no default) ---
    model_name_or_path: str = field(
        metadata={"help": "The name of the model to use (via the transformers library) for the prompt annotation."},
    )
    per_device_eval_batch_size: int = field(
        metadata={"help": "The per-device batch size to use for inference."},
    )

    # --- checkpoint selection / loading ---
    # `None` is the default, so the annotation is Optional rather than a bare `str`.
    model_variant: Optional[str] = field(
        default=None,
        metadata={"help": "If specified load weights from `variant` filename, *e.g.* pytorch_model.<variant>.bin. "},
    )
    model_revision: str = field(
        default="main",
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
    )
    torch_dtype: Optional[str] = field(
        default="float16",
        metadata={
            "help": (
                "Floating-point format in which the model weights should be initialized"
                " and the computations run. Choose one of `[float32, float16, bfloat16]`."
            )
        },
    )
    attn_implementation: Optional[str] = field(
        default="sdpa",
        metadata={"help": "Which attn type to use: ['eager', 'sdpa', 'flash_attention_2']"},
    )

    # --- quantization ---
    load_in_8bit: Optional[bool] = field(
        default=False, metadata={"help": "Whether to use 8-bit precision for inference."}
    )
    load_in_4bit: Optional[bool] = field(
        default=False, metadata={"help": "Whether to use 4-bit precision for inference."}
    )
    bnb_4bit_quant_type: Optional[str] = field(
        default="nf4", metadata={"help": "precise the quantization type (fp4 or nf4)"}
    )
    use_bnb_nested_quant: Optional[bool] = field(default=False, metadata={"help": "use nested quantization"})

    # --- Hub / tokenizer options ---
    trust_remote_code: Optional[bool] = field(
        default=False,
        metadata={
            "help": (
                "Whether or not to allow for custom models defined on the Hub in their own modeling files. This option "
                "should only be set to `True` for repositories you trust and in which you have read the code, as it will "
                "execute code present on the Hub on your local machine."
            )
        },
    )
    use_fast_tokenizer: Optional[bool] = field(
        default=True, metadata={"help": "Use fast tokenizer for encoding/decoding input ids"}
    )
    token: Optional[bool] = field(
        default=True,
        metadata={
            "help": "Whether or not to use an authentication token when loading/uploading from the Hugging Face Hub"
        },
    )

    # --- generation ---
    do_sample: Optional[bool] = field(default=True, metadata={"help": "Whether to use sampling mode for generation"})
    temperature: Optional[float] = field(default=0.6, metadata={"help": "Temperature for sampling-based generation"})
    max_new_tokens: Optional[int] = field(
        default=256, metadata={"help": "Maximum number of new tokens during generation"}
    )
    torch_compile: Optional[bool] = field(
        default=False,
        metadata={
            "help": "Whether to compile the forward pass (not sampling) in generate. Only compatible with Gemma and LlaMA."
        },
    )
@dataclass
class DataArguments:
    """
    Arguments pertaining to the dataset that is annotated: which dataset/config/splits to load, where the
    results and intermediate checkpoints are written, and whether the result is pushed to the Hub.
    """

    # NOTE(review): this field has no default, so it is required even though the help text
    # describes a fallback for the "unspecified" case.
    output_dir: str = field(
        metadata={
            "help": "Where to save the processed dataset to disk. If unspecified, uses a 'pretty' version of the "
            "original dataset name. E.g. 'facebook/voxpopuli' will be saved under 'voxpopuli'."
        },
    )
    # `None` is the default, so the annotation is Optional rather than a bare `str`.
    dataset_name: Optional[str] = field(
        default=None,
        metadata={"help": "The name of the dataset to use (via the datasets library)"},
    )
    dataset_config_name: Optional[str] = field(
        default=None,
        metadata={"help": "The configuration name of the dataset to use (via the datasets library)."},
    )
    dataset_split_name: Optional[str] = field(
        default=None,
        metadata={"help": "The split name of the dataset to use (via the datasets library)."},
    )
    dataset_cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Path to cache directory for saving and loading datasets"},
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={"help": "Maximum number of samples for generation - use for debugging purposes."},
    )
    overwrite_cache: bool = field(
        default=False,
        metadata={"help": "Overwrite the cached training and evaluation sets"},
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )
    dataloader_num_workers: Optional[int] = field(
        default=0,
        metadata={"help": "The number of processes to use for the dataloader."},
    )
    push_to_hub: Optional[bool] = field(
        default=False,
        metadata={"help": "Whether or not to push the processed dataset to the Hub."},
    )
    hub_dataset_id: Optional[str] = field(
        default=None,
        metadata={"help": "Repository namespace if pushing to the Hugging Face Hub."},
    )
    overwrite_output_dir: Optional[bool] = field(
        default=False,
        metadata={"help": "Overwrite the content of the output directory each time the script is run."},
    )
    save_steps: Optional[int] = field(
        default=500,
        metadata={"help": "Save the generated prompts every save_steps."},
    )
    save_total_limit: Optional[int] = field(
        default=1, metadata={"help": ("If a value is passed, will limit the total number of saved checkpoints")}
    )

    def __post_init__(self):
        """Reject `push_to_hub` without a target repository.

        Raises:
            ValueError: if `push_to_hub` is set and `hub_dataset_id` is None.
        """
        if self.push_to_hub and self.hub_dataset_id is None:
            raise ValueError("You must specify the `hub_dataset_id` when setting `--push_to_hub=True`")
def get_quantization_config(model_args: ModelArguments) -> Union[BitsAndBytesConfig, None]:
    """Build the `BitsAndBytesConfig` requested by `model_args`.

    The 4-bit flag is checked first, then the 8-bit flag; when neither is set, `None` is returned.
    """
    if model_args.load_in_4bit:
        # Compute in fp16 unless an explicit torch dtype name (anything but "auto"/None) was given.
        dtype_name = model_args.torch_dtype
        compute_dtype = torch.float16 if dtype_name in {"auto", None} else getattr(torch, dtype_name)
        return BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=compute_dtype,
            bnb_4bit_quant_type=model_args.bnb_4bit_quant_type,
            bnb_4bit_use_double_quant=model_args.use_bnb_nested_quant,
        )

    if model_args.load_in_8bit:
        return BitsAndBytesConfig(load_in_8bit=True)

    return None
def get_current_device() -> Union[int, str]:
    """Get the current device.

    Returns:
        The local process index (an int) when CUDA is available, so that each process of a multi-GPU
        run gets its own index; otherwise the string "cpu".
    """
    if not torch.cuda.is_available():
        return "cpu"
    return Accelerator().local_process_index
def get_kbit_device_map() -> Union[Dict[str, int], None]:
    """Device map for quantized models: pass the result as `device_map=get_kbit_device_map()`.

    Maps the whole model (key "") to the current device when CUDA is available, and returns None otherwise.
    """
    if not torch.cuda.is_available():
        return None
    return {"": get_current_device()}
# File-name prefix shared by every checkpoint file (`<prefix>-<step>.json`).
CHECKPOINT_PREFIX = "checkpoint"
# Matches exactly `checkpoint-<step>.json` and captures the step. The dot is escaped so that
# only a literal ".json" extension is accepted (an unescaped "." would match any character).
_RE_CHECKPOINT = re.compile(rf"^{re.escape(CHECKPOINT_PREFIX)}-(\d+)\.json$")
def save_checkpoint(output_dir, all_generated_ids, step):
    """Write the generated ids accumulated so far to `<output_dir>/checkpoint-<step>.json`.

    Args:
        output_dir: directory receiving the checkpoint; created if it does not exist.
        all_generated_ids: iterable of array-like objects exposing `.tolist()`.
        step: step number embedded in the checkpoint file name.

    The JSON is first written to a temporary file in the same directory and then moved into place with
    `os.replace`, so an interruption mid-write never leaves a truncated `checkpoint-<step>.json` behind.
    """
    os.makedirs(output_dir, exist_ok=True)
    checkpoint_name = f"{CHECKPOINT_PREFIX}-{step}.json"
    output_path = os.path.join(output_dir, checkpoint_name)
    # The temporary name starts with a dot so it can never match the `checkpoint-*` patterns used to
    # discover checkpoints; the pid keeps concurrent writers from sharing a temp file.
    tmp_path = os.path.join(output_dir, f".{checkpoint_name}.{os.getpid()}.tmp")

    serializable_ids = [ids.tolist() for ids in all_generated_ids]
    try:
        with open(tmp_path, "w") as file:
            json.dump(serializable_ids, file)
        os.replace(tmp_path, output_path)
    finally:
        # Only still present when the write or the rename failed.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
def load_checkpoint(checkpoint_path):
    """Read a JSON checkpoint file and return its entries as a list of numpy arrays (one per entry)."""
    with open(checkpoint_path) as handle:
        stored_sequences = json.load(handle)
    return [np.array(sequence) for sequence in stored_sequences]
def sorted_checkpoints(output_dir=None) -> List[str]:
    """Return the paths matching `checkpoint-*` in `output_dir`, ordered by ascending step number.

    Entries whose path carries no numeric step after the prefix are ignored.
    """
    step_pattern = re.compile(f".*{CHECKPOINT_PREFIX}-([0-9]+)")
    numbered_paths = []
    for candidate in Path(output_dir).glob(f"{CHECKPOINT_PREFIX}-*"):
        candidate_path = str(candidate)
        found = step_pattern.match(candidate_path)
        if found is None:
            continue
        numbered_paths.append((int(found.group(1)), candidate_path))
    # Tuples sort on the step first, then on the path.
    numbered_paths.sort()
    return [path for _, path in numbered_paths]
def rotate_checkpoints(save_total_limit=None, output_dir=None) -> None:
    """Delete the oldest checkpoints in `output_dir` so that at most `save_total_limit` remain.

    A limit of None, zero or a negative number disables rotation entirely.
    """
    if save_total_limit is None or save_total_limit <= 0:
        return

    existing = sorted_checkpoints(output_dir=output_dir)
    surplus = len(existing) - save_total_limit
    if surplus <= 0:
        return

    # `existing` is ordered oldest first, so the surplus is taken from the front.
    for checkpoint in existing[:surplus]:
        logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit")
        os.remove(checkpoint)
def get_last_checkpoint(folder) -> Tuple[List, int]:
    """Load the most recent checkpoint stored in `folder`.

    Args:
        folder: directory that holds `checkpoint-<step>.json` files. It is created when missing.

    Returns:
        A tuple `(all_generated_ids, cur_step)`: the contents of the checkpoint with the highest step
        and that step number, or `([], 0)` when the folder was missing or holds no checkpoint.
    """
    if not os.path.isdir(folder):
        os.makedirs(folder, exist_ok=True)
        return [], 0

    latest_name, latest_step = None, 0
    for name in os.listdir(folder):
        match = _RE_CHECKPOINT.search(name)
        if match is None:
            continue
        # The step is parsed from the bare file name, never from the joined path, so a parent
        # directory whose name resembles a checkpoint cannot be mistaken for the step.
        step = int(match.group(1))
        if latest_name is None or step > latest_step:
            latest_name, latest_step = name, step

    if latest_name is None:
        return [], 0

    all_generated_ids = load_checkpoint(os.path.join(folder, latest_name))
    return all_generated_ids, latest_step
@dataclass
class DataCollatorWithPadding:
    """
    Collator that hands the `input_ids` of a batch to `tokenizer.pad`, padding them to the longest
    sequence in the batch and requesting PyTorch tensors together with an attention mask.
    """

    tokenizer: Any

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # Only `input_ids` is forwarded; any other key in the features is dropped.
        token_sequences = [example["input_ids"] for example in features]
        return self.tokenizer.pad(
            {"input_ids": token_sequences},
            return_tensors="pt",
            padding="longest",
            return_attention_mask=True,
        )
# Maps the dataset's raw speaker identifiers to the speaker names used in the prompt.
id_to_name = {
    "ex01": "Jerry",
    "ex02": "Elisabeth",
    "ex03": "Thomas",
    "ex04": "Talia",
}

# Instruction template sent to the model. The `[speaker_id]` and `[style]` placeholders on the last
# line are substituted per sample before tokenization.
PROMPT = """You will be given a name and an enunciation style related to an audio sample of a person's speech.
1. The name will be one of those: Jerry, Elisabeth, Thomas, Talia.
2. The enunciation style will be one of those: 'enunciated', 'happy', 'confused', 'default', 'laughing', 'sad', 'whisper', 'emphasis'.
The enunciation style 'default' can be associated to 'with no particular emotion conveyed'.
Your task is to create a text description using these information that accurately describes the speech sample. Ensure that the generated description is grammatically correct, easy to understand, and most importantly, concise.
For example, given the following keywords: 'Talia', 'happy', a valid description would be: 'In an excellent recording, Talia speaks happily.'.
Another valid description would be: 'Talia delivers her words happily.'
Another example, given the following keywords: 'Jerry', 'emphasis': 'Jerry speaks with emphasis on certain words.'
You are free to change the order of the information, and replace synonymous terms.
You must give one and only one description and nothing else. Remember, I only want one description and nothing else.
For the information: '[speaker_id]', '[style]', the corresponding description is:"""
def main():
    """Annotate a dataset with model-generated text descriptions.

    Steps: parse the arguments (a single `.json` file or the command line), load the dataset splits,
    load the causal LM and its tokenizer, turn each sample's `speaker_id`/`style` into a chat prompt,
    generate a description per sample with periodic checkpoints so an interrupted run can resume, decode
    the generations into a `text_description` column, then save the dataset to disk and optionally push
    it to the Hub.

    Raises:
        ValueError: if the dataset lacks the `speaker_id` or `style` column, or if `--torch_compile` is
            requested for a model without a `_setup_cache` method.
    """
    # 1. Parse input arguments
    parser = HfArgumentParser((ModelArguments, DataArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # A single `.json` argument is treated as a file holding all the arguments.
        model_args, data_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args = parser.parse_args_into_dataclasses()

    # 2. Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )

    accelerator = Accelerator()

    if data_args.overwrite_output_dir and os.path.exists(data_args.output_dir) and os.path.isdir(data_args.output_dir):
        logger.info("Cleaning output dir from previous run...")
        shutil.rmtree(data_args.output_dir)

    # 3. Load annotated dataset
    logger.info("*** Load annotated dataset ***")
    # Honour `--dataset_cache_dir` when given; otherwise keep using the model cache dir as before.
    dataset_cache_dir = (
        data_args.dataset_cache_dir if data_args.dataset_cache_dir is not None else model_args.cache_dir
    )
    if data_args.dataset_split_name is not None:
        raw_datasets = DatasetDict()
        data_splits = data_args.dataset_split_name.split("+")
        # load on a split-wise basis
        for split in data_splits:
            with accelerator.local_main_process_first():
                raw_datasets[split] = load_dataset(
                    data_args.dataset_name,
                    data_args.dataset_config_name,
                    split=split,
                    cache_dir=dataset_cache_dir,
                    token=model_args.token,
                    num_proc=data_args.preprocessing_num_workers,
                )
    else:
        with accelerator.local_main_process_first():
            # load all splits for annotation
            raw_datasets = load_dataset(
                data_args.dataset_name,
                data_args.dataset_config_name,
                cache_dir=dataset_cache_dir,
                token=model_args.token,
                num_proc=data_args.preprocessing_num_workers,
            )

    # The column check below only inspects the first split.
    raw_datasets_features = set(raw_datasets[next(iter(raw_datasets))].features.keys())

    if data_args.max_eval_samples is not None:
        for split in raw_datasets:
            # Clamp to the split size so a limit larger than the split does not raise.
            num_samples = min(data_args.max_eval_samples, len(raw_datasets[split]))
            raw_datasets[split] = raw_datasets[split].select(range(num_samples))

    EXPECTED_COLUMNS = {"speaker_id", "style"}
    if not EXPECTED_COLUMNS.issubset(raw_datasets_features):
        missing_columns = EXPECTED_COLUMNS - raw_datasets_features
        raise ValueError(
            f"Missing columns {missing_columns} from the dataset features. Got dataset features {raw_datasets_features}"
        )

    # 4. Load pre-trained model
    logger.info("*** Load pretrained model ***")
    torch_dtype = (
        model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
    )
    quantization_config = get_quantization_config(model_args)

    model = AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        revision=model_args.model_revision,
        variant=model_args.model_variant,
        trust_remote_code=model_args.trust_remote_code,
        attn_implementation=model_args.attn_implementation,
        torch_dtype=torch_dtype,
        device_map=get_kbit_device_map() if quantization_config is not None else None,
        quantization_config=quantization_config,
        low_cpu_mem_usage=True,
        token=model_args.token,
    ).eval()

    if model_args.torch_compile:
        # The static k/v cache needs a `_setup_cache` method on the model.
        if not callable(getattr(model, "_setup_cache", None)):
            raise ValueError(
                f"Static k/v cache is not compatible with the model {model.__class__.__name__}. "
                "Set `--torch_compile=False` for dynamic k/v cache"
            )
        model.generation_config.cache_implementation = "static"
        model = torch.compile(model, mode="reduce-overhead", fullgraph=True)

    tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        revision=model_args.model_revision,
        trust_remote_code=model_args.trust_remote_code,
        use_fast=model_args.use_fast_tokenizer,
        padding_side="left",
        token=model_args.token,
    )
    if tokenizer.pad_token_id is None:
        # No pad token defined: pad inputs with BOS and let generation pad with EOS.
        tokenizer.pad_token_id = tokenizer.bos_token_id
        model.generation_config.pad_token_id = model.generation_config.eos_token_id

    def prepare_dataset(sample):
        """Fill the prompt template for one sample and tokenize it with the chat template."""
        sample_prompt = PROMPT
        # NOTE: this overwrites the `speaker_id` column with the speaker's name.
        sample["speaker_id"] = id_to_name[sample["speaker_id"]]
        for key in EXPECTED_COLUMNS:
            sample_prompt = sample_prompt.replace(f"[{key}]", sample[key])
        sample_prompt = [{"role": "user", "content": sample_prompt}]
        token_ids = tokenizer.apply_chat_template(sample_prompt)
        sample["input_ids"] = token_ids
        return sample

    with accelerator.local_main_process_first():
        vectorized_datasets = raw_datasets.map(
            prepare_dataset, num_proc=data_args.preprocessing_num_workers, desc="Preparing prompts"
        )

    # Prepare everything with our `accelerator`
    model = accelerator.prepare(model)
    data_collator = DataCollatorWithPadding(tokenizer)

    def generate_step(batch):
        """Generate for one batch and pad the result along dim 1 across processes."""
        output_ids = accelerator.unwrap_model(model).generate(
            batch["input_ids"],
            attention_mask=batch["attention_mask"],
            do_sample=model_args.do_sample,
            temperature=model_args.temperature,
            max_new_tokens=model_args.max_new_tokens,
        )
        output_ids = accelerator.pad_across_processes(output_ids, dim=1, pad_index=tokenizer.pad_token_id)
        return output_ids

    def postprocess_dataset(sample):
        """Decode prompt and generation, keeping only the text that follows the prompt."""
        prompt_text = tokenizer.decode(sample["input_ids"], skip_special_tokens=True)
        generated_text = tokenizer.decode(sample["generated_ids"], skip_special_tokens=True)
        sample["text_description"] = generated_text[len(prompt_text) :]
        return sample

    # A `save_steps` of None or 0 disables the periodic checkpoints (the end-of-split one is always
    # written) instead of crashing on the modulo below.
    save_steps = data_args.save_steps

    for split in vectorized_datasets:
        data_loader = DataLoader(
            vectorized_datasets[split],
            batch_size=model_args.per_device_eval_batch_size,
            collate_fn=data_collator,
            num_workers=data_args.dataloader_num_workers,
            pin_memory=True,
        )
        data_loader = accelerator.prepare(data_loader)
        total_inference_steps = len(data_loader)
        progress_bar = tqdm(
            range(total_inference_steps), desc=" ... ", position=0, disable=not accelerator.is_local_main_process
        )

        split_output_dir = os.path.join(data_args.output_dir, split)
        all_generated_ids, cur_step = get_last_checkpoint(split_output_dir)

        if cur_step > 0:
            logger.info(f"Resuming {split} from step {cur_step}")
            # skip the batches that the loaded checkpoint already covers
            data_loader = skip_first_batches(data_loader, cur_step)
            progress_bar.update(cur_step)

        while cur_step < total_inference_steps:
            for batch in data_loader:
                generated_ids = generate_step(batch)
                generated_ids = accelerator.gather_for_metrics(generated_ids)
                all_generated_ids.extend(generated_ids.cpu().numpy())

                cur_step += 1
                progress_bar.update(1)

                is_periodic_save = bool(save_steps) and cur_step % save_steps == 0
                if is_periodic_save or cur_step == total_inference_steps:
                    # Every process holds the same gathered ids, so a single writer per machine is
                    # enough; letting all of them write and rotate the same files races on `os.remove`.
                    if accelerator.is_local_main_process:
                        save_checkpoint(split_output_dir, all_generated_ids, cur_step)
                        rotate_checkpoints(data_args.save_total_limit, output_dir=split_output_dir)

        vectorized_datasets[split] = vectorized_datasets[split].add_column("generated_ids", all_generated_ids)

        if accelerator.is_main_process:
            vectorized_datasets[split] = vectorized_datasets[split].map(
                postprocess_dataset,
                num_proc=data_args.preprocessing_num_workers,
                desc="Postprocessing dataset",
                remove_columns=["input_ids", "generated_ids"],
            )
        accelerator.wait_for_everyone()

    if accelerator.is_main_process:
        vectorized_datasets.save_to_disk(data_args.output_dir)
        if data_args.push_to_hub:
            vectorized_datasets.push_to_hub(
                data_args.hub_dataset_id,
                config_name=data_args.dataset_config_name if data_args.dataset_config_name is not None else "default",
                token=model_args.token,
            )
    accelerator.wait_for_everyone()
    accelerator.end_training()
if __name__ == "__main__": |
main() |