John6666 commited on
Commit
008c48b
·
verified ·
1 Parent(s): 27354c0

Upload modutils.py

Browse files
Files changed (1) hide show
  1. modutils.py +73 -6
modutils.py CHANGED
@@ -1,10 +1,12 @@
1
  import spaces
2
  import json
3
  import gradio as gr
4
- from huggingface_hub import HfApi
5
  import os
6
  from pathlib import Path
7
  from PIL import Image
 
 
 
8
 
9
 
10
  from env import (HF_LORA_PRIVATE_REPOS1, HF_LORA_PRIVATE_REPOS2,
@@ -61,6 +63,48 @@ def get_local_model_list(dir_path):
61
  return model_list
62
 
63
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  def download_things(directory, url, hf_token="", civitai_api_key=""):
65
  url = url.strip()
66
  if "drive.google.com" in url:
@@ -73,11 +117,7 @@ def download_things(directory, url, hf_token="", civitai_api_key=""):
73
  # url = urllib.parse.quote(url, safe=':/') # fix encoding
74
  if "/blob/" in url:
75
  url = url.replace("/blob/", "/resolve/")
76
- user_header = f'"Authorization: Bearer {hf_token}"'
77
- if hf_token:
78
- os.system(f"aria2c --console-log-level=error --summary-interval=10 --header={user_header} -c -x 16 -k 1M -s 16 {url} -d {directory} -o {url.split('/')[-1]}")
79
- else:
80
- os.system(f"aria2c --optimize-concurrent-downloads --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 {url} -d {directory} -o {url.split('/')[-1]}")
81
  elif "civitai.com" in url:
82
  if "?" in url:
83
  url = url.split("?")[0]
@@ -90,6 +130,33 @@ def download_things(directory, url, hf_token="", civitai_api_key=""):
90
  os.system(f"aria2c --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 -d {directory} {url}")
91
 
92
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
  def escape_lora_basename(basename: str):
94
  return basename.replace(".", "_").replace(" ", "_").replace(",", "")
95
 
 
1
  import spaces
2
  import json
3
  import gradio as gr
 
4
  import os
5
  from pathlib import Path
6
  from PIL import Image
7
+ from huggingface_hub import HfApi, HfFolder, hf_hub_download, snapshot_download
8
+ import urllib.parse
9
+ import re
10
 
11
 
12
  from env import (HF_LORA_PRIVATE_REPOS1, HF_LORA_PRIVATE_REPOS2,
 
63
  return model_list
64
 
65
 
66
+ def get_token():
67
+ try:
68
+ token = HfFolder.get_token()
69
+ except Exception:
70
+ token = ""
71
+ return token
72
+
73
+
74
+ def set_token(token):
75
+ try:
76
+ HfFolder.save_token(token)
77
+ except Exception:
78
+ print(f"Error: Failed to save token.")
79
+
80
+
81
+ set_token(HF_TOKEN)
82
+
83
+
84
+ def split_hf_url(url: str):
85
+ try:
86
+ s = list(re.findall(r'^(?:https?://huggingface.co/)(?:(datasets)/)?(.+?/.+?)/\w+?/.+?/(?:(.+)/)?(.+?.\w+)(?:\?download=true)?$', url)[0])
87
+ if len(s) < 4: return "", "", "", ""
88
+ repo_id = s[1]
89
+ repo_type = "dataset" if s[0] == "datasets" else "model"
90
+ subfolder = urllib.parse.unquote(s[2]) if s[2] else None
91
+ filename = urllib.parse.unquote(s[3])
92
+ return repo_id, filename, subfolder, repo_type
93
+ except Exception as e:
94
+ print(e)
95
+
96
+
97
+ def download_hf_file(directory, url, progress=gr.Progress(track_tqdm=True)):
98
+ hf_token = get_token()
99
+ repo_id, filename, subfolder, repo_type = split_hf_url(url)
100
+ try:
101
+ print(f"Downloading {url} to {directory}")
102
+ if subfolder is not None: hf_hub_download(repo_id=repo_id, filename=filename, subfolder=subfolder, repo_type=repo_type, local_dir=directory, token=hf_token)
103
+ else: hf_hub_download(repo_id=repo_id, filename=filename, repo_type=repo_type, local_dir=directory, token=hf_token)
104
+ except Exception as e:
105
+ print(f"Failed to download: {e}")
106
+
107
+
108
  def download_things(directory, url, hf_token="", civitai_api_key=""):
109
  url = url.strip()
110
  if "drive.google.com" in url:
 
117
  # url = urllib.parse.quote(url, safe=':/') # fix encoding
118
  if "/blob/" in url:
119
  url = url.replace("/blob/", "/resolve/")
120
+ download_hf_file(directory, url)
 
 
 
 
121
  elif "civitai.com" in url:
122
  if "?" in url:
123
  url = url.split("?")[0]
 
130
  os.system(f"aria2c --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 -d {directory} {url}")
131
 
132
 
133
+ def get_download_file(temp_dir, url, civitai_key="", progress=gr.Progress(track_tqdm=True)):
134
+ if not "http" in url and is_repo_name(url) and not Path(url).exists():
135
+ print(f"Use HF Repo: {url}")
136
+ new_file = url
137
+ elif not "http" in url and Path(url).exists():
138
+ print(f"Use local file: {url}")
139
+ new_file = url
140
+ elif Path(f"{temp_dir}/{url.split('/')[-1]}").exists():
141
+ print(f"File to download alreday exists: {url}")
142
+ new_file = f"{temp_dir}/{url.split('/')[-1]}"
143
+ else:
144
+ print(f"Start downloading: {url}")
145
+ before = get_local_model_list(temp_dir)
146
+ try:
147
+ download_things(temp_dir, url.strip(), HF_TOKEN, civitai_key)
148
+ except Exception:
149
+ print(f"Download failed: {url}")
150
+ return ""
151
+ after = get_local_model_list(temp_dir)
152
+ new_file = list_sub(after, before)[0] if list_sub(after, before) else ""
153
+ if not new_file:
154
+ print(f"Download failed: {url}")
155
+ return ""
156
+ print(f"Download completed: {url}")
157
+ return new_file
158
+
159
+
160
  def escape_lora_basename(basename: str):
161
  return basename.replace(".", "_").replace(" ", "_").replace(",", "")
162