davidberenstein1957 (HF staff) committed on
Commit 3c6a88c · unverified · 2 Parent(s): 9a45614 ce401b1

Merge pull request #20 from argilla-io/feat/improve-support-local-deployment

README.md CHANGED
@@ -28,7 +28,7 @@ hf_oauth_scopes:
28
 
29
  ## Introduction
30
 
31
- Synthetic Data Generator is a tool that allows you to create high-quality datasets for training and fine-tuning language models. It leverages the power of distilabel and LLMs to generate synthetic data tailored to your specific needs. [The announcement blog](https://huggingface.co/blog/synthetic-data-generator) goes over a practical example of how to use it but you can also wathh the [video](https://www.youtube.com/watch?v=nXjVtnGeEss) to see it in action.
32
 
33
  Supported Tasks:
34
 
@@ -76,21 +76,25 @@ launch()
76
 
77
  - `HF_TOKEN`: Your [Hugging Face token](https://huggingface.co/settings/tokens/new?ownUserPermissions=repo.content.read&ownUserPermissions=repo.write&globalPermissions=inference.serverless.write&tokenType=fineGrained) to push your datasets to the Hugging Face Hub and generate free completions from Hugging Face Inference Endpoints. You can find some configuration examples in the [examples](examples/) folder.
78
 
79
- Optionally, you can set the following environment variables to customize the generation process.
80
 
81
  - `MAX_NUM_TOKENS`: The maximum number of tokens to generate, defaults to `2048`.
82
  - `MAX_NUM_ROWS`: The maximum number of rows to generate, defaults to `1000`.
83
  - `DEFAULT_BATCH_SIZE`: The default batch size to use for generating the dataset, defaults to `5`.
84
 
85
- Optionally, you can use different models and APIs. For providers outside of Hugging Face, we provide an integration through [LiteLLM](https://docs.litellm.ai/docs/providers).
86
 
87
- - `BASE_URL`: The base URL for any OpenAI compatible API, e.g. `https://api.openai.com/v1/`, `http://127.0.0.1:11434/v1/`.
88
- - `MODEL`: The model to use for generating the dataset, e.g. `meta-llama/Meta-Llama-3.1-8B-Instruct`, `openai/gpt-4o`, `ollama/llama3.1`.
89
  - `API_KEY`: The API key to use for the generation API, e.g. `hf_...`, `sk-...`. If not provided, it will default to the provided `HF_TOKEN` environment variable.
 
 
 
 
90
 
91
- SFT and Chat Data generation is only supported with Hugging Face Inference Endpoints , and you can set the following environment variables use it with models other than Llama3 and Qwen2.
92
 
93
- - `MAGPIE_PRE_QUERY_TEMPLATE`: Enforce setting the pre-query template for Magpie, which is only supported with Hugging Face Inference Endpoints. Llama3 and Qwen2 are supported out of the box and will use `"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"` and `"<|im_start|>user\n"` respectively. For other models, you can pass a custom pre-query template string.
 
94
 
95
  Optionally, you can also push your datasets to Argilla for further curation by setting the following environment variables:
96
 
 
28
 
29
  ## Introduction
30
 
31
+ Synthetic Data Generator is a tool that allows you to create high-quality datasets for training and fine-tuning language models. It leverages the power of distilabel and LLMs to generate synthetic data tailored to your specific needs. [The announcement blog](https://huggingface.co/blog/synthetic-data-generator) goes over a practical example of how to use it but you can also watch the [video](https://www.youtube.com/watch?v=nXjVtnGeEss) to see it in action.
32
 
33
  Supported Tasks:
34
 
 
76
 
77
  - `HF_TOKEN`: Your [Hugging Face token](https://huggingface.co/settings/tokens/new?ownUserPermissions=repo.content.read&ownUserPermissions=repo.write&globalPermissions=inference.serverless.write&tokenType=fineGrained) to push your datasets to the Hugging Face Hub and generate free completions from Hugging Face Inference Endpoints. You can find some configuration examples in the [examples](examples/) folder.
78
 
79
+ You can set the following environment variables to customize the generation process.
80
 
81
  - `MAX_NUM_TOKENS`: The maximum number of tokens to generate, defaults to `2048`.
82
  - `MAX_NUM_ROWS`: The maximum number of rows to generate, defaults to `1000`.
83
  - `DEFAULT_BATCH_SIZE`: The default batch size to use for generating the dataset, defaults to `5`.
84
 
85
+ Optionally, you can use different API providers and models.
86
 
87
+ - `MODEL`: The model to use for generating the dataset, e.g. `meta-llama/Meta-Llama-3.1-8B-Instruct`, `gpt-4o`, `llama3.1`.
 
88
  - `API_KEY`: The API key to use for the generation API, e.g. `hf_...`, `sk-...`. If not provided, it will default to the provided `HF_TOKEN` environment variable.
89
+ - `OPENAI_BASE_URL`: The base URL for any OpenAI compatible API, e.g. `https://api.openai.com/v1/`.
90
+ - `OLLAMA_BASE_URL`: The base URL for any Ollama compatible API, e.g. `http://127.0.0.1:11434/`.
91
+ - `HUGGINGFACE_BASE_URL`: The base URL for any Hugging Face compatible API, e.g. a TGI server or a dedicated Inference Endpoint. If you want to use serverless inference, set only the `MODEL`.
92
+ - `VLLM_BASE_URL`: The base URL for any VLLM compatible API, e.g. `http://localhost:8000/`.
93
 
94
+ SFT and Chat Data generation is not supported with OpenAI Endpoints. Additionally, it needs to be configured per model family, based on their prompt templates, using the correct `TOKENIZER_ID` and `MAGPIE_PRE_QUERY_TEMPLATE` environment variables.
95
 
96
+ - `TOKENIZER_ID`: The tokenizer ID to use for the magpie pipeline, e.g. `meta-llama/Meta-Llama-3.1-8B-Instruct`.
97
+ - `MAGPIE_PRE_QUERY_TEMPLATE`: Enforce setting the pre-query template for Magpie, which is only supported with Hugging Face Inference Endpoints. `llama3` and `qwen2` are supported out of the box and will use `"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"` and `"<|im_start|>user\n"`, respectively. For other models, you can pass a custom pre-query template string.
98
 
99
  Optionally, you can also push your datasets to Argilla for further curation by setting the following environment variables:
100
 
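For models whose chat template is not covered by the built-in `llama3` and `qwen2` options, the README above says a custom pre-query template string can be passed. A minimal sketch of what that could look like for a Gemma-style instruct model, mirroring the serverless example further down in this commit; the exact prefix string is an assumption and should be verified against the model's chat template:

```python
import os

from synthetic_dataset_generator import launch

os.environ["HF_TOKEN"] = "hf_..."  # push the data to huggingface
os.environ["MODEL"] = "google/gemma-2-9b-it"  # serverless inference model id
os.environ["TOKENIZER_ID"] = "google/gemma-2-9b-it"  # tokenizer used by the magpie pipeline
# assumed user-turn prefix for Gemma-style chat templates; verify against the tokenizer's chat template
os.environ["MAGPIE_PRE_QUERY_TEMPLATE"] = "<bos><start_of_turn>user\n"

launch()
```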
examples/{argilla_deployment.py → argilla-deployment.py} RENAMED
@@ -9,7 +9,10 @@ import os
9
  from synthetic_dataset_generator import launch
10
 
11
  # Follow https://docs.argilla.io/latest/getting_started/quickstart/ to get your Argilla API key and URL
12
- os.environ["ARGILLA_API_URL"] = "https://[your-owner-name]-[your_space_name].hf.space"
13
- os.environ["ARGILLA_API_KEY"] = "my_api_key"
 
 
 
14
 
15
  launch()
 
9
  from synthetic_dataset_generator import launch
10
 
11
  # Follow https://docs.argilla.io/latest/getting_started/quickstart/ to get your Argilla API key and URL
12
+ os.environ["HF_TOKEN"] = "hf_..."
13
+ os.environ["ARGILLA_API_URL"] = (
14
+ "https://[your-owner-name]-[your_space_name].hf.space" # argilla base url
15
+ )
16
+ os.environ["ARGILLA_API_KEY"] = "my_api_key" # argilla api key
17
 
18
  launch()
examples/fine-tune-modernbert-classifier.ipynb CHANGED
@@ -530,7 +530,7 @@
530
  "name": "python",
531
  "nbconvert_exporter": "python",
532
  "pygments_lexer": "ipython3",
533
- "version": "3.11.9"
534
  }
535
  },
536
  "nbformat": 4,
 
530
  "name": "python",
531
  "nbconvert_exporter": "python",
532
  "pygments_lexer": "ipython3",
533
+ "version": "3.11.11"
534
  }
535
  },
536
  "nbformat": 4,
examples/{openai_local.py → hf-dedicated-or-tgi-deployment.py} RENAMED
@@ -8,9 +8,12 @@ import os
8
 
9
  from synthetic_dataset_generator import launch
10
 
11
- assert os.getenv("HF_TOKEN") # push the data to huggingface
12
- os.environ["BASE_URL"] = "https://api.openai.com/v1/"
13
- os.environ["API_KEY"] = os.getenv("OPENAI_API_KEY")
14
- os.environ["MODEL"] = "gpt-4o"
 
 
 
15
 
16
  launch()
 
8
 
9
  from synthetic_dataset_generator import launch
10
 
11
+ os.environ["HF_TOKEN"] = "hf_..." # push the data to huggingface
12
+ os.environ["HUGGINGFACE_BASE_URL"] = "http://127.0.0.1:3000/" # dedicated endpoint/TGI
13
+ os.environ["MAGPIE_PRE_QUERY_TEMPLATE"] = "llama3" # magpie template
14
+ os.environ["TOKENIZER_ID"] = (
15
+ "meta-llama/Llama-3.1-8B-Instruct" # tokenizer for model hosted on endpoint
16
+ )
17
+ os.environ["MODEL"] = None # model is linked to endpoint
18
 
19
  launch()
examples/{enforce_mapgie_template copy.py → hf-serverless-deployment.py} RENAMED
@@ -8,7 +8,8 @@ import os
8
 
9
  from synthetic_dataset_generator import launch
10
 
11
- os.environ["MAGPIE_PRE_QUERY_TEMPLATE"] = "my_custom_template"
12
- os.environ["MODEL"] = "google/gemma-2-9b-it"
 
13
 
14
  launch()
 
8
 
9
  from synthetic_dataset_generator import launch
10
 
11
+ os.environ["HF_TOKEN"] = "hf_..." # push the data to huggingface
12
+ os.environ["MODEL"] = "meta-llama/Llama-3.1-8B-Instruct" # use instruct model
13
+ os.environ["MAGPIE_PRE_QUERY_TEMPLATE"] = "llama3" # use the template for the model
14
 
15
  launch()
examples/ollama-deployment.py ADDED
@@ -0,0 +1,22 @@
1
+ # /// script
2
+ # requires-python = ">=3.11,<3.12"
3
+ # dependencies = [
4
+ # "synthetic-dataset-generator",
5
+ # ]
6
+ # ///
7
+ # ollama serve
8
+ # ollama run qwen2.5:32b-instruct-q5_K_S
9
+ import os
10
+
11
+ from synthetic_dataset_generator import launch
12
+
13
+ os.environ["HF_TOKEN"] = "hf_..." # push the data to huggingface
14
+ os.environ["OLLAMA_BASE_URL"] = "http://127.0.0.1:11434/" # ollama base url
15
+ os.environ["MODEL"] = "qwen2.5:32b-instruct-q5_K_S" # model id
16
+ os.environ["TOKENIZER_ID"] = "Qwen/Qwen2.5-32B-Instruct" # tokenizer id
17
+ os.environ["MAGPIE_PRE_QUERY_TEMPLATE"] = "qwen2"
18
+ os.environ["MAX_NUM_ROWS"] = "10000"
19
+ os.environ["DEFAULT_BATCH_SIZE"] = "2"
20
+ os.environ["MAX_NUM_TOKENS"] = "1024"
21
+
22
+ launch()
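The `# /// script` header in the new example files is inline script metadata (PEP 723), so the scripts can presumably be run directly with a compatible runner, e.g. `uv run examples/ollama-deployment.py`, after starting `ollama serve` and pulling the model as the comments indicate.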
examples/{ollama_local.py → openai-deployment.py} RENAMED
@@ -4,12 +4,15 @@
4
  # "synthetic-dataset-generator",
5
  # ]
6
  # ///
 
7
  import os
8
 
9
  from synthetic_dataset_generator import launch
10
 
11
- assert os.getenv("HF_TOKEN") # push the data to huggingface
12
- os.environ["BASE_URL"] = "http://127.0.0.1:11434/v1/"
13
- os.environ["MODEL"] = "llama3.1"
 
 
14
 
15
  launch()
 
4
  # "synthetic-dataset-generator",
5
  # ]
6
  # ///
7
+
8
  import os
9
 
10
  from synthetic_dataset_generator import launch
11
 
12
+ os.environ["HF_TOKEN"] = "hf_..." # push the data to huggingface
13
+ os.environ["OPENAI_BASE_URL"] = "https://api.openai.com/v1/" # openai base url
14
+ os.environ["API_KEY"] = os.getenv("OPENAI_API_KEY") # openai api key
15
+ os.environ["MODEL"] = "gpt-4o" # model id
16
+ os.environ["MAGPIE_PRE_QUERY_TEMPLATE"] = None # chat data not supported with OpenAI
17
 
18
  launch()
examples/vllm-deployment.py ADDED
@@ -0,0 +1,21 @@
1
+ # /// script
2
+ # requires-python = ">=3.11,<3.12"
3
+ # dependencies = [
4
+ # "synthetic-dataset-generator",
5
+ # ]
6
+ # ///
7
+ # vllm serve Qwen/Qwen2.5-1.5B-Instruct
8
+ import os
9
+
10
+ from synthetic_dataset_generator import launch
11
+
12
+ os.environ["HF_TOKEN"] = "hf_..." # push the data to huggingface
13
+ os.environ["VLLM_BASE_URL"] = "http://127.0.0.1:8000/" # vllm base url
14
+ os.environ["MODEL"] = "Qwen/Qwen2.5-1.5B-Instruct" # model id
15
+ os.environ["TOKENIZER_ID"] = "Qwen/Qwen2.5-1.5B-Instruct" # tokenizer id
16
+ os.environ["MAGPIE_PRE_QUERY_TEMPLATE"] = "qwen2"
17
+ os.environ["MAX_NUM_ROWS"] = "10000"
18
+ os.environ["DEFAULT_BATCH_SIZE"] = "2"
19
+ os.environ["MAX_NUM_TOKENS"] = "1024"
20
+
21
+ launch()
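The `http://localhost:8000/` base URL matches the default port of vLLM's OpenAI-compatible server started by `vllm serve`; if the server is launched with a different `--port`, `VLLM_BASE_URL` would need to be adjusted accordingly.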
pdm.lock CHANGED
@@ -5,7 +5,7 @@
5
  groups = ["default"]
6
  strategy = ["inherit_metadata"]
7
  lock_version = "4.5.0"
8
- content_hash = "sha256:a6d7f86e9a168e7eb78801faafa8a9ea13270310ee9c21a0e23aeab277d973a0"
9
 
10
  [[metadata.targets]]
11
  requires_python = ">=3.10,<3.13"
@@ -491,8 +491,11 @@ files = [
491
 
492
  [[package]]
493
  name = "distilabel"
494
- version = "1.4.1"
495
  requires_python = ">=3.9"
 
 
 
496
  summary = "Distilabel is an AI Feedback (AIF) framework for building datasets with and for LLMs."
497
  groups = ["default"]
498
  dependencies = [
@@ -512,30 +515,30 @@ dependencies = [
512
  "typer>=0.9.0",
513
  "universal-pathlib>=0.2.2",
514
  ]
515
- files = [
516
- {file = "distilabel-1.4.1-py3-none-any.whl", hash = "sha256:4643da7f3abae86a330d86d1498443ea56978e462e21ae3d106a4c6013386965"},
517
- {file = "distilabel-1.4.1.tar.gz", hash = "sha256:0c373be234e8f2982ec7f940d9a95585b15306b6ab5315f5a6a45214d8f34006"},
518
- ]
519
 
520
  [[package]]
521
  name = "distilabel"
522
- version = "1.4.1"
523
- extras = ["argilla", "hf-inference-endpoints", "instructor", "outlines"]
524
  requires_python = ">=3.9"
 
 
 
525
  summary = "Distilabel is an AI Feedback (AIF) framework for building datasets with and for LLMs."
526
  groups = ["default"]
527
  dependencies = [
528
  "argilla>=2.0.0",
529
- "distilabel==1.4.1",
530
  "huggingface-hub>=0.22.0",
531
  "instructor>=1.2.3",
532
  "ipython",
 
533
  "numba>=0.54.0",
 
 
534
  "outlines>=0.0.40",
535
- ]
536
- files = [
537
- {file = "distilabel-1.4.1-py3-none-any.whl", hash = "sha256:4643da7f3abae86a330d86d1498443ea56978e462e21ae3d106a4c6013386965"},
538
- {file = "distilabel-1.4.1.tar.gz", hash = "sha256:0c373be234e8f2982ec7f940d9a95585b15306b6ab5315f5a6a45214d8f34006"},
539
  ]
540
 
541
  [[package]]
@@ -824,7 +827,7 @@ files = [
824
 
825
  [[package]]
826
  name = "httpx"
827
- version = "0.28.1"
828
  requires_python = ">=3.8"
829
  summary = "The next generation HTTP client."
830
  groups = ["default"]
@@ -833,10 +836,11 @@ dependencies = [
833
  "certifi",
834
  "httpcore==1.*",
835
  "idna",
 
836
  ]
837
  files = [
838
- {file = "httpx-0.28.1-py3-none-any.whl", hash = "sha256:d909fcccc110f8c7faf814ca82a9a4d816bc5a6dbfea25d6591d6985b8ba59ad"},
839
- {file = "httpx-0.28.1.tar.gz", hash = "sha256:75e98c5f16b0f35b567856f597f06ff2270a374470a5c2392242528e3e3e42fc"},
840
  ]
841
 
842
  [[package]]
@@ -1068,6 +1072,22 @@ files = [
1068
  {file = "lark-1.2.2.tar.gz", hash = "sha256:ca807d0162cd16cef15a8feecb862d7319e7a09bdb13aef927968e45040fed80"},
1069
  ]
1070
 
1071
  [[package]]
1072
  name = "llvmlite"
1073
  version = "0.43.0"
@@ -1538,6 +1558,21 @@ files = [
1538
  {file = "nvidia_nvtx_cu12-12.4.127-py3-none-win_amd64.whl", hash = "sha256:641dccaaa1139f3ffb0d3164b4b84f9d253397e38246a4f2f36728b48566d485"},
1539
  ]
1540
 
  [[package]]
1542
  name = "openai"
1543
  version = "1.57.4"
 
5
  groups = ["default"]
6
  strategy = ["inherit_metadata"]
7
  lock_version = "4.5.0"
8
+ content_hash = "sha256:e95140895657d62ad438ff1815ddf1798abbb342ddd2649ae462620b8b3f5350"
9
 
10
  [[metadata.targets]]
11
  requires_python = ">=3.10,<3.13"
 
491
 
492
  [[package]]
493
  name = "distilabel"
494
+ version = "1.5.0"
495
  requires_python = ">=3.9"
496
+ git = "https://github.com/argilla-io/distilabel.git"
497
+ ref = "feat/add-magpie-support-llama-cpp-ollama"
498
+ revision = "4e291e7bf1c27b734a683a3af1fefe58965d77d6"
499
  summary = "Distilabel is an AI Feedback (AIF) framework for building datasets with and for LLMs."
500
  groups = ["default"]
501
  dependencies = [
 
515
  "typer>=0.9.0",
516
  "universal-pathlib>=0.2.2",
517
  ]
 
 
 
 
518
 
519
  [[package]]
520
  name = "distilabel"
521
+ version = "1.5.0"
522
+ extras = ["argilla", "hf-inference-endpoints", "hf-transformers", "instructor", "llama-cpp", "ollama", "openai", "outlines"]
523
  requires_python = ">=3.9"
524
+ git = "https://github.com/argilla-io/distilabel.git"
525
+ ref = "feat/add-magpie-support-llama-cpp-ollama"
526
+ revision = "4e291e7bf1c27b734a683a3af1fefe58965d77d6"
527
  summary = "Distilabel is an AI Feedback (AIF) framework for building datasets with and for LLMs."
528
  groups = ["default"]
529
  dependencies = [
530
  "argilla>=2.0.0",
531
+ "distilabel @ git+https://github.com/argilla-io/distilabel.git@feat/add-magpie-support-llama-cpp-ollama",
532
  "huggingface-hub>=0.22.0",
533
  "instructor>=1.2.3",
534
  "ipython",
535
+ "llama-cpp-python>=0.2.0",
536
  "numba>=0.54.0",
537
+ "ollama>=0.1.7",
538
+ "openai>=1.0.0",
539
  "outlines>=0.0.40",
540
+ "torch>=2.0.0",
541
+ "transformers>=4.34.1",
 
 
542
  ]
543
 
544
  [[package]]
 
827
 
828
  [[package]]
829
  name = "httpx"
830
+ version = "0.27.2"
831
  requires_python = ">=3.8"
832
  summary = "The next generation HTTP client."
833
  groups = ["default"]
 
836
  "certifi",
837
  "httpcore==1.*",
838
  "idna",
839
+ "sniffio",
840
  ]
841
  files = [
842
+ {file = "httpx-0.27.2-py3-none-any.whl", hash = "sha256:7bb2708e112d8fdd7829cd4243970f0c223274051cb35ee80c03301ee29a3df0"},
843
+ {file = "httpx-0.27.2.tar.gz", hash = "sha256:f7c2be1d2f3c3c3160d441802406b206c2b76f5947b11115e6df10c6c65e66c2"},
844
  ]
845
 
846
  [[package]]
 
1072
  {file = "lark-1.2.2.tar.gz", hash = "sha256:ca807d0162cd16cef15a8feecb862d7319e7a09bdb13aef927968e45040fed80"},
1073
  ]
1074
 
1075
+ [[package]]
1076
+ name = "llama-cpp-python"
1077
+ version = "0.3.5"
1078
+ requires_python = ">=3.8"
1079
+ summary = "Python bindings for the llama.cpp library"
1080
+ groups = ["default"]
1081
+ dependencies = [
1082
+ "diskcache>=5.6.1",
1083
+ "jinja2>=2.11.3",
1084
+ "numpy>=1.20.0",
1085
+ "typing-extensions>=4.5.0",
1086
+ ]
1087
+ files = [
1088
+ {file = "llama_cpp_python-0.3.5.tar.gz", hash = "sha256:f5ce47499d53d3973e28ca5bdaf2dfe820163fa3fb67e3050f98e2e9b58d2cf6"},
1089
+ ]
1090
+
1091
  [[package]]
1092
  name = "llvmlite"
1093
  version = "0.43.0"
 
1558
  {file = "nvidia_nvtx_cu12-12.4.127-py3-none-win_amd64.whl", hash = "sha256:641dccaaa1139f3ffb0d3164b4b84f9d253397e38246a4f2f36728b48566d485"},
1559
  ]
1560
 
1561
+ [[package]]
1562
+ name = "ollama"
1563
+ version = "0.4.4"
1564
+ requires_python = "<4.0,>=3.8"
1565
+ summary = "The official Python client for Ollama."
1566
+ groups = ["default"]
1567
+ dependencies = [
1568
+ "httpx<0.28.0,>=0.27.0",
1569
+ "pydantic<3.0.0,>=2.9.0",
1570
+ ]
1571
+ files = [
1572
+ {file = "ollama-0.4.4-py3-none-any.whl", hash = "sha256:0f466e845e2205a1cbf5a2fef4640027b90beaa3b06c574426d8b6b17fd6e139"},
1573
+ {file = "ollama-0.4.4.tar.gz", hash = "sha256:e1db064273c739babc2dde9ea84029c4a43415354741b6c50939ddd3dd0f7ffb"},
1574
+ ]
1575
+
1576
  [[package]]
1577
  name = "openai"
1578
  version = "1.57.4"
pyproject.toml CHANGED
@@ -18,7 +18,7 @@ readme = "README.md"
18
  license = {text = "Apache 2"}
19
 
20
  dependencies = [
21
- "distilabel[hf-inference-endpoints,argilla,outlines,instructor]>=1.4.1,<2.0.0",
22
  "gradio[oauth]>=5.4.0,<6.0.0",
23
  "transformers>=4.44.2,<5.0.0",
24
  "sentence-transformers>=3.2.0,<4.0.0",
 
18
  license = {text = "Apache 2"}
19
 
20
  dependencies = [
21
+ "distilabel[argilla,hf-inference-endpoints,hf-transformers,instructor,llama-cpp,ollama,openai,outlines,vllm] @ git+https://github.com/argilla-io/distilabel.git@develop",
22
  "gradio[oauth]>=5.4.0,<6.0.0",
23
  "transformers>=4.44.2,<5.0.0",
24
  "sentence-transformers>=3.2.0,<4.0.0",
src/synthetic_dataset_generator/_distiset.py CHANGED
@@ -81,6 +81,15 @@ class CustomDistisetWithAdditionalTag(distilabel.distiset.Distiset):
81
  dataset[0] if not isinstance(dataset, dict) else dataset["train"][0]
82
  )
83
 
 
 
 
 
 
 
 
 
 
84
  readme_metadata = {}
85
  if repo_id and token:
86
  readme_metadata = self._extract_readme_metadata(repo_id, token)
@@ -90,6 +99,7 @@ class CustomDistisetWithAdditionalTag(distilabel.distiset.Distiset):
90
  "size_categories": size_categories_parser(
91
  max(len(dataset) for dataset in self.values())
92
  ),
 
93
  "tags": [
94
  "synthetic",
95
  "distilabel",
 
81
  dataset[0] if not isinstance(dataset, dict) else dataset["train"][0]
82
  )
83
 
84
+ keys = list(sample_records.keys())
85
+ task_categories = []
+ if len(keys) == 2 and (
86
+ ("label" in keys and "text" in keys)
87
+ or ("labels" in keys and "text" in keys)
88
+ ):
89
+ task_categories = ["text-classification"]
90
+ elif "prompt" in keys or "messages" in keys:
91
+ task_categories = ["text-generation", "text2text-generation"]
92
+
93
  readme_metadata = {}
94
  if repo_id and token:
95
  readme_metadata = self._extract_readme_metadata(repo_id, token)
 
99
  "size_categories": size_categories_parser(
100
  max(len(dataset) for dataset in self.values())
101
  ),
102
+ "task_categories": task_categories,
103
  "tags": [
104
  "synthetic",
105
  "distilabel",
src/synthetic_dataset_generator/apps/base.py CHANGED
@@ -77,10 +77,15 @@ def validate_push_to_hub(org_name, repo_name):
77
  return repo_id
78
 
79
 
80
- def combine_datasets(repo_id: str, dataset: Dataset) -> Dataset:
 
 
81
  try:
82
  new_dataset = load_dataset(
83
- repo_id, split="train", download_mode="force_redownload"
 
 
 
84
  )
85
  return concatenate_datasets([dataset, new_dataset])
86
  except Exception:
 
77
  return repo_id
78
 
79
 
80
+ def combine_datasets(
81
+ repo_id: str, dataset: Dataset, oauth_token: Union[OAuthToken, None]
82
+ ) -> Dataset:
83
  try:
84
  new_dataset = load_dataset(
85
+ repo_id,
86
+ split="train",
87
+ download_mode="force_redownload",
88
+ token=oauth_token.token,
89
  )
90
  return concatenate_datasets([dataset, new_dataset])
91
  except Exception:
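The new `oauth_token` parameter means the push handlers below now call `combine_datasets(repo_id, dataset, oauth_token)` instead of `combine_datasets(repo_id, dataset)`, so the existing dataset can be re-downloaded with the user's OAuth token (for example when the target repo is private) before the freshly generated rows are appended.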
src/synthetic_dataset_generator/apps/chat.py CHANGED
@@ -25,12 +25,12 @@ from synthetic_dataset_generator.constants import (
25
  MODEL,
26
  SFT_AVAILABLE,
27
  )
 
28
  from synthetic_dataset_generator.pipelines.chat import (
29
  DEFAULT_DATASET_DESCRIPTIONS,
30
  generate_pipeline_code,
31
  get_magpie_generator,
32
  get_prompt_generator,
33
- get_prompt_rewriter,
34
  get_response_generator,
35
  )
36
  from synthetic_dataset_generator.pipelines.embeddings import (
@@ -40,6 +40,7 @@ from synthetic_dataset_generator.pipelines.embeddings import (
40
  from synthetic_dataset_generator.utils import (
41
  get_argilla_client,
42
  get_org_dropdown,
 
43
  swap_visibility,
44
  )
45
 
@@ -106,7 +107,6 @@ def generate_dataset(
106
  ) -> pd.DataFrame:
107
  num_rows = test_max_num_rows(num_rows)
108
  progress(0.0, desc="(1/2) Generating instructions")
109
- prompt_rewriter = get_prompt_rewriter()
110
  magpie_generator = get_magpie_generator(
111
  system_prompt, num_turns, temperature, is_sample
112
  )
@@ -117,14 +117,7 @@ def generate_dataset(
117
  batch_size = DEFAULT_BATCH_SIZE
118
 
119
  # create prompt rewrites
120
- inputs = [
121
- {
122
- "instruction": f"Rewrite this prompt keeping the same structure but highlighting different aspects of the original without adding anything new. Original prompt: {system_prompt} Rewritten prompt: "
123
- }
124
- for i in range(int(num_rows / 50))
125
- ]
126
- batch = list(prompt_rewriter.process(inputs=inputs))
127
- prompt_rewrites = [entry["generation"] for entry in batch[0]] + [system_prompt]
128
 
129
  # create instructions
130
  n_processed = 0
@@ -142,6 +135,7 @@ def generate_dataset(
142
  batch = list(magpie_generator.process(inputs=inputs))
143
  magpie_results.extend(batch[0])
144
  n_processed += batch_size
 
145
  progress(0.5, desc="(1/2) Generating instructions")
146
 
147
  # generate responses
@@ -158,6 +152,7 @@ def generate_dataset(
158
  responses = list(response_generator.process(inputs=batch))
159
  response_results.extend(responses[0])
160
  n_processed += batch_size
 
161
  for result in response_results:
162
  result["prompt"] = result["instruction"]
163
  result["completion"] = result["generation"]
@@ -178,6 +173,7 @@ def generate_dataset(
178
  responses = list(response_generator.process(inputs=batch))
179
  response_results.extend(responses[0])
180
  n_processed += batch_size
 
181
  for result in response_results:
182
  result["messages"].append(
183
  {"role": "assistant", "content": result["generation"]}
@@ -236,7 +232,7 @@ def push_dataset_to_hub(
236
  dataframe = convert_dataframe_messages(dataframe)
237
  progress(0.7, desc="Creating dataset")
238
  dataset = Dataset.from_pandas(dataframe)
239
- dataset = combine_datasets(repo_id, dataset)
240
  progress(0.9, desc="Pushing dataset")
241
  distiset = Distiset({"default": dataset})
242
  distiset.push_to_hub(
@@ -600,4 +596,5 @@ with gr.Blocks() as app:
600
  outputs=[dataset_description, system_prompt, num_turns, dataframe],
601
  )
602
  app.load(fn=get_org_dropdown, outputs=[org_name])
 
603
  app.load(fn=swap_visibility, outputs=main_ui)
 
25
  MODEL,
26
  SFT_AVAILABLE,
27
  )
28
+ from synthetic_dataset_generator.pipelines.base import get_rewriten_prompts
29
  from synthetic_dataset_generator.pipelines.chat import (
30
  DEFAULT_DATASET_DESCRIPTIONS,
31
  generate_pipeline_code,
32
  get_magpie_generator,
33
  get_prompt_generator,
 
34
  get_response_generator,
35
  )
36
  from synthetic_dataset_generator.pipelines.embeddings import (
 
40
  from synthetic_dataset_generator.utils import (
41
  get_argilla_client,
42
  get_org_dropdown,
43
+ get_random_repo_name,
44
  swap_visibility,
45
  )
46
 
 
107
  ) -> pd.DataFrame:
108
  num_rows = test_max_num_rows(num_rows)
109
  progress(0.0, desc="(1/2) Generating instructions")
 
110
  magpie_generator = get_magpie_generator(
111
  system_prompt, num_turns, temperature, is_sample
112
  )
 
117
  batch_size = DEFAULT_BATCH_SIZE
118
 
119
  # create prompt rewrites
120
+ prompt_rewrites = get_rewriten_prompts(system_prompt, num_rows)
 
 
 
 
 
 
 
121
 
122
  # create instructions
123
  n_processed = 0
 
135
  batch = list(magpie_generator.process(inputs=inputs))
136
  magpie_results.extend(batch[0])
137
  n_processed += batch_size
138
+ random.seed(a=random.randint(0, 2**32 - 1))
139
  progress(0.5, desc="(1/2) Generating instructions")
140
 
141
  # generate responses
 
152
  responses = list(response_generator.process(inputs=batch))
153
  response_results.extend(responses[0])
154
  n_processed += batch_size
155
+ random.seed(a=random.randint(0, 2**32 - 1))
156
  for result in response_results:
157
  result["prompt"] = result["instruction"]
158
  result["completion"] = result["generation"]
 
173
  responses = list(response_generator.process(inputs=batch))
174
  response_results.extend(responses[0])
175
  n_processed += batch_size
176
+ random.seed(a=random.randint(0, 2**32 - 1))
177
  for result in response_results:
178
  result["messages"].append(
179
  {"role": "assistant", "content": result["generation"]}
 
232
  dataframe = convert_dataframe_messages(dataframe)
233
  progress(0.7, desc="Creating dataset")
234
  dataset = Dataset.from_pandas(dataframe)
235
+ dataset = combine_datasets(repo_id, dataset, oauth_token)
236
  progress(0.9, desc="Pushing dataset")
237
  distiset = Distiset({"default": dataset})
238
  distiset.push_to_hub(
 
596
  outputs=[dataset_description, system_prompt, num_turns, dataframe],
597
  )
598
  app.load(fn=get_org_dropdown, outputs=[org_name])
599
+ app.load(fn=get_random_repo_name, outputs=[repo_name])
600
  app.load(fn=swap_visibility, outputs=main_ui)
src/synthetic_dataset_generator/apps/eval.py CHANGED
@@ -41,6 +41,7 @@ from synthetic_dataset_generator.utils import (
41
  extract_column_names,
42
  get_argilla_client,
43
  get_org_dropdown,
 
44
  pad_or_truncate_list,
45
  process_columns,
46
  swap_visibility,
@@ -359,7 +360,7 @@ def push_dataset_to_hub(
359
  ):
360
  repo_id = validate_push_to_hub(org_name, repo_name)
361
  dataset = Dataset.from_pandas(dataframe)
362
- dataset = combine_datasets(repo_id, dataset)
363
  distiset = Distiset({"default": dataset})
364
  distiset.push_to_hub(
365
  repo_id=repo_id,
@@ -907,3 +908,4 @@ with gr.Blocks() as app:
907
 
908
  app.load(fn=swap_visibility, outputs=main_ui)
909
  app.load(fn=get_org_dropdown, outputs=[org_name])
 
 
41
  extract_column_names,
42
  get_argilla_client,
43
  get_org_dropdown,
44
+ get_random_repo_name,
45
  pad_or_truncate_list,
46
  process_columns,
47
  swap_visibility,
 
360
  ):
361
  repo_id = validate_push_to_hub(org_name, repo_name)
362
  dataset = Dataset.from_pandas(dataframe)
363
+ dataset = combine_datasets(repo_id, dataset, oauth_token)
364
  distiset = Distiset({"default": dataset})
365
  distiset.push_to_hub(
366
  repo_id=repo_id,
 
908
 
909
  app.load(fn=swap_visibility, outputs=main_ui)
910
  app.load(fn=get_org_dropdown, outputs=[org_name])
911
+ app.load(fn=get_random_repo_name, outputs=[repo_name])
src/synthetic_dataset_generator/apps/textcat.py CHANGED
@@ -20,6 +20,7 @@ from synthetic_dataset_generator.apps.base import (
20
  validate_push_to_hub,
21
  )
22
  from synthetic_dataset_generator.constants import DEFAULT_BATCH_SIZE
 
23
  from synthetic_dataset_generator.pipelines.embeddings import (
24
  get_embeddings,
25
  get_sentence_embedding_dimensions,
@@ -35,6 +36,7 @@ from synthetic_dataset_generator.utils import (
35
  get_argilla_client,
36
  get_org_dropdown,
37
  get_preprocess_labels,
 
38
  swap_visibility,
39
  )
40
 
@@ -106,7 +108,7 @@ def generate_dataset(
106
  )
107
  updated_system_prompt = f"{system_prompt}. Optional labels: {', '.join(labels)}."
108
  if multi_label:
109
- updated_system_prompt = f"{updated_system_prompt}. Only apply relevant labels. Applying less labels is better than applying too many labels."
110
  labeller_generator = get_labeller_generator(
111
  system_prompt=updated_system_prompt,
112
  labels=labels,
@@ -118,6 +120,7 @@ def generate_dataset(
118
  # create text classification data
119
  n_processed = 0
120
  textcat_results = []
 
121
  while n_processed < num_rows:
122
  progress(
123
  2 * 0.5 * n_processed / num_rows,
@@ -128,25 +131,24 @@ def generate_dataset(
128
  batch_size = min(batch_size, remaining_rows)
129
  inputs = []
130
  for _ in range(batch_size):
 
131
  if multi_label:
132
  num_labels = len(labels)
133
  k = int(
134
  random.betavariate(alpha=(num_labels - 1), beta=num_labels)
135
  * num_labels
136
  )
137
- else:
138
- k = 1
139
-
140
  sampled_labels = random.sample(labels, min(k, len(labels)))
141
  random.shuffle(sampled_labels)
142
  inputs.append(
143
  {
144
- "task": f"{system_prompt}. The text represents the following categories: {', '.join(sampled_labels)}"
145
  }
146
  )
147
  batch = list(textcat_generator.process(inputs=inputs))
148
  textcat_results.extend(batch[0])
149
  n_processed += batch_size
 
150
  for result in textcat_results:
151
  result["text"] = result["input_text"]
152
 
@@ -164,6 +166,7 @@ def generate_dataset(
164
  labels_batch = list(labeller_generator.process(inputs=batch))
165
  labeller_results.extend(labels_batch[0])
166
  n_processed += batch_size
 
167
  progress(
168
  1,
169
  total=total_steps,
@@ -178,26 +181,43 @@ def generate_dataset(
178
 
179
  dataframe = pd.DataFrame(distiset_results)
180
  if multi_label:
181
- dataframe["labels"] = dataframe["labels"].apply(
182
- lambda x: list(
183
- set(
184
- [
 
 
 
185
  label.lower().strip()
186
- if (label is not None and label.lower().strip() in labels)
187
- else random.choice(labels)
188
  for label in x
189
- ]
 
190
  )
191
- )
192
- )
 
 
193
  dataframe = dataframe[dataframe["labels"].notna()]
194
  else:
 
195
  dataframe = dataframe.rename(columns={"labels": "label"})
196
- dataframe["label"] = dataframe["label"].apply(
197
- lambda x: x.lower().strip()
198
- if x and x.lower().strip() in labels
199
- else random.choice(labels)
200
- )
201
  dataframe = dataframe[dataframe["text"].notna()]
202
 
203
  progress(1.0, desc="Dataset created")
@@ -235,7 +255,7 @@ def push_dataset_to_hub(
235
  dataframe.reset_index(drop=True),
236
  features=features,
237
  )
238
- dataset = combine_datasets(repo_id, dataset)
239
  distiset = Distiset({"default": dataset})
240
  progress(0.9, desc="Pushing dataset")
241
  distiset.push_to_hub(
@@ -647,3 +667,4 @@ with gr.Blocks() as app:
647
 
648
  app.load(fn=swap_visibility, outputs=main_ui)
649
  app.load(fn=get_org_dropdown, outputs=[org_name])
 
 
20
  validate_push_to_hub,
21
  )
22
  from synthetic_dataset_generator.constants import DEFAULT_BATCH_SIZE
23
+ from synthetic_dataset_generator.pipelines.base import get_rewriten_prompts
24
  from synthetic_dataset_generator.pipelines.embeddings import (
25
  get_embeddings,
26
  get_sentence_embedding_dimensions,
 
36
  get_argilla_client,
37
  get_org_dropdown,
38
  get_preprocess_labels,
39
+ get_random_repo_name,
40
  swap_visibility,
41
  )
42
 
 
108
  )
109
  updated_system_prompt = f"{system_prompt}. Optional labels: {', '.join(labels)}."
110
  if multi_label:
111
+ updated_system_prompt = f"{updated_system_prompt}. Only apply relevant labels. Applying less labels is always better than applying too many labels."
112
  labeller_generator = get_labeller_generator(
113
  system_prompt=updated_system_prompt,
114
  labels=labels,
 
120
  # create text classification data
121
  n_processed = 0
122
  textcat_results = []
123
+ rewritten_system_prompts = get_rewriten_prompts(system_prompt, num_rows)
124
  while n_processed < num_rows:
125
  progress(
126
  2 * 0.5 * n_processed / num_rows,
 
131
  batch_size = min(batch_size, remaining_rows)
132
  inputs = []
133
  for _ in range(batch_size):
134
+ k = 1
135
  if multi_label:
136
  num_labels = len(labels)
137
  k = int(
138
  random.betavariate(alpha=(num_labels - 1), beta=num_labels)
139
  * num_labels
140
  )
 
 
 
141
  sampled_labels = random.sample(labels, min(k, len(labels)))
142
  random.shuffle(sampled_labels)
143
  inputs.append(
144
  {
145
+ "task": f"{random.choice(rewritten_system_prompts)}. The text represents the following categories: {', '.join(sampled_labels)}"
146
  }
147
  )
148
  batch = list(textcat_generator.process(inputs=inputs))
149
  textcat_results.extend(batch[0])
150
  n_processed += batch_size
151
+ random.seed(a=random.randint(0, 2**32 - 1))
152
  for result in textcat_results:
153
  result["text"] = result["input_text"]
154
 
 
166
  labels_batch = list(labeller_generator.process(inputs=batch))
167
  labeller_results.extend(labels_batch[0])
168
  n_processed += batch_size
169
+ random.seed(a=random.randint(0, 2**32 - 1))
170
  progress(
171
  1,
172
  total=total_steps,
 
181
 
182
  dataframe = pd.DataFrame(distiset_results)
183
  if multi_label:
184
+
185
+ def _validate_labels(x):
186
+ if isinstance(x, str): # single label
187
+ return [x.lower().strip()]
188
+ elif isinstance(x, list): # multiple labels
189
+ return list(
190
+ set(
191
  label.lower().strip()
 
 
192
  for label in x
193
+ if label.lower().strip() in labels
194
+ )
195
  )
196
+ else:
197
+ return list(set([random.choice(labels)]))
198
+
199
+ dataframe["labels"] = dataframe["labels"].apply(_validate_labels)
200
  dataframe = dataframe[dataframe["labels"].notna()]
201
  else:
202
+
203
+ def _validate_labels(x):
204
+ if isinstance(x, str) and x.lower().strip() in labels:
205
+ return x.lower().strip()
206
+ elif isinstance(x, list):
207
+ options = [
208
+ label.lower().strip()
209
+ for label in x
210
+ if isinstance(label, str) and label.lower().strip() in labels
211
+ ]
212
+ if options:
213
+ return random.choice(options)
214
+ else:
215
+ return random.choice(labels)
216
+ else:
217
+ return random.choice(labels)
218
+
219
  dataframe = dataframe.rename(columns={"labels": "label"})
220
+ dataframe["label"] = dataframe["label"].apply(_validate_labels)
 
 
 
 
221
  dataframe = dataframe[dataframe["text"].notna()]
222
 
223
  progress(1.0, desc="Dataset created")
 
255
  dataframe.reset_index(drop=True),
256
  features=features,
257
  )
258
+ dataset = combine_datasets(repo_id, dataset, oauth_token)
259
  distiset = Distiset({"default": dataset})
260
  progress(0.9, desc="Pushing dataset")
261
  distiset.push_to_hub(
 
667
 
668
  app.load(fn=swap_visibility, outputs=main_ui)
669
  app.load(fn=get_org_dropdown, outputs=[org_name])
670
+ app.load(fn=get_random_repo_name, outputs=[repo_name])
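To make the label clean-up concrete: with `labels = ["positive", "negative"]`, the multi-label helper would map `["POSITIVE ", "spam"]` to `["positive"]` (unknown labels are dropped), while the single-label helper maps `["spam", "Positive"]` to `"positive"` and falls back to a random valid label when nothing usable is left. These inputs are hypothetical examples, not outputs from the app.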
src/synthetic_dataset_generator/constants.py CHANGED
@@ -7,39 +7,66 @@ import argilla as rg
7
  TEXTCAT_TASK = "text_classification"
8
  SFT_TASK = "supervised_fine_tuning"
9
 
10
- # Hugging Face
11
  HF_TOKEN = os.getenv("HF_TOKEN")
12
  if not HF_TOKEN:
13
  raise ValueError(
14
  "HF_TOKEN is not set. Ensure you have set the HF_TOKEN environment variable that has access to the Hugging Face Hub repositories and Inference Endpoints."
15
  )
16
 
17
- # Inference
18
- MAX_NUM_TOKENS = int(os.getenv("MAX_NUM_TOKENS", 2048))
19
- MAX_NUM_ROWS: str | int = int(os.getenv("MAX_NUM_ROWS", 1000))
20
- DEFAULT_BATCH_SIZE = int(os.getenv("DEFAULT_BATCH_SIZE", 5))
21
- MODEL = os.getenv("MODEL", "meta-llama/Meta-Llama-3.1-8B-Instruct")
22
- BASE_URL = os.getenv("BASE_URL", default=None)
23
-
24
  _API_KEY = os.getenv("API_KEY")
25
- if _API_KEY:
26
- API_KEYS = [_API_KEY]
27
- else:
28
- API_KEYS = [os.getenv("HF_TOKEN")] + [
29
- os.getenv(f"HF_TOKEN_{i}") for i in range(1, 10)
30
- ]
31
  API_KEYS = [token for token in API_KEYS if token]
32
 
33
  # Determine if SFT is available
34
  SFT_AVAILABLE = False
35
  llama_options = ["llama3", "llama-3", "llama 3"]
36
  qwen_options = ["qwen2", "qwen-2", "qwen 2"]
37
- if os.getenv("MAGPIE_PRE_QUERY_TEMPLATE"):
 
38
  SFT_AVAILABLE = True
39
- passed_pre_query_template = os.getenv("MAGPIE_PRE_QUERY_TEMPLATE")
40
- if passed_pre_query_template.lower() in llama_options:
41
  MAGPIE_PRE_QUERY_TEMPLATE = "llama3"
42
- elif passed_pre_query_template.lower() in qwen_options:
43
  MAGPIE_PRE_QUERY_TEMPLATE = "qwen2"
44
  else:
45
  MAGPIE_PRE_QUERY_TEMPLATE = passed_pre_query_template
@@ -54,12 +81,12 @@ elif MODEL.lower() in qwen_options or any(
54
  SFT_AVAILABLE = True
55
  MAGPIE_PRE_QUERY_TEMPLATE = "qwen2"
56
 
57
- if BASE_URL:
58
  SFT_AVAILABLE = False
59
 
60
  if not SFT_AVAILABLE:
61
  warnings.warn(
62
- message="`SFT_AVAILABLE` is set to `False`. Use Hugging Face Inference Endpoints to generate chat data."
63
  )
64
  MAGPIE_PRE_QUERY_TEMPLATE = None
65
 
@@ -67,11 +94,12 @@ if not SFT_AVAILABLE:
67
  STATIC_EMBEDDING_MODEL = "minishlab/potion-base-8M"
68
 
69
  # Argilla
70
- ARGILLA_API_URL = os.getenv("ARGILLA_API_URL")
71
- ARGILLA_API_KEY = os.getenv("ARGILLA_API_KEY")
72
- if ARGILLA_API_URL is None or ARGILLA_API_KEY is None:
73
- ARGILLA_API_URL = os.getenv("ARGILLA_API_URL_SDG_REVIEWER")
74
- ARGILLA_API_KEY = os.getenv("ARGILLA_API_KEY_SDG_REVIEWER")
 
75
 
76
  if not ARGILLA_API_URL or not ARGILLA_API_KEY:
77
  warnings.warn("ARGILLA_API_URL or ARGILLA_API_KEY is not set or is empty")
 
7
  TEXTCAT_TASK = "text_classification"
8
  SFT_TASK = "supervised_fine_tuning"
9
 
10
+ # Inference
11
+ MAX_NUM_TOKENS = int(os.getenv("MAX_NUM_TOKENS", 2048))
12
+ MAX_NUM_ROWS = int(os.getenv("MAX_NUM_ROWS", 1000))
13
+ DEFAULT_BATCH_SIZE = int(os.getenv("DEFAULT_BATCH_SIZE", 5))
14
+
15
+ # Models
16
+ MODEL = os.getenv("MODEL", "meta-llama/Meta-Llama-3.1-8B-Instruct")
17
+ TOKENIZER_ID = os.getenv(key="TOKENIZER_ID", default=None)
18
+ OPENAI_BASE_URL = os.getenv("OPENAI_BASE_URL")
19
+ OLLAMA_BASE_URL = os.getenv("OLLAMA_BASE_URL")
20
+ HUGGINGFACE_BASE_URL = os.getenv("HUGGINGFACE_BASE_URL")
21
+ VLLM_BASE_URL = os.getenv("VLLM_BASE_URL")
22
+
23
+ # check if model is set correctly
24
+ if HUGGINGFACE_BASE_URL and MODEL:
25
+ raise ValueError(
26
+ "`HUGGINGFACE_BASE_URL` and `MODEL` cannot be set at the same time. Use a model id for serverless inference and a base URL dedicated to Hugging Face Inference Endpoints."
27
+ )
28
+ if not MODEL:
29
+ if OPENAI_BASE_URL or OLLAMA_BASE_URL or VLLM_BASE_URL:
30
+ raise ValueError("`MODEL` is not set. Please provide a model id for inference.")
31
+
32
+ # Check if multiple base URLs are provided
33
+ base_urls = [
34
+ url
35
+ for url in [OPENAI_BASE_URL, OLLAMA_BASE_URL, HUGGINGFACE_BASE_URL, VLLM_BASE_URL]
36
+ if url
37
+ ]
38
+ if len(base_urls) > 1:
39
+ raise ValueError(
40
+ f"Multiple base URLs provided: {', '.join(base_urls)}. Only one base URL can be set at a time."
41
+ )
42
+ BASE_URL = OPENAI_BASE_URL or OLLAMA_BASE_URL or HUGGINGFACE_BASE_URL or VLLM_BASE_URL
43
+
44
+
45
+ # API Keys
46
  HF_TOKEN = os.getenv("HF_TOKEN")
47
  if not HF_TOKEN:
48
  raise ValueError(
49
  "HF_TOKEN is not set. Ensure you have set the HF_TOKEN environment variable that has access to the Hugging Face Hub repositories and Inference Endpoints."
50
  )
51
 
 
 
 
 
 
 
 
52
  _API_KEY = os.getenv("API_KEY")
53
+ API_KEYS = (
54
+ [_API_KEY]
55
+ if _API_KEY
56
+ else [HF_TOKEN] + [os.getenv(f"HF_TOKEN_{i}") for i in range(1, 10)]
57
+ )
 
58
  API_KEYS = [token for token in API_KEYS if token]
59
 
60
  # Determine if SFT is available
61
  SFT_AVAILABLE = False
62
  llama_options = ["llama3", "llama-3", "llama 3"]
63
  qwen_options = ["qwen2", "qwen-2", "qwen 2"]
64
+
65
+ if passed_pre_query_template := os.getenv("MAGPIE_PRE_QUERY_TEMPLATE", "").lower():
66
  SFT_AVAILABLE = True
67
+ if passed_pre_query_template in llama_options:
 
68
  MAGPIE_PRE_QUERY_TEMPLATE = "llama3"
69
+ elif passed_pre_query_template in qwen_options:
70
  MAGPIE_PRE_QUERY_TEMPLATE = "qwen2"
71
  else:
72
  MAGPIE_PRE_QUERY_TEMPLATE = passed_pre_query_template
 
81
  SFT_AVAILABLE = True
82
  MAGPIE_PRE_QUERY_TEMPLATE = "qwen2"
83
 
84
+ if OPENAI_BASE_URL:
85
  SFT_AVAILABLE = False
86
 
87
  if not SFT_AVAILABLE:
88
  warnings.warn(
89
+ "`SFT_AVAILABLE` is set to `False`. Use Hugging Face Inference Endpoints or Ollama to generate chat data, provide a `TOKENIZER_ID` and `MAGPIE_PRE_QUERY_TEMPLATE`. You can also use `HUGGINGFACE_BASE_URL` to with vllm."
90
  )
91
  MAGPIE_PRE_QUERY_TEMPLATE = None
92
 
 
94
  STATIC_EMBEDDING_MODEL = "minishlab/potion-base-8M"
95
 
96
  # Argilla
97
+ ARGILLA_API_URL = os.getenv("ARGILLA_API_URL") or os.getenv(
98
+ "ARGILLA_API_URL_SDG_REVIEWER"
99
+ )
100
+ ARGILLA_API_KEY = os.getenv("ARGILLA_API_KEY") or os.getenv(
101
+ "ARGILLA_API_KEY_SDG_REVIEWER"
102
+ )
103
 
104
  if not ARGILLA_API_URL or not ARGILLA_API_KEY:
105
  warnings.warn("ARGILLA_API_URL or ARGILLA_API_KEY is not set or is empty")
src/synthetic_dataset_generator/pipelines/base.py CHANGED
@@ -1,4 +1,21 @@
1
- from synthetic_dataset_generator.constants import API_KEYS
2
 
3
  TOKEN_INDEX = 0
4
 
@@ -8,3 +25,117 @@ def _get_next_api_key():
8
  api_key = API_KEYS[TOKEN_INDEX % len(API_KEYS)]
9
  TOKEN_INDEX += 1
10
  return api_key
 
1
+ import math
2
+ import random
3
+
4
+ import gradio as gr
5
+ from distilabel.llms import ClientvLLM, InferenceEndpointsLLM, OllamaLLM, OpenAILLM
6
+ from distilabel.steps.tasks import TextGeneration
7
+
8
+ from synthetic_dataset_generator.constants import (
9
+ API_KEYS,
10
+ DEFAULT_BATCH_SIZE,
11
+ HUGGINGFACE_BASE_URL,
12
+ MAGPIE_PRE_QUERY_TEMPLATE,
13
+ MODEL,
14
+ OLLAMA_BASE_URL,
15
+ OPENAI_BASE_URL,
16
+ TOKENIZER_ID,
17
+ VLLM_BASE_URL,
18
+ )
19
 
20
  TOKEN_INDEX = 0
21
 
 
25
  api_key = API_KEYS[TOKEN_INDEX % len(API_KEYS)]
26
  TOKEN_INDEX += 1
27
  return api_key
28
+
29
+
30
+ def _get_prompt_rewriter():
31
+ generation_kwargs = {
32
+ "temperature": 1,
33
+ }
34
+ system_prompt = "You are a prompt rewriter. You are given a prompt and you need to rewrite it keeping the same structure but highlighting different aspects of the original without adding anything new."
35
+ prompt_rewriter = TextGeneration(
36
+ llm=_get_llm(generation_kwargs=generation_kwargs),
37
+ system_prompt=system_prompt,
38
+ use_system_prompt=True,
39
+ )
40
+ prompt_rewriter.load()
41
+ return prompt_rewriter
42
+
43
+
44
+ def get_rewriten_prompts(prompt: str, num_rows: int):
45
+ prompt_rewriter = _get_prompt_rewriter()
46
+ # create prompt rewrites
47
+ inputs = [
48
+ {"instruction": f"Original prompt: {prompt} \nRewritten prompt: "}
49
+ for i in range(math.floor(num_rows / 100))
50
+ ]
51
+ n_processed = 0
52
+ prompt_rewrites = [prompt]
53
+ while n_processed < num_rows:
54
+ batch = list(
55
+ prompt_rewriter.process(
56
+ inputs=inputs[n_processed : n_processed + DEFAULT_BATCH_SIZE]
57
+ )
58
+ )
59
+ prompt_rewrites += [entry["generation"] for entry in batch[0]]
60
+ n_processed += DEFAULT_BATCH_SIZE
61
+ random.seed(a=random.randint(0, 2**32 - 1))
62
+ return prompt_rewrites
63
+
64
+
65
+ def _get_llm(use_magpie_template=False, **kwargs):
66
+ if OPENAI_BASE_URL:
67
+ llm = OpenAILLM(
68
+ model=MODEL,
69
+ base_url=OPENAI_BASE_URL,
70
+ api_key=_get_next_api_key(),
71
+ **kwargs,
72
+ )
73
+ if "generation_kwargs" in kwargs:
74
+ if "stop_sequences" in kwargs["generation_kwargs"]:
75
+ kwargs["generation_kwargs"]["stop"] = kwargs["generation_kwargs"][
76
+ "stop_sequences"
77
+ ]
78
+ del kwargs["generation_kwargs"]["stop_sequences"]
79
+ if "do_sample" in kwargs["generation_kwargs"]:
80
+ del kwargs["generation_kwargs"]["do_sample"]
81
+ elif OLLAMA_BASE_URL:
82
+ if "generation_kwargs" in kwargs:
83
+ if "max_new_tokens" in kwargs["generation_kwargs"]:
84
+ kwargs["generation_kwargs"]["num_predict"] = kwargs[
85
+ "generation_kwargs"
86
+ ]["max_new_tokens"]
87
+ del kwargs["generation_kwargs"]["max_new_tokens"]
88
+ if "stop_sequences" in kwargs["generation_kwargs"]:
89
+ kwargs["generation_kwargs"]["stop"] = kwargs["generation_kwargs"][
90
+ "stop_sequences"
91
+ ]
92
+ del kwargs["generation_kwargs"]["stop_sequences"]
93
+ if "do_sample" in kwargs["generation_kwargs"]:
94
+ del kwargs["generation_kwargs"]["do_sample"]
95
+ options = kwargs["generation_kwargs"]
96
+ del kwargs["generation_kwargs"]
97
+ kwargs["generation_kwargs"] = {}
98
+ kwargs["generation_kwargs"]["options"] = options
99
+ llm = OllamaLLM(
100
+ model=MODEL,
101
+ host=OLLAMA_BASE_URL,
102
+ tokenizer_id=TOKENIZER_ID or MODEL,
103
+ **kwargs,
104
+ )
105
+ elif HUGGINGFACE_BASE_URL:
106
+ kwargs["generation_kwargs"]["do_sample"] = True
107
+ llm = InferenceEndpointsLLM(
108
+ api_key=_get_next_api_key(),
109
+ base_url=HUGGINGFACE_BASE_URL,
110
+ tokenizer_id=TOKENIZER_ID or MODEL,
111
+ **kwargs,
112
+ )
113
+ elif VLLM_BASE_URL:
114
+ if "generation_kwargs" in kwargs:
115
+ if "do_sample" in kwargs["generation_kwargs"]:
116
+ del kwargs["generation_kwargs"]["do_sample"]
117
+ llm = ClientvLLM(
118
+ base_url=VLLM_BASE_URL,
119
+ model=MODEL,
120
+ tokenizer=TOKENIZER_ID or MODEL,
121
+ api_key=_get_next_api_key(),
122
+ **kwargs,
123
+ )
124
+ else:
125
+ llm = InferenceEndpointsLLM(
126
+ api_key=_get_next_api_key(),
127
+ tokenizer_id=TOKENIZER_ID or MODEL,
128
+ model_id=MODEL,
129
+ magpie_pre_query_template=MAGPIE_PRE_QUERY_TEMPLATE,
130
+ **kwargs,
131
+ )
132
+
133
+ return llm
134
+
135
+
136
+ try:
137
+ llm = _get_llm()
138
+ llm.load()
139
+ llm.generate([[{"content": "Hello, world!", "role": "user"}]])
140
+ except Exception as e:
141
+ gr.Error(f"Error loading {llm.__class__.__name__}: {e}")
src/synthetic_dataset_generator/pipelines/chat.py CHANGED
@@ -1,4 +1,3 @@
1
- from distilabel.llms import InferenceEndpointsLLM
2
  from distilabel.steps.tasks import ChatGeneration, Magpie, TextGeneration
3
 
4
  from synthetic_dataset_generator.constants import (
@@ -7,7 +6,7 @@ from synthetic_dataset_generator.constants import (
7
  MAX_NUM_TOKENS,
8
  MODEL,
9
  )
10
- from synthetic_dataset_generator.pipelines.base import _get_next_api_key
11
 
12
  INFORMATION_SEEKING_PROMPT = (
13
  "You are an AI assistant designed to provide accurate and concise information on a wide"
@@ -149,18 +148,13 @@ def _get_output_mappings(num_turns):
149
 
150
 
151
  def get_prompt_generator():
 
 
 
 
 
152
  prompt_generator = TextGeneration(
153
- llm=InferenceEndpointsLLM(
154
- api_key=_get_next_api_key(),
155
- model_id=MODEL,
156
- tokenizer_id=MODEL,
157
- base_url=BASE_URL,
158
- generation_kwargs={
159
- "temperature": 0.8,
160
- "max_new_tokens": MAX_NUM_TOKENS,
161
- "do_sample": True,
162
- },
163
- ),
164
  system_prompt=PROMPT_CREATION_PROMPT,
165
  use_system_prompt=True,
166
  )
@@ -172,38 +166,34 @@ def get_magpie_generator(system_prompt, num_turns, temperature, is_sample):
172
  input_mappings = _get_output_mappings(num_turns)
173
  output_mappings = input_mappings.copy()
174
  if num_turns == 1:
 
 
 
 
 
 
175
  magpie_generator = Magpie(
176
- llm=InferenceEndpointsLLM(
177
- model_id=MODEL,
178
- tokenizer_id=MODEL,
179
- base_url=BASE_URL,
180
- api_key=_get_next_api_key(),
181
  magpie_pre_query_template=MAGPIE_PRE_QUERY_TEMPLATE,
182
- generation_kwargs={
183
- "temperature": temperature,
184
- "do_sample": True,
185
- "max_new_tokens": 256 if is_sample else int(MAX_NUM_TOKENS * 0.25),
186
- "stop_sequences": _STOP_SEQUENCES,
187
- },
188
  ),
189
  n_turns=num_turns,
190
  output_mappings=output_mappings,
191
  only_instruction=True,
192
  )
193
  else:
 
 
 
 
 
 
194
  magpie_generator = Magpie(
195
- llm=InferenceEndpointsLLM(
196
- model_id=MODEL,
197
- tokenizer_id=MODEL,
198
- base_url=BASE_URL,
199
- api_key=_get_next_api_key(),
200
  magpie_pre_query_template=MAGPIE_PRE_QUERY_TEMPLATE,
201
- generation_kwargs={
202
- "temperature": temperature,
203
- "do_sample": True,
204
- "max_new_tokens": 256 if is_sample else int(MAX_NUM_TOKENS * 0.5),
205
- "stop_sequences": _STOP_SEQUENCES,
206
- },
207
  ),
208
  end_with_user=True,
209
  n_turns=num_turns,
@@ -213,51 +203,25 @@ def get_magpie_generator(system_prompt, num_turns, temperature, is_sample):
213
  return magpie_generator
214
 
215
 
216
- def get_prompt_rewriter():
217
- prompt_rewriter = TextGeneration(
218
- llm=InferenceEndpointsLLM(
219
- model_id=MODEL,
220
- tokenizer_id=MODEL,
221
- base_url=BASE_URL,
222
- api_key=_get_next_api_key(),
223
- generation_kwargs={
224
- "temperature": 1,
225
- },
226
- ),
227
- )
228
- prompt_rewriter.load()
229
- return prompt_rewriter
230
-
231
-
232
  def get_response_generator(system_prompt, num_turns, temperature, is_sample):
233
  if num_turns == 1:
 
 
 
 
234
  response_generator = TextGeneration(
235
- llm=InferenceEndpointsLLM(
236
- model_id=MODEL,
237
- tokenizer_id=MODEL,
238
- base_url=BASE_URL,
239
- api_key=_get_next_api_key(),
240
- generation_kwargs={
241
- "temperature": temperature,
242
- "max_new_tokens": 256 if is_sample else int(MAX_NUM_TOKENS * 0.5),
243
- },
244
- ),
245
  system_prompt=system_prompt,
246
  output_mappings={"generation": "completion"},
247
  input_mappings={"instruction": "prompt"},
248
  )
249
  else:
 
 
 
 
250
  response_generator = ChatGeneration(
251
- llm=InferenceEndpointsLLM(
252
- model_id=MODEL,
253
- tokenizer_id=MODEL,
254
- base_url=BASE_URL,
255
- api_key=_get_next_api_key(),
256
- generation_kwargs={
257
- "temperature": temperature,
258
- "max_new_tokens": MAX_NUM_TOKENS,
259
- },
260
- ),
261
  output_mappings={"generation": "completion"},
262
  input_mappings={"conversation": "messages"},
263
  )
@@ -293,7 +257,7 @@ with Pipeline(name="sft") as pipeline:
293
  "max_new_tokens": {MAX_NUM_TOKENS},
294
  "stop_sequences": {_STOP_SEQUENCES}
295
  }},
296
- api_key=os.environ["BASE_URL"],
297
  ),
298
  n_turns={num_turns},
299
  num_rows={num_rows},
 
 
1
  from distilabel.steps.tasks import ChatGeneration, Magpie, TextGeneration
2
 
3
  from synthetic_dataset_generator.constants import (
 
6
  MAX_NUM_TOKENS,
7
  MODEL,
8
  )
9
+ from synthetic_dataset_generator.pipelines.base import _get_llm
10
 
11
  INFORMATION_SEEKING_PROMPT = (
12
  "You are an AI assistant designed to provide accurate and concise information on a wide"
 
148
 
149
 
150
  def get_prompt_generator():
151
+ generation_kwargs = {
152
+ "temperature": 0.8,
153
+ "max_new_tokens": MAX_NUM_TOKENS,
154
+ "do_sample": True,
155
+ }
156
  prompt_generator = TextGeneration(
157
+ llm=_get_llm(generation_kwargs=generation_kwargs),
 
 
 
 
 
 
 
 
 
 
158
  system_prompt=PROMPT_CREATION_PROMPT,
159
  use_system_prompt=True,
160
  )
 
166
  input_mappings = _get_output_mappings(num_turns)
167
  output_mappings = input_mappings.copy()
168
  if num_turns == 1:
169
+ generation_kwargs = {
170
+ "temperature": temperature,
171
+ "do_sample": True,
172
+ "max_new_tokens": 256 if is_sample else int(MAX_NUM_TOKENS * 0.25),
173
+ "stop_sequences": _STOP_SEQUENCES,
174
+ }
175
  magpie_generator = Magpie(
176
+ llm=_get_llm(
177
+ generation_kwargs=generation_kwargs,
 
 
 
178
  magpie_pre_query_template=MAGPIE_PRE_QUERY_TEMPLATE,
179
+ use_magpie_template=True,
 
 
 
 
 
180
  ),
181
  n_turns=num_turns,
182
  output_mappings=output_mappings,
183
  only_instruction=True,
184
  )
185
  else:
186
+ generation_kwargs = {
187
+ "temperature": temperature,
188
+ "do_sample": True,
189
+ "max_new_tokens": 256 if is_sample else int(MAX_NUM_TOKENS * 0.5),
190
+ "stop_sequences": _STOP_SEQUENCES,
191
+ }
192
  magpie_generator = Magpie(
193
+ llm=_get_llm(
194
+ generation_kwargs=generation_kwargs,
 
 
 
195
  magpie_pre_query_template=MAGPIE_PRE_QUERY_TEMPLATE,
196
+ use_magpie_template=True,
 
 
 
 
 
197
  ),
198
  end_with_user=True,
199
  n_turns=num_turns,
 
203
  return magpie_generator
204
 
205
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
206
  def get_response_generator(system_prompt, num_turns, temperature, is_sample):
207
  if num_turns == 1:
208
+ generation_kwargs = {
209
+ "temperature": temperature,
210
+ "max_new_tokens": 256 if is_sample else int(MAX_NUM_TOKENS * 0.5),
211
+ }
212
  response_generator = TextGeneration(
213
+ llm=_get_llm(generation_kwargs=generation_kwargs),
 
 
 
 
 
 
 
 
 
214
  system_prompt=system_prompt,
215
  output_mappings={"generation": "completion"},
216
  input_mappings={"instruction": "prompt"},
217
  )
218
  else:
219
+ generation_kwargs = {
220
+ "temperature": temperature,
221
+ "max_new_tokens": MAX_NUM_TOKENS,
222
+ }
223
  response_generator = ChatGeneration(
224
+ llm=_get_llm(generation_kwargs=generation_kwargs),
 
 
 
 
 
 
 
 
 
225
  output_mappings={"generation": "completion"},
226
  input_mappings={"conversation": "messages"},
227
  )
 
257
  "max_new_tokens": {MAX_NUM_TOKENS},
258
  "stop_sequences": {_STOP_SEQUENCES}
259
  }},
260
+ api_key=os.environ["API_KEY"],
261
  ),
262
  n_turns={num_turns},
263
  num_rows={num_rows},
src/synthetic_dataset_generator/pipelines/textcat.py CHANGED
@@ -1,7 +1,6 @@
1
  import random
2
  from typing import List
3
 
4
- from distilabel.llms import InferenceEndpointsLLM, OpenAILLM
5
  from distilabel.steps.tasks import (
6
  GenerateTextClassificationData,
7
  TextClassification,
@@ -9,8 +8,12 @@ from distilabel.steps.tasks import (
9
  )
10
  from pydantic import BaseModel, Field
11
 
12
- from synthetic_dataset_generator.constants import BASE_URL, MAX_NUM_TOKENS, MODEL
13
- from synthetic_dataset_generator.pipelines.base import _get_next_api_key
 
 
 
 
14
  from synthetic_dataset_generator.utils import get_preprocess_labels
15
 
16
  PROMPT_CREATION_PROMPT = """You are an AI assistant specialized in generating very precise text classification tasks for dataset creation.
@@ -69,23 +72,10 @@ def get_prompt_generator():
69
  "temperature": 0.8,
70
  "max_new_tokens": MAX_NUM_TOKENS,
71
  }
72
- if BASE_URL:
73
- llm = OpenAILLM(
74
- model=MODEL,
75
- base_url=BASE_URL,
76
- api_key=_get_next_api_key(),
77
- structured_output=structured_output,
78
- generation_kwargs=generation_kwargs,
79
- )
80
- else:
81
- generation_kwargs["do_sample"] = True
82
- llm = InferenceEndpointsLLM(
83
- api_key=_get_next_api_key(),
84
- model_id=MODEL,
85
- base_url=BASE_URL,
86
- structured_output=structured_output,
87
- generation_kwargs=generation_kwargs,
88
- )
89
 
90
  prompt_generator = TextGeneration(
91
  llm=llm,
@@ -103,22 +93,7 @@ def get_textcat_generator(difficulty, clarity, temperature, is_sample):
103
  "max_new_tokens": 256 if is_sample else MAX_NUM_TOKENS,
104
  "top_p": 0.95,
105
  }
106
- if BASE_URL:
107
- llm = OpenAILLM(
108
- model=MODEL,
109
- base_url=BASE_URL,
110
- api_key=_get_next_api_key(),
111
- generation_kwargs=generation_kwargs,
112
- )
113
- else:
114
- generation_kwargs["do_sample"] = True
115
- llm = InferenceEndpointsLLM(
116
- model_id=MODEL,
117
- base_url=BASE_URL,
118
- api_key=_get_next_api_key(),
119
- generation_kwargs=generation_kwargs,
120
- )
121
-
122
  textcat_generator = GenerateTextClassificationData(
123
  llm=llm,
124
  difficulty=None if difficulty == "mixed" else difficulty,
@@ -134,22 +109,7 @@ def get_labeller_generator(system_prompt, labels, multi_label):
134
  "temperature": 0.01,
135
  "max_new_tokens": MAX_NUM_TOKENS,
136
  }
137
-
138
- if BASE_URL:
139
- llm = OpenAILLM(
140
- model=MODEL,
141
- base_url=BASE_URL,
142
- api_key=_get_next_api_key(),
143
- generation_kwargs=generation_kwargs,
144
- )
145
- else:
146
- llm = InferenceEndpointsLLM(
147
- model_id=MODEL,
148
- base_url=BASE_URL,
149
- api_key=_get_next_api_key(),
150
- generation_kwargs=generation_kwargs,
151
- )
152
-
153
  labeller_generator = TextClassification(
154
  llm=llm,
155
  context=system_prompt,
 
1
  import random
2
  from typing import List
3
 
 
4
  from distilabel.steps.tasks import (
5
  GenerateTextClassificationData,
6
  TextClassification,
 
8
  )
9
  from pydantic import BaseModel, Field
10
 
11
+ from synthetic_dataset_generator.constants import (
12
+ BASE_URL,
13
+ MAX_NUM_TOKENS,
14
+ MODEL,
15
+ )
16
+ from synthetic_dataset_generator.pipelines.base import _get_llm
17
  from synthetic_dataset_generator.utils import get_preprocess_labels
18
 
19
  PROMPT_CREATION_PROMPT = """You are an AI assistant specialized in generating very precise text classification tasks for dataset creation.
 
72
  "temperature": 0.8,
73
  "max_new_tokens": MAX_NUM_TOKENS,
74
  }
75
+ llm = _get_llm(
76
+ structured_output=structured_output,
77
+ generation_kwargs=generation_kwargs,
78
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
79
 
80
  prompt_generator = TextGeneration(
81
  llm=llm,
 
93
  "max_new_tokens": 256 if is_sample else MAX_NUM_TOKENS,
94
  "top_p": 0.95,
95
  }
96
+ llm = _get_llm(generation_kwargs=generation_kwargs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
  textcat_generator = GenerateTextClassificationData(
98
  llm=llm,
99
  difficulty=None if difficulty == "mixed" else difficulty,
 
109
  "temperature": 0.01,
110
  "max_new_tokens": MAX_NUM_TOKENS,
111
  }
112
+ llm = _get_llm(generation_kwargs=generation_kwargs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
  labeller_generator = TextClassification(
114
  llm=llm,
115
  context=system_prompt,
src/synthetic_dataset_generator/utils.py CHANGED
@@ -1,4 +1,5 @@
1
  import json
 
2
  import warnings
3
  from typing import List, Optional, Union
4
 
@@ -55,6 +56,10 @@ def list_orgs(oauth_token: Union[OAuthToken, None] = None):
55
  return organizations
56
 
57
 
 
 
 
 
58
  def get_org_dropdown(oauth_token: Union[OAuthToken, None] = None):
59
  if oauth_token is not None:
60
  orgs = list_orgs(oauth_token)
 
1
  import json
2
+ import uuid
3
  import warnings
4
  from typing import List, Optional, Union
5
 
 
56
  return organizations
57
 
58
 
59
+ def get_random_repo_name():
60
+ return f"my-distiset-{str(uuid.uuid4())[:8]}"
61
+
62
+
63
  def get_org_dropdown(oauth_token: Union[OAuthToken, None] = None):
64
  if oauth_token is not None:
65
  orgs = list_orgs(oauth_token)