Refactor
- app_upload.py +8 -3
- trainer.py +19 -17
- uploader.py +0 -1
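This commit moves the Hugging Face Hub upload out of the generated training command and into an explicit call: `LoRAModelUploader.upload_lora_model()` now takes all upload options as parameters, `Trainer` holds a `LoRAModelUploader` instance and calls it after training finishes, and `Uploader` no longer prints its status message itself (the caller now does).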
app_upload.py CHANGED

@@ -13,9 +13,14 @@ from utils import find_exp_dirs
 
 
 class LoRAModelUploader(Uploader):
-    def upload_lora_model(
-        ...
-        ...
+    def upload_lora_model(
+        self,
+        folder_path: str,
+        repo_name: str,
+        upload_to: str,
+        private: bool,
+        delete_existing_repo: bool,
+    ) -> str:
         if not repo_name:
             repo_name = pathlib.Path(folder_path).name
         repo_name = slugify.slugify(repo_name)
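For context, the widened `upload_lora_model` signature lets the caller pass all Hub options explicitly. A minimal usage sketch under assumed values (the token, folder name, and `upload_to` string below are placeholders, not taken from the repo):

from app_upload import LoRAModelUploader

# Placeholder token and paths; in the Space these come from the app's settings
# and from Trainer.run() (see the trainer.py diff below).
uploader = LoRAModelUploader('hf_xxx')
result = uploader.upload_lora_model(
    folder_path='experiments/my-lora',   # local folder containing the trained weights
    repo_name='my-lora',                 # an empty string falls back to the folder name
    upload_to='LoRA Library',            # placeholder for one of the app's UploadTarget values
    private=True,
    delete_existing_repo=False,
)
print(result)                            # status message returned by the uploader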
trainer.py CHANGED

@@ -13,7 +13,7 @@ import slugify
 import torch
 from huggingface_hub import HfApi
 
-from ...
+from app_upload import LoRAModelUploader
 
 
 def pad_image(image: PIL.Image.Image) -> PIL.Image.Image:

@@ -32,8 +32,8 @@ def pad_image(image: PIL.Image.Image) -> PIL.Image.Image:
 
 class Trainer:
     def __init__(self, hf_token: str | None = None):
-        self.hf_token = hf_token
         self.api = HfApi(token=hf_token)
+        self.model_uploader = LoRAModelUploader(hf_token)
 
     def prepare_dataset(self, instance_images: list, resolution: int,
                         instance_data_dir: pathlib.Path) -> None:

@@ -91,8 +91,7 @@ class Trainer:
         output_dir = repo_dir / 'experiments' / output_model_name
         if overwrite_existing_model or upload_to_hub:
             shutil.rmtree(output_dir, ignore_errors=True)
-        ...
-        output_dir.mkdir(parents=True)
+        output_dir.mkdir(parents=True)
 
         instance_data_dir = repo_dir / 'training_data' / output_model_name
         self.prepare_dataset(instance_images, resolution, instance_data_dir)

@@ -121,16 +120,23 @@ class Trainer:
             command += ' --use_8bit_adam'
         if use_wandb:
             command += ' --report_to wandb'
-        if upload_to_hub:
-            command += f' --push_to_hub --hub_token {self.hf_token}'
-            if use_private_repo:
-                command += ' --private_repo'
-            if delete_existing_repo:
-                command += ' --delete_existing_repo'
-            if upload_to == UploadTarget.LORA_LIBRARY.value:
-                command += ' --upload_to_lora_library'
 
+        with open(output_dir / 'train.sh', 'w') as f:
+            command_s = ' '.join(command.split())
+            f.write(command_s)
         subprocess.run(shlex.split(command))
+        message = 'Training completed!'
+        print(message)
+
+        if upload_to_hub:
+            upload_message = self.model_uploader.upload_lora_model(
+                folder_path=output_dir.as_posix(),
+                repo_name=output_model_name,
+                upload_to=upload_to,
+                private=use_private_repo,
+                delete_existing_repo=delete_existing_repo)
+            print(upload_message)
+            message = message + '\n' + upload_message
 
         if remove_gpu_after_training:
             space_id = os.getenv('SPACE_ID')

@@ -138,8 +144,4 @@ class Trainer:
         self.api.request_space_hardware(repo_id=space_id,
                                         hardware='cpu-basic')
 
-        with open(output_dir / 'train.sh', 'w') as f:
-            command_s = ' '.join(command.split())
-            f.write(command_s)
-
-        return 'Training completed!'
+        return message
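Net effect in trainer.py: the `--push_to_hub`, `--private_repo`, `--delete_existing_repo`, and `--upload_to_lora_library` flags are no longer appended to the training command. Instead, the launch command is saved to `train.sh`, training runs, and the upload (when requested) happens afterwards through `LoRAModelUploader`. A condensed sketch of the new tail of `Trainer.run()`, using only names that appear in the diff:

# Condensed from the diff above; all names are Trainer.run() arguments/locals.
with open(output_dir / 'train.sh', 'w') as f:
    f.write(' '.join(command.split()))      # record the exact launch command
subprocess.run(shlex.split(command))        # run training synchronously

message = 'Training completed!'
if upload_to_hub:
    # Upload now goes through LoRAModelUploader rather than --push_to_hub flags.
    message += '\n' + self.model_uploader.upload_lora_model(
        folder_path=output_dir.as_posix(),
        repo_name=output_model_name,
        upload_to=upload_to,
        private=use_private_repo,
        delete_existing_repo=delete_existing_repo)
return message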
uploader.py CHANGED

@@ -35,5 +35,4 @@ class Uploader:
             message = f'Your model was successfully uploaded to <a href="{url}" target="_blank">{url}</a>.'
         except Exception as e:
             message = str(e)
-        print(message)
         return message