File size: 2,605 Bytes
0163a2c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import json
import os
import time

import gradio.routes

import scripts.runner as runner
import scripts.shared as shared
from scripts.shared import ROOT_DIR, is_webui_extension
from scripts.ui import create_ui


def create_js():
    """Load ``script.js`` and fill in its data placeholders.

    Reads ``script.js`` from ``ROOT_DIR`` and replaces the literal tokens
    ``kohya_sd_webui__help_map`` and ``kohya_sd_webui__all_tabs`` with the
    JSON encodings of ``shared.help_title_map`` and ``shared.loaded_tabs``.

    Returns:
        The JavaScript source as a string, ready to embed in a ``<script>`` tag.
    """
    jsfile = os.path.join(ROOT_DIR, "script.js")
    # Read as UTF-8 explicitly: the default text encoding is locale-dependent
    # (e.g. cp1252 on many Windows installs), which can fail on or garble any
    # non-ASCII characters in the bundled script.
    with open(jsfile, mode="r", encoding="utf-8") as f:
        js = f.read()

    js = js.replace("kohya_sd_webui__help_map", json.dumps(shared.help_title_map))
    js = js.replace("kohya_sd_webui__all_tabs", json.dumps(shared.loaded_tabs))
    return js


def create_head():
    """Arrange for the extension's JavaScript to be injected into served pages.

    The script produced by ``create_js`` is wrapped in a ``<script>`` tag and
    spliced in front of every ``</head>`` in each rendered template body.
    When running as a web UI extension, the wrapper is handed to the web UI
    through ``modules.shared.GradioTemplateResponseOriginal``. Otherwise
    Gradio's ``TemplateResponse`` is replaced directly.
    """
    script_tag = '<script type="text/javascript">' + create_js() + "</script>"
    head_with_script = script_tag + "</head>"

    def template_response_for_webui(*args, **kwargs):
        # Render with the saved, unpatched TemplateResponse, then inject.
        response = shared.gradio_template_response_original(*args, **kwargs)
        patched_body = response.body.replace(
            b"</head>", head_with_script.encode("utf8")
        )
        response.body = patched_body
        return response

    def template_response(*args, **kwargs):
        response = template_response_for_webui(*args, **kwargs)
        # The body changed size, so rebuild the response headers
        # (presumably to keep Content-Length in sync).
        response.init_headers()
        return response

    if not is_webui_extension():
        gradio.routes.templates.TemplateResponse = template_response
        return

    import modules.shared

    modules.shared.GradioTemplateResponseOriginal = template_response_for_webui


def wait_on_server():
    """Block the calling thread forever, waking every half second.

    ``launch`` starts Gradio with ``prevent_thread_lock=True`` and then calls
    this function so the standalone process stays alive.
    """
    interval = 0.5
    while True:
        time.sleep(interval)


def on_ui_tabs():
    """Build the Kohya sd-scripts UI and install the page-head script hook.

    Reads ``style.css`` from ``ROOT_DIR``, passes its contents to
    ``create_ui``, then calls ``create_head`` so the extension's JavaScript
    is injected into served pages.

    Returns:
        A single-element list ``[(ui, "Kohya sd-scripts", "kohya_sd_scripts")]``.
        ``ui`` is whatever ``create_ui`` built; ``launch`` calls ``.launch()``
        on it, so it is presumably a Gradio ``Blocks``. When running as a web UI
        extension, this function is registered with
        ``script_callbacks.on_ui_tabs``.
    """
    cssfile = os.path.join(ROOT_DIR, "style.css")
    # Explicit UTF-8: the default encoding is locale-dependent and may fail to
    # decode non-ASCII characters in the stylesheet (e.g. on Windows).
    with open(cssfile, mode="r", encoding="utf-8") as f:
        css = f.read()
    sd_scripts = create_ui(css)
    create_head()
    return [(sd_scripts, "Kohya sd-scripts", "kohya_sd_scripts")]


def launch():
    """Run the UI as a standalone Gradio app (outside the web UI).

    Starts the Gradio server without blocking and registers the runner API on
    the returned app. If ``--ngrok`` was given, it opens an ngrok tunnel to the
    port the server actually bound. Finally it parks the main thread forever
    via ``wait_on_server``.
    """
    from urllib.parse import urlparse

    block, _, _ = on_ui_tabs()[0]
    opts = shared.cmd_opts

    app, local_url, _share_url = block.launch(
        share=opts.share,
        server_port=opts.port,
        server_name=opts.host,
        prevent_thread_lock=True,
    )

    runner.initialize_api(app)

    if opts.ngrok is not None:
        import scripts.ngrok as ngrok

        # Tunnel to the port Gradio really bound. When no port is configured,
        # Gradio searches for a free port starting at 7860, so assuming 7860
        # could point the tunnel at the wrong (or no) server.
        port = urlparse(local_url).port if local_url else None
        if port is None:
            port = opts.port if opts.port is not None else 7860

        address = ngrok.connect(opts.ngrok, port, opts.ngrok_region)
        print("Running on ngrok URL: " + address)

    wait_on_server()


# Save Gradio's TemplateResponse the first time this module runs.
# create_head() wraps this saved callable rather than whatever is currently
# installed, so calling create_head() again does not stack injections. The
# hasattr guard keeps the first saved value if this module is executed again
# (NOTE(review): presumably guards against module reloads — confirm).
if not hasattr(shared, "gradio_template_response_original"):
    shared.gradio_template_response_original = gradio.routes.templates.TemplateResponse

# When loaded as a web UI extension, hook into the host's callback system
# instead of launching a server of our own.
if is_webui_extension():
    from modules import script_callbacks

    def initialize_api(_, app):
        # on_app_started callback. The first argument is ignored (presumably
        # the Gradio Blocks — confirm against the web UI's script_callbacks);
        # `app` is forwarded to runner.initialize_api.
        runner.initialize_api(app)

    # Register the tab builder first, then the API hook.
    script_callbacks.on_ui_tabs(on_ui_tabs)
    script_callbacks.on_app_started(initialize_api)

# Standalone entry point: build the UI and serve it with Gradio directly.
if __name__ == "__main__":
    launch()