MonsterMMORPG commited on
Commit
e2e62c9
·
verified ·
1 Parent(s): 59db811

Upload kohya_gui_kaggle.py

Browse files
Files changed (1) hide show
  1. kohya_gui_kaggle.py +51 -57
kohya_gui_kaggle.py CHANGED
@@ -4,57 +4,56 @@ import argparse
4
  from dreambooth_gui import dreambooth_tab
5
  from finetune_gui import finetune_tab
6
  from textual_inversion_gui import ti_tab
7
- from library.utilities import utilities_tab
8
  from lora_gui import lora_tab
9
- from library.class_lora_tab import LoRATools
10
 
11
- import os
12
- from library.custom_logging import setup_logging
13
- from library.localization_ext import add_javascript
14
 
15
  # Set up logging
16
  log = setup_logging()
17
 
18
 
19
  def UI(**kwargs):
20
- add_javascript(kwargs.get('language'))
21
- css = ''
22
 
23
- headless = kwargs.get('headless', False)
24
- log.info(f'headless: {headless}')
25
 
26
- if os.path.exists('./style.css'):
27
- with open(os.path.join('./style.css'), 'r', encoding='utf8') as file:
28
- log.info('Load CSS...')
29
- css += file.read() + '\n'
30
 
31
- if os.path.exists('./.release'):
32
- with open(os.path.join('./.release'), 'r', encoding='utf8') as file:
33
  release = file.read()
34
 
35
- if os.path.exists('./README.md'):
36
- with open(os.path.join('./README.md'), 'r', encoding='utf8') as file:
37
  README = file.read()
38
 
39
  interface = gr.Blocks(
40
- css=css, title=f'Kohya_ss GUI {release}', theme=gr.themes.Default()
41
  )
42
 
43
  with interface:
44
- with gr.Tab('Dreambooth'):
45
  (
46
  train_data_dir_input,
47
  reg_data_dir_input,
48
  output_dir_input,
49
  logging_dir_input,
50
  ) = dreambooth_tab(headless=headless)
51
- with gr.Tab('LoRA'):
52
  lora_tab(headless=headless)
53
- with gr.Tab('Textual Inversion'):
54
  ti_tab(headless=headless)
55
- with gr.Tab('Finetuning'):
56
  finetune_tab(headless=headless)
57
- with gr.Tab('Utilities'):
58
  utilities_tab(
59
  train_data_dir_input=train_data_dir_input,
60
  reg_data_dir_input=reg_data_dir_input,
@@ -63,11 +62,11 @@ def UI(**kwargs):
63
  enable_copy_info_button=True,
64
  headless=headless,
65
  )
66
- with gr.Tab('LoRA'):
67
  _ = LoRATools(headless=headless)
68
- with gr.Tab('About'):
69
- gr.Markdown(f'kohya_ss GUI release {release}')
70
- with gr.Tab('README'):
71
  gr.Markdown(README)
72
 
73
  htmlStr = f"""
@@ -80,62 +79,57 @@ def UI(**kwargs):
80
  gr.HTML(htmlStr)
81
  # Show the interface
82
  launch_kwargs = {}
83
- username = kwargs.get('username')
84
- password = kwargs.get('password')
85
- server_port = kwargs.get('server_port', 0)
86
- inbrowser = kwargs.get('inbrowser', False)
87
  share = False
88
- server_name = kwargs.get('listen')
89
 
90
- launch_kwargs['server_name'] = server_name
91
  if username and password:
92
- launch_kwargs['auth'] = (username, password)
93
  if server_port > 0:
94
- launch_kwargs['server_port'] = server_port
95
  if inbrowser:
96
- launch_kwargs['inbrowser'] = inbrowser
97
  if share:
98
- launch_kwargs['share'] = share
99
- interface.launch(**launch_kwargs, share=False)
 
100
 
101
 
102
- if __name__ == '__main__':
103
  # torch.cuda.set_per_process_memory_fraction(0.48)
104
  parser = argparse.ArgumentParser()
105
  parser.add_argument(
106
- '--listen',
107
  type=str,
108
- default='127.0.0.1',
109
- help='IP to listen on for connections to Gradio',
110
  )
111
  parser.add_argument(
112
- '--username', type=str, default='', help='Username for authentication'
113
  )
114
  parser.add_argument(
115
- '--password', type=str, default='', help='Password for authentication'
116
  )
117
  parser.add_argument(
118
- '--server_port',
119
  type=int,
120
  default=0,
121
- help='Port to run the server listener on',
122
- )
123
- parser.add_argument(
124
- '--inbrowser', action='store_true', help='Open in browser'
125
  )
 
 
126
  parser.add_argument(
127
- '--share', action='store_true', help='Share the gradio UI'
128
  )
129
  parser.add_argument(
130
- '--headless', action='store_true', help='Is the server headless'
131
- )
132
- parser.add_argument(
133
- '--language', type=str, default=None, help='Set custom language'
134
  )
135
 
136
- parser.add_argument(
137
- '--use-ipex', action='store_true', help='Use IPEX environment'
138
- )
139
 
140
  args = parser.parse_args()
141
 
 
4
  from dreambooth_gui import dreambooth_tab
5
  from finetune_gui import finetune_tab
6
  from textual_inversion_gui import ti_tab
7
+ from kohya_gui.utilities import utilities_tab
8
  from lora_gui import lora_tab
9
+ from kohya_gui.class_lora_tab import LoRATools
10
 
11
+ from kohya_gui.custom_logging import setup_logging
12
+ from kohya_gui.localization_ext import add_javascript
 
13
 
14
  # Set up logging
15
  log = setup_logging()
16
 
17
 
18
  def UI(**kwargs):
19
+ add_javascript(kwargs.get("language"))
20
+ css = ""
21
 
22
+ headless = kwargs.get("headless", False)
23
+ log.info(f"headless: {headless}")
24
 
25
+ if os.path.exists("./style.css"):
26
+ with open(os.path.join("./style.css"), "r", encoding="utf8") as file:
27
+ log.info("Load CSS...")
28
+ css += file.read() + "\n"
29
 
30
+ if os.path.exists("./.release"):
31
+ with open(os.path.join("./.release"), "r", encoding="utf8") as file:
32
  release = file.read()
33
 
34
+ if os.path.exists("./README.md"):
35
+ with open(os.path.join("./README.md"), "r", encoding="utf8") as file:
36
  README = file.read()
37
 
38
  interface = gr.Blocks(
39
+ css=css, title=f"Kohya_ss GUI {release}", theme=gr.themes.Default()
40
  )
41
 
42
  with interface:
43
+ with gr.Tab("Dreambooth"):
44
  (
45
  train_data_dir_input,
46
  reg_data_dir_input,
47
  output_dir_input,
48
  logging_dir_input,
49
  ) = dreambooth_tab(headless=headless)
50
+ with gr.Tab("LoRA"):
51
  lora_tab(headless=headless)
52
+ with gr.Tab("Textual Inversion"):
53
  ti_tab(headless=headless)
54
+ with gr.Tab("Finetuning"):
55
  finetune_tab(headless=headless)
56
+ with gr.Tab("Utilities"):
57
  utilities_tab(
58
  train_data_dir_input=train_data_dir_input,
59
  reg_data_dir_input=reg_data_dir_input,
 
62
  enable_copy_info_button=True,
63
  headless=headless,
64
  )
65
+ with gr.Tab("LoRA"):
66
  _ = LoRATools(headless=headless)
67
+ with gr.Tab("About"):
68
+ gr.Markdown(f"kohya_ss GUI release {release}")
69
+ with gr.Tab("README"):
70
  gr.Markdown(README)
71
 
72
  htmlStr = f"""
 
79
  gr.HTML(htmlStr)
80
  # Show the interface
81
  launch_kwargs = {}
82
+ username = kwargs.get("username")
83
+ password = kwargs.get("password")
84
+ server_port = kwargs.get("server_port", 0)
85
+ inbrowser = kwargs.get("inbrowser", False)
86
  share = False
87
+ server_name = kwargs.get("listen")
88
 
89
+ launch_kwargs["server_name"] = server_name
90
  if username and password:
91
+ launch_kwargs["auth"] = (username, password)
92
  if server_port > 0:
93
+ launch_kwargs["server_port"] = server_port
94
  if inbrowser:
95
+ launch_kwargs["inbrowser"] = inbrowser
96
  if share:
97
+ launch_kwargs["share"] = False
98
+ launch_kwargs["debug"] = True
99
+ interface.launch(**launch_kwargs)
100
 
101
 
102
+ if __name__ == "__main__":
103
  # torch.cuda.set_per_process_memory_fraction(0.48)
104
  parser = argparse.ArgumentParser()
105
  parser.add_argument(
106
+ "--listen",
107
  type=str,
108
+ default="127.0.0.1",
109
+ help="IP to listen on for connections to Gradio",
110
  )
111
  parser.add_argument(
112
+ "--username", type=str, default="", help="Username for authentication"
113
  )
114
  parser.add_argument(
115
+ "--password", type=str, default="", help="Password for authentication"
116
  )
117
  parser.add_argument(
118
+ "--server_port",
119
  type=int,
120
  default=0,
121
+ help="Port to run the server listener on",
 
 
 
122
  )
123
+ parser.add_argument("--inbrowser", action="store_true", help="Open in browser")
124
+ parser.add_argument("--share", action="store_true", help="Share the gradio UI")
125
  parser.add_argument(
126
+ "--headless", action="store_true", help="Is the server headless"
127
  )
128
  parser.add_argument(
129
+ "--language", type=str, default=None, help="Set custom language"
 
 
 
130
  )
131
 
132
+ parser.add_argument("--use-ipex", action="store_true", help="Use IPEX environment")
 
 
133
 
134
  args = parser.parse_args()
135