#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
This script automates the process of updating a Stable Diffusion training script
with settings extracted from a LoRA model's JSON metadata.

It performs the following main tasks:
1. Reads a JSON file containing LoRA model metadata
2. Parses an existing Stable Diffusion training script
3. Maps metadata keys to corresponding script arguments
4. Updates the script with values from the metadata
5. Handles special cases and complex arguments (e.g., network_args)
6. Writes the updated script to a new file

Usage:
    python steal_sdscripts_metadata.py <metadata_file> <script_file> <output_file>

This tool is particularly useful for replicating training conditions or
fine-tuning existing models based on successful previous runs.
"""

import argparse
import json
import re

# Parse command-line arguments
parser = argparse.ArgumentParser(
    description='Update training script based on metadata.'
)
parser.add_argument(
    'metadata_file', type=str, help='Path to the metadata JSON file'
)
parser.add_argument(
    'script_file', type=str, help='Path to the training script file'
)
parser.add_argument(
    'output_file', type=str, help='Path to save the updated training script'
)
args = parser.parse_args()

# Read the metadata JSON file
with open(args.metadata_file, 'r', encoding='utf-8') as f:
    metadata = json.load(f)

# Read the training script
with open(args.script_file, 'r', encoding='utf-8') as f:
    script_content = f.read()

# Define mappings between JSON metadata keys and script arguments
mappings = {
    'ss_network_dim': '--network_dim',
    'ss_network_alpha': '--network_alpha',
    'ss_learning_rate': '--learning_rate',
    'ss_unet_lr': '--unet_lr',
    'ss_text_encoder_lr': '--text_encoder_lr',
    'ss_max_train_steps': '--max_train_steps',
    'ss_train_batch_size': '--train_batch_size',
    'ss_gradient_accumulation_steps': '--gradient_accumulation_steps',
    'ss_mixed_precision': '--mixed_precision',
    'ss_seed': '--seed',
    'ss_resolution': '--resolution',
    'ss_clip_skip': '--clip_skip',
    'ss_lr_scheduler': '--lr_scheduler',
    'ss_network_module': '--network_module',
}

# Update script content based on metadata
for json_key, script_arg in mappings.items():
    if json_key in metadata:
        value = metadata[json_key]

        # Handle special cases
        if json_key == 'ss_resolution':
            # Strip the surrounding parentheses (metadata stores the
            # resolution as a string such as "(512, 512)") and add quotes
            value = f'"{value[1:-1]}"'
        elif isinstance(value, str):
            value = f'"{value}"'

        # Replace the argument if it already appears in the script;
        # otherwise append it to the `args=(` array the script is assumed
        # to define
        pattern = f'{script_arg}=\\S+'
        replacement = f'{script_arg}={value}'
        if re.search(pattern, script_content):
            script_content = re.sub(pattern, replacement, script_content)
        else:
            script_content = script_content.replace(
                'args=(', f'args=(\n {replacement}'
            )

# Handle network_args separately as it's more complex
if 'ss_network_args' in metadata:
    network_args = metadata['ss_network_args']
    # Assumption: some metadata dumps store ss_network_args as a
    # JSON-encoded string rather than a dict; parse it in that case.
    if isinstance(network_args, str):
        network_args = json.loads(network_args)
    network_args_str = ' '.join(
        [f'"{k}={v}"' for k, v in network_args.items()]
    )
    pattern = r'--network_args(\s+".+")+'
    replacement = f'--network_args\n {network_args_str}'
    script_content = re.sub(pattern, replacement, script_content)

# Write the updated script
with open(args.output_file, 'w', encoding='utf-8') as f:
    f.write(script_content)

print(f"Updated training script has been saved as '{args.output_file}'")
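
# ---------------------------------------------------------------------------
# Example invocation (hypothetical; the file names and metadata values below
# are illustrative, not part of this repository):
#
#   python steal_sdscripts_metadata.py lora_metadata.json train_lora.sh train_lora_new.sh
#
# If lora_metadata.json contains {"ss_network_dim": "32", "ss_learning_rate": "1e-4"},
# existing `--network_dim=...` and `--learning_rate=...` entries in train_lora.sh
# are rewritten to `--network_dim="32"` and `--learning_rate="1e-4"`; arguments
# not yet present are appended to the script's `args=(` array.
# ---------------------------------------------------------------------------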