File size: 1,973 Bytes
19b3da3
1bc457e
 
19b3da3
 
 
1bc457e
19b3da3
 
 
 
b71808f
 
1bc457e
19b3da3
86248f3
 
1bc457e
 
 
 
 
 
 
 
 
 
 
 
19b3da3
b71808f
 
 
 
 
19b3da3
 
 
 
 
 
 
 
1bc457e
19b3da3
1bc457e
 
19b3da3
 
 
 
 
b71808f
 
 
 
 
19b3da3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b71808f
 
 
 
 
19b3da3
 
 
 
 
 
 
 
f1235a4
19b3da3
5e62aa8
19b3da3
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import os
from pathlib import Path
from typing import Union

from internals.data.task import Task

# Module-level runtime configuration. These globals are read and mutated
# through the set_*/get_* helpers defined in this module.
env = "prod"  # "prod" unless set_configs_from_task switches it to "gamma"
nsfw_threshold = 0.0
nsfw_access = False
access_token = ""  # sent as the "Access-Token" header by api_headers()
root_dir = ""
model_dir = ""
# Hugging Face token. Credentials must never be hard-coded in source: read it
# from the environment so the secret stays out of version control. Any token
# previously committed here should be considered leaked and revoked.
hf_token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN", "")
hf_cache_dir = "/tmp/hf_hub"

num_return_sequences = 4  # the number of results to generate

# Ensure the default cache directory exists at import time.
os.makedirs(hf_cache_dir, exist_ok=True)


def set_hf_cache_dir(dir: Union[str, Path]):
    """Override the module-level Hugging Face cache directory.

    The value is stored as a plain string; this function does not create
    the directory.
    """
    global hf_cache_dir
    as_text = str(dir)
    hf_cache_dir = as_text


def get_hf_cache_dir():
    """Return the currently configured Hugging Face cache directory string."""
    return hf_cache_dir


def set_model_dir(dir: str):
    """Store *dir* as the module-level model directory (no validation)."""
    global model_dir
    new_value = dir
    model_dir = new_value


def set_root_dir(main_file: str):
    """Set the root directory to the folder that contains *main_file*.

    *main_file* is first made absolute (relative paths are resolved against
    the current working directory); symlinks are not resolved.
    """
    global root_dir
    absolute_path = os.path.abspath(main_file)
    parent, _basename = os.path.split(absolute_path)
    root_dir = parent


def set_configs_from_task(task: Task):
    """Populate environment and access settings from *task*.

    A queue name starting with "gamma" selects the "gamma" environment; any
    other queue name selects "prod". The NSFW threshold, NSFW access flag
    and access token are copied from the task's accessors.
    """
    global env, nsfw_threshold, nsfw_access, access_token
    queue_name = task.get_queue_name()
    env = "gamma" if queue_name.startswith("gamma") else "prod"
    nsfw_threshold = task.get_nsfw_threshold()
    nsfw_access = task.can_access_nsfw()
    access_token = task.get_access_token()


def get_model_dir():
    """Return the model directory last stored via set_model_dir ("" by default)."""
    return model_dir


def get_root_dir():
    """Return the root directory last stored via set_root_dir ("" by default)."""
    return root_dir


def get_environment():
    """Return the active environment name held in the module-level ``env``."""
    return env


def get_nsfw_threshold():
    """Return the NSFW threshold currently held at module level."""
    return nsfw_threshold


def get_nsfw_access():
    """Return the NSFW access flag currently held at module level."""
    return nsfw_access


def get_hf_token():
    """Return the Hugging Face token held in the module-level ``hf_token``."""
    return hf_token


def api_headers():
    """Build HTTP headers for API calls.

    Returns a fresh dict carrying the current module-level access token
    under the "Access-Token" key.
    """
    headers = {}
    headers["Access-Token"] = access_token
    return headers


def api_endpoint():
    """Return the API base URL for the active environment.

    Only ``env == "prod"`` maps to the production host; any other value
    falls through to the gamma host.
    """
    if env != "prod":
        return "https://gamma-api.autodraft.in"
    return "https://api.autodraft.in"


def comic_url():
    """Return the internal comic service URL for the active environment.

    ``env == "prod"`` selects the prod load balancer; any other value selects
    the gamma load balancer.
    """
    prod_url = "http://internal-k8s-prod-internal-bb9c57a6bb-1524739074.ap-south-1.elb.amazonaws.com:80"
    gamma_url = "http://internal-k8s-gamma-internal-ea8e32da94-1997933257.ap-south-1.elb.amazonaws.com:80"
    return prod_url if env == "prod" else gamma_url