Commit
•
c973277
1
Parent(s):
a13f86c
Reduce simple dataset generation time
Browse files
src/distilabel_dataset_generator/pipelines/sft.py
CHANGED
@@ -190,31 +190,73 @@ if __name__ == "__main__":
|
|
190 |
def get_pipeline(num_turns, num_rows, system_prompt):
|
191 |
input_mappings = _get_output_mappings(num_turns)
|
192 |
output_mappings = input_mappings
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
|
209 |
-
|
210 |
-
|
211 |
-
|
212 |
-
|
213 |
-
|
214 |
-
|
215 |
-
|
216 |
-
|
217 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
218 |
|
219 |
|
220 |
def get_prompt_generation_step():
|
|
|
190 |
def get_pipeline(num_turns, num_rows, system_prompt):
|
191 |
input_mappings = _get_output_mappings(num_turns)
|
192 |
output_mappings = input_mappings
|
193 |
+
if num_turns == 1:
|
194 |
+
with Pipeline(name="sft") as pipeline:
|
195 |
+
magpie = MagpieGenerator(
|
196 |
+
llm=InferenceEndpointsLLM(
|
197 |
+
model_id=MODEL,
|
198 |
+
tokenizer_id=MODEL,
|
199 |
+
api_key=os.environ["HF_TOKEN"],
|
200 |
+
magpie_pre_query_template="llama3",
|
201 |
+
generation_kwargs={
|
202 |
+
"temperature": 0.8, # it's the best value for Llama 3.1 70B Instruct
|
203 |
+
"do_sample": True,
|
204 |
+
"max_new_tokens": 512,
|
205 |
+
"stop_sequences": _STOP_SEQUENCES,
|
206 |
+
},
|
207 |
+
),
|
208 |
+
batch_size=2,
|
209 |
+
n_turns=num_turns,
|
210 |
+
num_rows=num_rows,
|
211 |
+
system_prompt=system_prompt,
|
212 |
+
output_mappings=output_mappings,
|
213 |
+
only_instructions=True
|
214 |
+
)
|
215 |
+
|
216 |
+
generate_response = TextGeneration(
|
217 |
+
llm=InferenceEndpointsLLM(
|
218 |
+
model_id=MODEL,
|
219 |
+
tokenizer_id=MODEL,
|
220 |
+
generation_kwargs={
|
221 |
+
"temperature": 0.8,
|
222 |
+
"max_new_tokens": 1024
|
223 |
+
},
|
224 |
+
)
|
225 |
+
)
|
226 |
+
|
227 |
+
keep_columns = KeepColumns(
|
228 |
+
columns=list(output_mappings.values()) + ["model_name"],
|
229 |
+
)
|
230 |
+
|
231 |
+
magpie.connect(generate_response)
|
232 |
+
generate_response.connect(keep_columns)
|
233 |
+
return pipeline
|
234 |
+
else:
|
235 |
+
with Pipeline(name="sft") as pipeline:
|
236 |
+
magpie = MagpieGenerator(
|
237 |
+
llm=InferenceEndpointsLLM(
|
238 |
+
model_id=MODEL,
|
239 |
+
tokenizer_id=MODEL,
|
240 |
+
api_key=os.environ["HF_TOKEN"],
|
241 |
+
magpie_pre_query_template="llama3",
|
242 |
+
generation_kwargs={
|
243 |
+
"temperature": 0.8, # it's the best value for Llama 3.1 70B Instruct
|
244 |
+
"do_sample": True,
|
245 |
+
"max_new_tokens": 2048,
|
246 |
+
"stop_sequences": _STOP_SEQUENCES,
|
247 |
+
},
|
248 |
+
),
|
249 |
+
batch_size=2,
|
250 |
+
n_turns=num_turns,
|
251 |
+
num_rows=num_rows,
|
252 |
+
system_prompt=system_prompt,
|
253 |
+
output_mappings=output_mappings,
|
254 |
+
)
|
255 |
+
keep_columns = KeepColumns(
|
256 |
+
columns=list(output_mappings.values()) + ["model_name"],
|
257 |
+
)
|
258 |
+
magpie.connect(keep_columns)
|
259 |
+
return pipeline
|
260 |
|
261 |
|
262 |
def get_prompt_generation_step():
|