|
import json |
|
import os |
|
import time |
|
|
|
import gradio.routes |
|
|
|
import scripts.runner as runner |
|
import scripts.shared as shared |
|
from scripts.shared import ROOT_DIR, is_webui_extension |
|
from scripts.ui import create_ui |
|
|
|
|
|
def create_js(): |
|
jsfile = os.path.join(ROOT_DIR, "script.js") |
|
with open(jsfile, mode="r") as f: |
|
js = f.read() |
|
|
|
js = js.replace("kohya_sd_webui__help_map", json.dumps(shared.help_title_map)) |
|
js = js.replace( |
|
"kohya_sd_webui__all_tabs", |
|
json.dumps(shared.loaded_tabs), |
|
) |
|
return js |
|
|
|
|
|
def create_head(): |
|
head = f'<script type="text/javascript">{create_js()}</script>' |
|
|
|
def template_response_for_webui(*args, **kwargs): |
|
res = shared.gradio_template_response_original(*args, **kwargs) |
|
res.body = res.body.replace(b"</head>", f"{head}</head>".encode("utf8")) |
|
return res |
|
|
|
def template_response(*args, **kwargs): |
|
res = template_response_for_webui(*args, **kwargs) |
|
res.init_headers() |
|
return res |
|
|
|
if is_webui_extension(): |
|
import modules.shared |
|
|
|
modules.shared.GradioTemplateResponseOriginal = template_response_for_webui |
|
else: |
|
gradio.routes.templates.TemplateResponse = template_response |
|
|
|
|
|
def wait_on_server(): |
|
while 1: |
|
time.sleep(0.5) |
|
|
|
|
|
def on_ui_tabs(): |
|
cssfile = os.path.join(ROOT_DIR, "style.css") |
|
with open(cssfile, mode="r") as f: |
|
css = f.read() |
|
sd_scripts = create_ui(css) |
|
create_head() |
|
return [(sd_scripts, "Kohya sd-scripts", "kohya_sd_scripts")] |
|
|
|
|
|
def launch(): |
|
block, _, _ = on_ui_tabs()[0] |
|
if shared.cmd_opts.ngrok is not None: |
|
import scripts.ngrok as ngrok |
|
|
|
address = ngrok.connect( |
|
shared.cmd_opts.ngrok, |
|
shared.cmd_opts.port if shared.cmd_opts.port is not None else 7860, |
|
shared.cmd_opts.ngrok_region, |
|
) |
|
print("Running on ngrok URL: " + address) |
|
|
|
app, local_url, share_url = block.launch( |
|
share=shared.cmd_opts.share, |
|
server_port=shared.cmd_opts.port, |
|
server_name=shared.cmd_opts.host, |
|
prevent_thread_lock=True, |
|
) |
|
|
|
runner.initialize_api(app) |
|
|
|
wait_on_server() |
|
|
|
|
|
if not hasattr(shared, "gradio_template_response_original"): |
|
shared.gradio_template_response_original = gradio.routes.templates.TemplateResponse |
|
|
|
if is_webui_extension(): |
|
from modules import script_callbacks |
|
|
|
def initialize_api(_, app): |
|
runner.initialize_api(app) |
|
|
|
script_callbacks.on_ui_tabs(on_ui_tabs) |
|
script_callbacks.on_app_started(initialize_api) |
|
|
|
if __name__ == "__main__": |
|
launch() |
|
|