File size: 3,701 Bytes
cd47483 5ac0c97 f5ab4cb 8dfc799 f5ab4cb 32d8669 8dfc799 f5ab4cb 8dfc799 f5ab4cb 8dfc799 f5ab4cb cd47483 ab34078 cd47483 1bff30e f5ab4cb cd47483 4106f96 62bb2f6 f5ab4cb a0cefd0 f5ab4cb a0cefd0 f5ab4cb a0cefd0 3b90025 85b97c4 62bb2f6 85b97c4 62bb2f6 4106f96 f5ab4cb cd47483 4106f96 cd47483 2841b26 cd47483 5ac0c97 cd47483 f5ab4cb cd47483 ab34078 cd47483 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 |
import os
import warnings
import argilla as rg
# Tasks
TEXTCAT_TASK = "text_classification"
SFT_TASK = "supervised_fine_tuning"
# Inference limits (overridable via environment variables)
MAX_NUM_TOKENS = int(os.getenv("MAX_NUM_TOKENS", 2048))  # max tokens per generation call
MAX_NUM_ROWS = int(os.getenv("MAX_NUM_ROWS", 1000))  # max rows per generated dataset
DEFAULT_BATCH_SIZE = int(os.getenv("DEFAULT_BATCH_SIZE", 5))  # rows generated per batch
# Models
MODEL = os.getenv("MODEL", "meta-llama/Meta-Llama-3.1-8B-Instruct")
TOKENIZER_ID = os.getenv("TOKENIZER_ID")  # optional explicit tokenizer id; falls back to MODEL elsewhere — TODO confirm
OPENAI_BASE_URL = os.getenv("OPENAI_BASE_URL")
OLLAMA_BASE_URL = os.getenv("OLLAMA_BASE_URL")
HUGGINGFACE_BASE_URL = os.getenv("HUGGINGFACE_BASE_URL")
VLLM_BASE_URL = os.getenv("VLLM_BASE_URL")
# check if model is set correctly
# NOTE(fix): compare against the *explicit* MODEL env var here. The original
# compared against the resolved `MODEL`, which always carries a hard-coded
# default, so setting HUGGINGFACE_BASE_URL alone unconditionally raised.
if HUGGINGFACE_BASE_URL and os.getenv("MODEL"):
    raise ValueError(
        "`HUGGINGFACE_BASE_URL` and `MODEL` cannot be set at the same time. Use a model id for serverless inference and a base URL dedicated to Hugging Face Inference Endpoints."
    )
if not MODEL:
    # MODEL can only be falsy when the env var is explicitly set to "".
    if OPENAI_BASE_URL or OLLAMA_BASE_URL or VLLM_BASE_URL:
        raise ValueError("`MODEL` is not set. Please provide a model id for inference.")
# Check if multiple base URLs are provided — at most one backend may be configured.
base_urls = [
    url
    for url in [OPENAI_BASE_URL, OLLAMA_BASE_URL, HUGGINGFACE_BASE_URL, VLLM_BASE_URL]
    if url
]
if len(base_urls) > 1:
    raise ValueError(
        f"Multiple base URLs provided: {', '.join(base_urls)}. Only one base URL can be set at a time."
    )
# The single active backend base URL, or None for serverless inference.
BASE_URL = OPENAI_BASE_URL or OLLAMA_BASE_URL or HUGGINGFACE_BASE_URL or VLLM_BASE_URL
# API Keys
HF_TOKEN = os.getenv("HF_TOKEN")
if not HF_TOKEN:
    raise ValueError(
        "HF_TOKEN is not set. Ensure you have set the HF_TOKEN environment variable that has access to the Hugging Face Hub repositories and Inference Endpoints."
    )
_API_KEY = os.getenv("API_KEY")
# A single explicit API_KEY wins; otherwise pool HF_TOKEN with any
# numbered fallbacks (HF_TOKEN_1 … HF_TOKEN_9).
if _API_KEY:
    API_KEYS = [_API_KEY]
else:
    _extra_tokens = (os.getenv(f"HF_TOKEN_{i}") for i in range(1, 10))
    API_KEYS = [HF_TOKEN, *_extra_tokens]
# Drop any unset (None/empty) entries.
API_KEYS = [token for token in API_KEYS if token]
# Determine if SFT is available: Magpie-style chat-data generation needs a
# model whose pre-query template is known (Llama 3 / Qwen 2 families), or an
# explicitly provided MAGPIE_PRE_QUERY_TEMPLATE.
SFT_AVAILABLE = False
llama_options = ["llama3", "llama-3", "llama 3"]
qwen_options = ["qwen2", "qwen-2", "qwen 2"]
if passed_pre_query_template := os.getenv("MAGPIE_PRE_QUERY_TEMPLATE", "").lower():
    SFT_AVAILABLE = True
    if passed_pre_query_template in llama_options:
        MAGPIE_PRE_QUERY_TEMPLATE = "llama3"
    elif passed_pre_query_template in qwen_options:
        MAGPIE_PRE_QUERY_TEMPLATE = "qwen2"
    else:
        # Any other template string is passed through verbatim.
        MAGPIE_PRE_QUERY_TEMPLATE = passed_pre_query_template
# NOTE(fix): the original also tested `MODEL.lower() in llama_options` /
# `MODEL.lower() in qwen_options` before each `any(...)`; exact membership
# implies substring containment, so those checks were redundant dead code.
elif any(option in MODEL.lower() for option in llama_options):
    SFT_AVAILABLE = True
    MAGPIE_PRE_QUERY_TEMPLATE = "llama3"
elif any(option in MODEL.lower() for option in qwen_options):
    SFT_AVAILABLE = True
    MAGPIE_PRE_QUERY_TEMPLATE = "qwen2"
# OpenAI-compatible endpoints cannot be used for Magpie-style generation here.
if OPENAI_BASE_URL:
    SFT_AVAILABLE = False
if not SFT_AVAILABLE:
    # NOTE(fix): original message read "to with vllm" — grammatical error.
    warnings.warn(
        "`SFT_AVAILABLE` is set to `False`. Use Hugging Face Inference Endpoints or Ollama to generate chat data, provide a `TOKENIZER_ID` and `MAGPIE_PRE_QUERY_TEMPLATE`. You can also use `HUGGINGFACE_BASE_URL` with vllm."
    )
    MAGPIE_PRE_QUERY_TEMPLATE = None
# Embeddings
STATIC_EMBEDDING_MODEL = "minishlab/potion-base-8M"
# Argilla
# Either the primary env vars or the *_SDG_REVIEWER fallbacks may supply
# the connection details.
ARGILLA_API_URL = os.getenv("ARGILLA_API_URL") or os.getenv(
    "ARGILLA_API_URL_SDG_REVIEWER"
)
ARGILLA_API_KEY = os.getenv("ARGILLA_API_KEY") or os.getenv(
    "ARGILLA_API_KEY_SDG_REVIEWER"
)
if ARGILLA_API_URL and ARGILLA_API_KEY:
    # Both credentials present: build the client eagerly at import time.
    argilla_client = rg.Argilla(
        api_url=ARGILLA_API_URL,
        api_key=ARGILLA_API_KEY,
    )
else:
    # Missing credentials are non-fatal; downstream code must handle None.
    warnings.warn("ARGILLA_API_URL or ARGILLA_API_KEY is not set or is empty")
    argilla_client = None
|