osanseviero HF staff commited on
Commit
2e046cf
1 Parent(s): 8a8b6d8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +64 -5
app.py CHANGED
@@ -1,14 +1,73 @@
1
  import gradio as gr
 
 
 
2
 
3
- def fork(source_repo, token, type):
4
- return token
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
  interface = gr.Interface(
7
  fn=fork,
8
  inputs=[
9
- gr.inputs.Textbox(placeholder="Source repository"),
10
- gr.inputs.Textbox(placeholder="Write access token"),
11
- gr.inputs.Dropdown(choices=["model", "dataset", "space"])
 
12
  ],
13
  outputs=["textbox"],
14
  allow_flagging=False,
 
1
  import gradio as gr
2
+ from huggingface_hub import create_repo, whoami
3
+ import subprocess
4
+ import os, shutil
5
 
6
+ def fork(source_repo, dst_repo, token, repo_type):
7
+ # Creating repos has inconsistent API (https://github.com/huggingface/huggingface_hub/issues/47)
8
+ repo_namespace, dst_id = dst_repo.split("/")
9
+ username = whoami(token)
10
+ org = None
11
+ if repo_namespace != username:
12
+ org = repo_namespace
13
+
14
+ # Create the destination repo
15
+ if repo_type in ["spaces", "dataset"]:
16
+ # For some reason create_repo does not allow repo_type="model"..., even if documentation says
17
+ # that's the default.
18
+ create_repo(dst_id, token=token, organization=org, repo_type=repo_type, space_sdk="static")
19
+ else:
20
+ create_repo(dst_id, token=token, organization=org)
21
+
22
+ # Clone source repo
23
+ endpoint = "https://huggingface.co/"
24
+ if repo_type in ["spaces", "dataset"]:
25
+ endpoint += repo_type
26
+ full_path = endpoint + "/" + source_repo
27
+ local_dir = "hub/" + source_repo
28
+ repo = Repository(local_dir=local_dir, clone_from=full_path)
29
+
30
+ # Change remote origin
31
+ command = f"git remote set-url origin https://user:{token}@huggingface.co/"
32
+ if repo_type in ["spaces", "dataset"]:
33
+ command += repo_type
34
+ command += dst_repo
35
+ subprocess.run(
36
+ command.split(),
37
+ stderr=subprocess.PIPE,
38
+ stdout=subprocess.PIPE,
39
+ encoding="utf-8",
40
+ check=True,
41
+ cwd=local_dir,
42
+ )
43
+
44
+ # Push!
45
+ subprocess.run(
46
+ "git push --force origin main".split(),
47
+ stderr=subprocess.PIPE,
48
+ stdout=subprocess.PIPE,
49
+ encoding="utf-8",
50
+ check=True,
51
+ cwd=local_dir,
52
+ )
53
+
54
+ # Clean up
55
+ for filename in os.listdir(local_dir):
56
+ file_path = os.path.join(local_dir, filename)
57
+ if os.path.isfile(file_path) or os.path.islink(file_path):
58
+ os.unlink(file_path)
59
+ elif os.path.isdir(file_path):
60
+ shutil.rmtree(file_path)
61
+
62
+ return endpoint + "/" + dst_repo
63
 
64
  interface = gr.Interface(
65
  fn=fork,
66
  inputs=[
67
+ gr.inputs.Textbox(placeholder="Source repository"),
68
+ gr.inputs.Textbox(placeholder="Destination repository name"),
69
+ gr.inputs.Textbox(placeholder="Write access token"),
70
+ gr.inputs.Dropdown(choices=["model", "dataset", "space"])
71
  ],
72
  outputs=["textbox"],
73
  allow_flagging=False,