akoyaki commited on
Commit
49cf8a5
1 Parent(s): 06543b3

Upload script.py

Browse files
Files changed (1) hide show
  1. script.py +129 -0
script.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import subprocess
3
+ from safetensors.torch import load_file
4
+ from diffusers import AutoPipelineForText2Image
5
+ from datasets import load_dataset
6
+ from huggingface_hub.repocard import RepoCard
7
+ from huggingface_hub import HfApi
8
+ import torch
9
+ import re
10
+ import argparse
11
+ import os
12
+ import zipfile
13
+
14
+ def do_preprocess(class_data_dir):
15
+ print("Unzipping dataset")
16
+ zip_file_path = f"{class_data_dir}/class_images.zip"
17
+ with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
18
+ zip_ref.extractall(class_data_dir)
19
+ os.remove(zip_file_path)
20
+
21
+ def do_train(script_args):
22
+ # Pass all arguments to trainer.py
23
+ print("Starting training...")
24
+ result = subprocess.run(['python', 'trainer.py'] + script_args)
25
+ if result.returncode != 0:
26
+ raise Exception("Training failed.")
27
+
28
+ def replace_output_dir(text, output_dir, replacement):
29
+ # Define a pattern that matches the output_dir followed by whitespace, '/', new line, or "'"
30
+ # Add system name from HF only in the correct spots
31
+ pattern = rf"{output_dir}(?=[\s/'\n])"
32
+ return re.sub(pattern, replacement, text)
33
+
34
+ def do_inference(dataset_name, output_dir, num_tokens):
35
+ widget_content = []
36
+ try:
37
+ print("Starting inference to generate example images...")
38
+ dataset = load_dataset(dataset_name)
39
+ pipe = AutoPipelineForText2Image.from_pretrained(
40
+ "./lora-ease-wsl/tPonynai3_v6.safetensors", torch_dtype=torch.float16
41
+ )
42
+ pipe = pipe.to("cuda")
43
+ pipe.load_lora_weights(f'{output_dir}/pytorch_lora_weights.safetensors')
44
+
45
+ prompts = dataset["train"]["prompt"]
46
+ if(num_tokens > 0):
47
+ tokens_sequence = ''.join(f'<s{i}>' for i in range(num_tokens))
48
+ tokens_list = [f'<s{i}>' for i in range(num_tokens)]
49
+
50
+ state_dict = load_file(f"{output_dir}/{output_dir}_emb.safetensors")
51
+ pipe.load_textual_inversion(state_dict["clip_l"], token=tokens_list, text_encoder=pipe.text_encoder, tokenizer=pipe.tokenizer)
52
+ pipe.load_textual_inversion(state_dict["clip_g"], token=tokens_list, text_encoder=pipe.text_encoder_2, tokenizer=pipe.tokenizer_2)
53
+
54
+ prompts = [prompt.replace("TOK", tokens_sequence) for prompt in prompts]
55
+
56
+ for i, prompt in enumerate(prompts):
57
+ image = pipe(prompt, num_inference_steps=25, guidance_scale=7.5).images[0]
58
+ filename = f"image-{i}.png"
59
+ image.save(f"{output_dir}/{filename}")
60
+ card_dict = {
61
+ "text": prompt,
62
+ "output": {
63
+ "url": filename
64
+ }
65
+ }
66
+ widget_content.append(card_dict)
67
+ except Exception as e:
68
+ print("Something went wrong with generating images, specifically: ", e)
69
+
70
+ try:
71
+ api = HfApi()
72
+ username = api.whoami()["name"]
73
+ repo_id = api.create_repo(f"{username}/{output_dir}", exist_ok=True, private=True).repo_id
74
+
75
+ with open(f'{output_dir}/README.md', 'r') as file:
76
+ readme_content = file.read()
77
+
78
+
79
+ readme_content = replace_output_dir(readme_content, output_dir, f"{username}/{output_dir}")
80
+
81
+ card = RepoCard(readme_content)
82
+ if widget_content:
83
+ card.data["widget"] = widget_content
84
+ card.save(f'{output_dir}/README.md')
85
+
86
+ print("Starting upload...")
87
+ api.upload_folder(
88
+ folder_path=output_dir,
89
+ repo_id=f"{username}/{output_dir}",
90
+ repo_type="model",
91
+ )
92
+ except Exception as e:
93
+ print("Something went wrong with uploading your model, specificaly: ", e)
94
+ else:
95
+ print("Upload finished!")
96
+
97
+ import sys
98
+ import argparse
99
+
100
+ def main():
101
+ # Capture all arguments except the script name
102
+ script_args = sys.argv[1:]
103
+
104
+ # Create the argument parser
105
+ parser = argparse.ArgumentParser()
106
+ parser.add_argument('--dataset_name', required=True)
107
+ parser.add_argument('--output_dir', required=True)
108
+ parser.add_argument('--num_new_tokens_per_abstraction', type=int, default=0)
109
+ parser.add_argument('--train_text_encoder_ti', action='store_true')
110
+ parser.add_argument('--class_data_dir', help="Name of the class images dataset")
111
+
112
+ # Parse known arguments
113
+ args, _ = parser.parse_known_args(script_args)
114
+
115
+ # Set num_tokens to 0 if '--train_text_encoder_ti' is not present
116
+ if not args.train_text_encoder_ti:
117
+ args.num_new_tokens_per_abstraction = 0
118
+
119
+ # Proceed with training and inference
120
+ if args.class_data_dir:
121
+ do_preprocess(args.class_data_dir)
122
+ print("Pre-processing finished!")
123
+ do_train(script_args)
124
+ print("Training finished!")
125
+ do_inference(args.dataset_name, args.output_dir, args.num_new_tokens_per_abstraction)
126
+ print("All finished!")
127
+
128
+ if __name__ == "__main__":
129
+ main()