import os
import warnings

import argilla as rg

# Tasks
TEXTCAT_TASK = "text_classification"
SFT_TASK = "supervised_fine_tuning"

# Inference
MAX_NUM_TOKENS = int(os.getenv("MAX_NUM_TOKENS", 2048))
MAX_NUM_ROWS = int(os.getenv("MAX_NUM_ROWS", 1000))
DEFAULT_BATCH_SIZE = int(os.getenv("DEFAULT_BATCH_SIZE", 5))

# Models
MODEL = os.getenv("MODEL", "meta-llama/Meta-Llama-3.1-8B-Instruct")
TOKENIZER_ID = os.getenv("TOKENIZER_ID")
OPENAI_BASE_URL = os.getenv("OPENAI_BASE_URL")
OLLAMA_BASE_URL = os.getenv("OLLAMA_BASE_URL")
HUGGINGFACE_BASE_URL = os.getenv("HUGGINGFACE_BASE_URL")
VLLM_BASE_URL = os.getenv("VLLM_BASE_URL")

# Check that the model and inference backend are configured consistently
if HUGGINGFACE_BASE_URL and MODEL:
    raise ValueError(
        "`HUGGINGFACE_BASE_URL` and `MODEL` cannot be set at the same time. Set `MODEL` for serverless inference, or set `HUGGINGFACE_BASE_URL` to point at a dedicated Hugging Face Inference Endpoint."
    )
if not MODEL and (OPENAI_BASE_URL or OLLAMA_BASE_URL or VLLM_BASE_URL):
    raise ValueError("`MODEL` is not set. Please provide a model id for inference.")

# Check if multiple base URLs are provided
base_urls = [
    url
    for url in [OPENAI_BASE_URL, OLLAMA_BASE_URL, HUGGINGFACE_BASE_URL, VLLM_BASE_URL]
    if url
]
if len(base_urls) > 1:
    raise ValueError(
        f"Multiple base URLs provided: {', '.join(base_urls)}. Only one base URL can be set at a time."
    )
BASE_URL = OPENAI_BASE_URL or OLLAMA_BASE_URL or HUGGINGFACE_BASE_URL or VLLM_BASE_URL


# API Keys
HF_TOKEN = os.getenv("HF_TOKEN")
if not HF_TOKEN:
    raise ValueError(
        "`HF_TOKEN` is not set. Set the `HF_TOKEN` environment variable to a token with access to the Hugging Face Hub repositories and Inference Endpoints."
    )

_API_KEY = os.getenv("API_KEY")
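# Use a single `API_KEY` if provided; otherwise fall back to `HF_TOKEN` plus any
# numbered `HF_TOKEN_1` ... `HF_TOKEN_9` tokens (empty values are filtered out below).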
API_KEYS = (
    [_API_KEY]
    if _API_KEY
    else [HF_TOKEN] + [os.getenv(f"HF_TOKEN_{i}") for i in range(1, 10)]
)
API_KEYS = [token for token in API_KEYS if token]

# Determine if SFT is available
SFT_AVAILABLE = False
llama_options = ["llama3", "llama-3", "llama 3"]
qwen_options = ["qwen2", "qwen-2", "qwen 2"]
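# Magpie needs a model-specific pre-query template; only Llama 3 and Qwen 2 are
# detected automatically, but `MAGPIE_PRE_QUERY_TEMPLATE` can override the detection below.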

if passed_pre_query_template := os.getenv("MAGPIE_PRE_QUERY_TEMPLATE", "").lower():
    SFT_AVAILABLE = True
    if passed_pre_query_template in llama_options:
        MAGPIE_PRE_QUERY_TEMPLATE = "llama3"
    elif passed_pre_query_template in qwen_options:
        MAGPIE_PRE_QUERY_TEMPLATE = "qwen2"
    else:
        MAGPIE_PRE_QUERY_TEMPLATE = passed_pre_query_template
elif any(option in MODEL.lower() for option in llama_options):
    SFT_AVAILABLE = True
    MAGPIE_PRE_QUERY_TEMPLATE = "llama3"
elif any(option in MODEL.lower() for option in qwen_options):
    SFT_AVAILABLE = True
    MAGPIE_PRE_QUERY_TEMPLATE = "qwen2"

if OPENAI_BASE_URL:
    SFT_AVAILABLE = False

if not SFT_AVAILABLE:
    warnings.warn(
        "`SFT_AVAILABLE` is set to `False`. Use Hugging Face Inference Endpoints or Ollama to generate chat data, provide a `TOKENIZER_ID` and `MAGPIE_PRE_QUERY_TEMPLATE`. You can also use `HUGGINGFACE_BASE_URL` to with vllm."
    )
    MAGPIE_PRE_QUERY_TEMPLATE = None

# Embeddings
STATIC_EMBEDDING_MODEL = "minishlab/potion-base-8M"

# Argilla
ARGILLA_API_URL = os.getenv("ARGILLA_API_URL") or os.getenv(
    "ARGILLA_API_URL_SDG_REVIEWER"
)
ARGILLA_API_KEY = os.getenv("ARGILLA_API_KEY") or os.getenv(
    "ARGILLA_API_KEY_SDG_REVIEWER"
)
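# Instantiate the Argilla client only when both credentials are present;
# callers should check for `argilla_client is None` before using it.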

if not ARGILLA_API_URL or not ARGILLA_API_KEY:
    warnings.warn("`ARGILLA_API_URL` or `ARGILLA_API_KEY` is not set or is empty")
    argilla_client = None
else:
    argilla_client = rg.Argilla(
        api_url=ARGILLA_API_URL,
        api_key=ARGILLA_API_KEY,
    )
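
# Minimal sanity check when this module is run directly (illustrative only; the
# application is expected to import these constants rather than execute this block).
if __name__ == "__main__":
    print(f"MODEL={MODEL!r} BASE_URL={BASE_URL!r}")
    print(f"SFT_AVAILABLE={SFT_AVAILABLE} MAGPIE_PRE_QUERY_TEMPLATE={MAGPIE_PRE_QUERY_TEMPLATE!r}")
    print(f"API keys configured: {len(API_KEYS)}")
    print(f"Argilla client configured: {argilla_client is not None}")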