# app.py
import gradio as gr
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
from transformers import (
    AutoTokenizer,
    CLIPTextModel,
)
from diffusers import (
    StableDiffusionPipeline,
    UNet2DConditionModel,
    AutoencoderKL,
    DDPMScheduler,
)
from diffusers.optimization import get_scheduler
from datasets import load_dataset, Dataset
from huggingface_hub import login, HfApi
from pathlib import Path
import os
import zipfile
from PIL import Image
import pandas as pd
import math
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from tqdm.auto import tqdm

# Set up logging
logger = get_logger(__name__)
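# The script assumes the following packages are installed (an illustrative
# list, matching the imports above): gradio, torch, torchvision, transformers,
# diffusers, datasets, huggingface_hub, accelerate, pandas, pillow, tqdm.
# Optional extras, used only when enabled in the UI: xformers, bitsandbytes.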
def create_app():
    with gr.Blocks() as demo:
        gr.Markdown("# Stable Diffusion Fine-Tuning Application")

        # The Hugging Face authentication box was removed; the token is read
        # from the HUGGINGFACE_TOKEN environment variable instead (see
        # start_training_fn below).

        # Model Selection
        with gr.Row():
            base_model = gr.Textbox(
                label="Base Model Name",
                placeholder="e.g., CompVis/stable-diffusion-v1-4",
                value="stabilityai/stable-diffusion-2-1-base",
            )
            output_model_name = gr.Textbox(
                label="Output Model Repository Name",
                placeholder="Enter a unique name for your fine-tuned model (e.g., username/my-fine-tuned-model)",
            )

        # Dataset Selection
        with gr.Group():
            gr.Markdown("## Dataset Selection")
            dataset_source = gr.Radio(
                label="Dataset Source",
                choices=["Select from Hugging Face", "Upload your own"],
                value="Select from Hugging Face",
            )
            dataset_name = gr.Textbox(
                label="Dataset Name (from Hugging Face Hub)",
                placeholder="Enter dataset path, e.g., username/dataset_name",
                visible=True,
            )
            dataset_viewer_toggle = gr.Checkbox(
                label="Preview Dataset",
                value=False,
            )
            dataset_preview = gr.Gallery(
                label="Dataset Preview",
                visible=False,
                height="auto",
            )
            dataset_upload = gr.File(
                label="Upload Dataset (ZIP file containing images and annotations)",
                file_types=[".zip"],
                visible=False,
            )

            def toggle_dataset_source(choice):
                return {
                    dataset_name: gr.update(visible=choice == "Select from Hugging Face"),
                    dataset_upload: gr.update(visible=choice == "Upload your own"),
                    dataset_viewer_toggle: gr.update(visible=choice == "Select from Hugging Face"),
                }

            dataset_source.change(
                fn=toggle_dataset_source,
                inputs=dataset_source,
                outputs=[dataset_name, dataset_upload, dataset_viewer_toggle],
            )

        # Column Mapping
        with gr.Group():
            gr.Markdown("## Column Mapping")
            image_column = gr.Textbox(
                label="Image Column Name",
                placeholder="Column name for images",
                value="image",
            )
            caption_column = gr.Textbox(
                label="Caption Column Name",
                placeholder="Column name for captions",
                value="text",
            )

        # Training Parameters
        with gr.Group():
            gr.Markdown("## Training Parameters")
            with gr.Row():
                num_train_epochs = gr.Slider(
                    label="Number of Training Epochs",
                    minimum=1,
                    maximum=100,
                    value=1,
                    step=1,
                )
                max_train_steps = gr.Number(
                    label="Max Training Steps",
                    value=1000,
                )
                train_batch_size = gr.Number(
                    label="Train Batch Size",
                    value=4,
                )
            with gr.Row():
                learning_rate = gr.Number(
                    label="Learning Rate",
                    value=5e-6,
                )
                gradient_accumulation_steps = gr.Number(
                    label="Gradient Accumulation Steps",
                    value=1,
                )
                checkpointing_steps = gr.Number(
                    label="Checkpointing Steps",
                    value=500,
                )
            with gr.Row():
                mixed_precision = gr.Radio(
                    label="Mixed Precision",
                    choices=["no", "fp16", "bf16"],
                    value="fp16",
                )
                use_8bit_adam = gr.Checkbox(
                    label="Use 8-bit Adam Optimizer",
                    value=True,
                )
                use_xformers = gr.Checkbox(
                    label="Enable XFormers Memory Efficient Attention",
                    value=True,
                )
            with gr.Row():
                resolution = gr.Slider(
                    label="Image Resolution",
                    minimum=256,
                    maximum=1024,
                    value=512,
                    step=64,
                )
                lr_scheduler = gr.Dropdown(
                    label="Learning Rate Scheduler",
                    choices=[
                        "linear",
                        "cosine",
                        "cosine_with_restarts",
                        "polynomial",
                        "constant",
                        "constant_with_warmup",
                    ],
                    value="constant",
                )
                lr_warmup_steps = gr.Number(
                    label="Learning Rate Warmup Steps",
                    value=0,
                )
                seed = gr.Number(
                    label="Seed",
                    value=42,
                )

        # Start Training Button
        start_training = gr.Button("Start Training")

        # Output
        training_output = gr.Textbox(
            label="Training Status",
            placeholder="Logs will appear here...",
            lines=10,
        )

        # Dataset Viewer Functionality
        def preview_dataset(dataset_name, preview, image_column_name, caption_column_name):
            if not preview:
                return gr.update(visible=False, value=None)
            try:
                dataset = load_dataset(dataset_name, split="train")
                images = []
                # Show up to four (image, caption) pairs from the dataset.
                for i in range(min(4, len(dataset))):
                    image = dataset[i][image_column_name]
                    if not isinstance(image, Image.Image):
                        image = Image.open(image)
                    images.append((image, dataset[i][caption_column_name]))
                return gr.update(visible=True, value=images)
            except Exception as e:
                gr.Warning(f"Error loading dataset: {str(e)}")
                return gr.update(visible=False, value=None)

        dataset_viewer_toggle.change(
            fn=preview_dataset,
            inputs=[dataset_name, dataset_viewer_toggle, image_column, caption_column],
            outputs=dataset_preview,
        )

        # Training Function
        def start_training_fn(
            # The HF token comes from the environment, so it is not an input here.
            base_model_name,
            output_model_name,
            dataset_source,
            dataset_name,
            dataset_upload,
            image_column_name,
            caption_column_name,
            num_train_epochs,
            max_train_steps,
            train_batch_size,
            learning_rate,
            gradient_accumulation_steps,
            checkpointing_steps,
            mixed_precision,
            use_8bit_adam,
            use_xformers,
            resolution,
            lr_scheduler_type,
            lr_warmup_steps,
            seed,
        ):
            try:
                # Get the Hugging Face token from the environment variable.
                hf_token = os.environ.get("HUGGINGFACE_TOKEN")
                if not hf_token:
                    return "HUGGINGFACE_TOKEN environment variable not found. Please set it in your Space's secrets."

                # Validate inputs
                if not base_model_name.strip():
                    return "Please provide a base model name."
                if not output_model_name.strip():
                    return "Please provide an output model repository name."

                # Log in to Hugging Face
                login(hf_token, add_to_git_credential=True)

                # Load dataset
                if dataset_source == "Select from Hugging Face":
                    if not dataset_name.strip():
                        return "Please provide the Hugging Face dataset name."
                    dataset = load_dataset(dataset_name, split="train")
                else:
                    if dataset_upload is None:
                        return "Please upload a dataset."
                    dataset = load_custom_dataset(dataset_upload.name)

                # Check that the specified columns exist.
                if image_column_name not in dataset.column_names:
                    return f"Image column '{image_column_name}' not found in the dataset."
                if caption_column_name not in dataset.column_names:
                    return f"Caption column '{caption_column_name}' not found in the dataset."

                # Preprocess the dataset
                dataset = preprocess_dataset(
                    dataset, image_column_name, caption_column_name, int(resolution), base_model_name
                )

                # Start training
                result = train_model(
                    hf_token=hf_token,
                    base_model_name=base_model_name,
                    dataset=dataset,
                    output_model_name=output_model_name,
                    num_train_epochs=int(num_train_epochs),
                    max_train_steps=int(max_train_steps) if max_train_steps else 0,
                    train_batch_size=int(train_batch_size),
                    learning_rate=float(learning_rate),
                    gradient_accumulation_steps=int(gradient_accumulation_steps),
                    checkpointing_steps=int(checkpointing_steps),
                    mixed_precision=mixed_precision,
                    use_8bit_adam=use_8bit_adam,
                    use_xformers=use_xformers,
                    lr_scheduler_type=lr_scheduler_type,
                    lr_warmup_steps=int(lr_warmup_steps),
                    resolution=int(resolution),
                    seed=int(seed),
                )
                return result
            except Exception as e:
                return f"An error occurred during training: {str(e)}"

        start_training.click(
            fn=start_training_fn,
            inputs=[
                base_model,
                output_model_name,
                dataset_source,
                dataset_name,
                dataset_upload,
                image_column,
                caption_column,
                num_train_epochs,
                max_train_steps,
                train_batch_size,
                learning_rate,
                gradient_accumulation_steps,
                checkpointing_steps,
                mixed_precision,
                use_8bit_adam,
                use_xformers,
                resolution,
                lr_scheduler,
                lr_warmup_steps,
                seed,
            ],
            outputs=training_output,
        )

    return demo
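# For uploaded datasets, load_custom_dataset (below) expects a ZIP laid out
# roughly like this (illustrative file names):
#
#   dataset.zip
#   ├── annotations.csv    # columns: file_name, caption
#   ├── image_001.png
#   └── image_002.jpg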
def preprocess_dataset(dataset, image_column_name, caption_column_name, resolution, base_model_name):
    # Use the tokenizer that ships with the base checkpoint so the token ids
    # match the text encoder loaded in train_model.
    tokenizer = AutoTokenizer.from_pretrained(base_model_name, subfolder="tokenizer")

    # Resize, center-crop, and normalize images to [-1, 1].
    transform = transforms.Compose([
        transforms.Resize(resolution, interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.CenterCrop(resolution),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])

    def process_example(example):
        # Load and preprocess the image.
        image = example[image_column_name]
        if not isinstance(image, Image.Image):
            image = Image.open(image)
        image = image.convert("RGB")
        image = transform(image)

        # Tokenize the caption.
        caption = example[caption_column_name]
        tokens = tokenizer(
            caption,
            truncation=True,
            max_length=tokenizer.model_max_length,
            padding="max_length",
            return_tensors="pt",
        )
        return {
            "pixel_values": image,
            "input_ids": tokens.input_ids.squeeze(),
            "attention_mask": tokens.attention_mask.squeeze(),
        }

    # Remove unused columns and map the dataset.
    columns_to_remove = set(dataset.column_names) - {image_column_name, caption_column_name}
    dataset = dataset.map(
        process_example,
        remove_columns=list(columns_to_remove),
        batched=False,
    )
    # `map` serializes tensors to lists; restore torch tensors (and hide the
    # raw image/caption columns) so the DataLoader can collate batches.
    dataset.set_format(type="torch", columns=["pixel_values", "input_ids", "attention_mask"])
    return dataset


def load_custom_dataset(zip_file_path):
    # Extract the zip file.
    extract_path = Path("extracted_dataset")
    with zipfile.ZipFile(zip_file_path, "r") as zip_ref:
        zip_ref.extractall(extract_path)

    # The archive must contain annotations.csv with 'file_name' and 'caption' columns.
    annotations_file = extract_path / "annotations.csv"
    if not annotations_file.exists():
        raise ValueError("annotations.csv not found in the dataset.")

    annotations = pd.read_csv(annotations_file)
    if "file_name" not in annotations.columns or "caption" not in annotations.columns:
        raise ValueError("annotations.csv must contain 'file_name' and 'caption' columns.")

    images = []
    captions = []
    for idx, row in annotations.iterrows():
        image_path = extract_path / row["file_name"]
        if image_path.exists():
            images.append(str(image_path))
            captions.append(row["caption"])
        else:
            raise ValueError(f"Image file {row['file_name']} not found in the dataset.")

    # Build a datasets.Dataset using the column names the UI defaults expect.
    data = {
        "image": images,
        "text": captions,
    }
    return Dataset.from_dict(data)
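# train_model (below) implements the standard latent-diffusion fine-tuning
# objective: images are encoded to VAE latents, noise is added at a randomly
# sampled timestep, and the UNet is trained with an MSE loss to predict that
# noise (or the velocity target, for v-prediction checkpoints).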
def train_model(
    hf_token,
    base_model_name,
    dataset,
    output_model_name,
    num_train_epochs,
    max_train_steps,
    train_batch_size,
    learning_rate,
    gradient_accumulation_steps,
    checkpointing_steps,
    mixed_precision,
    use_8bit_adam,
    use_xformers,
    lr_scheduler_type,
    lr_warmup_steps,
    resolution,
    seed,
):
    # Set seed for reproducibility.
    set_seed(seed)

    # Initialize Accelerator.
    accelerator = Accelerator(
        gradient_accumulation_steps=gradient_accumulation_steps,
        mixed_precision=mixed_precision,
    )

    # Load tokenizer and models from the base checkpoint.
    tokenizer = AutoTokenizer.from_pretrained(base_model_name, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(base_model_name, subfolder="text_encoder")
    vae = AutoencoderKL.from_pretrained(base_model_name, subfolder="vae")
    unet = UNet2DConditionModel.from_pretrained(base_model_name, subfolder="unet")

    # Freeze the VAE and text encoder; only the UNet is fine-tuned.
    vae.eval()
    text_encoder.eval()
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)

    # Enable xformers memory-efficient attention if requested.
    if use_xformers:
        try:
            import xformers  # noqa: F401
        except ImportError:
            return "Error: xformers is not installed. Please install xformers or disable it."
        unet.enable_xformers_memory_efficient_attention()

    # Prepare optimizer.
    if use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError:
            return "Error: bitsandbytes is not installed. Please install bitsandbytes or disable 8-bit Adam."
        optimizer_class = bnb.optim.AdamW8bit
    else:
        optimizer_class = torch.optim.AdamW

    optimizer = optimizer_class(
        unet.parameters(),
        lr=learning_rate,
    )

    # Prepare data loader.
    train_dataloader = DataLoader(
        dataset,
        batch_size=train_batch_size,
        shuffle=True,
        num_workers=4,
    )

    # Calculate total training steps.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
    if max_train_steps is None or max_train_steps == 0:
        max_train_steps = num_train_epochs * num_update_steps_per_epoch
    else:
        num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)

    # Prepare learning rate scheduler.
    lr_scheduler = get_scheduler(
        lr_scheduler_type,
        optimizer=optimizer,
        num_warmup_steps=lr_warmup_steps * gradient_accumulation_steps,
        num_training_steps=max_train_steps * gradient_accumulation_steps,
    )

    # Prepare everything with accelerator.
    unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        unet, optimizer, train_dataloader, lr_scheduler
    )

    # Pick the dtype used for the frozen weights under mixed precision.
    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    vae.to(accelerator.device, dtype=weight_dtype)
    text_encoder.to(accelerator.device, dtype=weight_dtype)

    # Move one batch to the device as an early sanity check.
    try:
        batch = next(iter(train_dataloader))
        for key in ("pixel_values", "input_ids", "attention_mask"):
            batch[key] = batch[key].to(accelerator.device)
    except Exception as e:
        return f"Error in moving batch to device: {str(e)}"

    # Set up the noise scheduler.
    noise_scheduler = DDPMScheduler.from_pretrained(base_model_name, subfolder="scheduler")

    # Training loop.
    total_batch_size = train_batch_size * accelerator.num_processes * gradient_accumulation_steps
    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(dataset)}")
    logger.info(f"  Num Epochs = {num_train_epochs}")
    logger.info(f"  Instantaneous batch size per device = {train_batch_size}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f"  Gradient Accumulation steps = {gradient_accumulation_steps}")
    logger.info(f"  Total optimization steps = {max_train_steps}")

    progress_bar = tqdm(range(max_train_steps), disable=not accelerator.is_local_main_process)
    progress_bar.set_description("Training")

    global_step = 0
    for epoch in range(num_train_epochs):
        unet.train()
        for step, batch in enumerate(train_dataloader):
            with accelerator.accumulate(unet):
                # Convert images to latent space and apply the VAE scaling
                # factor (0.18215, i.e. vae.config.scaling_factor for
                # Stable Diffusion VAEs).
                latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
                latents = latents * 0.18215

                # Sample noise and a random timestep for each image.
                noise = torch.randn_like(latents)
                bsz = latents.shape[0]
                timesteps = torch.randint(
                    0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device
                ).long()

                # Add noise to the latents according to the noise magnitude at
                # each timestep (the forward diffusion process).
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                # Get the text embedding for conditioning.
                encoder_hidden_states = text_encoder(batch["input_ids"])[0]

                # Predict the noise residual.
                model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

                # Get the target for the loss, depending on the scheduler's
                # prediction type.
                if noise_scheduler.config.prediction_type == "epsilon":
                    target = noise
                elif noise_scheduler.config.prediction_type == "v_prediction":
                    target = noise_scheduler.get_velocity(latents, noise, timesteps)
                else:
                    raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

                # Compute loss.
                loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
                accelerator.backward(loss)

                # Update the model parameters.
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            # Advance counters only when gradients were actually synced,
            # i.e., after a full gradient-accumulation cycle.
            if accelerator.sync_gradients:
                progress_bar.update(1)
                progress_bar.set_postfix(loss=loss.item())
                global_step += 1

                if global_step % checkpointing_steps == 0 and accelerator.is_main_process:
                    # Save a training checkpoint.
                    save_path = f"{output_model_name}_checkpoint_{global_step}"
                    accelerator.save_state(save_path)

            if global_step >= max_train_steps:
                break
        if global_step >= max_train_steps:
            break

    # Save the final model.
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        unet = accelerator.unwrap_model(unet)
        pipeline = StableDiffusionPipeline(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=DDPMScheduler.from_pretrained(base_model_name, subfolder="scheduler"),
            safety_checker=None,
            feature_extractor=None,
            requires_safety_checker=False,
        )
        pipeline.save_pretrained(output_model_name)

        # Upload to the Hugging Face Hub.
        api = HfApi()
        repo_url = api.create_repo(
            repo_id=output_model_name,
            token=hf_token,
            private=False,
            exist_ok=True,
        )
        api.upload_folder(
            folder_path=output_model_name,
            repo_id=output_model_name,
            token=hf_token,
            commit_message=f"Fine-tuned model at step {global_step}",
        )
        return f"Training complete. The model has been uploaded to Hugging Face Hub at {repo_url}"
    return "Training complete."
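# Example inference with the fine-tuned model (illustrative; substitute the
# repository name you trained above):
#
#   from diffusers import StableDiffusionPipeline
#   pipe = StableDiffusionPipeline.from_pretrained("username/my-fine-tuned-model")
#   pipe = pipe.to("cuda")
#   image = pipe("a photo of your subject").images[0]
#   image.save("sample.png")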
app = create_app()

# Start the Gradio app
if __name__ == "__main__":
    app.launch()
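# To run locally: `python app.py`; Gradio serves on http://127.0.0.1:7860 by default.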