File size: 2,500 Bytes
19b3da3 1bc457e 19b3da3 fd5252e 19b3da3 1bc457e 19b3da3 fd5252e b71808f 1bc457e a3f5c82 19b3da3 86248f3 1bc457e 19b3da3 fd5252e 19b3da3 a3f5c82 19b3da3 1bc457e 19b3da3 1bc457e 19b3da3 a3f5c82 19b3da3 b71808f fd5252e b71808f a3f5c82 10230ea 19b3da3 b71808f 19b3da3 f1235a4 19b3da3 5e62aa8 19b3da3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 |
import os
from pathlib import Path
from typing import Union
from internals.data.task import Task
from internals.util.model_loader import ModelConfig
# Module-level runtime configuration. Several of these values (env,
# nsfw_threshold, nsfw_access, access_token, base_dimension) are overwritten
# per task by set_configs_from_task(); the defaults below apply until then.
env = "prod"  # "prod" selects the production endpoints; anything else -> gamma
nsfw_threshold = 0.0
nsfw_access = False
access_token = ""  # sent as the "Access-Token" header by api_headers()
root_dir = ""
model_config = None  # populated by set_model_config(); model getters need it
# SECURITY: a Hugging Face token used to be hard-coded here. Secrets must not
# live in source control, so the token is now read from the environment.
# Treat the previously committed token as compromised and revoke it.
hf_token = os.environ.get("HF_TOKEN", "")
hf_cache_dir = "/tmp/hf_hub"
base_dimension = 512  # needed for high res
num_return_sequences = 4  # the number of results to generate

# Ensure the default Hugging Face cache directory exists as soon as the
# module is imported.
os.makedirs(hf_cache_dir, exist_ok=True)
def set_hf_cache_dir(dir: Union[str, Path]):
    """Point the module at a different Hugging Face cache directory.

    The value is stored as a plain ``str``, whether a ``str`` or a ``Path``
    was passed in. Unlike the import-time default, the new directory is
    not created here.
    """
    global hf_cache_dir
    as_text = str(dir)
    hf_cache_dir = as_text
def get_hf_cache_dir():
    """Return the currently configured Hugging Face cache directory."""
    # Reading a module global needs no ``global`` declaration.
    return hf_cache_dir
def set_root_dir(main_file: str):
    """Record the directory containing ``main_file`` as the project root.

    ``main_file`` is made absolute first, so relative paths resolve against
    the current working directory. Its parent directory is then stored.
    """
    global root_dir
    absolute_path = os.path.abspath(main_file)
    parent, _ = os.path.split(absolute_path)
    root_dir = parent
def set_model_config(config: ModelConfig):
    """Install ``config`` as the active model configuration.

    get_model_dir(), get_inpaint_model_path() and get_is_sdxl() all read
    their values from the object stored here.
    """
    global model_config
    model_config = config
def set_configs_from_task(task: Task):
    """Refresh the per-task module settings from ``task``.

    The environment becomes "gamma" when the task's queue name starts with
    "gamma", and "prod" otherwise. The NSFW threshold, NSFW access flag,
    API access token and base dimension are copied straight from the task.
    """
    global env, nsfw_threshold, nsfw_access, access_token, base_dimension
    queue_name = task.get_queue_name()
    env = "gamma" if queue_name.startswith("gamma") else "prod"
    nsfw_threshold = task.get_nsfw_threshold()
    nsfw_access = task.can_access_nsfw()
    access_token = task.get_access_token()
    base_dimension = task.get_base_dimension()
def get_model_dir():
    """Return ``base_model_path`` from the active model config.

    set_model_config() must have been called first. While the config is
    still ``None``, this raises ``AttributeError``.
    """
    config = model_config
    return config.base_model_path  # pyright: ignore
def get_inpaint_model_path():
    """Return ``base_inpaint_model_path`` from the active model config.

    set_model_config() must have been called first. While the config is
    still ``None``, this raises ``AttributeError``.
    """
    config = model_config
    return config.base_inpaint_model_path  # pyright: ignore
def get_base_dimension():
    """Return the current base image dimension.

    It is set per task by set_configs_from_task().
    """
    return base_dimension
def get_is_sdxl():
    """Return the ``is_sdxl`` flag of the active model config.

    set_model_config() must have been called first. While the config is
    still ``None``, this raises ``AttributeError``.
    """
    config = model_config
    return config.is_sdxl  # pyright: ignore
def get_root_dir():
    """Return the project root recorded by set_root_dir()."""
    return root_dir
def get_environment():
    """Return the active environment name (e.g. "prod" or "gamma")."""
    return env
def get_nsfw_threshold():
    """Return the NSFW threshold copied from the current task."""
    return nsfw_threshold
def get_nsfw_access():
    """Return whether the current task is allowed NSFW access."""
    return nsfw_access
def get_hf_token():
    """Return the module's Hugging Face access token."""
    return hf_token
def api_headers():
    """Build the HTTP headers for calls to the backend API.

    Returns a fresh dict on every call, carrying the current access token
    under the "Access-Token" key.
    """
    headers = dict()
    headers["Access-Token"] = access_token
    return headers
def api_endpoint():
    """Return the backend API base URL for the current environment.

    Only ``env == "prod"`` maps to production. Any other value falls back
    to the gamma endpoint.
    """
    is_prod = env == "prod"
    return "https://api.autodraft.in" if is_prod else "https://gamma-api.autodraft.in"
def comic_url():
    """Return the internal comic-service URL for the current environment.

    ``env == "prod"`` gets the prod internal ELB host. Every other value
    uses the gamma host.
    """
    if env != "prod":
        return "http://internal-k8s-gamma-internal-ea8e32da94-1997933257.ap-south-1.elb.amazonaws.com:80"
    return "http://internal-k8s-prod-internal-bb9c57a6bb-1524739074.ap-south-1.elb.amazonaws.com:80"
|