davidberenstein1957 HF staff commited on
Commit
d7a6ff4
1 Parent(s): f949aa9

fix: Update batching logic

Browse files
src/distilabel_dataset_generator/apps/sft.py CHANGED
@@ -92,31 +92,37 @@ def generate_dataset(
92
  batch_size = DEFAULT_BATCH_SIZE
93
 
94
  # create instructions
 
95
  magpie_results = []
96
- for i in range(0, num_rows, batch_size):
97
  progress(
98
- 0.5 * min(i + batch_size, num_rows) / num_rows,
99
  total=total_steps,
100
  desc="(1/2) Generating instructions",
101
  )
102
- batch = list(magpie_generator.process())[:batch_size]
103
- magpie_results.extend([item[0] for item in batch])
 
 
 
 
104
  progress(0.5, desc="(1/2) Generating instructions")
105
 
106
  # generate responses
 
107
  response_results = []
108
  if num_turns == 1:
109
- for i in range(0, num_rows, batch_size):
110
  progress(
111
- 0.5 + 0.5 * min(i + batch_size, num_rows) / num_rows,
112
  total=total_steps,
113
  desc="(2/2) Generating responses",
114
  )
115
- batch = magpie_results[i : i + batch_size]
116
- batch = [entry[0] for entry in batch]
117
  responses = list(response_generator.process(inputs=batch))
118
- response_results.extend(responses)
119
- for result in response_results[0]:
 
120
  result["prompt"] = result["instruction"]
121
  result["completion"] = result["generation"]
122
  result["system_prompt"] = system_prompt
@@ -126,18 +132,17 @@ def generate_dataset(
126
  0, {"role": "system", "content": system_prompt}
127
  )
128
  result[0]["messages"] = result[0]["conversation"]
129
- for i in range(0, num_rows, batch_size):
130
  progress(
131
- 0.5 + 0.5 * min(i + batch_size, num_rows) / num_rows,
132
  total=total_steps,
133
  desc="(2/2) Generating responses",
134
  )
135
- batch = magpie_results[i : i + batch_size]
136
- batch = [entry[0] for entry in batch]
137
  responses = list(response_generator.process(inputs=batch))
138
- response_results.extend(responses)
139
-
140
- for result in response_results[0]:
141
  result["messages"].append(
142
  {"role": "assistant", "content": result["generation"]}
143
  )
@@ -149,7 +154,7 @@ def generate_dataset(
149
 
150
  # create distiset
151
  distiset_results = []
152
- for result in response_results[0]:
153
  record = {}
154
  for relevant_keys in [
155
  "messages",
 
92
  batch_size = DEFAULT_BATCH_SIZE
93
 
94
  # create instructions
95
+ n_processed = 0
96
  magpie_results = []
97
+ while n_processed < num_rows:
98
  progress(
99
+ 0.5 * n_processed / num_rows,
100
  total=total_steps,
101
  desc="(1/2) Generating instructions",
102
  )
103
+ remaining_rows = num_rows - n_processed
104
+ batch_size = min(batch_size, remaining_rows)
105
+ inputs = [{"system_prompt": system_prompt} for _ in range(batch_size)]
106
+ batch = list(magpie_generator.process(inputs=inputs))
107
+ magpie_results.extend(batch[0])
108
+ n_processed += batch_size
109
  progress(0.5, desc="(1/2) Generating instructions")
110
 
111
  # generate responses
112
+ n_processed = 0
113
  response_results = []
114
  if num_turns == 1:
115
+ while n_processed < num_rows:
116
  progress(
117
+ 0.5 + 0.5 * n_processed / num_rows,
118
  total=total_steps,
119
  desc="(2/2) Generating responses",
120
  )
121
+ batch = magpie_results[n_processed : n_processed + batch_size]
 
122
  responses = list(response_generator.process(inputs=batch))
123
+ response_results.extend(responses[0])
124
+ n_processed += batch_size
125
+ for result in response_results:
126
  result["prompt"] = result["instruction"]
127
  result["completion"] = result["generation"]
128
  result["system_prompt"] = system_prompt
 
132
  0, {"role": "system", "content": system_prompt}
133
  )
134
  result[0]["messages"] = result[0]["conversation"]
135
+ while n_processed < num_rows:
136
  progress(
137
+ 0.5 + 0.5 * n_processed / num_rows,
138
  total=total_steps,
139
  desc="(2/2) Generating responses",
140
  )
141
+ batch = magpie_results[n_processed : n_processed + batch_size]
 
142
  responses = list(response_generator.process(inputs=batch))
143
+ response_results.extend(responses[0])
144
+ n_processed += batch_size
145
+ for result in response_results:
146
  result["messages"].append(
147
  {"role": "assistant", "content": result["generation"]}
148
  )
 
154
 
155
  # create distiset
156
  distiset_results = []
157
+ for result in response_results:
158
  record = {}
159
  for relevant_keys in [
160
  "messages",
src/distilabel_dataset_generator/pipelines/sft.py CHANGED
@@ -4,7 +4,7 @@ from distilabel.distiset import Distiset
4
  from distilabel.llms import InferenceEndpointsLLM
5
  from distilabel.pipeline import Pipeline
6
  from distilabel.steps import KeepColumns
7
- from distilabel.steps.tasks import ChatGeneration, MagpieGenerator, TextGeneration
8
 
9
  from src.distilabel_dataset_generator.utils import HF_TOKENS
10
 
@@ -221,7 +221,7 @@ def get_magpie_generator(num_turns, num_rows, system_prompt, is_sample):
221
  input_mappings = _get_output_mappings(num_turns)
222
  output_mappings = input_mappings.copy()
223
  if num_turns == 1:
224
- magpie_generator = MagpieGenerator(
225
  llm=InferenceEndpointsLLM(
226
  model_id=MODEL,
227
  tokenizer_id=MODEL,
@@ -234,15 +234,13 @@ def get_magpie_generator(num_turns, num_rows, system_prompt, is_sample):
234
  "stop_sequences": _STOP_SEQUENCES,
235
  },
236
  ),
237
- batch_size=DEFAULT_BATCH_SIZE,
238
  n_turns=num_turns,
239
- num_rows=num_rows,
240
  system_prompt=system_prompt,
241
  output_mappings=output_mappings,
242
  only_instruction=True,
243
  )
244
  else:
245
- magpie_generator = MagpieGenerator(
246
  llm=InferenceEndpointsLLM(
247
  model_id=MODEL,
248
  tokenizer_id=MODEL,
@@ -255,10 +253,8 @@ def get_magpie_generator(num_turns, num_rows, system_prompt, is_sample):
255
  "stop_sequences": _STOP_SEQUENCES,
256
  },
257
  ),
258
- batch_size=DEFAULT_BATCH_SIZE,
259
  end_with_user=True,
260
  n_turns=num_turns,
261
- num_rows=num_rows,
262
  system_prompt=system_prompt,
263
  output_mappings=output_mappings,
264
  )
 
4
  from distilabel.llms import InferenceEndpointsLLM
5
  from distilabel.pipeline import Pipeline
6
  from distilabel.steps import KeepColumns
7
+ from distilabel.steps.tasks import ChatGeneration, Magpie, TextGeneration
8
 
9
  from src.distilabel_dataset_generator.utils import HF_TOKENS
10
 
 
221
  input_mappings = _get_output_mappings(num_turns)
222
  output_mappings = input_mappings.copy()
223
  if num_turns == 1:
224
+ magpie_generator = Magpie(
225
  llm=InferenceEndpointsLLM(
226
  model_id=MODEL,
227
  tokenizer_id=MODEL,
 
234
  "stop_sequences": _STOP_SEQUENCES,
235
  },
236
  ),
 
237
  n_turns=num_turns,
 
238
  system_prompt=system_prompt,
239
  output_mappings=output_mappings,
240
  only_instruction=True,
241
  )
242
  else:
243
+ magpie_generator = Magpie(
244
  llm=InferenceEndpointsLLM(
245
  model_id=MODEL,
246
  tokenizer_id=MODEL,
 
253
  "stop_sequences": _STOP_SEQUENCES,
254
  },
255
  ),
 
256
  end_with_user=True,
257
  n_turns=num_turns,
 
258
  system_prompt=system_prompt,
259
  output_mappings=output_mappings,
260
  )