#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
This script automates the process of updating a Stable Diffusion training
script with settings extracted from a LoRA model's JSON metadata.
It performs the following main tasks:
1. Reads a JSON file containing LoRA model metadata
2. Parses an existing Stable Diffusion training script
3. Maps metadata keys to corresponding script arguments
4. Updates the script with values from the metadata
5. Handles special cases and complex arguments (e.g., network_args)
6. Writes the updated script to a new file

Usage:
    python steal_sdscripts_metadata.py <metadata_file> <script_file> <output_file>

This tool is particularly useful for replicating training conditions or
fine-tuning existing models based on successful previous runs.
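
Example (illustrative values only; the `ss_*` keys follow the sd-scripts
metadata naming convention handled below):
    metadata entry:      {"ss_network_dim": "32"}
    resulting argument:  --network_dim="32"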
"""
import json
import re
import argparse

# Parse command-line arguments
parser = argparse.ArgumentParser(
    description='Update training script based on metadata.'
)
parser.add_argument(
    'metadata_file', type=str, help='Path to the metadata JSON file'
)
parser.add_argument(
    'script_file', type=str, help='Path to the training script file'
)
parser.add_argument(
    'output_file', type=str, help='Path to save the updated training script'
)
args = parser.parse_args()

# Read the metadata JSON file
with open(args.metadata_file, 'r', encoding='utf-8') as f:
    metadata = json.load(f)

# Read the training script
with open(args.script_file, 'r', encoding='utf-8') as f:
    script_content = f.read()

# Define mappings between JSON keys and script arguments
mappings = {
    'ss_network_dim': '--network_dim',
    'ss_network_alpha': '--network_alpha',
    'ss_learning_rate': '--learning_rate',
    'ss_unet_lr': '--unet_lr',
    'ss_text_encoder_lr': '--text_encoder_lr',
    'ss_max_train_steps': '--max_train_steps',
    'ss_train_batch_size': '--train_batch_size',
    'ss_gradient_accumulation_steps': '--gradient_accumulation_steps',
    'ss_mixed_precision': '--mixed_precision',
    'ss_seed': '--seed',
    'ss_resolution': '--resolution',
    'ss_clip_skip': '--clip_skip',
    'ss_lr_scheduler': '--lr_scheduler',
    'ss_network_module': '--network_module',
}
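
# Example of the substitution performed by the loop below (illustrative values):
# a metadata entry {"ss_resolution": "(512, 512)"} rewrites an existing
# --resolution=... argument in the script to --resolution="512, 512"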

# Update script content based on metadata
for json_key, script_arg in mappings.items():
    if json_key in metadata:
        value = metadata[json_key]
        # Handle special cases
        if json_key == 'ss_resolution':
            value = f'"{value[1:-1]}"'  # Remove parentheses and add quotes
        elif isinstance(value, str):
            value = f'"{value}"'
        # Replace or add the argument in the script
        pattern = f'{script_arg}=\\S+'
        replacement = f'{script_arg}={value}'
        if re.search(pattern, script_content):
            script_content = re.sub(pattern, replacement, script_content)
        else:
            script_content = script_content.replace(
                'args=(', f'args=(\n {replacement}'
            )
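
# Illustration of the network_args rewrite performed below (assumed metadata
# values, not taken from a real run):
#   {"ss_network_args": {"conv_dim": "8", "conv_alpha": "4"}}
# replaces an existing --network_args "..." "..." block with:
#   --network_args
#     "conv_dim=8" "conv_alpha=4"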

# Handle network_args separately as it's more complex
if 'ss_network_args' in metadata:
    network_args = metadata['ss_network_args']
    NETWORK_ARGS_STR = ' '.join(
        [f'"{k}={v}"' for k, v in network_args.items()]
    )
    PATTERN = r'--network_args(\s+".+")+'
    replacement = f'--network_args\n {NETWORK_ARGS_STR}'
    script_content = re.sub(PATTERN, replacement, script_content)

# Write the updated script
with open(args.output_file, 'w', encoding='utf-8') as f:
    f.write(script_content)

print(f"Updated training script has been saved as '{args.output_file}'")