File size: 2,500 Bytes
19b3da3
1bc457e
 
19b3da3
 
fd5252e
19b3da3
1bc457e
19b3da3
 
 
 
fd5252e
b71808f
1bc457e
a3f5c82
19b3da3
86248f3
 
1bc457e
 
 
 
 
 
 
 
 
 
 
 
19b3da3
 
 
 
 
 
fd5252e
 
 
 
 
19b3da3
a3f5c82
19b3da3
1bc457e
19b3da3
1bc457e
 
19b3da3
 
 
a3f5c82
19b3da3
 
b71808f
fd5252e
 
 
 
 
 
 
b71808f
 
a3f5c82
 
 
 
 
10230ea
 
 
 
 
19b3da3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b71808f
 
 
 
 
19b3da3
 
 
 
 
 
 
 
f1235a4
19b3da3
5e62aa8
19b3da3
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import os
from pathlib import Path
from typing import Union

from internals.data.task import Task
from internals.util.model_loader import ModelConfig

env = "prod"  # "prod" or "gamma"; set_configs_from_task derives it from the queue name
nsfw_threshold = 0.0
nsfw_access = False
access_token = ""  # sent as the "Access-Token" header by api_headers()
root_dir = ""  # filled in by set_root_dir()
model_config = None  # holds the ModelConfig passed to set_model_config()
# NOTE(security): a Hugging Face token was committed to source control here.
# Supply it through the HF_TOKEN environment variable instead. The literal
# fallback is kept only so existing deployments keep working; the token should
# be revoked on the Hugging Face side and this fallback removed.
hf_token = os.environ.get("HF_TOKEN") or "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"
hf_cache_dir = "/tmp/hf_hub"
base_dimension = 512  # needed for high res

num_return_sequences = 4  # the number of results to generate

# Create the default cache directory at import time. set_hf_cache_dir() only
# records a new path; it does not create that directory.
os.makedirs(hf_cache_dir, exist_ok=True)


def set_hf_cache_dir(dir: Union[str, Path]):
    """Record *dir* as the Hugging Face cache directory, stored as a ``str``.

    Only the module setting is updated; no directory is created here.
    """
    global hf_cache_dir
    as_text = str(dir)
    hf_cache_dir = as_text


def get_hf_cache_dir():
    """Return the currently configured Hugging Face cache directory path."""
    # Reading a module global needs no ``global`` declaration.
    return hf_cache_dir


def set_root_dir(main_file: str):
    """Set the root directory to the folder containing *main_file*.

    *main_file* is made absolute first, so a relative path is resolved
    against the current working directory.
    """
    global root_dir
    parent, _filename = os.path.split(os.path.abspath(main_file))
    root_dir = parent


def set_model_config(config: ModelConfig):
    """Install *config* as the active model configuration.

    The model getters in this module read their paths and flags from it.
    """
    global model_config
    model_config = config


def set_configs_from_task(task: Task):
    """Copy per-task runtime settings from *task* into module state.

    The environment becomes "gamma" when the task's queue name starts with
    "gamma", otherwise "prod". The NSFW threshold, NSFW access flag, API
    access token and base dimension are taken as-is from the task's getters.
    """
    global env, nsfw_threshold, nsfw_access, access_token, base_dimension
    queue = task.get_queue_name()
    env = "gamma" if queue.startswith("gamma") else "prod"
    # Task getters are called in the same order as before.
    nsfw_threshold = task.get_nsfw_threshold()
    nsfw_access = task.can_access_nsfw()
    access_token = task.get_access_token()
    base_dimension = task.get_base_dimension()


def _require_model_config() -> "ModelConfig":
    """Return the active model config, failing clearly if none was set.

    Raises:
        AttributeError: if ``set_model_config`` has not been called yet. This
            is the same exception type the previous unguarded ``None.<attr>``
            access raised, so existing callers that catch it keep working.
    """
    config = model_config
    if config is None:
        raise AttributeError(
            "model config has not been set; call set_model_config() first"
        )
    return config


def get_model_dir():
    """Return ``base_model_path`` from the active model config."""
    return _require_model_config().base_model_path


def get_inpaint_model_path():
    """Return ``base_inpaint_model_path`` from the active model config."""
    return _require_model_config().base_inpaint_model_path


def get_base_dimension():
    """Return the current base dimension (updated by set_configs_from_task)."""
    return base_dimension


def get_is_sdxl():
    """Return the ``is_sdxl`` flag from the active model config."""
    return _require_model_config().is_sdxl


def get_root_dir():
    """Return the root directory recorded by set_root_dir()."""
    return root_dir


def get_environment():
    """Return the current deployment environment name."""
    return env


def get_nsfw_threshold():
    """Return the NSFW threshold most recently taken from a task."""
    return nsfw_threshold


def get_nsfw_access():
    """Return whether the current task is allowed NSFW access."""
    return nsfw_access


def get_hf_token():
    """Return the Hugging Face token configured for this module."""
    return hf_token


def api_headers():
    """Build a fresh header dict carrying the current API access token."""
    headers = {"Access-Token": access_token}
    return headers


def api_endpoint():
    """Return the public API base URL for the current environment.

    Any environment other than "prod" maps to the gamma API.
    """
    prod_url = "https://api.autodraft.in"
    gamma_url = "https://gamma-api.autodraft.in"
    return prod_url if env == "prod" else gamma_url


def comic_url():
    """Return the internal comic-service base URL for the current environment.

    Any environment other than "prod" maps to the gamma load balancer.
    """
    if env != "prod":
        return "http://internal-k8s-gamma-internal-ea8e32da94-1997933257.ap-south-1.elb.amazonaws.com:80"
    return "http://internal-k8s-prod-internal-bb9c57a6bb-1524739074.ap-south-1.elb.amazonaws.com:80"