davidberenstein1957 committed
Commit
c1b3b74
1 Parent(s): 753af07

feat: move generation outside of pipeline

src/distilabel_dataset_generator/apps/sft.py CHANGED
@@ -1,23 +1,24 @@
 import io
-import multiprocessing
-import time
 from typing import Union
 
 import gradio as gr
 import pandas as pd
 from datasets import Dataset
 from distilabel.distiset import Distiset
+from distilabel.steps.tasks.text_generation import TextGeneration
 from gradio.oauth import OAuthToken
 from huggingface_hub import upload_file
 
 from src.distilabel_dataset_generator.pipelines.sft import (
+    DEFAULT_BATCH_SIZE,
     DEFAULT_DATASET_DESCRIPTIONS,
     DEFAULT_DATASETS,
     DEFAULT_SYSTEM_PROMPTS,
     PROMPT_CREATION_PROMPT,
     generate_pipeline_code,
-    get_pipeline,
-    get_prompt_generation_step,
+    get_magpie_generator,
+    get_prompt_generator,
+    get_response_generator,
 )
 from src.distilabel_dataset_generator.utils import (
     get_login_button,
@@ -26,22 +27,15 @@ from src.distilabel_dataset_generator.utils import (
 )
 
 
-def _run_pipeline(result_queue, num_turns, num_rows, system_prompt, is_sample):
-    pipeline = get_pipeline(num_turns, num_rows, system_prompt, is_sample)
-    distiset: Distiset = pipeline.run(use_cache=False)
-    result_queue.put(distiset)
-
-
 def generate_system_prompt(dataset_description, progress=gr.Progress()):
+    progress(0.0, desc="Generating system prompt")
     if dataset_description in DEFAULT_DATASET_DESCRIPTIONS:
         index = DEFAULT_DATASET_DESCRIPTIONS.index(dataset_description)
         if index < len(DEFAULT_SYSTEM_PROMPTS):
             return DEFAULT_SYSTEM_PROMPTS[index]
 
-    progress(0.1, desc="Initializing text generation")
-    generate_description = get_prompt_generation_step()
-    progress(0.4, desc="Loading model")
-    generate_description.load()
+    progress(0.3, desc="Initializing text generation")
+    generate_description: TextGeneration = get_prompt_generator()
     progress(0.7, desc="Generating system prompt")
     result = next(
         generate_description.process(
@@ -62,12 +56,9 @@ def generate_sample_dataset(system_prompt, progress=gr.Progress()):
         index = DEFAULT_SYSTEM_PROMPTS.index(system_prompt)
         if index < len(DEFAULT_DATASETS):
             return DEFAULT_DATASETS[index]
-
-    progress(0.1, desc="Initializing sample dataset generation")
     result = generate_dataset(
         system_prompt, num_turns=1, num_rows=1, progress=progress, is_sample=True
     )
-    progress(1.0, desc="Sample dataset generated")
     return result
 
 
@@ -92,52 +83,98 @@ def generate_dataset(
     is_sample: bool = False,
     progress=gr.Progress(),
 ):
-    if num_rows < 5:
-        duration = 25
-    elif num_rows < 10:
-        duration = 60
-    elif num_rows < 30:
-        duration = 120
-    elif num_rows < 100:
-        duration = 240
-    elif num_rows < 300:
-        duration = 600
-    elif num_rows < 1000:
-        duration = 1200
-    else:
-        duration = 2400
-
-    result_queue = multiprocessing.Queue()
-    p = multiprocessing.Process(
-        target=_run_pipeline,
-        args=(result_queue, num_turns, num_rows, system_prompt, is_sample),
+    progress(0.0, desc="(1/2) Generating instructions")
+    magpie_generator = get_magpie_generator(
+        num_turns, num_rows, system_prompt, is_sample
     )
+    response_generator = get_response_generator(num_turns, system_prompt, is_sample)
+    total_steps: int = num_rows * 2
+    batch_size = DEFAULT_BATCH_SIZE
+
+    # create instructions
+    magpie_results = []
+    for i in range(0, num_rows, batch_size):
+        progress(
+            0.5 * min(i + batch_size, num_rows) / num_rows,
+            total=total_steps,
+            desc="(1/2) Generating instructions",
+        )
+        batch = list(magpie_generator.process())[:batch_size]
+        magpie_results.extend([item[0] for item in batch])
+    progress(0.5, desc="(1/2) Generating instructions")
 
-    try:
-        p.start()
-        total_steps = 100
-        for step in range(total_steps):
-            if not p.is_alive() or p._popen.poll() is not None:
-                break
+    # generate responses
+    response_results = []
+    if num_turns == 1:
+        for i in range(0, num_rows, batch_size):
             progress(
-                (step + 1) / total_steps,
-                desc=f"Generating dataset with {num_rows} rows. Don't close this window.",
+                0.5 + 0.5 * min(i + batch_size, num_rows) / num_rows,
+                total=total_steps,
+                desc="(2/2) Generating responses",
             )
-            time.sleep(duration / total_steps)  # Adjust this value based on your needs
-        p.join()
-    except Exception as e:
-        raise gr.Error(f"An error occurred during dataset generation: {str(e)}")
+            batch = magpie_results[i : i + batch_size]
+            batch = [entry[0] for entry in batch]
+            responses = list(response_generator.process(inputs=batch))
+            response_results.extend(responses)
+        for result in response_results[0]:
+            result["prompt"] = result["instruction"]
+            result["completion"] = result["generation"]
+            result["system_prompt"] = system_prompt
+    else:
+        for result in magpie_results:
+            result[0]["conversation"].insert(
+                0, {"role": "system", "content": system_prompt}
+            )
+            result[0]["messages"] = result[0]["conversation"]
+        for i in range(0, num_rows, batch_size):
+            progress(
+                0.5 + 0.5 * min(i + batch_size, num_rows) / num_rows,
+                total=total_steps,
+                desc="(2/2) Generating responses",
+            )
+            batch = magpie_results[i : i + batch_size]
+            batch = [entry[0] for entry in batch]
+            responses = list(response_generator.process(inputs=batch))
+            response_results.extend(responses)
+
+        for result in response_results[0]:
+            result["messages"].append(
+                {"role": "assistant", "content": result["generation"]}
+            )
+    progress(
+        1,
+        total=total_steps,
+        desc="(2/2) Generating responses",
+    )
 
-    distiset = result_queue.get()
+    # create distiset
+    distiset_results = []
+    for result in response_results[0]:
+        record = {}
+        for relevant_keys in [
+            "messages",
+            "prompt",
+            "completion",
+            "model_name",
+            "system_prompt",
+        ]:
+            if relevant_keys in result:
+                record[relevant_keys] = result[relevant_keys]
+        distiset_results.append(record)
+
+    distiset = Distiset(
+        {
+            "default": Dataset.from_list(distiset_results),
+        }
+    )
 
     # If not pushing to hub generate the dataset directly
-    distiset = distiset["default"]["train"]
+    distiset = distiset["default"]
     if num_turns == 1:
-        outputs = distiset.to_pandas()[["prompt", "completion"]]
+        outputs = distiset.to_pandas()[["system_prompt", "prompt", "completion"]]
     else:
         outputs = distiset.to_pandas()[["messages"]]
     dataframe = pd.DataFrame(outputs)
-
     progress(1.0, desc="Dataset generation completed")
     return dataframe
 
@@ -233,7 +270,7 @@ with gr.Blocks(
     )
 
     with gr.Row():
-        sample_dataset = gr.DataFrame(
+        sample_dataset = gr.Dataframe(
            value=DEFAULT_DATASETS[0],
            label="Sample dataset. Prompts and completions truncated to 256 tokens.",
            interactive=False,
@@ -311,7 +348,7 @@ with gr.Blocks(
            value="Push to Hub", variant="primary", scale=2
        )
    with gr.Row():
-        final_dataset = gr.DataFrame(
+        final_dataset = gr.Dataframe(
            value=DEFAULT_DATASETS[0],
            label="Generated dataset",
            interactive=False,
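With this commit the app no longer spawns a subprocess that runs a full distilabel Pipeline; generate_dataset now drives the Magpie instruction generator and the response generator directly, batch by batch. The following minimal sketch mirrors that flow using the helpers introduced above; the system prompt, row count, and is_sample flag are illustrative assumptions, not values from the commit.

# Sketch only: mirrors the new step-wise flow in generate_dataset.
from src.distilabel_dataset_generator.pipelines.sft import (
    DEFAULT_BATCH_SIZE,
    get_magpie_generator,
    get_response_generator,
)

# Illustrative inputs (assumed, not part of the commit).
system_prompt = "You are a helpful customer support assistant."
num_rows, num_turns = 2, 1

# The factories already call .load() on the steps before returning them.
magpie_generator = get_magpie_generator(num_turns, num_rows, system_prompt, True)
response_generator = get_response_generator(num_turns, system_prompt, True)

# (1/2) generate instructions in batches, as generate_dataset does
magpie_results = []
for _ in range(0, num_rows, DEFAULT_BATCH_SIZE):
    batch = list(magpie_generator.process())[:DEFAULT_BATCH_SIZE]
    magpie_results.extend([item[0] for item in batch])

# (2/2) generate responses for the single-turn case
inputs = [entry[0] for entry in magpie_results]
responses = list(response_generator.process(inputs=inputs))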
src/distilabel_dataset_generator/pipelines/sft.py CHANGED
@@ -1,10 +1,12 @@
 import pandas as pd
+from datasets import Dataset
+from distilabel.distiset import Distiset
 from distilabel.llms import InferenceEndpointsLLM
 from distilabel.pipeline import Pipeline
 from distilabel.steps import KeepColumns
-from distilabel.steps.tasks import MagpieGenerator, TextGeneration
+from distilabel.steps.tasks import ChatGeneration, MagpieGenerator, TextGeneration
 
-from src.distilabel_dataset_generator.utils import HF_TOKENS
+from distilabel_dataset_generator.utils import HF_TOKENS
 
 INFORMATION_SEEKING_PROMPT = (
     "You are an AI assistant designed to provide accurate and concise information on a wide"
@@ -118,7 +120,7 @@ The prompt you write should follow the same style and structure as the following
 User dataset description:
 """
 
-MODEL = "meta-llama/Meta-Llama-3.1-70B-Instruct"
+MODEL = "meta-llama/Meta-Llama-3.1-8B-Instruct"
 DEFAULT_DATASET_DESCRIPTIONS = (
     "rude customer assistant for a phone company",
     "assistant that solves math puzzles using python",
@@ -155,7 +157,7 @@ _STOP_SEQUENCES = [
     "assistant",
     " \n\n",
 ]
-DEFAULT_BATCH_SIZE = 50
+DEFAULT_BATCH_SIZE = 5
 TOKEN_INDEX = 0
 
 
@@ -198,7 +200,7 @@ with Pipeline(name="sft") as pipeline:
         output_mappings={input_mappings},
     )
     keep_columns = KeepColumns(
-        columns={list(input_mappings.values())} + ["model_name"],
+        columns={list(input_mappings.values())} + ["model_name", "system_prompt"],
     )
     magpie.connect(keep_columns)
 
@@ -208,92 +210,101 @@ if __name__ == "__main__":
     return code
 
 
-def get_pipeline(num_turns, num_rows, system_prompt, is_sample):
+def _get_next_api_key():
     global TOKEN_INDEX
-    input_mappings = _get_output_mappings(num_turns)
-    output_mappings = input_mappings
     api_key = HF_TOKENS[TOKEN_INDEX % len(HF_TOKENS)]
     TOKEN_INDEX += 1
-    MODEL = "meta-llama/Meta-Llama-3.1-8B-Instruct"
-    print("is sample?", is_sample)
-    if num_turns == 1:
-        with Pipeline(name="sft") as pipeline:
-            magpie = MagpieGenerator(
-                llm=InferenceEndpointsLLM(
-                    model_id=MODEL,
-                    tokenizer_id=MODEL,
-                    api_key=api_key,
-                    magpie_pre_query_template="llama3",
-                    generation_kwargs={
-                        "temperature": 0.8,  # it's the best value for Llama 3.1 70B Instruct
-                        "do_sample": True,
-                        "max_new_tokens": 256 if is_sample else 512,
-                        "stop_sequences": _STOP_SEQUENCES,
-                    },
-                ),
-                batch_size=DEFAULT_BATCH_SIZE,
-                n_turns=num_turns,
-                num_rows=num_rows,
-                system_prompt=system_prompt,
-                output_mappings={"instruction": "prompt"},
-                only_instruction=True,
-            )
+    return api_key
 
-            generate_response = TextGeneration(
-                llm=InferenceEndpointsLLM(
-                    model_id=MODEL,
-                    tokenizer_id=MODEL,
-                    api_key=api_key,
-                    generation_kwargs={
-                        "temperature": 0.8,
-                        "max_new_tokens": 256 if is_sample else 1024,
-                    },
-                ),
-                system_prompt=system_prompt,
-                output_mappings={"generation": "completion"},
-                input_mappings={"instruction": "prompt"},
-            )
 
-            keep_columns = KeepColumns(
-                columns=list(output_mappings.values()) + ["model_name"],
-            )
+def get_magpie_generator(num_turns, num_rows, system_prompt, is_sample):
+    input_mappings = _get_output_mappings(num_turns)
+    output_mappings = input_mappings.copy()
+    if num_turns == 1:
+        magpie_generator = MagpieGenerator(
+            llm=InferenceEndpointsLLM(
+                model_id=MODEL,
+                tokenizer_id=MODEL,
+                api_key=_get_next_api_key(),
+                magpie_pre_query_template="llama3",
+                generation_kwargs={
+                    "temperature": 0.8,
+                    "do_sample": True,
+                    "max_new_tokens": 256 if is_sample else 512,
+                    "stop_sequences": _STOP_SEQUENCES,
+                },
+            ),
+            batch_size=DEFAULT_BATCH_SIZE,
+            n_turns=num_turns,
+            num_rows=num_rows,
+            system_prompt=system_prompt,
+            output_mappings=output_mappings,
+            only_instruction=True,
+        )
+    else:
+        magpie_generator = MagpieGenerator(
+            llm=InferenceEndpointsLLM(
+                model_id=MODEL,
+                tokenizer_id=MODEL,
+                api_key=_get_next_api_key(),
+                magpie_pre_query_template="llama3",
+                generation_kwargs={
+                    "temperature": 0.8,
+                    "do_sample": True,
+                    "max_new_tokens": 256 if is_sample else 1024,
+                    "stop_sequences": _STOP_SEQUENCES,
+                },
+            ),
+            batch_size=DEFAULT_BATCH_SIZE,
+            end_with_user=True,
+            n_turns=num_turns,
+            num_rows=num_rows,
+            system_prompt=system_prompt,
+            output_mappings=output_mappings,
+        )
+    magpie_generator.load()
+    return magpie_generator
 
-            magpie.connect(generate_response)
-            generate_response.connect(keep_columns)
-            return pipeline
+
+def get_response_generator(num_turns, system_prompt, is_sample):
+    if num_turns == 1:
+        response_generator = TextGeneration(
+            llm=InferenceEndpointsLLM(
+                model_id=MODEL,
+                tokenizer_id=MODEL,
+                api_key=_get_next_api_key(),
+                generation_kwargs={
+                    "temperature": 0.8,
+                    "max_new_tokens": 256 if is_sample else 1024,
+                },
+            ),
+            system_prompt=system_prompt,
+            output_mappings={"generation": "completion"},
+            input_mappings={"instruction": "prompt"},
+        )
     else:
-        with Pipeline(name="sft") as pipeline:
-            magpie = MagpieGenerator(
-                llm=InferenceEndpointsLLM(
-                    model_id=MODEL,
-                    tokenizer_id=MODEL,
-                    api_key=api_key,
-                    magpie_pre_query_template="llama3",
-                    generation_kwargs={
-                        "temperature": 0.8,  # it's the best value for Llama 3.1 70B Instruct
-                        "do_sample": True,
-                        "max_new_tokens": 2048,
-                        "stop_sequences": _STOP_SEQUENCES,
-                    },
-                ),
-                batch_size=DEFAULT_BATCH_SIZE,
-                n_turns=num_turns,
-                num_rows=num_rows,
-                system_prompt=system_prompt,
-                output_mappings=output_mappings,
-            )
-            keep_columns = KeepColumns(
-                columns=list(output_mappings.values()) + ["model_name"],
-            )
-            magpie.connect(keep_columns)
-            return pipeline
+        response_generator = ChatGeneration(
+            llm=InferenceEndpointsLLM(
+                model_id=MODEL,
+                tokenizer_id=MODEL,
+                api_key=_get_next_api_key(),
+                generation_kwargs={
+                    "temperature": 0.8,
+                    "max_new_tokens": 2048,
+                },
+            ),
+            output_mappings={"generation": "completion"},
+            input_mappings={"conversation": "messages"},
+        )
    response_generator.load()
    return response_generator
 
 
-def get_prompt_generation_step():
+def get_prompt_generator():
     global TOKEN_INDEX
     api_key = HF_TOKENS[TOKEN_INDEX % len(HF_TOKENS)]
     TOKEN_INDEX += 1
-    generate_description = TextGeneration(
+    prompt_generator = TextGeneration(
         llm=InferenceEndpointsLLM(
             api_key=api_key,
             model_id=MODEL,
@@ -306,13 +317,30 @@ def get_prompt_generation_step():
         ),
         use_system_prompt=True,
     )
-    return generate_description
+    prompt_generator.load()
+    return prompt_generator
+
+
+def get_pipeline(num_turns, num_rows, system_prompt, is_sample):
+    input_mappings = _get_output_mappings(num_turns)
+    output_mappings = input_mappings
+
+    with Pipeline(name="sft") as pipeline:
+        magpie = get_magpie_generator(num_turns, num_rows, system_prompt, is_sample)
+        generate_response = get_response_generator(system_prompt, is_sample)
+
+        keep_columns = KeepColumns(
+            columns=list(output_mappings.values()) + ["model_name"],
+        )
+
+        magpie.connect(generate_response)
+        generate_response.connect(keep_columns)
+        return pipeline
 
 
 if __name__ == "__main__":
-    prompt_generation_step = get_prompt_generation_step()
-    prompt_generation_step.load()
-    result = next(
+    prompt_generation_step = get_prompt_generator()
+    system_prompt = next(
        prompt_generation_step.process(
            [
                {
@@ -322,5 +350,64 @@ if __name__ == "__main__":
             ]
         )
     )[0]["generation"]
-    pipeline = get_pipeline(num_rows=100, num_turns=1, system_prompt=result)
-    pipeline.run()
+    num_rows = 2
+    num_turns = 1
+    magpie_generator = get_magpie_generator(num_turns, num_rows, system_prompt, False)
+    response_generator = get_response_generator(num_turns, system_prompt, False)
+    total_steps = num_rows * 2
+    batch_size = 5  # Adjust this value as needed
+
+    # create instructions
+    magpie_results = []
+    for i in range(0, num_rows, batch_size):
+        batch = list(magpie_generator.process())[:batch_size]
+        magpie_results.extend([item[0] for item in batch])
+
+    # generate responses
+    response_results = []
+    if num_turns == 1:
+        for i in range(0, len(magpie_results), batch_size):
+            batch = magpie_results[i : i + batch_size]
+            batch = [entry[0] for entry in batch]
+            responses = list(response_generator.process(inputs=batch))
+            response_results.extend(responses)
+        for result in response_results:
+            result[0]["prompt"] = result[0]["instruction"]
+            result[0]["completion"] = result[0]["generation"]
+            result[0]["system_prompt"] = system_prompt
+    else:
+        for result in magpie_results:
+            result[0]["conversation"].insert(
+                0, {"role": "system", "content": system_prompt}
+            )
+            result[0]["messages"] = result[0]["conversation"]
+        for i in range(0, len(magpie_results), batch_size):
+            batch = magpie_results[i : i + batch_size]
+            batch = [entry[0] for entry in batch]
+            responses = list(response_generator.process(inputs=batch))
+            response_results.extend(responses)
+
+        for result in response_results:
+            result[0]["messages"].append(
+                {"role": "assistant", "content": result[0]["generation"]}
+            )
+
+    distiset_results = []
+    for result in response_results[0]:
+        record = {}
+        for relevant_keys in [
+            "messages",
+            "prompt",
+            "completion",
+            "model_name",
+            "system_prompt",
+        ]:
+            if relevant_keys in result:
+                record[relevant_keys] = result[relevant_keys]
+        distiset_results.append(record)
+
+    distiset = Distiset(
+        {
+            "default": Dataset.from_list(distiset_results),
+        }
+    )
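The old monolithic get_pipeline is split into small factories that each call .load() before returning their step: _get_next_api_key rotates across HF_TOKENS, get_magpie_generator produces instructions (single-turn) or conversations (multi-turn), get_response_generator is a TextGeneration task for single-turn and a ChatGeneration task for multi-turn, and get_prompt_generator wraps the system-prompt TextGeneration task. A minimal sketch of calling the prompt factory on its own follows; the input record keys (instruction plus system_prompt, since the task sets use_system_prompt=True) are an assumption, because the __main__ example's record body is truncated in this diff.

# Sketch only: generate a system prompt with the new factory.
from src.distilabel_dataset_generator.pipelines.sft import (
    PROMPT_CREATION_PROMPT,
    get_prompt_generator,
)

prompt_generator = get_prompt_generator()  # already loaded by the factory
system_prompt = next(
    prompt_generator.process(
        [
            {
                # Assumed keys for a TextGeneration task with use_system_prompt=True.
                "system_prompt": PROMPT_CREATION_PROMPT,
                "instruction": "assistant that solves math puzzles using python",
            }
        ]
    )
)[0]["generation"]
print(system_prompt)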