How to fine-tune?
#1
opened by NickyNicky
Thank you very much for the model. I really liked it, and I would like to fine-tune it. Do you have a Google Colab notebook I could use to experiment with the model and run a fine-tune?
Thank you so much.
+1
+1
+1
I am running the script below to fine-tune just the language part of the model (LoRA on the language model, with the vision encoder and aligner frozen). If someone has a more generic version, that would be helpful.
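A note on dependencies: the script expects the deepseek_vl package (from the DeepSeek-VL repository) plus transformers, trl, peft, accelerate, pandas, numpy, Pillow, pillow-heif, tqdm and wandb to be installed. The research.utilities.* imports, the prompt files under research/utilities/, and deep_seek_get_compute_metrics_fn are internal to my own codebase, so you will have to swap in your own equivalents.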
import os
import ast
import re
from typing import Dict, Union, List
import dataclasses
import gc
import json
import time
from datetime import datetime
import logging
import argparse
import pandas as pd
import numpy as np
from functools import partial
from IPython.display import display
from tqdm.auto import tqdm
import openai
from PIL import Image
from pillow_heif import register_heif_opener
import io
import base64
import torch
import torch.distributed as dist  # needed for the rank checks in StreamingSFTTrainer.evaluation_loop
from torch.utils.data import DataLoader, Dataset
from transformers import AutoProcessor, AutoModelForCausalLM, GenerationConfig, TrainingArguments, EvalPrediction
from transformers.integrations import WandbCallback
from deepseek_vl.models import VLChatProcessor, MultiModalityCausalLM
from deepseek_vl.utils.io import load_pil_images
from string import Template
from research.utilities.custom_callbacks import *
from trl import ModelConfig, SFTConfig, SFTTrainer
from transformers.trainer_callback import TrainerCallback
from peft import LoraConfig, get_peft_model, TaskType
from accelerate import Accelerator
from accelerate.utils import DistributedDataParallelKwargs
import wandb
from multiprocessing import Pool, cpu_count, Manager
from accelerate import FullyShardedDataParallelPlugin
from torch.distributed.fsdp.fully_sharded_data_parallel import (
FullOptimStateDictConfig,
FullStateDictConfig,
StateDictType,
)
# Note: this FSDP plugin is only used if it is passed to the Accelerator (fsdp_plugin=fsdp_plugin);
# the Accelerator created below does not receive it, so it currently has no effect.
fsdp_plugin = FullyShardedDataParallelPlugin(
    state_dict_config=FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
    optim_state_dict_config=FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True),
    use_orig_params=True
)
register_heif_opener()
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
# os.environ["TOKENIZERS_PARALLELISM"] = "false"
pd.set_option('display.max_columns', None)
logging.basicConfig(
level=logging.INFO, # Logging level: DEBUG, INFO, WARNING, ERROR, CRITICAL
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
handlers=[logging.FileHandler("app.log"), logging.StreamHandler()]
)
logger = logging.getLogger("VLM fine tuning | Training")
os.environ["WANDB_LOG_MODEL"] = "checkpoint"
accelerator = Accelerator(kwargs_handlers=[DistributedDataParallelKwargs(find_unused_parameters=True)])
class VLTrainingWrapper(torch.nn.Module):
def __init__(self, processor, model):
super().__init__()
# Store only what we need, avoid keeping multiple references
vision_model = model.vision_model
aligner = model.aligner
language_model = model.language_model
self.processor = processor
self.processor.tokenizer.padding_side = 'right'
# Clear the original model to free memory and avoid refs
del model
# Store components we need
self.vision_model = vision_model
self.aligner = aligner
# Freeze vision components
self.vision_model.requires_grad_(False)
self.aligner.requires_grad_(False)
# Apply LoRA only to the language model part
peft_config = LoraConfig(
r=16,
lora_alpha=32,
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM",
target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj'],
inference_mode=False
)
# Convert and wrap language model
language_model = language_model.to(torch.bfloat16)
self.language_model = get_peft_model(language_model, peft_config)
del language_model
# Ensure LoRA parameters are trainable
for name, param in self.language_model.named_parameters():
if 'lora' in name:
param.requires_grad_(True)
else:
param.requires_grad_(False)
def get_vision_device(self):
return next(self.vision_model.parameters()).device
def get_language_device(self):
return next(self.language_model.parameters()).device
def prepare_inputs_embeds(self, input_ids=None, pixel_values=None,
images_seq_mask=None, images_emb_mask=None):
vision_device = self.get_vision_device()
language_device = self.get_language_device()
# Process vision inputs
if len(pixel_values.shape) == 5:
pixel_values = pixel_values.squeeze(1)
if pixel_values.shape[-1] == 3:
pixel_values = pixel_values.permute(0, 3, 1, 2)
pixel_values = pixel_values.to(vision_device, dtype=torch.bfloat16)
with torch.no_grad():
vision_embeds = self.vision_model(pixel_values)
aligned_embeds = self.aligner(vision_embeds)
del vision_embeds # Free memory
# Process text inputs
input_ids = input_ids.to(language_device)
images_seq_mask = images_seq_mask.to(language_device)
text_embeds = self.language_model.get_input_embeddings()(input_ids)
# Get dimensions
bsz = input_ids.shape[0]
text_seq_length = text_embeds.shape[1]
vision_seq_length = aligned_embeds.shape[1]
seq_length = images_seq_mask.shape[1]
# Create embeddings tensor
inputs_embeds = torch.zeros(
(bsz, seq_length, text_embeds.shape[-1]),
dtype=text_embeds.dtype,
device=text_embeds.device,
requires_grad=True
)
# Move aligned embeds to same device
aligned_embeds = aligned_embeds.to(text_embeds.device)
# Combine embeddings
for b in range(bsz):
text_idx = 0
vision_idx = 0
for pos in range(seq_length):
if pos < images_seq_mask.shape[1] and images_seq_mask[b][pos]:
if vision_idx < vision_seq_length:
inputs_embeds[b, pos] = aligned_embeds[b, vision_idx]
vision_idx += 1
else:
if text_idx < text_seq_length:
inputs_embeds[b, pos] = text_embeds[b, text_idx]
text_idx += 1
del aligned_embeds, text_embeds # Free memory
return inputs_embeds
    def forward(self, inputs_embeds, attention_mask, labels=None):
        device = self.get_language_device()
        inputs_embeds = inputs_embeds.to(device)
        attention_mask = attention_mask.to(device)
        if labels is not None:
            labels = labels.to(device)
        # Only pass arguments the causal-LM forward actually accepts; generation-time settings
        # (pad/bos/eos ids, do_sample) belong in generate(), and some transformers versions
        # reject them as unexpected kwargs in forward().
        return self.language_model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            labels=labels,
            use_cache=False  # use_cache conflicts with gradient checkpointing during training
        )
def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
if hasattr(self.language_model, "gradient_checkpointing_enable"):
self.language_model.gradient_checkpointing_enable(gradient_checkpointing_kwargs)
# Ensure trainable parameters have requires_grad=True after enabling checkpointing
for name, param in self.language_model.named_parameters():
if 'lora' in name:
param.requires_grad_(True)
def gradient_checkpointing_disable(self):
if hasattr(self.language_model, "gradient_checkpointing_disable"):
self.language_model.gradient_checkpointing_disable()
def clear_memory(model=None):
if "inputs" in globals():
del globals()["inputs"]
if "model" in globals():
del globals()["model"]
if "processor" in globals():
del globals()["processor"]
if "trainer" in globals():
del globals()["trainer"]
if "peft_model" in globals():
del globals()["peft_model"]
if "bnb_config" in globals():
del globals()["bnb_config"]
time.sleep(2)
    if model is not None:
        if hasattr(model, "llm_engine"):  # only vLLM engines carry a driver worker holding GPU memory
            del model.llm_engine.model_executor.driver_worker
        del model  # not strictly necessary for releasing memory, but it does not hurt
gc.collect()
torch.cuda.empty_cache()
time.sleep(2)
torch.cuda.synchronize()
time.sleep(2)
gc.collect()
print(f"GPU allocated memory: {torch.cuda.memory_allocated() / 1024 ** 3:.2f} GB")
print(f"GPU reserved memory: {torch.cuda.memory_reserved() / 1024 ** 3:.2f} GB")
class TextValidator:
def __init__(self, text, injected_list=None):
self.text = text
self.summary = None
self.list_data = None
self.json_data = None
self.injected_list = injected_list
self.result = Template("Summary:\n$summary\n\nRecommendations:\n$list_data\n")
def extract_summary(self):
match = re.search(r"(?<=[Ss]ummary:\n)[^\n]*", self.text, re.S)
if match:
self.summary = match.group(0).strip()
    def extract_list(self):
        list_pattern = re.search(r"\[([\s\S]*?)\]", self.text)
        if list_pattern:
            extracted_list = list_pattern.group(1)
            # ast.literal_eval is safer than eval() for parsing a list literal out of model output
            self.list_data = ast.literal_eval(f"[{extracted_list.strip()}]")
        else:
            self.list_data = []
    def reassemble(self):
        self.extract_summary()
        self.extract_list()
        res = ""
        if self.injected_list is not None:
            if self.summary is not None:
                res = self.result.substitute(summary=self.summary,
                                             list_data=json.dumps(self.injected_list, indent=4))
        elif (self.summary is not None) and (self.list_data is not None):
            res = self.result.substitute(summary=self.summary,
                                         list_data=json.dumps(self.list_data, indent=4))
        return res
def encode_image(image_path, max_size=(512, 512)):
try:
# Open the image file
with Image.open(image_path) as img:
# Check if the image exceeds max dimensions
if img.size[0] > max_size[0] or img.size[1] > max_size[1]:
# Resize the image to fit within max dimensions, maintaining aspect ratio
img.thumbnail(max_size)
if img.mode == "RGBA":
img = img.convert("RGB")
            # display(img)
            # Despite the function name, no base64 encoding happens here:
            # the (possibly resized) PIL image is returned directly.
            return img
except Exception as e:
print(f"Error encoding image: {e}")
return None
@dataclasses.dataclass
class VLCollator:
def __init__(self, processor, model):
self.processor = processor
self.model = model
def pad_tensor(self, tensor: torch.Tensor, target_size: int, padding_value: int = 0) -> torch.Tensor:
current_size = tensor.size(1)
if current_size >= target_size:
return tensor[:, :target_size]
batch_size = tensor.size(0)
remaining_dims = tensor.size()[2:] if len(tensor.size()) > 2 else []
padding_shape = (batch_size, target_size - current_size, *remaining_dims)
return torch.cat([
tensor,
torch.full(padding_shape, padding_value, dtype=tensor.dtype, device='cpu')
], dim=1)
def __call__(self, features: List[List[Dict]]) -> Dict:
def create_assistant_mask(input_ids: torch.Tensor, tokenizer) -> torch.Tensor:
"""Create a mask for everything before the Assistant part.
Returns None if no Assistant token is found, indicating to keep original labels."""
batch_size, seq_length = input_ids.shape
mask = torch.ones_like(input_ids, dtype=torch.bool) # True means will be masked
assistant_token_id = tokenizer.encode("Assistant", add_special_tokens=False)[0]
found_assistant = False
for i in range(batch_size):
# Find the position of the Assistant token
assistant_positions = (input_ids[i] == assistant_token_id).nonzero()
if len(assistant_positions) > 0:
# Get the last Assistant token position (in case there are multiple)
assistant_pos = assistant_positions[-1]
# Unmask everything after the Assistant token
mask[i, assistant_pos + 1:] = False
found_assistant = True
return mask if found_assistant else None
# Process each conversation
batch_outputs = [
self.processor(
conversations=feature,
images=[feature[1]['images'][0]],
force_batchify=True
) for feature in features
]
# Find maximum sequence length
max_length = max(output.input_ids.size(1) for output in batch_outputs)
# Prepare padded tensors, keeping everything on CPU
padded_outputs = [{
'input_ids': self.pad_tensor(
output.input_ids.cpu(),
max_length,
self.processor.tokenizer.pad_token_id or 0
),
'attention_mask': self.pad_tensor(output.attention_mask.cpu(), max_length, 0),
'images_seq_mask': self.pad_tensor(output.images_seq_mask.cpu(), max_length, 0),
'images_emb_mask': output.images_emb_mask.cpu(),
'pixel_values': output.pixel_values.to(torch.bfloat16).cpu()
} for output in batch_outputs]
# Stack tensors on CPU
input_ids = torch.cat([out['input_ids'] for out in padded_outputs], dim=0)
pixel_values = torch.cat([out['pixel_values'] for out in padded_outputs], dim=0)
attention_mask = torch.cat([out['attention_mask'] for out in padded_outputs], dim=0)
images_seq_mask = torch.cat([out['images_seq_mask'] for out in padded_outputs], dim=0)
images_emb_mask = torch.cat([out['images_emb_mask'] for out in padded_outputs], dim=0)
# Get embeddings and prepare labels
with torch.no_grad():
# Let the model wrapper handle device placement
inputs_embeds = self.model.prepare_inputs_embeds(
input_ids=input_ids,
pixel_values=pixel_values,
images_seq_mask=images_seq_mask,
images_emb_mask=images_emb_mask
)
labels = input_ids.clone()
assistant_mask = create_assistant_mask(input_ids, self.processor.tokenizer)
if assistant_mask is not None:
labels.masked_fill_(assistant_mask | images_seq_mask | (attention_mask == 0), -100)
else:
# Fallback to original masking if no Assistant token found
labels.masked_fill_(images_seq_mask | (attention_mask == 0), -100)
# labels.masked_fill_(images_seq_mask | (attention_mask == 0), -100)
return {
"inputs_embeds": inputs_embeds.cpu(), # Return to CPU for pin_memory
"attention_mask": attention_mask,
"labels": labels
}
def process_row(row_args):
row, system_prompt, user_prompt, thumbnail_size, counter = row_args
sized_image = encode_image(row['images'], max_size=thumbnail_size)
# Construct the messages
messages = [
{
"role": "System",
"content": system_prompt,
},
{
"role": "User",
"content": f"<image_placeholder>{user_prompt}",
"images": [sized_image],
},
{
"role": "Assistant",
"content": row['example'],
},
]
counter.put(1)
return messages
def generate_examples(df, system_prompt, user_prompt, thumbnail_size):
rows = df.to_dict(orient='records') # Convert DataFrame to a list of dictionaries
total_rows = len(rows)
with Manager() as manager:
counter = manager.Queue()
row_args = [(row, system_prompt, user_prompt, thumbnail_size, counter)
for row in rows]
pbar = tqdm(total=total_rows, desc="Processing rows", unit=" rows")
with Pool(cpu_count()) as pool:
# Start the async result
async_result = pool.map_async(process_row, row_args)
# Update progress bar until completion
completed = 0
while not async_result.ready():
# Get all current updates from queue
while not counter.empty():
_ = counter.get()
completed += 1
pbar.update(1)
async_result.wait(0.1) # Small timeout to prevent busy waiting
# Get any remaining updates
while not counter.empty():
_ = counter.get()
completed += 1
pbar.update(1)
# Get the results
batch_inputs = async_result.get()
# Close progress bar
pbar.close()
return batch_inputs
def generate_text_from_sample(model, processor, sample, max_new_tokens=10_000, device="cuda"):
    # Note: this helper is not called from main() and relies on process_vision_info
    # (e.g. from qwen_vl_utils), which is not imported above; treat it as leftover reference code.
    # Prepare the text input by applying the chat template
text_input = processor.apply_chat_template(
sample[1:2], tokenize=False, add_generation_prompt=True # Use the sample without the system message
)
# Process the visual input from the sample
image_inputs, _ = process_vision_info(sample)
# Prepare the inputs for the model
model_inputs = processor(
text=[text_input],
images=image_inputs,
return_tensors="pt",
).to(
device
) # Move inputs to the specified device
# Generate text with the model
generated_ids = model.generate(**model_inputs, max_new_tokens=max_new_tokens)
# Trim the generated ids to remove the input ids
trimmed_generated_ids = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(model_inputs.input_ids, generated_ids)]
# Decode the output text
output_text = processor.batch_decode(
trimmed_generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False, truncation=True, max_length=8000
)
return output_text[0]
def get_model(model_id=None):
vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_id)
vl_gpt: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True,
torch_dtype=torch.bfloat16)
vl_gpt = vl_gpt.to(torch.bfloat16)
# model = setup_model_with_rope_wrapper(model, accelerator)
vl_gpt.config.use_cache = False
return vl_gpt, vl_chat_processor
def calculate_lora_parameters(model):
lora_params = 0
trainable_params = 0
for name, param in model.named_parameters():
if "lora" in name and param.requires_grad:
lora_params += param.numel()
if param.requires_grad:
trainable_params += param.numel()
return lora_params, trainable_params
def check_rope_implementation(model):
print("Checking RoPE implementation...")
# Check if model has get_rope_index
has_rope_index = hasattr(model, 'get_rope_index')
print(f"Has get_rope_index method: {has_rope_index}")
# Check if language model has RoPE
if hasattr(model, 'language_model'):
print("Checking language model attributes...")
# Check for rotary embeddings in attention layers
if hasattr(model.language_model, 'model') and hasattr(model.language_model.model, 'layers'):
layer = model.language_model.model.layers[0]
if hasattr(layer, 'self_attn'):
has_rotary = hasattr(layer.self_attn, 'rotary_emb')
print(f"Has rotary embeddings in attention: {has_rotary}")
return has_rope_index
def setup_model_with_rope_wrapper(model, accelerator=None):
import torch
# For distributed training
def get_device():
if accelerator is not None:
return accelerator.device
elif torch.distributed.is_initialized():
return f'cuda:{torch.distributed.get_rank()}'
else:
return torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Move model to device if not using accelerator
if accelerator is None:
model = model.to(get_device())
if hasattr(model, 'get_rope_index'):
original_get_rope_index = model.get_rope_index
def get_rope_index_wrapper(*args, **kwargs):
try:
# Get current device from input_ids (first argument)
target_device = args[0].device if args else get_device()
# Handle attention mask in kwargs
if 'attention_mask' in kwargs and kwargs['attention_mask'] is not None:
kwargs['attention_mask'] = kwargs['attention_mask'].to(target_device)
# Handle attention mask in positional args
elif len(args) > 3 and args[3] is not None:
args = list(args)
args[3] = args[3].to(target_device)
args = tuple(args)
return original_get_rope_index(*args, **kwargs)
except Exception as e:
print(f"Error in get_rope_index_wrapper: {str(e)}")
print(f"Args devices: {[arg.device if torch.is_tensor(arg) else 'not tensor' for arg in args]}")
raise e
model.get_rope_index = get_rope_index_wrapper
print("Successfully wrapped get_rope_index method for multi-GPU setup")
return model
class DatasetWrapper:
def __init__(self, data):
self.data = data
self.column_names = ['messages'] # Matches your data structure
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
# Return raw messages, let collate_fn handle processing
return self.data[idx]
class StreamingSFTTrainer(SFTTrainer):
def _convert_to_native_types(self, obj):
"""Convert tensors and numpy arrays to native Python types"""
if isinstance(obj, torch.Tensor):
return obj.detach().cpu().item() if obj.numel() == 1 else obj.detach().cpu().tolist()
elif isinstance(obj, np.ndarray):
return obj.item() if obj.size == 1 else obj.tolist()
elif isinstance(obj, dict) or dataclasses.is_dataclass(obj):
if dataclasses.is_dataclass(obj):
obj = dataclasses.asdict(obj)
return {key: self._convert_to_native_types(value) for key, value in obj.items()}
elif isinstance(obj, (list, tuple)):
return [self._convert_to_native_types(item) for item in obj]
return obj
def evaluation_loop(
self,
dataloader,
description: str,
prediction_loss_only: bool = None,
ignore_keys: bool = None,
metric_key_prefix: str = "eval",
):
"""
Override evaluation loop to compute metrics in streaming fashion
"""
model = self._wrap_model(self.model, training=False, dataloader=dataloader)
model.eval()
# Initialize metrics storage
total_eval_loss = 0.0 # Initialize as float
total_samples = 0
hit_rates = []
coverages = []
# Only compute metrics on rank 0 if using distributed training
compute_metrics = dist.get_rank() == 0 if dist.is_initialized() else True
for step, inputs in enumerate(dataloader):
# Move inputs to appropriate device
inputs = self._prepare_inputs(inputs)
with torch.no_grad():
# Forward pass
outputs = model(**inputs)
loss = outputs.loss
# Update loss tracking - convert to float immediately
total_eval_loss += loss.detach().cpu().float().item()
total_samples += 1
# Compute metrics if on rank 0
if compute_metrics and self.compute_metrics is not None:
# Get predictions
logits = outputs.logits.detach().cpu()
labels = inputs["labels"].detach().cpu()
# Create EvalPrediction object
eval_pred = EvalPrediction(
predictions=(logits,),
label_ids=labels
)
# Compute metrics for this batch
try:
metrics = self.compute_metrics(eval_pred)
hit_rates.append(float(metrics.get("hit_rate", 0.0))) # Convert to float
coverages.append(float(metrics.get("coverage", 0.0))) # Convert to float
except Exception as e:
self.log({"eval_metric_error": str(e)})
continue
# Log progress periodically
if step % self.args.logging_steps == 0:
current_metrics = {
"eval_step": step,
"eval_loss": total_eval_loss / max(total_samples, 1),
"current_hit_rate": float(np.mean(hit_rates)) if hit_rates else 0.0,
"current_coverage": float(np.mean(coverages)) if coverages else 0.0
}
# Convert metrics to native types before logging
self.log(self._convert_to_native_types(current_metrics))
# Compute final metrics
metrics = {
f"{metric_key_prefix}_loss": total_eval_loss / max(total_samples, 1)
}
# Add custom metrics if on rank 0
if compute_metrics and hit_rates:
metrics.update({
f"{metric_key_prefix}_hit_rate": float(np.mean(hit_rates)),
f"{metric_key_prefix}_coverage": float(np.mean(coverages)),
f"{metric_key_prefix}_samples": len(hit_rates)
})
# Convert all metrics to native Python types
metrics = self._convert_to_native_types(metrics)
# Create output object matching parent class expectations
class EvalLoopOutput:
def __init__(self, metrics):
self.metrics = metrics
self.predictions = None
self.label_ids = None
self.num_samples = total_samples
# Log final metrics
self.log(metrics)
return EvalLoopOutput(metrics)
def main():
    parser = argparse.ArgumentParser(description="Fine-tune DeepSeek-VL with LoRA on the language model.")
    parser.add_argument('-i', '--input_dataset', type=str, default=None, help="Folder containing complete_dataset.parquet.")
    parser.add_argument('-o', '--output_dir', type=str, default=None, help="Output folder for checkpoints.")
    parser.add_argument('-v', '--verbose', action='store_true', help="Enable verbose output.")
    parser.add_argument('-g', '--gpu', type=int, default=1, help="Number of GPUs to use.")
    parser.add_argument('-r', '--run_name', type=str, default="_", help="Suffix appended to the run name.")
parser.add_argument('-d', '--debug', action='store_true', help="Enable debug mode.")
args = parser.parse_args()
logger.info("Stage. Loading LLM Model")
model_id = "deepseek-ai/deepseek-vl-7b-chat"
with open("research/utilities/system_prompt.txt", "r") as f:
system_prompt = f.read()
with open("research/utilities/alternative_prompts/openai_prompt_v0.5.txt", "r") as f:
user_prompt = f.read()
logger.info("Stage. Reading Data")
stage_start = time.time()
complete_dataset = pd.read_parquet(f"{args.input_dataset}/complete_dataset.parquet")
complete_dataset = complete_dataset[complete_dataset['example']!=""].reset_index(drop=True)
# if in debug mode, limit the dataset to 100 rows
if args.debug:
complete_dataset = complete_dataset.iloc[:100]
logger.info(f"raw data size: {len(complete_dataset)} | Stage duration: {(time.time() - stage_start) / 60:.2f} minutes")
stage_start = time.time()
complete_batchs = generate_examples(complete_dataset, system_prompt, user_prompt, (896, 896))
training_dataset = DatasetWrapper(complete_batchs[:(int(len(complete_batchs)*0.9))])
validation_dataset = DatasetWrapper(complete_batchs[int(len(complete_batchs) * 0.9):int(len(complete_batchs) * 0.95)])
test_dataset = DatasetWrapper(complete_batchs[int(len(complete_batchs) * 0.95):len(complete_batchs)])
del complete_batchs
logger.info(f"Training dataset size: {len(training_dataset)}\n Validation dataset size: {len(validation_dataset)}\n "
f"Test dataset size: {len(test_dataset)} | Stage duration: {(time.time() - stage_start) / 60:.2f} minutes")
# remove LLM GPU signature and start training.
logger.info("Stage. Model Parameters Definition")
model, processor = get_model(model_id=model_id)
# Check RoPE implementation
needs_rope_wrapper = check_rope_implementation(model)
    if needs_rope_wrapper:
        logger.info("Model uses RoPE, applying wrapper...")
        model = setup_model_with_rope_wrapper(model, accelerator)
    else:
        logger.info("Model doesn't use get_rope_index method, no wrapper needed")
model = VLTrainingWrapper(processor, model)
    logger.info("Stage. Trainer Setup")
stage_start = time.time()
# Training arguments
training_args = SFTConfig(
output_dir=os.path.join(args.output_dir, f"deepseek-7b-chat-sft-facetune{args.run_name}"),
num_train_epochs=3,
per_device_train_batch_size=1, # Batch size for training
per_device_eval_batch_size=1, # Batch size for evaluation
gradient_accumulation_steps=1, # Steps to accumulate gradients
gradient_checkpointing=True, # Enable gradient checkpointing for memory efficiency
optim="adamw_torch_fused", # Optimizer type
        fsdp_transformer_layer_cls_to_wrap="LlamaDecoderLayer",  # only takes effect if FSDP is enabled (e.g. via accelerate config)
learning_rate=2e-4, # Learning rate for training
lr_scheduler_type="constant", # Type of learning rate scheduler
# Logging and evaluation
logging_steps=50, # Steps interval for logging
eval_steps=100, # Steps interval for evaluation
eval_strategy="steps", # Strategy for evaluation
save_strategy="steps", # Strategy for saving the model
save_total_limit=1,
bf16=True, # Use bfloat16 precision
tf32=True, # Use TensorFloat-32 precision
save_steps=100, # Steps interval for saving
metric_for_best_model="eval_loss", # Metric to evaluate the best model
greater_is_better=False, # Whether higher metric values are better
load_best_model_at_end=True, # Load the best model after training
# Mixed precision and gradient settings
max_grad_norm=1.0, # Maximum norm for gradient clipping
warmup_ratio=0.03, # Ratio of total steps for warmup
report_to="wandb",
run_name=f"ds-suggest-edit-vlm-deepseek{args.run_name}",
# gradient_checkpointing_kwargs={"use_reentrant": False}, # Options for gradient checkpointing
dataset_text_field="messages",
dataset_kwargs={"skip_prepare_dataset": True},
remove_unused_columns=False,
ddp_find_unused_parameters=False,
dataloader_drop_last=True,
        save_safetensors=True,
        push_to_hub=False,
        save_only_model=True
    )
# Initialize the trainer
# trainer = SFTTrainer(
# model=model,
# args=training_args,
# train_dataset=training_dataset,
# eval_dataset=validation_dataset,
# data_collator=VLCollator(processor, model),
# tokenizer=processor.tokenizer
# )
trainer = StreamingSFTTrainer(
model=model,
args=training_args,
train_dataset=training_dataset,
eval_dataset=validation_dataset,
data_collator=VLCollator(processor, model),
tokenizer=processor.tokenizer,
        compute_metrics=deep_seek_get_compute_metrics_fn(processor.tokenizer, logger)  # comes from the research.utilities.custom_callbacks wildcard import above
)
    # SFTTrainer already integrates with Accelerate internally; calling accelerator.prepare()
    # on a Trainer object is effectively a no-op here, so this line could be dropped.
    trainer = accelerator.prepare(trainer)
    lora_params, trainable_params = calculate_lora_parameters(model)
    logger.info(f"LoRA params: {lora_params}, Trainable params: {trainable_params}")
# eval_results = trainer.evaluate()
# logger.info("Evaluation Results:", eval_results)
logger.info("Stage. Training Begins")
logger.info(f"Accelerator config: {accelerator.state}")
logger.info(f"Local Rank: {accelerator.local_process_index}")
logger.info(f"World Size: {accelerator.num_processes}")
logger.info(f"Device: {accelerator.device}")
# callback = DeviceCheckCallback()
# callback.on_train_begin_prepare(trainer) # Pass the trainer reference
# trainer.add_callback(callback)
# wandb_callback = LLMSampleCB(trainer, validation_dataset, num_samples=50, max_new_tokens=10_000, logger=logger)
# trainer.add_callback(wandb_callback)
# trainer.evaluate()
trainer.train()
# Save and push to hub
# trainer.save_model(os.path.join(args.output_dir, f"deepseek-vl-7b-chat-sft-facetune{args.run_name}"))
if trainer.is_world_process_zero(): # Only save on main process
output_dir = os.path.join(args.output_dir, f"deepseek-7b-chat-sft-facetune{args.run_name}-final")
# Save the base model configuration
trainer.model.language_model.config.save_pretrained(output_dir)
# Save the PEFT model and its configuration
trainer.model.language_model.save_pretrained(output_dir)
# Save the processor/tokenizer if needed
processor.save_pretrained(output_dir)
    logger.info(f"Stage. Training Complete | Duration {(time.time() - stage_start)/60:.2f} minutes")
# adapter_path = os.path.join(args.output_dir, "qwen2-7b-instruct-sft-facetune-12k")
# model.load_adapter(adapter_path)
if __name__ == '__main__':
main()
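I launch this with something like accelerate launch finetune_deepseek_vl.py -i <dataset_folder> -o <output_folder> (the script name here is just a placeholder). For completeness, below is a minimal, untested sketch of how the LoRA adapter saved at the end of main() could be loaded back for inference; the adapter_dir path is hypothetical and should point at the "-final" folder the script writes.
import torch
from transformers import AutoModelForCausalLM
from peft import PeftModel
from deepseek_vl.models import VLChatProcessor, MultiModalityCausalLM

base_id = "deepseek-ai/deepseek-vl-7b-chat"
adapter_dir = "outputs/deepseek-7b-chat-sft-facetune_run-final"  # hypothetical path to the saved adapter

processor = VLChatProcessor.from_pretrained(base_id)
vl_gpt: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained(
    base_id, trust_remote_code=True, torch_dtype=torch.bfloat16
)

# Attach the trained LoRA weights to the language-model sub-module only,
# mirroring how VLTrainingWrapper applied LoRA during training.
vl_gpt.language_model = PeftModel.from_pretrained(vl_gpt.language_model, adapter_dir)
vl_gpt.language_model = vl_gpt.language_model.merge_and_unload()  # optional: fold LoRA into the base weights
vl_gpt = vl_gpt.cuda().eval()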