sdiazlor HF staff committed on
Commit
cfdc9ea
·
1 Parent(s): 85b97c4

validate labels on click, remove debug messages

Browse files
src/synthetic_dataset_generator/apps/textcat.py CHANGED
@@ -177,13 +177,6 @@ def generate_dataset(
177
  distiset_results.append(record)
178
 
179
  dataframe = pd.DataFrame(distiset_results)
180
- if (
181
- not labels
182
- or len(set(label.lower().strip() for label in labels if label.strip())) < 2
183
- ):
184
- raise gr.Error(
185
- "Please provide at least 2 unique, non-empty labels to classify your text."
186
- )
187
  if multi_label:
188
  dataframe["labels"] = dataframe["labels"].apply(
189
  lambda x: list(
@@ -222,10 +215,6 @@ def push_dataset_to_hub(
222
  pipeline_code: str = "",
223
  progress=gr.Progress(),
224
  ):
225
- gr.Info(
226
- message=f"Dataframe columns in push dataset to hub: {dataframe.columns}",
227
- duration=20,
228
- )
229
  progress(0.0, desc="Validating")
230
  repo_id = validate_push_to_hub(org_name, repo_name)
231
  progress(0.3, desc="Preprocessing")
@@ -284,7 +273,6 @@ def push_dataset(
284
  num_rows=num_rows,
285
  temperature=temperature,
286
  )
287
- gr.Info(message=f"Dataframe columns: {dataframe.columns}", duration=20)
288
  push_dataset_to_hub(
289
  dataframe,
290
  org_name,
@@ -393,10 +381,13 @@ def push_dataset(
393
  return ""
394
 
395
 
396
- def validate_input_labels(labels):
397
- if not labels or len(labels) < 2:
 
 
 
398
  raise gr.Error(
399
- f"Please select at least 2 labels to classify your text. You selected {len(labels) if labels else 0}."
400
  )
401
  return labels
402
 
@@ -569,6 +560,11 @@ with gr.Blocks() as app:
569
  )
570
 
571
  btn_apply_to_sample_dataset.click(
 
 
 
 
 
572
  fn=generate_sample_dataset,
573
  inputs=[system_prompt, difficulty, clarity, labels, multi_label],
574
  outputs=[dataframe],
@@ -585,6 +581,11 @@ with gr.Blocks() as app:
585
  inputs=[org_name, repo_name],
586
  outputs=[success_message],
587
  show_progress=True,
 
 
 
 
 
588
  ).success(
589
  fn=hide_success_message,
590
  outputs=[success_message],
 
177
  distiset_results.append(record)
178
 
179
  dataframe = pd.DataFrame(distiset_results)
 
 
 
 
 
 
 
180
  if multi_label:
181
  dataframe["labels"] = dataframe["labels"].apply(
182
  lambda x: list(
 
215
  pipeline_code: str = "",
216
  progress=gr.Progress(),
217
  ):
 
 
 
 
218
  progress(0.0, desc="Validating")
219
  repo_id = validate_push_to_hub(org_name, repo_name)
220
  progress(0.3, desc="Preprocessing")
 
273
  num_rows=num_rows,
274
  temperature=temperature,
275
  )
 
276
  push_dataset_to_hub(
277
  dataframe,
278
  org_name,
 
381
  return ""
382
 
383
 
384
def validate_input_labels(labels: List[str]) -> List[str]:
    """Ensure at least two distinct, non-empty labels were supplied.

    Labels are compared case-insensitively after stripping surrounding
    whitespace, so ``"Spam"`` and ``" spam "`` count as a single label and
    whitespace-only entries are ignored.

    Args:
        labels: Label strings entered by the user; may be empty or ``None``.

    Returns:
        The ``labels`` argument, unchanged, when validation passes.

    Raises:
        gr.Error: If fewer than two unique, non-empty labels are present.
    """
    unique_labels = {
        label.lower().strip() for label in (labels or []) if label.strip()
    }
    if len(unique_labels) < 2:
        # Report the number of labels that actually counted towards the check;
        # the raw length would read "You provided 2" for ["spam", "Spam"],
        # contradicting the requirement stated in the same message.
        raise gr.Error(
            "Please provide at least 2 unique, non-empty labels to classify your text. "
            f"You provided {len(unique_labels)}."
        )
    return labels
393
 
 
560
  )
561
 
562
  btn_apply_to_sample_dataset.click(
563
+ fn=validate_input_labels,
564
+ inputs=[labels],
565
+ outputs=[labels],
566
+ show_progress=True,
567
+ ).success(
568
  fn=generate_sample_dataset,
569
  inputs=[system_prompt, difficulty, clarity, labels, multi_label],
570
  outputs=[dataframe],
 
581
  inputs=[org_name, repo_name],
582
  outputs=[success_message],
583
  show_progress=True,
584
+ ).success(
585
+ fn=validate_input_labels,
586
+ inputs=[labels],
587
+ outputs=[labels],
588
+ show_progress=True,
589
  ).success(
590
  fn=hide_success_message,
591
  outputs=[success_message],