|
|
|
|
|
|
|
"""
|
|
This script automates the process of updating a Stable Diffusion training
|
|
script with settings extracted from a LoRA model's JSON metadata.
|
|
|
|
It performs the following main tasks:
|
|
1. Reads a JSON file containing LoRA model metadata
|
|
2. Parses an existing Stable Diffusion training script
|
|
3. Maps metadata keys to corresponding script arguments
|
|
4. Updates the script with values from the metadata
|
|
5. Handles special cases and complex arguments (e.g., network_args)
|
|
6. Writes the updated script to a new file
|
|
|
|
Usage:
|
|
python steal_sdscripts_metadata.py <metadata_file> <script_file> <output_file>
|
|
|
|
This tool is particularly useful for replicating training conditions or
|
|
fine-tuning existing models based on successful previous runs.
|
|
"""
|
|
|
|
import argparse
import json
import re
import sys
|
|
|
|
|
|
# Metadata key -> training-script flag that receives its value. The ``ss_``
# keys are presumably the ones sd-scripts writes into a trained LoRA.
mappings = {
    'ss_network_dim': '--network_dim',
    'ss_network_alpha': '--network_alpha',
    'ss_learning_rate': '--learning_rate',
    'ss_unet_lr': '--unet_lr',
    'ss_text_encoder_lr': '--text_encoder_lr',
    'ss_max_train_steps': '--max_train_steps',
    'ss_train_batch_size': '--train_batch_size',
    'ss_gradient_accumulation_steps': '--gradient_accumulation_steps',
    'ss_mixed_precision': '--mixed_precision',
    'ss_seed': '--seed',
    'ss_resolution': '--resolution',
    'ss_clip_skip': '--clip_skip',
    'ss_lr_scheduler': '--lr_scheduler',
    'ss_network_module': '--network_module',
}

# A flag value: a double- or single-quoted string (which may contain spaces,
# e.g. "512, 768") or a bare run of non-whitespace characters.
_VALUE_RE = r"""(?:"[^"]*"|'[^']*'|\S+)"""

# Opening of the bash array holding the script's arguments. The word boundary
# stops it from matching inside names such as ``network_args=(``.
_ARGS_ARRAY_RE = re.compile(r'\bargs=\(')

# ``--network_args`` followed by one or more double-quoted items. ``[^"]*``
# (rather than a greedy ``.+``) keeps a match from running past the closing
# quote into unrelated flags on the same line.
_NETWORK_ARGS_RE = re.compile(r'--network_args(?:\s+"[^"]*")+')


def _parse_args(argv=None):
    """Parse the command line (``sys.argv[1:]`` when *argv* is None)."""
    parser = argparse.ArgumentParser(
        description='Update training script based on metadata.'
    )
    parser.add_argument(
        'metadata_file', type=str, help='Path to the metadata JSON file'
    )
    parser.add_argument(
        'script_file', type=str, help='Path to the training script file'
    )
    parser.add_argument(
        'output_file', type=str, help='Path to save the updated training script'
    )
    return parser.parse_args(argv)


def _is_unset(value):
    """Return True if *value* means "no setting" (missing, null, or "None").

    Metadata values are often stringified, so an unset option may appear as
    the string ``'None'`` rather than JSON ``null``.
    """
    return value is None or value == 'None'


def _format_value(json_key, value):
    """Render a metadata value as it should appear after ``--flag=``.

    ``ss_resolution`` may look like ``"(512, 512)"``: surrounding brackets
    are stripped and the rest is double-quoted; a list/tuple resolution is
    joined with commas. Other strings are double-quoted; non-strings are
    written via ``str()``.
    """
    if json_key == 'ss_resolution':
        if isinstance(value, (list, tuple)):
            inner = ','.join(str(part) for part in value)
        else:
            inner = str(value).strip().strip('()[]')
        return f'"{inner}"'
    if isinstance(value, str):
        return f'"{value}"'
    return f'{value}'


def _insert_arg(script_content, text):
    """Insert *text* as the first element of the script's ``args=(`` array.

    Returns ``(new_content, inserted)``; ``inserted`` is False when the
    script has no ``args=(`` array, in which case the content is unchanged.
    """
    # A function replacement is used so *text* is inserted literally
    # (a string replacement would interpret backslashes in it).
    new_content, count = _ARGS_ARRAY_RE.subn(
        lambda _m: f'args=(\n {text}', script_content, count=1
    )
    return new_content, count > 0


def _set_arg(script_content, script_arg, value):
    """Set every ``script_arg=...`` to *value*, adding the flag if absent.

    Returns ``(new_content, placed)``; ``placed`` is False only when the flag
    was absent and there was no ``args=(`` array to add it to.
    """
    replacement = f'{script_arg}={value}'
    # Lookbehind avoids matching the flag as the tail of a longer token.
    pattern = re.compile(rf'(?<![\w-]){re.escape(script_arg)}={_VALUE_RE}')
    new_content, count = pattern.subn(lambda _m: replacement, script_content)
    if count:
        return new_content, True
    return _insert_arg(script_content, replacement)


def _apply_network_args(script_content, network_args):
    """Replace (or add) the ``--network_args`` items with *network_args*.

    Returns ``(new_content, placed)`` like :func:`_set_arg`.
    """
    items = ' '.join(f'"{key}={val}"' for key, val in network_args.items())
    replacement = f'--network_args\n {items}'
    new_content, count = _NETWORK_ARGS_RE.subn(
        lambda _m: replacement, script_content
    )
    # Nothing to add when the flag already existed or there are no items.
    if count or not network_args:
        return new_content, True
    return _insert_arg(script_content, replacement)


def update_script(metadata, script_content):
    """Apply the training settings in *metadata* to *script_content*.

    Args:
        metadata: Mapping of metadata keys (``ss_*``) to values, as loaded
            from the metadata JSON file.
        script_content: Text of the training script. Flags are expected in
            ``--flag=value`` form; missing flags are added to the first
            bash-style ``args=(`` array.

    Returns:
        A ``(updated_content, messages)`` tuple; ``messages`` describes
        settings that could not be applied.

    Raises:
        json.JSONDecodeError: if ``ss_network_args`` is a string that is not
            valid JSON.
    """
    messages = []

    for json_key, script_arg in mappings.items():
        value = metadata.get(json_key)
        # Missing keys and unset values are skipped rather than written into
        # the script as a literal "None".
        if _is_unset(value):
            continue
        script_content, placed = _set_arg(
            script_content, script_arg, _format_value(json_key, value)
        )
        if not placed:
            messages.append(
                f'{script_arg} not found and no args=( array to add it to; '
                f'{json_key} was not applied'
            )

    network_args = metadata.get('ss_network_args')
    if isinstance(network_args, str) and not _is_unset(network_args):
        # Presumably stored as a JSON-encoded string in the metadata.
        network_args = json.loads(network_args)
    if isinstance(network_args, dict):
        script_content, placed = _apply_network_args(
            script_content, network_args
        )
        if not placed:
            messages.append(
                '--network_args not found and no args=( array to add it to; '
                'ss_network_args was not applied'
            )
    elif not _is_unset(network_args):
        messages.append('ss_network_args is not a JSON object; ignored')

    return script_content, messages


def main(argv=None):
    """Read metadata and script, write the updated script, report the result."""
    args = _parse_args(argv)

    with open(args.metadata_file, 'r', encoding='utf-8') as f:
        metadata = json.load(f)

    with open(args.script_file, 'r', encoding='utf-8') as f:
        script_content = f.read()

    script_content, messages = update_script(metadata, script_content)
    for message in messages:
        print(f'Warning: {message}', file=sys.stderr)

    with open(args.output_file, 'w', encoding='utf-8') as f:
        f.write(script_content)

    print(f"Updated training script has been saved as '{args.output_file}'")


if __name__ == '__main__':
    main()
|
|
|