import base64
import os
from pathlib import Path
from typing import Union

from internals.data.task import Task
from internals.util.model_loader import ModelConfig

# --- Process-wide runtime configuration ------------------------------------
# These module-level values are overwritten by the setters in this module
# (e.g. ``set_configs_from_task``) and read back through the getters.

env = "prod"  # "prod" or "gamma"; derived from the task's queue name
nsfw_threshold = 0.0
nsfw_access = False
access_token = ""  # sent as the "Access-Token" header by api_headers()
root_dir = ""
model_config = None  # set via set_model_config(); model getters assume it is set

# SECURITY: a Hugging Face token is embedded in source below. base64 is only
# an encoding, not protection, so anyone with the source has the token.
# The HF_TOKEN environment variable now takes precedence; the embedded value
# is kept solely as a backward-compatible fallback and should be rotated and
# removed from the repository.
hf_token = os.environ.get("HF_TOKEN") or base64.b64decode(
    b"aGZfVFZCTHNUam1tT3d6T0h1dlVZWkhEbEZ4WVdOSUdGamVCbA=="
).decode()
hf_cache_dir = "/tmp/hf_hub"

base_dimension = 512  # needed for high res

num_return_sequences = 4  # the number of results to generate

# Create the default cache location eagerly at import time so it exists
# before anything tries to use it.
os.makedirs(hf_cache_dir, exist_ok=True)


def set_hf_cache_dir(dir: Union[str, Path]):
    """Override the Hugging Face cache directory for this process.

    ``dir`` may be a ``str`` or a ``Path``; it is stored as a plain string.
    The directory is not created here.
    """
    global hf_cache_dir
    as_text = str(dir)
    hf_cache_dir = as_text


def get_hf_cache_dir():
    """Return the currently configured Hugging Face cache directory."""
    # Read-only access to the module global; no ``global`` statement needed.
    return hf_cache_dir


def set_root_dir(main_file: str):
    """Record the directory that contains ``main_file`` as the root dir.

    ``main_file`` is first made absolute (relative paths are resolved
    against the current working directory); its parent directory is stored.
    """
    global root_dir
    absolute_path = os.path.abspath(main_file)
    root_dir = os.path.dirname(absolute_path)


def set_model_config(config: ModelConfig):
    """Install ``config`` as the process-wide model configuration.

    The model getters in this module (model dir, inpaint path, SDXL flag,
    fallback base dimension) read their attributes from this object.
    """
    global model_config
    model_config = config


def set_configs_from_task(task: Task):
    """Copy per-task runtime settings from ``task`` into module state.

    A queue name starting with ``"gamma"`` selects the gamma environment;
    any other queue selects prod. The NSFW threshold, NSFW access flag,
    API access token and base dimension are taken from the task as-is.
    """
    global env, nsfw_threshold, nsfw_access, access_token, base_dimension
    queue_name = task.get_queue_name()
    env = "gamma" if queue_name.startswith("gamma") else "prod"
    nsfw_threshold = task.get_nsfw_threshold()
    nsfw_access = task.can_access_nsfw()
    access_token = task.get_access_token()
    base_dimension = task.get_base_dimension()


def get_model_dir():
    """Return ``base_model_path`` from the installed model configuration.

    Raises AttributeError while ``model_config`` is still ``None``
    (i.e. before ``set_model_config`` has been called).
    """
    config = model_config
    return config.base_model_path  # pyright: ignore


def get_inpaint_model_path():
    """Return ``base_inpaint_model_path`` from the installed model config.

    Raises AttributeError while ``model_config`` is still ``None``
    (i.e. before ``set_model_config`` has been called).
    """
    config = model_config
    return config.base_inpaint_model_path  # pyright: ignore


def get_base_dimension():
    """Return the base image dimension to use.

    The per-task ``base_dimension`` wins when it is truthy; otherwise (e.g.
    ``None`` or ``0``) the value falls back to ``model_config.base_dimension``.
    That fallback raises AttributeError if ``model_config`` is still ``None``.
    """
    # Only module state is read here, so no ``global`` declaration is needed.
    if base_dimension:
        return base_dimension
    return model_config.base_dimension  # pyright: ignore


def get_is_sdxl():
    """Return the ``is_sdxl`` flag of the installed model configuration.

    Raises AttributeError while ``model_config`` is still ``None``.
    """
    config = model_config
    return config.is_sdxl  # pyright: ignore


def get_root_dir():
    """Return the root directory recorded by ``set_root_dir`` ("" if unset)."""
    # Plain read of the module global; no ``global`` statement required.
    return root_dir


def get_environment():
    """Return the active environment name (``"prod"`` or ``"gamma"``)."""
    # Plain read of the module global; no ``global`` statement required.
    return env


def get_nsfw_threshold():
    """Return the NSFW threshold most recently loaded from a task."""
    # Plain read of the module global; no ``global`` statement required.
    return nsfw_threshold


def get_nsfw_access():
    """Return whether the current task is allowed NSFW access."""
    # Plain read of the module global; no ``global`` statement required.
    return nsfw_access


def get_hf_token():
    """Return the Hugging Face token configured for this process."""
    # Plain read of the module global; no ``global`` statement required.
    return hf_token


def api_headers():
    """Return a new dict of HTTP headers carrying the current access token.

    The token is sent under the ``Access-Token`` header name.
    """
    headers = dict()
    headers["Access-Token"] = access_token
    return headers


def api_endpoint():
    """Return the base API URL for the active environment.

    ``env == "prod"`` selects the production host; any other value selects
    the gamma host.
    """
    if env != "prod":
        return "https://gamma-api.autodraft.in"
    return "https://api.autodraft.in"


def comic_url():
    """Return the internal comic base URL for the active environment.

    Both URLs point at internal AWS load balancers on port 80 (plain HTTP).
    ``env == "prod"`` selects the prod balancer; anything else selects gamma.
    """
    if env != "prod":
        return "http://internal-k8s-gamma-internal-ea8e32da94-1997933257.ap-south-1.elb.amazonaws.com:80"
    return "http://internal-k8s-prod-internal-bb9c57a6bb-1524739074.ap-south-1.elb.amazonaws.com:80"