Merge pull request #20 from argilla-io/feat/improve-support-local-deployment
- README.md +11 -7
- examples/{argilla_deployment.py → argilla-deployment.py} +5 -2
- examples/fine-tune-modernbert-classifier.ipynb +1 -1
- examples/{openai_local.py → hf-dedicated-or-tgi-deployment.py} +7 -4
- examples/{enforce_mapgie_template copy.py → hf-serverless-deployment.py} +3 -2
- examples/ollama-deployment.py +22 -0
- examples/{ollama_local.py → openai-deployment.py} +6 -3
- examples/vllm-deployment.py +21 -0
- pdm.lock +51 -16
- pyproject.toml +1 -1
- src/synthetic_dataset_generator/_distiset.py +10 -0
- src/synthetic_dataset_generator/apps/base.py +7 -2
- src/synthetic_dataset_generator/apps/chat.py +8 -11
- src/synthetic_dataset_generator/apps/eval.py +3 -1
- src/synthetic_dataset_generator/apps/textcat.py +41 -20
- src/synthetic_dataset_generator/constants.py +53 -25
- src/synthetic_dataset_generator/pipelines/base.py +132 -1
- src/synthetic_dataset_generator/pipelines/chat.py +36 -72
- src/synthetic_dataset_generator/pipelines/textcat.py +12 -52
- src/synthetic_dataset_generator/utils.py +5 -0
README.md
CHANGED
@@ -28,7 +28,7 @@ hf_oauth_scopes:

 ## Introduction

-Synthetic Data Generator is a tool that allows you to create high-quality datasets for training and fine-tuning language models. It leverages the power of distilabel and LLMs to generate synthetic data tailored to your specific needs. [The announcement blog](https://huggingface.co/blog/synthetic-data-generator) goes over a practical example of how to use it but you can also …
+Synthetic Data Generator is a tool that allows you to create high-quality datasets for training and fine-tuning language models. It leverages the power of distilabel and LLMs to generate synthetic data tailored to your specific needs. [The announcement blog](https://huggingface.co/blog/synthetic-data-generator) goes over a practical example of how to use it but you can also watch the [video](https://www.youtube.com/watch?v=nXjVtnGeEss) to see it in action.

 Supported Tasks:

@@ -76,21 +76,25 @@ launch()

 - `HF_TOKEN`: Your [Hugging Face token](https://huggingface.co/settings/tokens/new?ownUserPermissions=repo.content.read&ownUserPermissions=repo.write&globalPermissions=inference.serverless.write&tokenType=fineGrained) to push your datasets to the Hugging Face Hub and generate free completions from Hugging Face Inference Endpoints. You can find some configuration examples in the [examples](examples/) folder.

+You can set the following environment variables to customize the generation process.

 - `MAX_NUM_TOKENS`: The maximum number of tokens to generate, defaults to `2048`.
 - `MAX_NUM_ROWS`: The maximum number of rows to generate, defaults to `1000`.
 - `DEFAULT_BATCH_SIZE`: The default batch size to use for generating the dataset, defaults to `5`.

-Optionally, you can use different …
+Optionally, you can use different API providers and models.

-- `MODEL`: The model to use for generating the dataset, e.g. `meta-llama/Meta-Llama-3.1-8B-Instruct`, `openai/gpt-4o`, `ollama/llama3.1`.
+- `MODEL`: The model to use for generating the dataset, e.g. `meta-llama/Meta-Llama-3.1-8B-Instruct`, `gpt-4o`, `llama3.1`.
 - `API_KEY`: The API key to use for the generation API, e.g. `hf_...`, `sk-...`. If not provided, it will default to the provided `HF_TOKEN` environment variable.
+- `OPENAI_BASE_URL`: The base URL for any OpenAI compatible API, e.g. `https://api.openai.com/v1/`.
+- `OLLAMA_BASE_URL`: The base URL for any Ollama compatible API, e.g. `http://127.0.0.1:11434/`.
+- `HUGGINGFACE_BASE_URL`: The base URL for any Hugging Face compatible API, e.g. TGI server or Dedicated Inference Endpoints. If you want to use serverless inference, only set the `MODEL`.
+- `VLLM_BASE_URL`: The base URL for any VLLM compatible API, e.g. `http://localhost:8000/`.

-SFT and Chat Data generation is …
+SFT and Chat Data generation is not supported with OpenAI Endpoints. Additionally, you need to configure it per model family based on their prompt templates using the right `TOKENIZER_ID` and `MAGPIE_PRE_QUERY_TEMPLATE` environment variables.

+- `TOKENIZER_ID`: The tokenizer ID to use for the magpie pipeline, e.g. `meta-llama/Meta-Llama-3.1-8B-Instruct`.
+- `MAGPIE_PRE_QUERY_TEMPLATE`: Enforce setting the pre-query template for Magpie, which is only supported with Hugging Face Inference Endpoints. `llama3` and `qwen2` are supported out of the box and will use `"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"` and `"<|im_start|>user\n"`, respectively. For other models, you can pass a custom pre-query template string.

 Optionally, you can also push your datasets to Argilla for further curation by setting the following environment variables:
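A minimal configuration sketch tying these variables together (the values here are placeholders; the scripts under examples/ in this PR are the complete, runnable variants):

import os

from synthetic_dataset_generator import launch

os.environ["HF_TOKEN"] = "hf_..."  # required: Hub pushes and serverless inference
os.environ["MODEL"] = "meta-llama/Meta-Llama-3.1-8B-Instruct"  # generation model
os.environ["MAGPIE_PRE_QUERY_TEMPLATE"] = "llama3"  # needed for SFT/chat data
os.environ["MAX_NUM_ROWS"] = "500"  # assumption: smaller run than the 1000-row default

launch()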
examples/{argilla_deployment.py → argilla-deployment.py}
RENAMED
@@ -9,7 +9,10 @@ import os

 from synthetic_dataset_generator import launch

 # Follow https://docs.argilla.io/latest/getting_started/quickstart/ to get your Argilla API key and URL
-os.environ["…
-os.environ["…
+os.environ["HF_TOKEN"] = "hf_..."
+os.environ["ARGILLA_API_URL"] = (
+    "https://[your-owner-name]-[your_space_name].hf.space"  # argilla base url
+)
+os.environ["ARGILLA_API_KEY"] = "my_api_key"  # argilla api key

 launch()
examples/fine-tune-modernbert-classifier.ipynb
CHANGED
@@ -530,7 +530,7 @@

    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.11.…"
+   "version": "3.11.11"
   }
  },
  "nbformat": 4,
examples/{openai_local.py → hf-dedicated-or-tgi-deployment.py}
RENAMED
@@ -8,9 +8,12 @@ import os

 from synthetic_dataset_generator import launch

-os.environ["…
-os.environ["…
-os.environ["…
+os.environ["HF_TOKEN"] = "hf_..."  # push the data to huggingface
+os.environ["HUGGINGFACE_BASE_URL"] = "http://127.0.0.1:3000/"  # dedicated endpoint/TGI
+os.environ["MAGPIE_PRE_QUERY_TEMPLATE"] = "llama3"  # magpie template
+os.environ["TOKENIZER_ID"] = (
+    "meta-llama/Llama-3.1-8B-Instruct"  # tokenizer for model hosted on endpoint
+)
+os.environ["MODEL"] = None  # model is linked to endpoint

 launch()
examples/{enforce_mapgie_template copy.py → hf-serverless-deployment.py}
RENAMED
@@ -8,7 +8,8 @@ import os

 from synthetic_dataset_generator import launch

-os.environ["…
-os.environ["MODEL"] = "…
+os.environ["HF_TOKEN"] = "hf_..."  # push the data to huggingface
+os.environ["MODEL"] = "meta-llama/Llama-3.1-8B-Instruct"  # use instruct model
+os.environ["MAGPIE_PRE_QUERY_TEMPLATE"] = "llama3"  # use the template for the model

 launch()
examples/ollama-deployment.py
ADDED
@@ -0,0 +1,22 @@
+# /// script
+# requires-python = ">=3.11,<3.12"
+# dependencies = [
+#     "synthetic-dataset-generator",
+# ]
+# ///
+# ollama serve
+# ollama run qwen2.5:32b-instruct-q5_K_S
+import os
+
+from synthetic_dataset_generator import launch
+
+os.environ["HF_TOKEN"] = "hf_..."  # push the data to huggingface
+os.environ["OLLAMA_BASE_URL"] = "http://127.0.0.1:11434/"  # ollama base url
+os.environ["MODEL"] = "qwen2.5:32b-instruct-q5_K_S"  # model id
+os.environ["TOKENIZER_ID"] = "Qwen/Qwen2.5-32B-Instruct"  # tokenizer id
+os.environ["MAGPIE_PRE_QUERY_TEMPLATE"] = "qwen2"
+os.environ["MAX_NUM_ROWS"] = "10000"
+os.environ["DEFAULT_BATCH_SIZE"] = "2"
+os.environ["MAX_NUM_TOKENS"] = "1024"
+
+launch()
examples/{ollama_local.py → openai-deployment.py}
RENAMED
@@ -4,12 +4,15 @@

 #     "synthetic-dataset-generator",
 # ]
 # ///
+
 import os

 from synthetic_dataset_generator import launch

-os.environ["…
-os.environ["…
+os.environ["HF_TOKEN"] = "hf_..."  # push the data to huggingface
+os.environ["OPENAI_BASE_URL"] = "https://api.openai.com/v1/"  # openai base url
+os.environ["API_KEY"] = os.getenv("OPENAI_API_KEY")  # openai api key
+os.environ["MODEL"] = "gpt-4o"  # model id
+os.environ["MAGPIE_PRE_QUERY_TEMPLATE"] = None  # chat data not supported with OpenAI

 launch()
examples/vllm-deployment.py
ADDED
@@ -0,0 +1,21 @@
+# /// script
+# requires-python = ">=3.11,<3.12"
+# dependencies = [
+#     "synthetic-dataset-generator",
+# ]
+# ///
+# vllm serve Qwen/Qwen2.5-1.5B-Instruct
+import os
+
+from synthetic_dataset_generator import launch
+
+os.environ["HF_TOKEN"] = "hf_..."  # push the data to huggingface
+os.environ["VLLM_BASE_URL"] = "http://127.0.0.1:8000/"  # vllm base url
+os.environ["MODEL"] = "Qwen/Qwen2.5-1.5B-Instruct"  # model id
+os.environ["TOKENIZER_ID"] = "Qwen/Qwen2.5-1.5B-Instruct"  # tokenizer id
+os.environ["MAGPIE_PRE_QUERY_TEMPLATE"] = "qwen2"
+os.environ["MAX_NUM_ROWS"] = "10000"
+os.environ["DEFAULT_BATCH_SIZE"] = "2"
+os.environ["MAX_NUM_TOKENS"] = "1024"
+
+launch()
pdm.lock
CHANGED
@@ -5,7 +5,7 @@

 groups = ["default"]
 strategy = ["inherit_metadata"]
 lock_version = "4.5.0"
-content_hash = "sha256:…"
+content_hash = "sha256:e95140895657d62ad438ff1815ddf1798abbb342ddd2649ae462620b8b3f5350"

 [[metadata.targets]]
 requires_python = ">=3.10,<3.13"

@@ -491,8 +491,11 @@ files = [

 [[package]]
 name = "distilabel"
-version = "1.4.1"
+version = "1.5.0"
 requires_python = ">=3.9"
+git = "https://github.com/argilla-io/distilabel.git"
+ref = "feat/add-magpie-support-llama-cpp-ollama"
+revision = "4e291e7bf1c27b734a683a3af1fefe58965d77d6"
 summary = "Distilabel is an AI Feedback (AIF) framework for building datasets with and for LLMs."
 groups = ["default"]
 dependencies = [

@@ -512,30 +515,30 @@ dependencies = [

     "typer>=0.9.0",
     "universal-pathlib>=0.2.2",
 ]
-files = [
-    {file = "distilabel-1.4.1-py3-none-any.whl", hash = "sha256:4643da7f3abae86a330d86d1498443ea56978e462e21ae3d106a4c6013386965"},
-    {file = "distilabel-1.4.1.tar.gz", hash = "sha256:0c373be234e8f2982ec7f940d9a95585b15306b6ab5315f5a6a45214d8f34006"},
-]

 [[package]]
 name = "distilabel"
-version = "1.4.1"
-extras = ["argilla", "hf-inference-endpoints", "instructor", "outlines"]
+version = "1.5.0"
+extras = ["argilla", "hf-inference-endpoints", "hf-transformers", "instructor", "llama-cpp", "ollama", "openai", "outlines"]
 requires_python = ">=3.9"
+git = "https://github.com/argilla-io/distilabel.git"
+ref = "feat/add-magpie-support-llama-cpp-ollama"
+revision = "4e291e7bf1c27b734a683a3af1fefe58965d77d6"
 summary = "Distilabel is an AI Feedback (AIF) framework for building datasets with and for LLMs."
 groups = ["default"]
 dependencies = [
     "argilla>=2.0.0",
-    "distilabel…",
+    "distilabel @ git+https://github.com/argilla-io/distilabel.git@feat/add-magpie-support-llama-cpp-ollama",
     "huggingface-hub>=0.22.0",
     "instructor>=1.2.3",
     "ipython",
+    "llama-cpp-python>=0.2.0",
     "numba>=0.54.0",
+    "ollama>=0.1.7",
+    "openai>=1.0.0",
     "outlines>=0.0.40",
+    "torch>=2.0.0",
+    "transformers>=4.34.1",
-    {file = "distilabel-1.4.1-py3-none-any.whl", hash = "sha256:4643da7f3abae86a330d86d1498443ea56978e462e21ae3d106a4c6013386965"},
-    {file = "distilabel-1.4.1.tar.gz", hash = "sha256:0c373be234e8f2982ec7f940d9a95585b15306b6ab5315f5a6a45214d8f34006"},
 ]

@@ -824,7 +827,7 @@ files = [

 [[package]]
 name = "httpx"
-version = "0.…"
+version = "0.27.2"
 requires_python = ">=3.8"
 summary = "The next generation HTTP client."
 groups = ["default"]

@@ -833,10 +836,11 @@ dependencies = [

     "certifi",
     "httpcore==1.*",
     "idna",
+    "sniffio",
 ]
 files = [
-    {file = "httpx-0.…"},
-    {file = "httpx-0.…"},
+    {file = "httpx-0.27.2-py3-none-any.whl", hash = "sha256:7bb2708e112d8fdd7829cd4243970f0c223274051cb35ee80c03301ee29a3df0"},
+    {file = "httpx-0.27.2.tar.gz", hash = "sha256:f7c2be1d2f3c3c3160d441802406b206c2b76f5947b11115e6df10c6c65e66c2"},
 ]

@@ -1068,6 +1072,22 @@ files = [

     {file = "lark-1.2.2.tar.gz", hash = "sha256:ca807d0162cd16cef15a8feecb862d7319e7a09bdb13aef927968e45040fed80"},
 ]

+[[package]]
+name = "llama-cpp-python"
+version = "0.3.5"
+requires_python = ">=3.8"
+summary = "Python bindings for the llama.cpp library"
+groups = ["default"]
+dependencies = [
+    "diskcache>=5.6.1",
+    "jinja2>=2.11.3",
+    "numpy>=1.20.0",
+    "typing-extensions>=4.5.0",
+]
+files = [
+    {file = "llama_cpp_python-0.3.5.tar.gz", hash = "sha256:f5ce47499d53d3973e28ca5bdaf2dfe820163fa3fb67e3050f98e2e9b58d2cf6"},
+]
+
 [[package]]
 name = "llvmlite"
 version = "0.43.0"

@@ -1538,6 +1558,21 @@ files = [

     {file = "nvidia_nvtx_cu12-12.4.127-py3-none-win_amd64.whl", hash = "sha256:641dccaaa1139f3ffb0d3164b4b84f9d253397e38246a4f2f36728b48566d485"},
 ]

+[[package]]
+name = "ollama"
+version = "0.4.4"
+requires_python = "<4.0,>=3.8"
+summary = "The official Python client for Ollama."
+groups = ["default"]
+dependencies = [
+    "httpx<0.28.0,>=0.27.0",
+    "pydantic<3.0.0,>=2.9.0",
+]
+files = [
+    {file = "ollama-0.4.4-py3-none-any.whl", hash = "sha256:0f466e845e2205a1cbf5a2fef4640027b90beaa3b06c574426d8b6b17fd6e139"},
+    {file = "ollama-0.4.4.tar.gz", hash = "sha256:e1db064273c739babc2dde9ea84029c4a43415354741b6c50939ddd3dd0f7ffb"},
+]
+
 [[package]]
 name = "openai"
 version = "1.57.4"
pyproject.toml
CHANGED
@@ -18,7 +18,7 @@ readme = "README.md"

 license = {text = "Apache 2"}

 dependencies = [
-    "distilabel[hf-inference-endpoints,…]",
+    "distilabel[argilla,hf-inference-endpoints,hf-transformers,instructor,llama-cpp,ollama,openai,outlines,vllm] @ git+https://github.com/argilla-io/distilabel.git@develop",
     "gradio[oauth]>=5.4.0,<6.0.0",
     "transformers>=4.44.2,<5.0.0",
     "sentence-transformers>=3.2.0,<4.0.0",
src/synthetic_dataset_generator/_distiset.py
CHANGED
@@ -81,6 +81,15 @@ class CustomDistisetWithAdditionalTag(distilabel.distiset.Distiset):

             dataset[0] if not isinstance(dataset, dict) else dataset["train"][0]
         )

+        keys = list(sample_records.keys())
+        if len(keys) != 2 or not (
+            ("label" in keys and "text" in keys)
+            or ("labels" in keys and "text" in keys)
+        ):
+            task_categories = ["text-classification"]
+        elif "prompt" in keys or "messages" in keys:
+            task_categories = ["text-generation", "text2text-generation"]
+
         readme_metadata = {}
         if repo_id and token:
             readme_metadata = self._extract_readme_metadata(repo_id, token)

@@ -90,6 +99,7 @@ class CustomDistisetWithAdditionalTag(distilabel.distiset.Distiset):

             "size_categories": size_categories_parser(
                 max(len(dataset) for dataset in self.values())
             ),
+            "task_categories": task_categories,
             "tags": [
                 "synthetic",
                 "distilabel",
src/synthetic_dataset_generator/apps/base.py
CHANGED
@@ -77,10 +77,15 @@ def validate_push_to_hub(org_name, repo_name):

     return repo_id


-def combine_datasets(…
+def combine_datasets(
+    repo_id: str, dataset: Dataset, oauth_token: Union[OAuthToken, None]
+) -> Dataset:
     try:
         new_dataset = load_dataset(
-            repo_id,…
+            repo_id,
+            split="train",
+            download_mode="force_redownload",
+            token=oauth_token.token,
         )
         return concatenate_datasets([dataset, new_dataset])
     except Exception:
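For orientation, a sketch of how the apps call the extended signature when pushing (the repo id and data below are illustrative placeholders, not from this PR):

# Sketch only: append freshly generated rows to an already published dataset.
from datasets import Dataset

generated = Dataset.from_dict({"text": ["example row"], "label": ["positive"]})  # hypothetical data
combined = combine_datasets(
    repo_id="my-org/my-distiset",  # hypothetical target repo
    dataset=generated,
    oauth_token=oauth_token,  # gr.OAuthToken forwarded from the Gradio app
)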
src/synthetic_dataset_generator/apps/chat.py
CHANGED
@@ -25,12 +25,12 @@ from synthetic_dataset_generator.constants import (

     MODEL,
     SFT_AVAILABLE,
 )
+from synthetic_dataset_generator.pipelines.base import get_rewriten_prompts
 from synthetic_dataset_generator.pipelines.chat import (
     DEFAULT_DATASET_DESCRIPTIONS,
     generate_pipeline_code,
     get_magpie_generator,
     get_prompt_generator,
-    get_prompt_rewriter,
     get_response_generator,
 )
 from synthetic_dataset_generator.pipelines.embeddings import (

@@ -40,6 +40,7 @@ from synthetic_dataset_generator.pipelines.embeddings import (

 from synthetic_dataset_generator.utils import (
     get_argilla_client,
     get_org_dropdown,
+    get_random_repo_name,
     swap_visibility,
 )

@@ -106,7 +107,6 @@ def generate_dataset(

 ) -> pd.DataFrame:
     num_rows = test_max_num_rows(num_rows)
     progress(0.0, desc="(1/2) Generating instructions")
-    prompt_rewriter = get_prompt_rewriter()
     magpie_generator = get_magpie_generator(
         system_prompt, num_turns, temperature, is_sample
     )

@@ -117,14 +117,7 @@ def generate_dataset(

     batch_size = DEFAULT_BATCH_SIZE

     # create prompt rewrites
-    inputs = [
-        {
-            "instruction": f"Rewrite this prompt keeping the same structure but highlighting different aspects of the original without adding anything new. Original prompt: {system_prompt} Rewritten prompt: "
-        }
-        for i in range(int(num_rows / 50))
-    ]
-    batch = list(prompt_rewriter.process(inputs=inputs))
-    prompt_rewrites = [entry["generation"] for entry in batch[0]] + [system_prompt]
+    prompt_rewrites = get_rewriten_prompts(system_prompt, num_rows)

     # create instructions
     n_processed = 0

@@ -142,6 +135,7 @@ def generate_dataset(

         batch = list(magpie_generator.process(inputs=inputs))
         magpie_results.extend(batch[0])
         n_processed += batch_size
+        random.seed(a=random.randint(0, 2**32 - 1))
     progress(0.5, desc="(1/2) Generating instructions")

     # generate responses

@@ -158,6 +152,7 @@ def generate_dataset(

         responses = list(response_generator.process(inputs=batch))
         response_results.extend(responses[0])
         n_processed += batch_size
+        random.seed(a=random.randint(0, 2**32 - 1))
     for result in response_results:
         result["prompt"] = result["instruction"]
         result["completion"] = result["generation"]

@@ -178,6 +173,7 @@ def generate_dataset(

         responses = list(response_generator.process(inputs=batch))
         response_results.extend(responses[0])
         n_processed += batch_size
+        random.seed(a=random.randint(0, 2**32 - 1))
     for result in response_results:
         result["messages"].append(
             {"role": "assistant", "content": result["generation"]}

@@ -236,7 +232,7 @@ def push_dataset_to_hub(

     dataframe = convert_dataframe_messages(dataframe)
     progress(0.7, desc="Creating dataset")
     dataset = Dataset.from_pandas(dataframe)
-    dataset = combine_datasets(repo_id, dataset)
+    dataset = combine_datasets(repo_id, dataset, oauth_token)
     progress(0.9, desc="Pushing dataset")
     distiset = Distiset({"default": dataset})
     distiset.push_to_hub(

@@ -600,4 +596,5 @@ with gr.Blocks() as app:

         outputs=[dataset_description, system_prompt, num_turns, dataframe],
     )
     app.load(fn=get_org_dropdown, outputs=[org_name])
+    app.load(fn=get_random_repo_name, outputs=[repo_name])
     app.load(fn=swap_visibility, outputs=main_ui)
src/synthetic_dataset_generator/apps/eval.py
CHANGED
@@ -41,6 +41,7 @@ from synthetic_dataset_generator.utils import (

     extract_column_names,
     get_argilla_client,
     get_org_dropdown,
+    get_random_repo_name,
     pad_or_truncate_list,
     process_columns,
     swap_visibility,

@@ -359,7 +360,7 @@ def push_dataset_to_hub(

 ):
     repo_id = validate_push_to_hub(org_name, repo_name)
     dataset = Dataset.from_pandas(dataframe)
-    dataset = combine_datasets(repo_id, dataset)
+    dataset = combine_datasets(repo_id, dataset, oauth_token)
     distiset = Distiset({"default": dataset})
     distiset.push_to_hub(
         repo_id=repo_id,

@@ -907,3 +908,4 @@ with gr.Blocks() as app:

     app.load(fn=swap_visibility, outputs=main_ui)
     app.load(fn=get_org_dropdown, outputs=[org_name])
+    app.load(fn=get_random_repo_name, outputs=[repo_name])
src/synthetic_dataset_generator/apps/textcat.py
CHANGED
@@ -20,6 +20,7 @@ from synthetic_dataset_generator.apps.base import (

     validate_push_to_hub,
 )
 from synthetic_dataset_generator.constants import DEFAULT_BATCH_SIZE
+from synthetic_dataset_generator.pipelines.base import get_rewriten_prompts
 from synthetic_dataset_generator.pipelines.embeddings import (
     get_embeddings,
     get_sentence_embedding_dimensions,

@@ -35,6 +36,7 @@ from synthetic_dataset_generator.utils import (

     get_argilla_client,
     get_org_dropdown,
     get_preprocess_labels,
+    get_random_repo_name,
     swap_visibility,
 )

@@ -106,7 +108,7 @@ def generate_dataset(

     )
     updated_system_prompt = f"{system_prompt}. Optional labels: {', '.join(labels)}."
     if multi_label:
-        updated_system_prompt = f"{updated_system_prompt}. Only apply relevant labels. Applying less labels is better than applying too many labels."
+        updated_system_prompt = f"{updated_system_prompt}. Only apply relevant labels. Applying less labels is always better than applying too many labels."
     labeller_generator = get_labeller_generator(
         system_prompt=updated_system_prompt,
         labels=labels,

@@ -118,6 +120,7 @@ def generate_dataset(

     # create text classification data
     n_processed = 0
     textcat_results = []
+    rewritten_system_prompts = get_rewriten_prompts(system_prompt, num_rows)
     while n_processed < num_rows:
         progress(
             2 * 0.5 * n_processed / num_rows,

@@ -128,25 +131,24 @@ def generate_dataset(

         batch_size = min(batch_size, remaining_rows)
         inputs = []
         for _ in range(batch_size):
+            k = 1
             if multi_label:
                 num_labels = len(labels)
                 k = int(
                     random.betavariate(alpha=(num_labels - 1), beta=num_labels)
                     * num_labels
                 )
-            else:
-                k = 1
-
             sampled_labels = random.sample(labels, min(k, len(labels)))
             random.shuffle(sampled_labels)
             inputs.append(
                 {
-                    "task": f"{…
+                    "task": f"{random.choice(rewritten_system_prompts)}. The text represents the following categories: {', '.join(sampled_labels)}"
                 }
             )
         batch = list(textcat_generator.process(inputs=inputs))
         textcat_results.extend(batch[0])
         n_processed += batch_size
+        random.seed(a=random.randint(0, 2**32 - 1))
     for result in textcat_results:
         result["text"] = result["input_text"]

@@ -164,6 +166,7 @@ def generate_dataset(

         labels_batch = list(labeller_generator.process(inputs=batch))
         labeller_results.extend(labels_batch[0])
         n_processed += batch_size
+        random.seed(a=random.randint(0, 2**32 - 1))
     progress(
         1,
         total=total_steps,

@@ -178,26 +181,43 @@ def generate_dataset(

     dataframe = pd.DataFrame(distiset_results)
     if multi_label:
-        …
-            if (label is not None and label.lower().strip() in labels)
-            else random.choice(labels)
-        …
+
+        def _validate_labels(x):
+            if isinstance(x, str):  # single label
+                return [x.lower().strip()]
+            elif isinstance(x, list):  # multiple labels
+                return list(
+                    set(
+                        label.lower().strip()
+                        for label in x
+                        if label.lower().strip() in labels
+                    )
+                )
+            else:
+                return list(set([random.choice(labels)]))
+
+        dataframe["labels"] = dataframe["labels"].apply(_validate_labels)
         dataframe = dataframe[dataframe["labels"].notna()]
     else:
+
+        def _validate_labels(x):
+            if isinstance(x, str) and x.lower().strip() in labels:
+                return x.lower().strip()
+            elif isinstance(x, list):
+                options = [
+                    label.lower().strip()
+                    for label in x
+                    if isinstance(label, str) and label.lower().strip() in labels
+                ]
+                if options:
+                    return random.choice(options)
+                else:
+                    return random.choice(labels)
+            else:
+                return random.choice(labels)
+
         dataframe = dataframe.rename(columns={"labels": "label"})
-        dataframe["label"] = dataframe["label"].apply(
-            lambda x: x.lower().strip()
-            if x and x.lower().strip() in labels
-            else random.choice(labels)
-        )
+        dataframe["label"] = dataframe["label"].apply(_validate_labels)
         dataframe = dataframe[dataframe["text"].notna()]

     progress(1.0, desc="Dataset created")

@@ -235,7 +255,7 @@ def push_dataset_to_hub(

         dataframe.reset_index(drop=True),
         features=features,
     )
-    dataset = combine_datasets(repo_id, dataset)
+    dataset = combine_datasets(repo_id, dataset, oauth_token)
     distiset = Distiset({"default": dataset})
     progress(0.9, desc="Pushing dataset")
     distiset.push_to_hub(

@@ -647,3 +667,4 @@ with gr.Blocks() as app:

     app.load(fn=swap_visibility, outputs=main_ui)
     app.load(fn=get_org_dropdown, outputs=[org_name])
+    app.load(fn=get_random_repo_name, outputs=[repo_name])
src/synthetic_dataset_generator/constants.py
CHANGED
@@ -7,39 +7,66 @@ import argilla as rg

 TEXTCAT_TASK = "text_classification"
 SFT_TASK = "supervised_fine_tuning"

-#…
+# Inference
+MAX_NUM_TOKENS = int(os.getenv("MAX_NUM_TOKENS", 2048))
+MAX_NUM_ROWS = int(os.getenv("MAX_NUM_ROWS", 1000))
+DEFAULT_BATCH_SIZE = int(os.getenv("DEFAULT_BATCH_SIZE", 5))
+
+# Models
+MODEL = os.getenv("MODEL", "meta-llama/Meta-Llama-3.1-8B-Instruct")
+TOKENIZER_ID = os.getenv(key="TOKENIZER_ID", default=None)
+OPENAI_BASE_URL = os.getenv("OPENAI_BASE_URL")
+OLLAMA_BASE_URL = os.getenv("OLLAMA_BASE_URL")
+HUGGINGFACE_BASE_URL = os.getenv("HUGGINGFACE_BASE_URL")
+VLLM_BASE_URL = os.getenv("VLLM_BASE_URL")
+
+# check if model is set correctly
+if HUGGINGFACE_BASE_URL and MODEL:
+    raise ValueError(
+        "`HUGGINGFACE_BASE_URL` and `MODEL` cannot be set at the same time. Use a model id for serverless inference and a base URL dedicated to Hugging Face Inference Endpoints."
+    )
+if not MODEL:
+    if OPENAI_BASE_URL or OLLAMA_BASE_URL or VLLM_BASE_URL:
+        raise ValueError("`MODEL` is not set. Please provide a model id for inference.")
+
+# Check if multiple base URLs are provided
+base_urls = [
+    url
+    for url in [OPENAI_BASE_URL, OLLAMA_BASE_URL, HUGGINGFACE_BASE_URL, VLLM_BASE_URL]
+    if url
+]
+if len(base_urls) > 1:
+    raise ValueError(
+        f"Multiple base URLs provided: {', '.join(base_urls)}. Only one base URL can be set at a time."
+    )
+BASE_URL = OPENAI_BASE_URL or OLLAMA_BASE_URL or HUGGINGFACE_BASE_URL or VLLM_BASE_URL
+
+
+# API Keys
 HF_TOKEN = os.getenv("HF_TOKEN")
 if not HF_TOKEN:
     raise ValueError(
         "HF_TOKEN is not set. Ensure you have set the HF_TOKEN environment variable that has access to the Hugging Face Hub repositories and Inference Endpoints."
     )

-# Inference
-MAX_NUM_TOKENS = int(os.getenv("MAX_NUM_TOKENS", 2048))
-MAX_NUM_ROWS: str | int = int(os.getenv("MAX_NUM_ROWS", 1000))
-DEFAULT_BATCH_SIZE = int(os.getenv("DEFAULT_BATCH_SIZE", 5))
-MODEL = os.getenv("MODEL", "meta-llama/Meta-Llama-3.1-8B-Instruct")
-BASE_URL = os.getenv("BASE_URL", default=None)
-
 _API_KEY = os.getenv("API_KEY")
-…
-]
+API_KEYS = (
+    [_API_KEY]
+    if _API_KEY
+    else [HF_TOKEN] + [os.getenv(f"HF_TOKEN_{i}") for i in range(1, 10)]
+)
 API_KEYS = [token for token in API_KEYS if token]

 # Determine if SFT is available
 SFT_AVAILABLE = False
 llama_options = ["llama3", "llama-3", "llama 3"]
 qwen_options = ["qwen2", "qwen-2", "qwen 2"]
-…
+
+if passed_pre_query_template := os.getenv("MAGPIE_PRE_QUERY_TEMPLATE", "").lower():
     SFT_AVAILABLE = True
-    passed_pre_query_template…
-    if passed_pre_query_template.lower() in llama_options:
+    if passed_pre_query_template in llama_options:
         MAGPIE_PRE_QUERY_TEMPLATE = "llama3"
-    elif passed_pre_query_template…
+    elif passed_pre_query_template in qwen_options:
         MAGPIE_PRE_QUERY_TEMPLATE = "qwen2"
     else:
         MAGPIE_PRE_QUERY_TEMPLATE = passed_pre_query_template

@@ -54,12 +81,12 @@ elif MODEL.lower() in qwen_options or any(

     SFT_AVAILABLE = True
     MAGPIE_PRE_QUERY_TEMPLATE = "qwen2"

-if …:
+if OPENAI_BASE_URL:
     SFT_AVAILABLE = False

 if not SFT_AVAILABLE:
     warnings.warn(
-        …
+        "`SFT_AVAILABLE` is set to `False`. Use Hugging Face Inference Endpoints or Ollama to generate chat data, provide a `TOKENIZER_ID` and `MAGPIE_PRE_QUERY_TEMPLATE`. You can also use `HUGGINGFACE_BASE_URL` to with vllm."
     )
     MAGPIE_PRE_QUERY_TEMPLATE = None

@@ -67,11 +94,12 @@ if not SFT_AVAILABLE:

 STATIC_EMBEDDING_MODEL = "minishlab/potion-base-8M"

 # Argilla
-ARGILLA_API_URL = os.getenv("ARGILLA_API_URL")
-…
+ARGILLA_API_URL = os.getenv("ARGILLA_API_URL") or os.getenv(
+    "ARGILLA_API_URL_SDG_REVIEWER"
+)
+ARGILLA_API_KEY = os.getenv("ARGILLA_API_KEY") or os.getenv(
+    "ARGILLA_API_KEY_SDG_REVIEWER"
+)

 if not ARGILLA_API_URL or not ARGILLA_API_KEY:
     warnings.warn("ARGILLA_API_URL or ARGILLA_API_KEY is not set or is empty")
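Read together, these checks mean exactly one provider base URL may be active and a model id must accompany it; a sketch of a configuration that passes them (values are illustrative, not from this PR):

import os

# Exactly one *_BASE_URL may be set: a second one raises
# ValueError("Multiple base URLs provided: ..."), and HUGGINGFACE_BASE_URL
# additionally conflicts with an explicit MODEL.
os.environ["MODEL"] = "Qwen/Qwen2.5-1.5B-Instruct"  # hypothetical model id
os.environ["VLLM_BASE_URL"] = "http://localhost:8000/"  # single active provider
os.environ["TOKENIZER_ID"] = "Qwen/Qwen2.5-1.5B-Instruct"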
src/synthetic_dataset_generator/pipelines/base.py
CHANGED
@@ -1,4 +1,21 @@
-…
+import math
+import random
+
+import gradio as gr
+from distilabel.llms import ClientvLLM, InferenceEndpointsLLM, OllamaLLM, OpenAILLM
+from distilabel.steps.tasks import TextGeneration
+
+from synthetic_dataset_generator.constants import (
+    API_KEYS,
+    DEFAULT_BATCH_SIZE,
+    HUGGINGFACE_BASE_URL,
+    MAGPIE_PRE_QUERY_TEMPLATE,
+    MODEL,
+    OLLAMA_BASE_URL,
+    OPENAI_BASE_URL,
+    TOKENIZER_ID,
+    VLLM_BASE_URL,
+)

 TOKEN_INDEX = 0

@@ -8,3 +25,117 @@ def _get_next_api_key():

     api_key = API_KEYS[TOKEN_INDEX % len(API_KEYS)]
     TOKEN_INDEX += 1
     return api_key
+
+
+def _get_prompt_rewriter():
+    generation_kwargs = {
+        "temperature": 1,
+    }
+    system_prompt = "You are a prompt rewriter. You are given a prompt and you need to rewrite it keeping the same structure but highlighting different aspects of the original without adding anything new."
+    prompt_rewriter = TextGeneration(
+        llm=_get_llm(generation_kwargs=generation_kwargs),
+        system_prompt=system_prompt,
+        use_system_prompt=True,
+    )
+    prompt_rewriter.load()
+    return prompt_rewriter
+
+
+def get_rewriten_prompts(prompt: str, num_rows: int):
+    prompt_rewriter = _get_prompt_rewriter()
+    # create prompt rewrites
+    inputs = [
+        {"instruction": f"Original prompt: {prompt} \nRewritten prompt: "}
+        for i in range(math.floor(num_rows / 100))
+    ]
+    n_processed = 0
+    prompt_rewrites = [prompt]
+    while n_processed < num_rows:
+        batch = list(
+            prompt_rewriter.process(
+                inputs=inputs[n_processed : n_processed + DEFAULT_BATCH_SIZE]
+            )
+        )
+        prompt_rewrites += [entry["generation"] for entry in batch[0]]
+        n_processed += DEFAULT_BATCH_SIZE
+        random.seed(a=random.randint(0, 2**32 - 1))
+    return prompt_rewrites
+
+
+def _get_llm(use_magpie_template=False, **kwargs):
+    if OPENAI_BASE_URL:
+        llm = OpenAILLM(
+            model=MODEL,
+            base_url=OPENAI_BASE_URL,
+            api_key=_get_next_api_key(),
+            **kwargs,
+        )
+        if "generation_kwargs" in kwargs:
+            if "stop_sequences" in kwargs["generation_kwargs"]:
+                kwargs["generation_kwargs"]["stop"] = kwargs["generation_kwargs"][
+                    "stop_sequences"
+                ]
+                del kwargs["generation_kwargs"]["stop_sequences"]
+            if "do_sample" in kwargs["generation_kwargs"]:
+                del kwargs["generation_kwargs"]["do_sample"]
+    elif OLLAMA_BASE_URL:
+        if "generation_kwargs" in kwargs:
+            if "max_new_tokens" in kwargs["generation_kwargs"]:
+                kwargs["generation_kwargs"]["num_predict"] = kwargs[
+                    "generation_kwargs"
+                ]["max_new_tokens"]
+                del kwargs["generation_kwargs"]["max_new_tokens"]
+            if "stop_sequences" in kwargs["generation_kwargs"]:
+                kwargs["generation_kwargs"]["stop"] = kwargs["generation_kwargs"][
+                    "stop_sequences"
+                ]
+                del kwargs["generation_kwargs"]["stop_sequences"]
+            if "do_sample" in kwargs["generation_kwargs"]:
+                del kwargs["generation_kwargs"]["do_sample"]
+            options = kwargs["generation_kwargs"]
+            del kwargs["generation_kwargs"]
+            kwargs["generation_kwargs"] = {}
+            kwargs["generation_kwargs"]["options"] = options
+        llm = OllamaLLM(
+            model=MODEL,
+            host=OLLAMA_BASE_URL,
+            tokenizer_id=TOKENIZER_ID or MODEL,
+            **kwargs,
+        )
+    elif HUGGINGFACE_BASE_URL:
+        kwargs["generation_kwargs"]["do_sample"] = True
+        llm = InferenceEndpointsLLM(
+            api_key=_get_next_api_key(),
+            base_url=HUGGINGFACE_BASE_URL,
+            tokenizer_id=TOKENIZER_ID or MODEL,
+            **kwargs,
+        )
+    elif VLLM_BASE_URL:
+        if "generation_kwargs" in kwargs:
+            if "do_sample" in kwargs["generation_kwargs"]:
+                del kwargs["generation_kwargs"]["do_sample"]
+        llm = ClientvLLM(
+            base_url=VLLM_BASE_URL,
+            model=MODEL,
+            tokenizer=TOKENIZER_ID or MODEL,
+            api_key=_get_next_api_key(),
+            **kwargs,
+        )
+    else:
+        llm = InferenceEndpointsLLM(
+            api_key=_get_next_api_key(),
+            tokenizer_id=TOKENIZER_ID or MODEL,
+            model_id=MODEL,
+            magpie_pre_query_template=MAGPIE_PRE_QUERY_TEMPLATE,
+            **kwargs,
+        )
+
+    return llm
+
+
+try:
+    llm = _get_llm()
+    llm.load()
+    llm.generate([[{"content": "Hello, world!", "role": "user"}]])
+except Exception as e:
+    gr.Error(f"Error loading {llm.__class__.__name__}: {e}")
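A brief usage sketch of the new helper, mirroring how the chat and textcat pipelines build their tasks on top of it (the prompt and kwargs below are illustrative assumptions):

# Sketch only: downstream tasks obtain a provider-appropriate LLM from _get_llm.
from distilabel.steps.tasks import TextGeneration

generator = TextGeneration(
    llm=_get_llm(generation_kwargs={"temperature": 0.8, "max_new_tokens": 512}),
    system_prompt="You are a helpful assistant.",  # hypothetical system prompt
    use_system_prompt=True,
)
generator.load()
batch = list(generator.process(inputs=[{"instruction": "Say hello."}]))
print(batch[0][0]["generation"])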
src/synthetic_dataset_generator/pipelines/chat.py
CHANGED
@@ -1,4 +1,3 @@
-from distilabel.llms import InferenceEndpointsLLM
 from distilabel.steps.tasks import ChatGeneration, Magpie, TextGeneration

 from synthetic_dataset_generator.constants import (

@@ -7,7 +6,7 @@ from synthetic_dataset_generator.constants import (

     MAX_NUM_TOKENS,
     MODEL,
 )
-from synthetic_dataset_generator.pipelines.base import …
+from synthetic_dataset_generator.pipelines.base import _get_llm

 INFORMATION_SEEKING_PROMPT = (
     "You are an AI assistant designed to provide accurate and concise information on a wide"

@@ -149,18 +148,13 @@ def _get_output_mappings(num_turns):

 def get_prompt_generator():
+    generation_kwargs = {
+        "temperature": 0.8,
+        "max_new_tokens": MAX_NUM_TOKENS,
+        "do_sample": True,
+    }
     prompt_generator = TextGeneration(
-        llm=…
-            api_key=_get_next_api_key(),
-            model_id=MODEL,
-            tokenizer_id=MODEL,
-            base_url=BASE_URL,
-            generation_kwargs={
-                "temperature": 0.8,
-                "max_new_tokens": MAX_NUM_TOKENS,
-                "do_sample": True,
-            },
-        ),
+        llm=_get_llm(generation_kwargs=generation_kwargs),
         system_prompt=PROMPT_CREATION_PROMPT,
         use_system_prompt=True,
     )

@@ -172,38 +166,34 @@ def get_magpie_generator(system_prompt, num_turns, temperature, is_sample):

     input_mappings = _get_output_mappings(num_turns)
     output_mappings = input_mappings.copy()
     if num_turns == 1:
+        generation_kwargs = {
+            "temperature": temperature,
+            "do_sample": True,
+            "max_new_tokens": 256 if is_sample else int(MAX_NUM_TOKENS * 0.25),
+            "stop_sequences": _STOP_SEQUENCES,
+        }
         magpie_generator = Magpie(
-            llm=…
-                tokenizer_id=MODEL,
-                base_url=BASE_URL,
-                api_key=_get_next_api_key(),
+            llm=_get_llm(
+                generation_kwargs=generation_kwargs,
                 magpie_pre_query_template=MAGPIE_PRE_QUERY_TEMPLATE,
-                generation_kwargs={
-                    "temperature": temperature,
-                    "do_sample": True,
-                    "max_new_tokens": 256 if is_sample else int(MAX_NUM_TOKENS * 0.25),
-                    "stop_sequences": _STOP_SEQUENCES,
-                },
+                use_magpie_template=True,
             ),
             n_turns=num_turns,
             output_mappings=output_mappings,
             only_instruction=True,
         )
     else:
+        generation_kwargs = {
+            "temperature": temperature,
+            "do_sample": True,
+            "max_new_tokens": 256 if is_sample else int(MAX_NUM_TOKENS * 0.5),
+            "stop_sequences": _STOP_SEQUENCES,
+        }
         magpie_generator = Magpie(
-            llm=…
-                tokenizer_id=MODEL,
-                base_url=BASE_URL,
-                api_key=_get_next_api_key(),
+            llm=_get_llm(
+                generation_kwargs=generation_kwargs,
                 magpie_pre_query_template=MAGPIE_PRE_QUERY_TEMPLATE,
-                generation_kwargs={
-                    "temperature": temperature,
-                    "do_sample": True,
-                    "max_new_tokens": 256 if is_sample else int(MAX_NUM_TOKENS * 0.5),
-                    "stop_sequences": _STOP_SEQUENCES,
-                },
+                use_magpie_template=True,
             ),
             end_with_user=True,
             n_turns=num_turns,

@@ -213,51 +203,25 @@ def get_magpie_generator(system_prompt, num_turns, temperature, is_sample):

     return magpie_generator


-def get_prompt_rewriter():
-    prompt_rewriter = TextGeneration(
-        llm=InferenceEndpointsLLM(
-            model_id=MODEL,
-            tokenizer_id=MODEL,
-            base_url=BASE_URL,
-            api_key=_get_next_api_key(),
-            generation_kwargs={
-                "temperature": 1,
-            },
-        ),
-    )
-    prompt_rewriter.load()
-    return prompt_rewriter
-
-
 def get_response_generator(system_prompt, num_turns, temperature, is_sample):
     if num_turns == 1:
+        generation_kwargs = {
+            "temperature": temperature,
+            "max_new_tokens": 256 if is_sample else int(MAX_NUM_TOKENS * 0.5),
+        }
         response_generator = TextGeneration(
-            llm=…
-                model_id=MODEL,
-                tokenizer_id=MODEL,
-                base_url=BASE_URL,
-                api_key=_get_next_api_key(),
-                generation_kwargs={
-                    "temperature": temperature,
-                    "max_new_tokens": 256 if is_sample else int(MAX_NUM_TOKENS * 0.5),
-                },
-            ),
+            llm=_get_llm(generation_kwargs=generation_kwargs),
             system_prompt=system_prompt,
             output_mappings={"generation": "completion"},
             input_mappings={"instruction": "prompt"},
         )
     else:
+        generation_kwargs = {
+            "temperature": temperature,
+            "max_new_tokens": MAX_NUM_TOKENS,
+        }
         response_generator = ChatGeneration(
-            llm=…
-                model_id=MODEL,
-                tokenizer_id=MODEL,
-                base_url=BASE_URL,
-                api_key=_get_next_api_key(),
-                generation_kwargs={
-                    "temperature": temperature,
-                    "max_new_tokens": MAX_NUM_TOKENS,
-                },
-            ),
+            llm=_get_llm(generation_kwargs=generation_kwargs),
             output_mappings={"generation": "completion"},
             input_mappings={"conversation": "messages"},
         )

@@ -293,7 +257,7 @@ with Pipeline(name="sft") as pipeline:

             "max_new_tokens": {MAX_NUM_TOKENS},
             "stop_sequences": {_STOP_SEQUENCES}
         }},
-        api_key=os.environ["…"],
+        api_key=os.environ["API_KEY"],
     ),
     n_turns={num_turns},
     num_rows={num_rows},

src/synthetic_dataset_generator/pipelines/textcat.py
CHANGED
@@ -1,7 +1,6 @@
 import random
 from typing import List

-from distilabel.llms import InferenceEndpointsLLM, OpenAILLM
 from distilabel.steps.tasks import (
     GenerateTextClassificationData,
     TextClassification,

@@ -9,8 +8,12 @@ from distilabel.steps.tasks import (

 )
 from pydantic import BaseModel, Field

-from synthetic_dataset_generator.constants import …
-…
+from synthetic_dataset_generator.constants import (
+    BASE_URL,
+    MAX_NUM_TOKENS,
+    MODEL,
+)
+from synthetic_dataset_generator.pipelines.base import _get_llm
 from synthetic_dataset_generator.utils import get_preprocess_labels

 PROMPT_CREATION_PROMPT = """You are an AI assistant specialized in generating very precise text classification tasks for dataset creation.

@@ -69,23 +72,10 @@ def get_prompt_generator():

         "temperature": 0.8,
         "max_new_tokens": MAX_NUM_TOKENS,
     }
-    …
-            api_key=_get_next_api_key(),
-            structured_output=structured_output,
-            generation_kwargs=generation_kwargs,
-        )
-    else:
-        generation_kwargs["do_sample"] = True
-        llm = InferenceEndpointsLLM(
-            api_key=_get_next_api_key(),
-            model_id=MODEL,
-            base_url=BASE_URL,
-            structured_output=structured_output,
-            generation_kwargs=generation_kwargs,
-        )
+    llm = _get_llm(
+        structured_output=structured_output,
+        generation_kwargs=generation_kwargs,
+    )

     prompt_generator = TextGeneration(
         llm=llm,

@@ -103,22 +93,7 @@ def get_textcat_generator(difficulty, clarity, temperature, is_sample):

         "max_new_tokens": 256 if is_sample else MAX_NUM_TOKENS,
         "top_p": 0.95,
     }
-    if BASE_URL:
-        llm = OpenAILLM(
-            model=MODEL,
-            base_url=BASE_URL,
-            api_key=_get_next_api_key(),
-            generation_kwargs=generation_kwargs,
-        )
-    else:
-        generation_kwargs["do_sample"] = True
-        llm = InferenceEndpointsLLM(
-            model_id=MODEL,
-            base_url=BASE_URL,
-            api_key=_get_next_api_key(),
-            generation_kwargs=generation_kwargs,
-        )
-
+    llm = _get_llm(generation_kwargs=generation_kwargs)
     textcat_generator = GenerateTextClassificationData(
         llm=llm,
         difficulty=None if difficulty == "mixed" else difficulty,

@@ -134,22 +109,7 @@ def get_labeller_generator(system_prompt, labels, multi_label):

         "temperature": 0.01,
         "max_new_tokens": MAX_NUM_TOKENS,
     }
-
-    if BASE_URL:
-        llm = OpenAILLM(
-            model=MODEL,
-            base_url=BASE_URL,
-            api_key=_get_next_api_key(),
-            generation_kwargs=generation_kwargs,
-        )
-    else:
-        llm = InferenceEndpointsLLM(
-            model_id=MODEL,
-            base_url=BASE_URL,
-            api_key=_get_next_api_key(),
-            generation_kwargs=generation_kwargs,
-        )
-
+    llm = _get_llm(generation_kwargs=generation_kwargs)
     labeller_generator = TextClassification(
         llm=llm,
         context=system_prompt,

src/synthetic_dataset_generator/utils.py
CHANGED
@@ -1,4 +1,5 @@
 import json
+import uuid
 import warnings
 from typing import List, Optional, Union

@@ -55,6 +56,10 @@ def list_orgs(oauth_token: Union[OAuthToken, None] = None):

     return organizations


+def get_random_repo_name():
+    return f"my-distiset-{str(uuid.uuid4())[:8]}"
+
+
 def get_org_dropdown(oauth_token: Union[OAuthToken, None] = None):
     if oauth_token is not None:
         orgs = list_orgs(oauth_token)
|