How to fine-tune?

#1
by NickyNicky - opened

Thank you very much for the model. I really like it, and I would like to fine-tune it. Do you have a Google Colab notebook I could use to try fine-tuning it interactively?

Thank you so much.

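I am running the script below, which trains only the language part of the model (LoRA adapters on the language model; the vision tower and aligner stay frozen). If someone has a more generic version, that would be helpful.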

import os
import re
from typing import Dict, Union, List
import dataclasses
import gc
import json
import time
from datetime import datetime
import logging
import argparse
import pandas as pd
import numpy as np
from functools import partial
from IPython.display import display
from tqdm.auto import tqdm
import openai
from PIL import Image
from pillow_heif import register_heif_opener
import io
import base64
import torch
from torch.utils.data import DataLoader, Dataset
from transformers import AutoProcessor, AutoModelForCausalLM, GenerationConfig, TrainingArguments, EvalPrediction
from transformers.integrations import WandbCallback
from deepseek_vl.models import VLChatProcessor, MultiModalityCausalLM
from deepseek_vl.utils.io import load_pil_images
from string import Template
from research.utilities.custom_callbacks import *
from trl import ModelConfig, SFTConfig, SFTTrainer
from transformers.trainer_callback import TrainerCallback
from peft import LoraConfig, get_peft_model, TaskType
from accelerate import Accelerator
from accelerate.utils import DistributedDataParallelKwargs
import wandb
from multiprocessing import Pool, cpu_count, Manager
from accelerate import FullyShardedDataParallelPlugin
from torch.distributed.fsdp.fully_sharded_data_parallel import (
    FullOptimStateDictConfig,
    FullStateDictConfig,
    StateDictType,
)

# FSDP plugin: full (unsharded) model/optimizer state dicts are offloaded to CPU and gathered on rank 0 only.
# NOTE(review): this object is not passed to ``Accelerator(...)`` below nor referenced anywhere else in
# this file, so it currently has no effect — pass ``fsdp_plugin=fsdp_plugin`` if FSDP is intended.
fsdp_plugin = FullyShardedDataParallelPlugin(
    state_dict_config=FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
    optim_state_dict_config=FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True),
    use_orig_params=True
)

# Register the pillow_heif opener so ``Image.open`` can read HEIC/HEIF files.
register_heif_opener()
# PyTorch CUDA caching-allocator option; it is read when CUDA is initialised, so it is set here,
# before the Accelerator below is created.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
# os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Show every column when DataFrames are displayed.
pd.set_option('display.max_columns', None)

# Log to both ./app.log and the console.
logging.basicConfig(
    level=logging.INFO,  # Logging level: DEBUG, INFO, WARNING, ERROR, CRITICAL
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[logging.FileHandler("app.log"),  logging.StreamHandler()]
)

logger = logging.getLogger("VLM fine tuning | Training")
# Ask the Weights & Biases integration to log model checkpoints (read by the W&B callback at train time).
os.environ["WANDB_LOG_MODEL"] = "checkpoint"

# Module-level Accelerator used by main() for logging process info and for ``prepare``.
# NOTE(review): main() also sets ``ddp_find_unused_parameters=False`` in SFTConfig; confirm which of
# the two settings actually applies to the DDP wrapper used by the trainer.
accelerator = Accelerator(kwargs_handlers=[DistributedDataParallelKwargs(find_unused_parameters=True)])


class VLTrainingWrapper(torch.nn.Module):
    """Trainable wrapper around a DeepSeek-VL multimodal model.

    The wrapped model must expose ``vision_model``, ``aligner`` and ``language_model``. The vision
    tower and aligner are frozen. The language model is cast to bfloat16 and wrapped with a LoRA
    adapter on the attention projections; the LoRA parameters are the only trainable ones.
    ``prepare_inputs_embeds`` builds the multimodal input embeddings (the collator calls it) and
    ``forward`` runs the language model on them.
    """

    def __init__(self, processor, model):
        super().__init__()
        # Keep references to the three sub-modules we need.
        vision_model = model.vision_model
        aligner = model.aligner
        language_model = model.language_model
        self.processor = processor
        # Right padding, matching the collator which appends padding at the end of sequences.
        self.processor.tokenizer.padding_side = 'right'

        # Drops only this function's local name; any caller reference still keeps the model alive.
        del model

        self.vision_model = vision_model
        self.aligner = aligner

        # Freeze vision components
        self.vision_model.requires_grad_(False)
        self.aligner.requires_grad_(False)

        # Apply LoRA only to the language model part
        peft_config = LoraConfig(
            r=16,
            lora_alpha=32,
            lora_dropout=0.05,
            bias="none",
            task_type="CAUSAL_LM",
            target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj'],
            inference_mode=False
        )

        # Convert and wrap language model
        language_model = language_model.to(torch.bfloat16)
        self.language_model = get_peft_model(language_model, peft_config)
        del language_model

        # Only LoRA parameters are trainable.
        for name, param in self.language_model.named_parameters():
            param.requires_grad_('lora' in name)

    def get_vision_device(self):
        """Device of the first vision-tower parameter."""
        return next(self.vision_model.parameters()).device

    def get_language_device(self):
        """Device of the first language-model parameter."""
        return next(self.language_model.parameters()).device

    def prepare_inputs_embeds(self, input_ids=None, pixel_values=None,
                              images_seq_mask=None, images_emb_mask=None):
        """Build combined text + image input embeddings for the language model.

        Every position ``pos`` takes the token embedding of ``input_ids[:, pos]`` (the same
        position, so embeddings stay aligned with labels built from ``input_ids``). Positions
        flagged in ``images_seq_mask`` are then overwritten, in order, with that sample's aligned
        vision embeddings.

        Args:
            input_ids: token ids, indexed as (batch, seq_len).
            pixel_values: images. A 5-D tensor has dim 1 squeezed (presumably one image per
                sample — TODO confirm). Channel-last input (last dim 3) is permuted to channel-first.
            images_seq_mask: (batch, seq_len) mask of image-placeholder positions.
            images_emb_mask: accepted for interface compatibility; not used here.

        Returns:
            Tensor of shape (batch, images_seq_mask.shape[1], hidden) with the text-embedding dtype.
        """
        vision_device = self.get_vision_device()
        language_device = self.get_language_device()

        if pixel_values.dim() == 5:
            pixel_values = pixel_values.squeeze(1)
        if pixel_values.shape[-1] == 3:
            pixel_values = pixel_values.permute(0, 3, 1, 2)

        pixel_values = pixel_values.to(vision_device, dtype=torch.bfloat16)
        with torch.no_grad():
            # Vision tower and aligner are frozen, so no autograd graph is needed.
            aligned_embeds = self.aligner(self.vision_model(pixel_values))

        input_ids = input_ids.to(language_device)
        images_seq_mask = images_seq_mask.to(language_device).bool()
        text_embeds = self.language_model.get_input_embeddings()(input_ids)

        bsz = input_ids.shape[0]
        seq_length = images_seq_mask.shape[1]
        shared_length = min(seq_length, text_embeds.shape[1])

        # Start from the position-aligned text embeddings (a fresh non-leaf-safe buffer, so the
        # in-place writes below are legal even when autograd is enabled).
        inputs_embeds = text_embeds.new_zeros((bsz, seq_length, text_embeds.shape[-1]))
        inputs_embeds[:, :shared_length] = text_embeds[:, :shared_length]

        aligned_embeds = aligned_embeds.to(device=inputs_embeds.device, dtype=inputs_embeds.dtype)
        vision_seq_length = aligned_embeds.shape[1]
        for b in range(bsz):
            # At most ``vision_seq_length`` image positions can be filled for this sample.
            image_positions = images_seq_mask[b].nonzero(as_tuple=True)[0][:vision_seq_length]
            inputs_embeds[b, image_positions] = aligned_embeds[b, :image_positions.numel()]

        return inputs_embeds

    def forward(self, inputs_embeds, attention_mask, labels=None):
        """Run the LoRA language model on precomputed ``inputs_embeds`` and return its output.

        When autograd is enabled, ``inputs_embeds`` is marked as requiring grad. The embeddings are
        produced outside the model (in the collator, under ``no_grad``), and reentrant gradient
        checkpointing only back-propagates through segments whose inputs require grad. Without
        this, the LoRA weights inside checkpointed layers would receive no gradient.
        """
        device = self.get_language_device()
        inputs_embeds = inputs_embeds.to(device)
        attention_mask = attention_mask.to(device)
        if labels is not None:
            labels = labels.to(device)

        if torch.is_grad_enabled() and not inputs_embeds.requires_grad:
            inputs_embeds.requires_grad_(True)

        # Only forward-pass arguments: generation options (pad/bos/eos ids, do_sample) do not apply
        # here, and the KV cache is useless (and memory-hungry) during training/loss evaluation.
        return self.language_model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            labels=labels,
            use_cache=False,
        )

    def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
        """Enable gradient checkpointing on the language model, if it supports it."""
        if hasattr(self.language_model, "gradient_checkpointing_enable"):
            self.language_model.gradient_checkpointing_enable(gradient_checkpointing_kwargs)
            # Ensure trainable parameters have requires_grad=True after enabling checkpointing
            for name, param in self.language_model.named_parameters():
                if 'lora' in name:
                    param.requires_grad_(True)

    def gradient_checkpointing_disable(self):
        """Disable gradient checkpointing on the language model, if it supports it."""
        if hasattr(self.language_model, "gradient_checkpointing_disable"):
            self.language_model.gradient_checkpointing_disable()


def clear_memory(model=None, settle_seconds=2):
    """Best-effort release of references and cached GPU memory, then print current GPU usage.

    The names ``inputs``, ``model``, ``processor``, ``trainer``, ``peft_model`` and ``bnb_config``
    are removed from this module's globals when present. If ``model`` is given, its
    ``llm_engine.model_executor.driver_worker`` is deleted when that attribute chain exists (the
    vLLM layout the original code targeted; other objects are no longer an AttributeError).
    Garbage collection and, when CUDA is available, CUDA cache clearing and synchronisation follow.

    Args:
        model: optional object whose memory should be released.
        settle_seconds: pause between cleanup steps, in seconds (previously hard-coded to 2).
    """
    module_globals = globals()
    for name in ("inputs", "model", "processor", "trainer", "peft_model", "bnb_config"):
        module_globals.pop(name, None)
    time.sleep(settle_seconds)

    if model is not None:
        llm_engine = getattr(model, "llm_engine", None)
        if llm_engine is not None:
            # Presumably the vLLM worker holds the loaded weights — TODO confirm for the vLLM version used.
            del llm_engine.model_executor.driver_worker
        del model  # only drops this function's reference
        gc.collect()
        cuda_available = torch.cuda.is_available()
        if cuda_available:
            torch.cuda.empty_cache()
        time.sleep(settle_seconds)
        if cuda_available:
            # synchronize() initialises CUDA and raises on machines without a GPU.
            torch.cuda.synchronize()
        time.sleep(settle_seconds)
        gc.collect()

    print(f"GPU allocated memory: {torch.cuda.memory_allocated() / 1024 ** 3:.2f} GB")
    print(f"GPU reserved memory: {torch.cuda.memory_reserved() / 1024 ** 3:.2f} GB")


class TextValidator:
    """Parse a model answer into a summary and a recommendation list, then re-render it.

    ``reassemble`` takes the line following ``Summary:``/``summary:`` and the first ``[...]`` span
    of ``text`` and renders them through a fixed template. When ``injected_list`` is given, it is
    rendered instead of the parsed list.
    """

    def __init__(self, text, injected_list=None):
        self.text = text
        self.summary = None
        self.list_data = None
        self.json_data = None
        self.injected_list = injected_list
        self.result = Template("Summary:\n$summary\n\nRecommendations:\n$list_data\n")

    def extract_summary(self):
        """Store the stripped text of the line following ``Summary:`` (or ``summary:``), if any."""
        match = re.search(r"(?<=[Ss]ummary:\n)[^\n]*", self.text, re.S)
        if match:
            self.summary = match.group(0).strip()

    def extract_list(self):
        """Parse the first ``[...]`` span of ``text`` as a Python list literal (``[]`` if absent).

        Only the first ``]`` closes the span, so nested lists are not supported.

        Raises:
            ValueError or SyntaxError: the bracket contents are not a valid Python literal.
        """
        import ast  # only needed here

        list_match = re.search(r"\[([\s\S]*?)\]", self.text)
        if list_match is None:
            self.list_data = []
            return
        # ``text`` is model output, i.e. untrusted: literal_eval accepts only literals, never code.
        self.list_data = ast.literal_eval(f"[{list_match.group(1).strip()}]")

    def reassemble(self):
        """Return the templated summary + JSON list, or ``""`` when no summary was found."""
        self.extract_summary()
        self.extract_list()

        items = self.injected_list if self.injected_list is not None else self.list_data
        if self.summary is None or items is None:
            return ""
        return self.result.substitute(summary=self.summary,
                                      list_data=json.dumps(items, indent=4))


def encode_image(image_path, max_size=(512, 512)):
    """Load an image from disk and return it as a fully loaded, in-memory RGB PIL image.

    Despite its name, this function performs no base64 encoding. An image larger than
    ``max_size`` in either dimension is shrunk with ``Image.thumbnail``, which keeps the aspect
    ratio. Any non-RGB mode (RGBA, P, L, ...) is converted to RGB.

    The returned image does not depend on the file handle closed by the ``with`` block. Previously,
    small RGB images were returned lazily loaded and failed when later loaded or pickled (e.g. when
    sent back from a multiprocessing worker).

    Returns:
        PIL.Image.Image, or None (after printing the error) if the file cannot be opened or decoded.
    """
    try:
        with Image.open(image_path) as img:
            if img.size[0] > max_size[0] or img.size[1] > max_size[1]:
                # Resize in place to fit within max_size, keeping the aspect ratio.
                img.thumbnail(max_size)

            # convert() always returns a new, loaded image (a copy when already RGB), so the result
            # stays valid after the file is closed.
            return img.convert("RGB")

    except Exception as e:
        print(f"Error encoding image: {e}")
        return None


@dataclasses.dataclass
class VLCollator:
    """Collate DeepSeek-VL conversations into ``inputs_embeds`` / ``attention_mask`` / ``labels``.

    Each feature is a conversation (list of message dicts). ``feature[1]['images'][0]`` is passed
    to the processor as the image. Embeddings are computed eagerly through
    ``model.prepare_inputs_embeds``.
    """

    # NOTE(review): no fields are annotated and ``__init__`` is explicit, so the dataclass decorator
    # only adds a field-less ``__repr__``/``__eq__``; kept for backward compatibility.

    def __init__(self, processor, model):
        self.processor = processor
        self.model = model

    def pad_tensor(self, tensor: torch.Tensor, target_size: int, padding_value: int = 0) -> torch.Tensor:
        """Right-pad (or truncate) ``tensor`` along dim 1 to ``target_size``; padding is created on CPU."""
        current_size = tensor.size(1)
        if current_size >= target_size:
            return tensor[:, :target_size]

        batch_size = tensor.size(0)
        remaining_dims = tensor.size()[2:] if len(tensor.size()) > 2 else []
        padding_shape = (batch_size, target_size - current_size, *remaining_dims)

        return torch.cat([
            tensor,
            torch.full(padding_shape, padding_value, dtype=tensor.dtype, device='cpu')
        ], dim=1)

    @staticmethod
    def _create_assistant_mask(input_ids: torch.Tensor, tokenizer):
        """Mask (True = excluded from the loss) everything up to and including each row's last
        ``Assistant`` token.

        Rows without that token are left unmasked (all False), so the caller's fallback masking
        alone applies to them. Returns None when no row contains the token.
        """
        assistant_token_id = tokenizer.encode("Assistant", add_special_tokens=False)[0]
        mask = torch.zeros_like(input_ids, dtype=torch.bool)

        found_assistant = False
        for i, row in enumerate(input_ids):
            positions = (row == assistant_token_id).nonzero(as_tuple=True)[0]
            if positions.numel() == 0:
                continue
            # Use the last occurrence in case "Assistant" also appears earlier in the prompt.
            mask[i, :int(positions[-1]) + 1] = True
            found_assistant = True
        return mask if found_assistant else None

    def __call__(self, features: List[List[Dict]]) -> Dict:
        """Collate a batch of conversations.

        Returns:
            dict with ``inputs_embeds`` (CPU), ``attention_mask`` and ``labels``. Labels are -100 at
            image positions and padding. Rows that contain an ``Assistant`` token are also masked
            up to and including its last occurrence, so only the assistant answer contributes to
            the loss.
        """
        # Process each conversation
        batch_outputs = [
            self.processor(
                conversations=feature,
                images=[feature[1]['images'][0]],
                force_batchify=True
            ) for feature in features
        ]

        # Find maximum sequence length
        max_length = max(output.input_ids.size(1) for output in batch_outputs)
        pad_token_id = self.processor.tokenizer.pad_token_id or 0

        # Prepare right-padded tensors, keeping everything on CPU
        padded_outputs = [{
            'input_ids': self.pad_tensor(output.input_ids.cpu(), max_length, pad_token_id),
            'attention_mask': self.pad_tensor(output.attention_mask.cpu(), max_length, 0),
            'images_seq_mask': self.pad_tensor(output.images_seq_mask.cpu(), max_length, 0),
            'images_emb_mask': output.images_emb_mask.cpu(),
            'pixel_values': output.pixel_values.to(torch.bfloat16).cpu()
        } for output in batch_outputs]

        # Stack tensors on CPU
        input_ids = torch.cat([out['input_ids'] for out in padded_outputs], dim=0)
        pixel_values = torch.cat([out['pixel_values'] for out in padded_outputs], dim=0)
        attention_mask = torch.cat([out['attention_mask'] for out in padded_outputs], dim=0)
        images_seq_mask = torch.cat([out['images_seq_mask'] for out in padded_outputs], dim=0)
        images_emb_mask = torch.cat([out['images_emb_mask'] for out in padded_outputs], dim=0)

        with torch.no_grad():
            # Let the model wrapper handle device placement
            inputs_embeds = self.model.prepare_inputs_embeds(
                input_ids=input_ids,
                pixel_values=pixel_values,
                images_seq_mask=images_seq_mask,
                images_emb_mask=images_emb_mask
            )

            labels = input_ids.clone()
            # Masking applied to every row: image positions and padding.
            ignore_mask = images_seq_mask.bool() | (attention_mask == 0)
            assistant_mask = self._create_assistant_mask(input_ids, self.processor.tokenizer)
            if assistant_mask is not None:
                ignore_mask |= assistant_mask
            labels.masked_fill_(ignore_mask, -100)

        return {
            "inputs_embeds": inputs_embeds.cpu(),  # Return to CPU for pin_memory
            "attention_mask": attention_mask,
            "labels": labels
        }


def process_row(row_args):
    """Turn one dataset row into a three-turn DeepSeek-VL conversation.

    ``row_args`` is the tuple ``(row, system_prompt, user_prompt, thumbnail_size, counter)``.
    ``row['images']`` is the image path and ``row['example']`` the target answer. One item is put on
    ``counter`` to report progress to the parent process.

    Returns:
        list of three message dicts: System, User (image placeholder + prompt, with the image),
        and Assistant (the target answer).
    """
    row, system_prompt, user_prompt, thumbnail_size, counter = row_args
    image = encode_image(row['images'], max_size=thumbnail_size)

    system_turn = {"role": "System", "content": system_prompt}
    user_turn = {
        "role": "User",
        "content": f"<image_placeholder>{user_prompt}",
        "images": [image],
    }
    assistant_turn = {"role": "Assistant", "content": row['example']}

    counter.put(1)

    return [system_turn, user_turn, assistant_turn]


def generate_examples(df, system_prompt, user_prompt, thumbnail_size):
    """Convert every DataFrame row into a conversation in parallel, showing a progress bar.

    Rows are processed by ``process_row`` in a ``Pool(cpu_count())``. Workers report progress
    through a ``Manager().Queue``, which this process drains to advance the bar.

    Returns:
        list of conversations in the same order as the rows of ``df`` (``map_async`` preserves order).
    """
    rows = df.to_dict(orient='records')  # one dict per row

    with Manager() as manager:
        counter = manager.Queue()
        row_args = [(row, system_prompt, user_prompt, thumbnail_size, counter) for row in rows]

        def drain_counter(pbar):
            # Advance the bar once per progress message currently queued by the workers.
            while not counter.empty():
                counter.get()
                pbar.update(1)

        # The context managers close the bar and pool even if a worker raises.
        with tqdm(total=len(rows), desc="Processing rows", unit=" rows") as pbar, \
                Pool(cpu_count()) as pool:
            async_result = pool.map_async(process_row, row_args)

            while not async_result.ready():
                drain_counter(pbar)
                async_result.wait(0.1)  # Small timeout to prevent busy waiting

            # Pick up messages that arrived after the last check.
            drain_counter(pbar)
            batch_inputs = async_result.get()

    return batch_inputs


def generate_text_from_sample(model, processor, sample, max_new_tokens=10_000, device="cuda"):
    """Generate a reply to the User turn of ``sample`` and return the newly generated text.

    NOTE(review): ``process_vision_info`` is neither defined nor imported in this file, so calling
    this raises NameError as written. ``processor.apply_chat_template`` and
    ``processor(text=..., images=...)`` look like a Qwen2-VL-style processor API rather than the
    ``VLChatProcessor`` used elsewhere here — TODO confirm before use.
    """
    # Template only the User turn (the system message at index 0 is skipped).
    prompt = processor.apply_chat_template(sample[1:2], tokenize=False, add_generation_prompt=True)

    image_inputs, _ = process_vision_info(sample)

    encoded = processor(text=[prompt], images=image_inputs, return_tensors="pt")
    model_inputs = encoded.to(device)

    generated_ids = model.generate(**model_inputs, max_new_tokens=max_new_tokens)

    # Keep only the tokens produced after each prompt.
    new_token_ids = []
    for prompt_ids, output_ids in zip(model_inputs.input_ids, generated_ids):
        new_token_ids.append(output_ids[len(prompt_ids):])

    decoded = processor.batch_decode(
        new_token_ids,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
        truncation=True,
        max_length=8000,
    )
    return decoded[0]


def get_model(model_id=None):
    """Load the DeepSeek-VL chat processor and model for ``model_id``.

    The model is loaded with ``trust_remote_code=True`` in bfloat16, cast to bfloat16 again, and
    its config's ``use_cache`` is switched off.

    Returns:
        tuple ``(model, processor)`` — model first.
    """
    processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_id)
    loaded: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained(
        model_id,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
    )
    model = loaded.to(torch.bfloat16)
    model.config.use_cache = False
    return model, processor


def calculate_lora_parameters(model):
    """Count trainable parameters of ``model``.

    Returns:
        tuple ``(lora_params, trainable_params)``: elements in trainable parameters whose name
        contains ``"lora"``, and elements in all trainable parameters.
    """
    trainable_counts = [
        (name, param.numel())
        for name, param in model.named_parameters()
        if param.requires_grad
    ]
    lora_params = sum(count for name, count in trainable_counts if "lora" in name)
    trainable_params = sum(count for _, count in trainable_counts)
    return lora_params, trainable_params


def check_rope_implementation(model):
    """Print RoPE-related diagnostics and report whether ``model`` has a ``get_rope_index`` attribute.

    When ``model.language_model.model.layers`` exists, it also prints whether the first layer's
    ``self_attn`` has a ``rotary_emb`` attribute.

    Returns:
        bool: ``hasattr(model, 'get_rope_index')``.
    """
    print("Checking RoPE implementation...")

    rope_index_present = hasattr(model, 'get_rope_index')
    print(f"Has get_rope_index method: {rope_index_present}")

    if not hasattr(model, 'language_model'):
        return rope_index_present

    print("Checking language model attributes...")
    decoder = getattr(model.language_model, 'model', None)
    if decoder is None or not hasattr(decoder, 'layers'):
        return rope_index_present

    first_layer = decoder.layers[0]
    if hasattr(first_layer, 'self_attn'):
        print(f"Has rotary embeddings in attention: {hasattr(first_layer.self_attn, 'rotary_emb')}")

    return rope_index_present


def setup_model_with_rope_wrapper(model, accelerator=None):
    """Patch ``model.get_rope_index`` so its attention mask is moved to the inputs' device.

    If ``accelerator`` is None, the model is first moved to the local device. Models without a
    ``get_rope_index`` attribute are returned unchanged (apart from that move).

    Returns:
        The same ``model`` object, or the result of ``model.to(...)`` when no accelerator is given.
    """

    def get_device():
        if accelerator is not None:
            return accelerator.device
        if torch.distributed.is_available() and torch.distributed.is_initialized():
            # Use the node-local rank: the global rank exceeds the local GPU count on multi-node
            # runs. LOCAL_RANK is set by launchers such as torchrun; fall back to the global rank.
            local_rank = int(os.environ.get("LOCAL_RANK", torch.distributed.get_rank()))
            return f'cuda:{local_rank}'
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Move model to device if not using accelerator
    if accelerator is None:
        model = model.to(get_device())

    if hasattr(model, 'get_rope_index'):
        original_get_rope_index = model.get_rope_index

        def get_rope_index_wrapper(*args, **kwargs):
            try:
                # Target the device of the first positional argument (presumably input_ids).
                target_device = args[0].device if args else get_device()

                # Handle attention mask in kwargs
                if 'attention_mask' in kwargs and kwargs['attention_mask'] is not None:
                    kwargs['attention_mask'] = kwargs['attention_mask'].to(target_device)

                # Handle attention mask in positional args (4th position)
                elif len(args) > 3 and args[3] is not None:
                    args = list(args)
                    args[3] = args[3].to(target_device)
                    args = tuple(args)

                return original_get_rope_index(*args, **kwargs)

            except Exception as e:
                print(f"Error in get_rope_index_wrapper: {str(e)}")
                print(f"Args devices: {[arg.device if torch.is_tensor(arg) else 'not tensor' for arg in args]}")
                raise  # re-raise with the original traceback

        model.get_rope_index = get_rope_index_wrapper
        print("Successfully wrapped get_rope_index method for multi-GPU setup")

    return model


class DatasetWrapper:
    """Minimal map-style dataset over an in-memory list of conversations.

    Items are returned unprocessed; the data collator does all tokenization and embedding.
    """

    def __init__(self, data):
        self.data = data
        # Single logical column, matching ``dataset_text_field="messages"`` in the training config.
        self.column_names = ['messages']

    def __getitem__(self, idx):
        return self.data[idx]

    def __len__(self):
        return len(self.data)


class StreamingSFTTrainer(SFTTrainer):
    """``SFTTrainer`` whose evaluation loop computes metrics batch by batch.

    Instead of collecting every batch's logits for a single ``compute_metrics`` call, each batch is
    scored immediately. Only the scalar ``hit_rate`` / ``coverage`` values are kept and averaged.
    """

    def _convert_to_native_types(self, obj):
        """Recursively convert tensors, numpy arrays, dicts, dataclasses and sequences to plain Python.

        Single-element tensors/arrays become scalars and larger ones become (nested) lists; tuples
        become lists.
        """
        if isinstance(obj, torch.Tensor):
            return obj.detach().cpu().item() if obj.numel() == 1 else obj.detach().cpu().tolist()
        elif isinstance(obj, np.ndarray):
            return obj.item() if obj.size == 1 else obj.tolist()
        elif isinstance(obj, dict) or dataclasses.is_dataclass(obj):
            if dataclasses.is_dataclass(obj):
                obj = dataclasses.asdict(obj)
            return {key: self._convert_to_native_types(value) for key, value in obj.items()}
        elif isinstance(obj, (list, tuple)):
            return [self._convert_to_native_types(item) for item in obj]
        return obj

    def evaluation_loop(
            self,
            dataloader,
            description: str,
            prediction_loss_only: Union[bool, None] = None,
            ignore_keys: Union[List[str], None] = None,
            metric_key_prefix: str = "eval",
    ):
        """Evaluate over ``dataloader`` and return the averaged loss plus streaming metrics.

        ``description``, ``prediction_loss_only`` and ``ignore_keys`` are accepted to match the
        parent signature but are not used.

        Returns:
            An object with ``metrics``, ``predictions`` (None), ``label_ids`` (None) and
            ``num_samples`` (the number of evaluated batches).
        """
        model = self._wrap_model(self.model, training=False, dataloader=dataloader)
        model.eval()

        # Initialize metrics storage
        total_eval_loss = 0.0  # Initialize as float
        total_samples = 0  # incremented once per batch
        hit_rates = []
        coverages = []

        # Only compute metrics on rank 0 when running distributed. ``torch.distributed`` is used
        # directly: this file never imports a ``dist`` alias, so the old code raised NameError.
        is_distributed = torch.distributed.is_available() and torch.distributed.is_initialized()
        compute_metrics = torch.distributed.get_rank() == 0 if is_distributed else True

        for step, inputs in enumerate(dataloader):
            # Move inputs to appropriate device
            inputs = self._prepare_inputs(inputs)

            with torch.no_grad():
                # Forward pass
                outputs = model(**inputs)
                loss = outputs.loss

                # Update loss tracking - convert to float immediately
                total_eval_loss += loss.detach().cpu().float().item()
                total_samples += 1

                # Compute metrics if on rank 0
                if compute_metrics and self.compute_metrics is not None:
                    logits = outputs.logits.detach().cpu()
                    labels = inputs["labels"].detach().cpu()

                    eval_pred = EvalPrediction(
                        predictions=(logits,),
                        label_ids=labels
                    )

                    try:
                        metrics = self.compute_metrics(eval_pred)
                        hit_rates.append(float(metrics.get("hit_rate", 0.0)))
                        coverages.append(float(metrics.get("coverage", 0.0)))
                    except Exception as e:
                        # Record the failure and skip this batch's metrics (and progress log).
                        self.log({"eval_metric_error": str(e)})
                        continue

            # Log progress periodically
            if step % self.args.logging_steps == 0:
                current_metrics = {
                    "eval_step": step,
                    "eval_loss": total_eval_loss / max(total_samples, 1),
                    "current_hit_rate": float(np.mean(hit_rates)) if hit_rates else 0.0,
                    "current_coverage": float(np.mean(coverages)) if coverages else 0.0
                }
                self.log(self._convert_to_native_types(current_metrics))

        # Compute final metrics
        metrics = {
            f"{metric_key_prefix}_loss": total_eval_loss / max(total_samples, 1)
        }

        # Add custom metrics if on rank 0
        if compute_metrics and hit_rates:
            metrics.update({
                f"{metric_key_prefix}_hit_rate": float(np.mean(hit_rates)),
                f"{metric_key_prefix}_coverage": float(np.mean(coverages)),
                f"{metric_key_prefix}_samples": len(hit_rates)
            })

        metrics = self._convert_to_native_types(metrics)

        # Minimal output object exposing the attributes the parent evaluate() reads.
        class EvalLoopOutput:
            def __init__(self, metrics):
                self.metrics = metrics
                self.predictions = None
                self.label_ids = None
                self.num_samples = total_samples

        # NOTE(review): the parent ``evaluate`` presumably logs these metrics again — confirm
        # whether this extra log call causes duplicate entries.
        self.log(metrics)

        return EvalLoopOutput(metrics)


def main():
    """Command-line entry point: build the dataset, wrap DeepSeek-VL with LoRA and train it.

    Steps:
        1. Read ``complete_dataset.parquet`` from ``--input_dataset`` and drop rows with an empty
           ``example``.
        2. Convert the rows to conversations and split them 90/5/5 into train/validation/test.
        3. Load the model and wrap it in ``VLTrainingWrapper``.
        4. Train with ``StreamingSFTTrainer``.
        5. On the main process, save the adapter, its config and the processor under
           ``--output_dir``.
    """

    parser = argparse.ArgumentParser(description="Fine-tune DeepSeek-VL with LoRA on the language model.")
    # Both paths are required: they are joined into file paths below, which fails for None.
    parser.add_argument('-i', '--input_dataset', type=str, required=True,
                        help="folder containing complete_dataset.parquet")
    parser.add_argument('-o', '--output_dir', type=str, required=True, help="outputs folders")
    parser.add_argument('-v', '--verbose', action='store_true', help="Enable verbose output.")  # currently unused
    parser.add_argument('-g', '--gpu', type=int, default=1, help="Number of gpu to use.")  # currently unused
    parser.add_argument('-r', '--run_name', type=str, default="_", help="run_name decorator")
    parser.add_argument('-d', '--debug', action='store_true', help="Enable debug mode.")
    args = parser.parse_args()

    logger.info("Stage. Loading LLM Model")
    model_id = "deepseek-ai/deepseek-vl-7b-chat"

    with open("research/utilities/system_prompt.txt", "r") as f:
        system_prompt = f.read()

    with open("research/utilities/alternative_prompts/openai_prompt_v0.5.txt", "r") as f:
        user_prompt = f.read()

    logger.info("Stage. Reading Data")
    stage_start = time.time()
    complete_dataset = pd.read_parquet(f"{args.input_dataset}/complete_dataset.parquet")
    complete_dataset = complete_dataset[complete_dataset['example']!=""].reset_index(drop=True)

    # if in debug mode, limit the dataset to 100 rows
    if args.debug:
        complete_dataset = complete_dataset.iloc[:100]

    logger.info(f"raw data size: {len(complete_dataset)} | Stage duration: {(time.time() - stage_start) / 60:.2f} minutes")

    stage_start = time.time()
    complete_batchs = generate_examples(complete_dataset, system_prompt, user_prompt, (896, 896))

    # 90% train / 5% validation / 5% test, in the dataset's row order.
    training_dataset = DatasetWrapper(complete_batchs[:(int(len(complete_batchs)*0.9))])
    validation_dataset = DatasetWrapper(complete_batchs[int(len(complete_batchs) * 0.9):int(len(complete_batchs) * 0.95)])
    test_dataset = DatasetWrapper(complete_batchs[int(len(complete_batchs) * 0.95):len(complete_batchs)])
    del complete_batchs
    logger.info(f"Training dataset size: {len(training_dataset)}\n Validation dataset size: {len(validation_dataset)}\n "
                f"Test dataset size: {len(test_dataset)} | Stage duration: {(time.time() - stage_start) / 60:.2f} minutes")

    # remove LLM GPU signature and start training.
    logger.info("Stage. Model Parameters Definition")
    model, processor = get_model(model_id=model_id)
    # Check RoPE implementation
    needs_rope_wrapper = check_rope_implementation(model)

    if needs_rope_wrapper:
        logger.info("Model uses RoPE, applying wrapper...")
        # The helper returns the (patched) model; rebind ``model`` (not ``logger.info``) to it.
        model = setup_model_with_rope_wrapper(model, accelerator)
    else:
        print("Model doesn't use get_rope_index method, no wrapper needed")

    model = VLTrainingWrapper(processor, model)

    logger.info("Stage. Starting Evaluation")
    # Times everything from here to the end of training/saving.
    stage_start = time.time()

    # Training arguments
    training_args = SFTConfig(
        output_dir=os.path.join(args.output_dir, f"deepseek-7b-chat-sft-facetune{args.run_name}"),
        num_train_epochs=3,
        per_device_train_batch_size=1,  # Batch size for training
        per_device_eval_batch_size=1,  # Batch size for evaluation
        gradient_accumulation_steps=1,  # Steps to accumulate gradients
        gradient_checkpointing=True,  # Enable gradient checkpointing for memory efficiency
        optim="adamw_torch_fused",  # Optimizer type
        fsdp_transformer_layer_cls_to_wrap="LlamaDecoderLayer",
        learning_rate=2e-4,  # Learning rate for training
        lr_scheduler_type="constant",  # Type of learning rate scheduler
        # Logging and evaluation
        logging_steps=50,  # Steps interval for logging
        eval_steps=100,  # Steps interval for evaluation
        eval_strategy="steps",  # Strategy for evaluation
        save_strategy="steps",  # Strategy for saving the model
        save_total_limit=1,
        bf16=True,  # Use bfloat16 precision
        tf32=True,  # Use TensorFloat-32 precision
        save_steps=100,  # Steps interval for saving
        metric_for_best_model="eval_loss",  # Metric to evaluate the best model
        greater_is_better=False,  # Whether higher metric values are better
        load_best_model_at_end=True,  # Load the best model after training
        # Mixed precision and gradient settings
        max_grad_norm=1.0,  # Maximum norm for gradient clipping
        warmup_ratio=0.03,  # Ratio of total steps for warmup
        report_to="wandb",
        run_name=f"ds-suggest-edit-vlm-deepseek{args.run_name}",
        # gradient_checkpointing_kwargs={"use_reentrant": False},  # Options for gradient checkpointing
        dataset_text_field="messages",
        # The collator produces ready-made embeddings, so TRL must not tokenize the dataset itself.
        dataset_kwargs={"skip_prepare_dataset": True},
        remove_unused_columns=False,
        ddp_find_unused_parameters=False,
        dataloader_drop_last=True,
        save_safetensors=True,
        push_to_hub=False,
        save_only_model = True
    )

    trainer = StreamingSFTTrainer(
        model=model,
        args=training_args,
        train_dataset=training_dataset,
        eval_dataset=validation_dataset,
        data_collator=VLCollator(processor, model),
        tokenizer=processor.tokenizer,
        compute_metrics=deep_seek_get_compute_metrics_fn(processor.tokenizer, logger)
    )

    # NOTE(review): Accelerator.prepare presumably returns objects it does not recognise (such as a
    # Trainer) unchanged, making this a no-op — confirm before relying on it.
    trainer = accelerator.prepare(trainer)
    lora_params, trainable_params = calculate_lora_parameters(model)
    logger.info(f"2. LoRA params: {lora_params}, Trainable params: {trainable_params}")

    logger.info("Stage. Training Begins")
    logger.info(f"Accelerator config: {accelerator.state}")
    logger.info(f"Local Rank: {accelerator.local_process_index}")
    logger.info(f"World Size: {accelerator.num_processes}")
    logger.info(f"Device: {accelerator.device}")

    # Optional callbacks (from research.utilities.custom_callbacks):
    # wandb_callback = LLMSampleCB(trainer, validation_dataset, num_samples=50, max_new_tokens=10_000, logger=logger)
    # trainer.add_callback(wandb_callback)

    trainer.train()

    if trainer.is_world_process_zero():  # Only save on main process
        output_dir = os.path.join(args.output_dir, f"deepseek-7b-chat-sft-facetune{args.run_name}-final")
        # Save the base model configuration
        trainer.model.language_model.config.save_pretrained(output_dir)
        # Save the PEFT adapter and its configuration
        trainer.model.language_model.save_pretrained(output_dir)
        # Save the processor/tokenizer alongside the adapter
        processor.save_pretrained(output_dir)

    logger.info(f"Stage. Evaluation Complete | Duration {(time.time() - stage_start)/60:.2f} minutes")


if __name__ == "__main__":
    # Run the training pipeline only when executed as a script, not when imported.
    main()

Sign up or log in to comment