File size: 2,605 Bytes
0163a2c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 |
import json
import os
import time
import gradio.routes
import scripts.runner as runner
import scripts.shared as shared
from scripts.shared import ROOT_DIR, is_webui_extension
from scripts.ui import create_ui
def create_js():
    """Return the extension's client-side script with server data inlined.

    Reads ``script.js`` from ``ROOT_DIR`` and substitutes two placeholder
    identifiers with JSON literals:

    * ``kohya_sd_webui__help_map`` -> ``json.dumps(shared.help_title_map)``
    * ``kohya_sd_webui__all_tabs`` -> ``json.dumps(shared.loaded_tabs)``

    Returns:
        The resulting JavaScript source as a ``str``.
    """
    jsfile = os.path.join(ROOT_DIR, "script.js")
    # Decode explicitly as UTF-8: without ``encoding`` open() uses the locale
    # default (e.g. cp1252 on Windows), which can fail on or mangle any
    # non-ASCII characters in the script.
    with open(jsfile, mode="r", encoding="utf-8") as f:
        js = f.read()

    js = js.replace("kohya_sd_webui__help_map", json.dumps(shared.help_title_map))
    js = js.replace("kohya_sd_webui__all_tabs", json.dumps(shared.loaded_tabs))
    return js
def create_head():
    """Hook the page template so every rendered page carries this extension's JS.

    The output of ``create_js()`` is wrapped in a ``<script>`` tag once, up
    front. A wrapper around ``shared.gradio_template_response_original`` then
    rewrites each response body, inserting that tag before every ``</head>``.

    Where the wrapper is installed depends on the host:

    * inside stable-diffusion-webui it becomes
      ``modules.shared.GradioTemplateResponseOriginal`` (presumably so
      webui's own template wrapper calls it and keeps handling headers —
      NOTE(review): confirm against the webui version in use);
    * standalone, it replaces ``gradio.routes.templates.TemplateResponse``
      directly, and additionally calls ``init_headers()`` on the response
      after the body is rewritten.

    Both wrappers delegate to the saved original rather than to whatever is
    currently installed, so calling this again swaps the hook instead of
    stacking a second injection.
    """
    script_tag = f'<script type="text/javascript">{create_js()}</script>'
    head_with_script = f"{script_tag}</head>".encode("utf8")

    def template_response_for_webui(*args, **kwargs):
        response = shared.gradio_template_response_original(*args, **kwargs)
        response.body = response.body.replace(b"</head>", head_with_script)
        return response

    def template_response(*args, **kwargs):
        response = template_response_for_webui(*args, **kwargs)
        # Body length changed above; have the response rebuild its headers.
        response.init_headers()
        return response

    if not is_webui_extension():
        gradio.routes.templates.TemplateResponse = template_response
        return

    import modules.shared

    modules.shared.GradioTemplateResponseOriginal = template_response_for_webui
def wait_on_server():
    """Park the calling thread indefinitely.

    ``launch()`` starts Gradio with ``prevent_thread_lock=True`` and calls
    this afterwards to keep the process alive. The loop never exits on its
    own; only an exception raised while sleeping (e.g. ``KeyboardInterrupt``)
    ends it.
    """
    while True:
        time.sleep(0.5)
def on_ui_tabs():
    """Build the Kohya sd-scripts tab and install the page-template hook.

    Registered as the stable-diffusion-webui ``on_ui_tabs`` callback, and also
    called by ``launch()`` in standalone mode. Loads ``style.css`` from
    ``ROOT_DIR``, builds the UI with it via ``create_ui``, then calls
    ``create_head()`` to inject the extension's script into served pages.

    Returns:
        A one-element list holding ``(ui, "Kohya sd-scripts",
        "kohya_sd_scripts")``. ``launch()`` calls ``.launch()`` on ``ui``;
        the remaining fields are presumably the tab label and element id per
        webui's ``on_ui_tabs`` convention.
    """
    cssfile = os.path.join(ROOT_DIR, "style.css")
    # Decode explicitly as UTF-8 rather than the locale default, matching
    # how create_js() reads script.js.
    with open(cssfile, mode="r", encoding="utf-8") as f:
        css = f.read()
    sd_scripts = create_ui(css)
    # Called after create_ui(); presumably create_ui populates the
    # shared.help_title_map / shared.loaded_tabs that create_js() inlines —
    # confirm before reordering.
    create_head()
    return [(sd_scripts, "Kohya sd-scripts", "kohya_sd_scripts")]
def launch():
    """Run the Kohya sd-scripts UI as a standalone Gradio application.

    Steps, in order:

    1. Build the UI via ``on_ui_tabs()`` and take the Gradio block from the
       first tab tuple.
    2. If ``--ngrok`` was supplied, open an ngrok tunnel to the configured
       port (7860 when no port is set) and print the returned URL.
    3. Start the Gradio server without blocking, attach the runner API to the
       app it returns, then block the main thread forever.
    """
    ui_block, _, _ = on_ui_tabs()[0]
    opts = shared.cmd_opts

    if opts.ngrok is not None:
        import scripts.ngrok as ngrok

        tunnel_port = 7860 if opts.port is None else opts.port
        public_url = ngrok.connect(opts.ngrok, tunnel_port, opts.ngrok_region)
        print("Running on ngrok URL: " + public_url)

    # prevent_thread_lock=True lets launch() return so the runner API can be
    # registered on the app; wait_on_server() then keeps the process alive.
    server_app, _, _ = ui_block.launch(
        share=opts.share,
        server_port=opts.port,
        server_name=opts.host,
        prevent_thread_lock=True,
    )
    runner.initialize_api(server_app)

    wait_on_server()
# Remember the TemplateResponse that gradio currently exposes, before any
# function in this file (create_head) patches it. The hasattr guard keeps the
# first captured value if this module body runs again — presumably on a webui
# script reload while `scripts.shared` stays loaded — so an already-installed
# wrapper is never mistaken for the original. NOTE(review): confirm reload path.
if not hasattr(shared, "gradio_template_response_original"):
    shared.gradio_template_response_original = gradio.routes.templates.TemplateResponse

# Hosted inside stable-diffusion-webui: register the tab builder and attach the
# runner API once the app has started. Standalone runs use launch() instead.
if is_webui_extension():
    from modules import script_callbacks

    def initialize_api(_, app):
        """``on_app_started`` callback: attach the runner API to ``app``.

        The first positional argument is ignored (presumably the Gradio
        Blocks instance webui passes — confirm).
        """
        runner.initialize_api(app)

    script_callbacks.on_ui_tabs(on_ui_tabs)
    script_callbacks.on_app_started(initialize_api)

# Executed directly as a script: start the standalone Gradio server.
if __name__ == "__main__":
    launch()
|