osanseviero
commited on
Commit
•
3b2d079
1
Parent(s):
b01e314
Update src/distilabel_dataset_generator/apps/sft.py
Browse files
src/distilabel_dataset_generator/apps/sft.py
CHANGED
@@ -18,7 +18,6 @@ from src.distilabel_dataset_generator.pipelines.sft import (
|
|
18 |
from src.distilabel_dataset_generator.utils import (
|
19 |
get_login_button,
|
20 |
get_org_dropdown,
|
21 |
-
get_token,
|
22 |
)
|
23 |
|
24 |
|
@@ -66,8 +65,8 @@ def generate_dataset(
|
|
66 |
private=True,
|
67 |
org_name=None,
|
68 |
repo_name=None,
|
69 |
-
token=None,
|
70 |
progress=gr.Progress(),
|
|
|
71 |
):
|
72 |
repo_id = (
|
73 |
f"{org_name}/{repo_name}"
|
@@ -79,12 +78,6 @@ def generate_dataset(
|
|
79 |
raise gr.Error(
|
80 |
"Please provide a repo_name and org_name to push the dataset to."
|
81 |
)
|
82 |
-
try:
|
83 |
-
whoami(token=token)
|
84 |
-
except Exception:
|
85 |
-
raise gr.Error(
|
86 |
-
"Provide a Hugging Face token with write access to the organization you want to push the dataset to."
|
87 |
-
)
|
88 |
|
89 |
if num_turns > 4:
|
90 |
num_turns = 4
|
@@ -133,7 +126,7 @@ def generate_dataset(
|
|
133 |
repo_id=repo_id,
|
134 |
private=private,
|
135 |
include_script=False,
|
136 |
-
token=token,
|
137 |
)
|
138 |
|
139 |
# If not pushing to hub generate the dataset directly
|
@@ -215,7 +208,7 @@ with gr.Blocks(
|
|
215 |
# Add a header for the full dataset generation section
|
216 |
gr.Markdown("## Generate full dataset")
|
217 |
gr.Markdown(
|
218 |
-
"Once you're satisfied with the sample, generate a larger dataset and push it to the
|
219 |
)
|
220 |
|
221 |
with gr.Column() as push_to_hub_ui:
|
@@ -281,7 +274,6 @@ with gr.Blocks(
|
|
281 |
private,
|
282 |
org_name,
|
283 |
repo_name,
|
284 |
-
hf_token,
|
285 |
],
|
286 |
outputs=[table],
|
287 |
show_progress=True,
|
@@ -300,5 +292,4 @@ with gr.Blocks(
|
|
300 |
label="Distilabel Pipeline Code",
|
301 |
)
|
302 |
|
303 |
-
app.load(get_token, outputs=[hf_token])
|
304 |
app.load(get_org_dropdown, outputs=[org_name])
|
|
|
18 |
from src.distilabel_dataset_generator.utils import (
|
19 |
get_login_button,
|
20 |
get_org_dropdown,
|
|
|
21 |
)
|
22 |
|
23 |
|
|
|
65 |
private=True,
|
66 |
org_name=None,
|
67 |
repo_name=None,
|
|
|
68 |
progress=gr.Progress(),
|
69 |
+
oauth_token: Union[gr.OAuthToken, None]
|
70 |
):
|
71 |
repo_id = (
|
72 |
f"{org_name}/{repo_name}"
|
|
|
78 |
raise gr.Error(
|
79 |
"Please provide a repo_name and org_name to push the dataset to."
|
80 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
81 |
|
82 |
if num_turns > 4:
|
83 |
num_turns = 4
|
|
|
126 |
repo_id=repo_id,
|
127 |
private=private,
|
128 |
include_script=False,
|
129 |
+
token=oauth_token.token,
|
130 |
)
|
131 |
|
132 |
# If not pushing to hub generate the dataset directly
|
|
|
208 |
# Add a header for the full dataset generation section
|
209 |
gr.Markdown("## Generate full dataset")
|
210 |
gr.Markdown(
|
211 |
+
"Once you're satisfied with the sample, generate a larger dataset and push it to the Hub."
|
212 |
)
|
213 |
|
214 |
with gr.Column() as push_to_hub_ui:
|
|
|
274 |
private,
|
275 |
org_name,
|
276 |
repo_name,
|
|
|
277 |
],
|
278 |
outputs=[table],
|
279 |
show_progress=True,
|
|
|
292 |
label="Distilabel Pipeline Code",
|
293 |
)
|
294 |
|
|
|
295 |
app.load(get_org_dropdown, outputs=[org_name])
|