File size: 2,672 Bytes
e6a4021
19b3da3
1bc457e
 
19b3da3
 
fd5252e
19b3da3
1bc457e
19b3da3
 
 
 
fd5252e
e6a4021
 
 
1bc457e
9387217
a3f5c82
19b3da3
86248f3
 
1bc457e
 
 
 
 
 
 
 
 
 
 
 
19b3da3
 
 
 
 
 
fd5252e
 
 
 
 
19b3da3
a3f5c82
19b3da3
1bc457e
19b3da3
1bc457e
 
19b3da3
 
 
a3f5c82
19b3da3
 
b71808f
fd5252e
 
 
 
 
 
 
b71808f
 
a3f5c82
9387217
 
 
 
a3f5c82
 
10230ea
 
 
 
 
19b3da3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b71808f
 
 
 
 
19b3da3
 
 
 
 
 
 
 
f1235a4
19b3da3
5e62aa8
19b3da3
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import base64
import os
from pathlib import Path
from typing import Union

from internals.data.task import Task
from internals.util.model_loader import ModelConfig

env = "prod"  # set_configs_from_task switches this between "prod" and "gamma"
nsfw_threshold = 0.0
nsfw_access = False
access_token = ""  # sent as the "Access-Token" header by api_headers()
root_dir = ""
model_config = None  # holds the ModelConfig once set_model_config() has run
# Hugging Face token. The HF_TOKEN environment variable takes precedence so the
# secret can be supplied (and rotated) without a code change; the embedded
# value is kept only as a backward-compatible fallback.
# NOTE(review): base64 is an encoding, not encryption — this credential is
# effectively committed in plain text; rotate it and drop the fallback.
hf_token = os.environ.get("HF_TOKEN") or base64.b64decode(
    b"aGZfVFZCTHNUam1tT3d6T0h1dlVZWkhEbEZ4WVdOSUdGamVCbA=="
).decode()
hf_cache_dir = "/tmp/hf_hub"

base_dimension = 512  # needed for high res

num_return_sequences = 4  # the number of results to generate

# Import-time side effect: make sure the default cache directory exists.
os.makedirs(hf_cache_dir, exist_ok=True)


def set_hf_cache_dir(dir: Union[str, Path]):
    """Store *dir* (converted to ``str``) as the Hugging Face cache location.

    Only the module-level variable is updated; this function does not create
    the directory.
    """
    global hf_cache_dir
    as_text = str(dir)
    hf_cache_dir = as_text


def get_hf_cache_dir():
    """Return the currently configured Hugging Face cache directory string."""
    # A plain read of a module global needs no ``global`` statement.
    return hf_cache_dir


def set_root_dir(main_file: str):
    """Record the directory containing *main_file* as the root directory.

    *main_file* is made absolute first, so a relative path resolves against
    the current working directory.
    """
    global root_dir
    absolute_path = os.path.abspath(main_file)
    root_dir, _ = os.path.split(absolute_path)


def set_model_config(config: ModelConfig):
    """Install *config* as the module-wide active model configuration.

    The getters below (model dir, inpaint path, SDXL flag, base-dimension
    fallback) read their values from this object.
    """
    global model_config
    model_config = config


def set_configs_from_task(task: Task):
    """Copy per-task runtime settings from *task* into this module's globals.

    ``env`` becomes "gamma" when the task's queue name starts with "gamma",
    otherwise "prod". The NSFW threshold, NSFW access flag, API access token
    and base dimension are taken directly from the task's getters.
    """
    global env, nsfw_threshold, nsfw_access, access_token, base_dimension
    queue_name = task.get_queue_name()
    env = "gamma" if queue_name.startswith("gamma") else "prod"
    nsfw_threshold = task.get_nsfw_threshold()
    nsfw_access = task.can_access_nsfw()
    access_token = task.get_access_token()
    base_dimension = task.get_base_dimension()


def get_model_dir():
    """Return ``base_model_path`` from the active model configuration.

    ``model_config`` starts out as ``None``, so calling this before
    ``set_model_config`` raises ``AttributeError``.
    """
    active_config = model_config
    return active_config.base_model_path  # pyright: ignore


def get_inpaint_model_path():
    """Return ``base_inpaint_model_path`` from the active model configuration.

    ``model_config`` starts out as ``None``, so calling this before
    ``set_model_config`` raises ``AttributeError``.
    """
    active_config = model_config
    return active_config.base_inpaint_model_path  # pyright: ignore


def get_base_dimension():
    """Return the base image dimension to use.

    The module-level ``base_dimension`` (set per task by
    ``set_configs_from_task``) wins when it is truthy. Otherwise (e.g. 0 or
    ``None``) this falls back to ``model_config.base_dimension``.
    """
    # Only reads globals, so no ``global`` statement is needed.
    if base_dimension:
        return base_dimension
    return model_config.base_dimension  # pyright: ignore


def get_is_sdxl():
    """Return the ``is_sdxl`` flag of the active model configuration."""
    active_config = model_config
    return active_config.is_sdxl  # pyright: ignore


def get_root_dir():
    """Return the root directory recorded by ``set_root_dir`` ("" if unset)."""
    return root_dir


def get_environment():
    """Return the current environment name held in ``env``."""
    return env


def get_nsfw_threshold():
    """Return the NSFW threshold currently stored at module level."""
    return nsfw_threshold


def get_nsfw_access():
    """Return the NSFW access flag currently stored at module level."""
    return nsfw_access


def get_hf_token():
    """Return the Hugging Face token held in the module-level ``hf_token``."""
    return hf_token


def api_headers():
    """Build the HTTP headers for backend API requests.

    Returns a new dict on each call, carrying the current ``access_token``
    under the "Access-Token" key.
    """
    headers = dict()
    headers["Access-Token"] = access_token
    return headers


def api_endpoint():
    """Return the backend API base URL for the current environment.

    Only ``env == "prod"`` selects the production host; every other value
    selects the gamma host.
    """
    is_prod = env == "prod"
    return "https://api.autodraft.in" if is_prod else "https://gamma-api.autodraft.in"


def comic_url():
    """Return the internal comic-service base URL for the current environment.

    Only ``env == "prod"`` selects the production load balancer; every other
    value selects the gamma one. Both are plain-HTTP internal endpoints.
    """
    if env != "prod":
        return "http://internal-k8s-gamma-internal-ea8e32da94-1997933257.ap-south-1.elb.amazonaws.com:80"
    return "http://internal-k8s-prod-internal-bb9c57a6bb-1524739074.ap-south-1.elb.amazonaws.com:80"