theNeofr commited on
Commit
3385bd3
·
verified ·
1 Parent(s): c99620c

Update utils.py

Browse files
Files changed (1) hide show
  1. utils.py +197 -299
utils.py CHANGED
@@ -1,299 +1,197 @@
1
- import gradio as gr
2
- from huggingface_hub import HfApi, HfFolder, hf_hub_download, snapshot_download
3
- import os
4
- from pathlib import Path
5
- import shutil
6
- import gc
7
- import re
8
- import urllib.parse
9
- import subprocess
10
- import time
11
- from typing import Any
12
-
13
-
14
- def get_token():
15
- try:
16
- token = HfFolder.get_token()
17
- except Exception:
18
- token = ""
19
- return token
20
-
21
-
22
- def set_token(token):
23
- try:
24
- HfFolder.save_token(token)
25
- except Exception:
26
- print(f"Error: Failed to save token.")
27
-
28
-
29
- def get_state(state: dict, key: str):
30
- if key in state.keys(): return state[key]
31
- else:
32
- print(f"State '{key}' not found.")
33
- return None
34
-
35
-
36
- def set_state(state: dict, key: str, value: Any):
37
- state[key] = value
38
-
39
-
40
- def get_user_agent():
41
- return 'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:127.0) Gecko/20100101 Firefox/127.0'
42
-
43
-
44
- def is_repo_exists(repo_id: str, repo_type: str="model"):
45
- hf_token = get_token()
46
- api = HfApi(token=hf_token)
47
- try:
48
- if api.repo_exists(repo_id=repo_id, repo_type=repo_type, token=hf_token): return True
49
- else: return False
50
- except Exception as e:
51
- print(f"Error: Failed to connect {repo_id} ({repo_type}). {e}")
52
- return True # for safe
53
-
54
-
55
- MODEL_TYPE_CLASS = {
56
- "diffusers:StableDiffusionPipeline": "SD 1.5",
57
- "diffusers:StableDiffusionXLPipeline": "SDXL",
58
- "diffusers:FluxPipeline": "FLUX",
59
- }
60
-
61
-
62
- def get_model_type(repo_id: str):
63
- hf_token = get_token()
64
- api = HfApi(token=hf_token)
65
- lora_filename = "pytorch_lora_weights.safetensors"
66
- diffusers_filename = "model_index.json"
67
- default = "SDXL"
68
- try:
69
- if api.file_exists(repo_id=repo_id, filename=lora_filename, token=hf_token): return "LoRA"
70
- if not api.file_exists(repo_id=repo_id, filename=diffusers_filename, token=hf_token): return "None"
71
- model = api.model_info(repo_id=repo_id, token=hf_token)
72
- tags = model.tags
73
- for tag in tags:
74
- if tag in MODEL_TYPE_CLASS.keys(): return MODEL_TYPE_CLASS.get(tag, default)
75
- except Exception:
76
- return default
77
- return default
78
-
79
-
80
- def list_uniq(l):
81
- return sorted(set(l), key=l.index)
82
-
83
-
84
- def list_sub(a, b):
85
- return [e for e in a if e not in b]
86
-
87
-
88
- def is_repo_name(s):
89
- return re.fullmatch(r'^[\w_\-\.]+/[\w_\-\.]+$', s)
90
-
91
-
92
- def get_hf_url(repo_id: str, repo_type: str="model"):
93
- if repo_type == "dataset": url = f"https://huggingface.co/datasets/{repo_id}"
94
- elif repo_type == "space": url = f"https://huggingface.co/spaces/{repo_id}"
95
- else: url = f"https://huggingface.co/{repo_id}"
96
- return url
97
-
98
-
99
- def split_hf_url(url: str):
100
- try:
101
- s = list(re.findall(r'^(?:https?://huggingface.co/)(?:(datasets|spaces)/)?(.+?/.+?)/\w+?/.+?/(?:(.+)/)?(.+?.\w+)(?:\?download=true)?$', url)[0])
102
- if len(s) < 4: return "", "", "", ""
103
- repo_id = s[1]
104
- if s[0] == "datasets": repo_type = "dataset"
105
- elif s[0] == "spaces": repo_type = "space"
106
- else: repo_type = "model"
107
- subfolder = urllib.parse.unquote(s[2]) if s[2] else None
108
- filename = urllib.parse.unquote(s[3])
109
- return repo_id, filename, subfolder, repo_type
110
- except Exception as e:
111
- print(e)
112
-
113
-
114
- def download_hf_file(directory, url, progress=gr.Progress(track_tqdm=True)):
115
- hf_token = get_token()
116
- repo_id, filename, subfolder, repo_type = split_hf_url(url)
117
- try:
118
- print(f"Downloading {url} to {directory}")
119
- if subfolder is not None: path = hf_hub_download(repo_id=repo_id, filename=filename, subfolder=subfolder, repo_type=repo_type, local_dir=directory, token=hf_token)
120
- else: path = hf_hub_download(repo_id=repo_id, filename=filename, repo_type=repo_type, local_dir=directory, token=hf_token)
121
- return path
122
- except Exception as e:
123
- print(f"Failed to download: {e}")
124
- return None
125
-
126
-
127
- def download_thing(directory, url, civitai_api_key="", progress=gr.Progress(track_tqdm=True)): # requires aria2, gdown
128
- try:
129
- url = url.strip()
130
- if "drive.google.com" in url:
131
- original_dir = os.getcwd()
132
- os.chdir(directory)
133
- subprocess.run(f"gdown --fuzzy {url}", shell=True)
134
- os.chdir(original_dir)
135
- elif "huggingface.co" in url:
136
- url = url.replace("?download=true", "")
137
- if "/blob/" in url: url = url.replace("/blob/", "/resolve/")
138
- download_hf_file(directory, url)
139
- elif "civitai.com" in url:
140
- if civitai_api_key:
141
- url = f"'{url}&token={civitai_api_key}'" if "?" in url else f"{url}?token={civitai_api_key}"
142
- print(f"Downloading {url}")
143
- subprocess.run(f"aria2c --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 -d {directory} {url}", shell=True)
144
- else:
145
- print("You need an API key to download Civitai models.")
146
- else:
147
- os.system(f"aria2c --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 -d {directory} {url}")
148
- except Exception as e:
149
- print(f"Failed to download: {e}")
150
-
151
-
152
- def get_local_file_list(dir_path, recursive=False):
153
- file_list = []
154
- pattern = "**/*.*" if recursive else "*/*.*"
155
- for file in Path(dir_path).glob(pattern):
156
- if file.is_file():
157
- file_path = str(file)
158
- file_list.append(file_path)
159
- return file_list
160
-
161
-
162
- def get_download_file(temp_dir, url, civitai_key, progress=gr.Progress(track_tqdm=True)):
163
- try:
164
- if not "http" in url and is_repo_name(url) and not Path(url).exists():
165
- print(f"Use HF Repo: {url}")
166
- new_file = url
167
- elif not "http" in url and Path(url).exists():
168
- print(f"Use local file: {url}")
169
- new_file = url
170
- elif Path(f"{temp_dir}/{url.split('/')[-1]}").exists():
171
- print(f"File to download alreday exists: {url}")
172
- new_file = f"{temp_dir}/{url.split('/')[-1]}"
173
- else:
174
- print(f"Start downloading: {url}")
175
- recursive = False if "huggingface.co" in url else True
176
- before = get_local_file_list(temp_dir, recursive)
177
- download_thing(temp_dir, url.strip(), civitai_key)
178
- after = get_local_file_list(temp_dir, recursive)
179
- new_file = list_sub(after, before)[0] if list_sub(after, before) else ""
180
- if not new_file:
181
- print(f"Download failed: {url}")
182
- return ""
183
- print(f"Download completed: {url}")
184
- return new_file
185
- except Exception as e:
186
- print(f"Download failed: {url} {e}")
187
- return ""
188
-
189
-
190
- def download_repo(repo_id: str, dir_path: str, progress=gr.Progress(track_tqdm=True)): # for diffusers repo
191
- hf_token = get_token()
192
- try:
193
- snapshot_download(repo_id=repo_id, local_dir=dir_path, token=hf_token, allow_patterns=["*.safetensors", "*.bin"],
194
- ignore_patterns=["*.fp16.*", "/*.safetensors", "/*.bin"], force_download=True)
195
- return True
196
- except Exception as e:
197
- print(f"Error: Failed to download {repo_id}. {e}")
198
- gr.Warning(f"Error: Failed to download {repo_id}. {e}")
199
- return False
200
-
201
-
202
- def upload_repo(repo_id: str, dir_path: str, is_private: bool, is_pr: bool=False, progress=gr.Progress(track_tqdm=True)): # for diffusers repo
203
- hf_token = get_token()
204
- api = HfApi(token=hf_token)
205
- try:
206
- progress(0, desc="Start uploading...")
207
- api.create_repo(repo_id=repo_id, token=hf_token, private=is_private, exist_ok=True)
208
- api.upload_folder(repo_id=repo_id, folder_path=dir_path, path_in_repo="", create_pr=is_pr, token=hf_token)
209
- progress(1, desc="Uploaded.")
210
- return get_hf_url(repo_id, "model")
211
- except Exception as e:
212
- print(f"Error: Failed to upload to {repo_id}. {e}")
213
- return ""
214
-
215
-
216
- def gate_repo(repo_id: str, gated_str: str, repo_type: str="model"):
217
- hf_token = get_token()
218
- api = HfApi(token=hf_token)
219
- try:
220
- if gated_str == "auto": gated = "auto"
221
- elif gated_str == "manual": gated = "manual"
222
- else: gated = False
223
- api.update_repo_settings(repo_id=repo_id, gated=gated, repo_type=repo_type, token=hf_token)
224
- except Exception as e:
225
- print(f"Error: Failed to update settings {repo_id}. {e}")
226
-
227
-
228
- HF_SUBFOLDER_NAME = ["None", "user_repo"]
229
-
230
-
231
- def duplicate_hf_repo(src_repo: str, dst_repo: str, src_repo_type: str, dst_repo_type: str,
232
- is_private: bool, subfolder_type: str=HF_SUBFOLDER_NAME[1], progress=gr.Progress(track_tqdm=True)):
233
- hf_token = get_token()
234
- api = HfApi(token=hf_token)
235
- try:
236
- if subfolder_type == "user_repo": subfolder = src_repo.replace("/", "_")
237
- else: subfolder = ""
238
- progress(0, desc="Start duplicating...")
239
- api.create_repo(repo_id=dst_repo, repo_type=dst_repo_type, private=is_private, exist_ok=True, token=hf_token)
240
- for path in api.list_repo_files(repo_id=src_repo, repo_type=src_repo_type, token=hf_token):
241
- file = hf_hub_download(repo_id=src_repo, filename=path, repo_type=src_repo_type, token=hf_token)
242
- if not Path(file).exists(): continue
243
- if Path(file).is_dir(): # unused for now
244
- api.upload_folder(repo_id=dst_repo, folder_path=file, path_in_repo=f"{subfolder}/{path}" if subfolder else path,
245
- repo_type=dst_repo_type, token=hf_token)
246
- elif Path(file).is_file():
247
- api.upload_file(repo_id=dst_repo, path_or_fileobj=file, path_in_repo=f"{subfolder}/{path}" if subfolder else path,
248
- repo_type=dst_repo_type, token=hf_token)
249
- if Path(file).exists(): Path(file).unlink()
250
- progress(1, desc="Duplicated.")
251
- return f"{get_hf_url(dst_repo, dst_repo_type)}/tree/main/{subfolder}" if subfolder else get_hf_url(dst_repo, dst_repo_type)
252
- except Exception as e:
253
- print(f"Error: Failed to duplicate repo {src_repo} to {dst_repo}. {e}")
254
- return ""
255
-
256
-
257
- BASE_DIR = str(Path(__file__).resolve().parent.resolve())
258
- CIVITAI_API_KEY = os.environ.get("CIVITAI_API_KEY")
259
-
260
-
261
- def get_file(url: str, path: str): # requires aria2, gdown
262
- print(f"Downloading {url} to {path}...")
263
- get_download_file(path, url, CIVITAI_API_KEY)
264
-
265
-
266
- def git_clone(url: str, path: str, pip: bool=False, addcmd: str=""): # requires git
267
- os.makedirs(str(Path(BASE_DIR, path)), exist_ok=True)
268
- os.chdir(Path(BASE_DIR, path))
269
- print(f"Cloning {url} to {path}...")
270
- cmd = f'git clone {url}'
271
- print(f'Running {cmd} at {Path.cwd()}')
272
- i = subprocess.run(cmd, shell=True).returncode
273
- if i != 0: print(f'Error occured at running {cmd}')
274
- p = url.split("/")[-1]
275
- if not Path(p).exists: return
276
- if pip:
277
- os.chdir(Path(BASE_DIR, path, p))
278
- cmd = f'pip install -r requirements.txt'
279
- print(f'Running {cmd} at {Path.cwd()}')
280
- i = subprocess.run(cmd, shell=True).returncode
281
- if i != 0: print(f'Error occured at running {cmd}')
282
- if addcmd:
283
- os.chdir(Path(BASE_DIR, path, p))
284
- cmd = addcmd
285
- print(f'Running {cmd} at {Path.cwd()}')
286
- i = subprocess.run(cmd, shell=True).returncode
287
- if i != 0: print(f'Error occured at running {cmd}')
288
-
289
-
290
- def run(cmd: str, timeout: float=0):
291
- print(f'Running {cmd} at {Path.cwd()}')
292
- if timeout == 0:
293
- i = subprocess.run(cmd, shell=True).returncode
294
- if i != 0: print(f'Error occured at running {cmd}')
295
- else:
296
- p = subprocess.Popen(cmd, shell=True)
297
- time.sleep(timeout)
298
- p.terminate()
299
- print(f'Terminated in {timeout} seconds')
 
1
+ import gradio as gr
2
+ from huggingface_hub import HfApi, HfFolder, hf_hub_download, snapshot_download
3
+ import os
4
+ from pathlib import Path
5
+ import shutil
6
+ import gc
7
+ import re
8
+ import urllib.parse
9
+ import subprocess
10
+ import time
11
+ from typing import Any
12
+
13
+
14
+
15
+
16
+
17
+ def get_state(state: dict, key: str):
18
+ if key in state.keys(): return state[key]
19
+ else:
20
+ print(f"State '{key}' not found.")
21
+ return None
22
+
23
+
24
+ def set_state(state: dict, key: str, value: Any):
25
+ state[key] = value
26
+
27
+
28
+ def get_user_agent():
29
+ return 'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:127.0) Gecko/20100101 Firefox/127.0'
30
+
31
+
32
+
33
+
34
+ MODEL_TYPE_CLASS = {
35
+ "diffusers:StableDiffusionPipeline": "SD 1.5",
36
+ "diffusers:StableDiffusionXLPipeline": "SDXL",
37
+ "diffusers:FluxPipeline": "FLUX",
38
+ }
39
+
40
+
41
+ def get_model_type(repo_id: str):
42
+ hf_token = get_token()
43
+ api = HfApi(token=hf_token)
44
+ lora_filename = "pytorch_lora_weights.safetensors"
45
+ diffusers_filename = "model_index.json"
46
+ default = "SDXL"
47
+ try:
48
+ if api.file_exists(repo_id=repo_id, filename=lora_filename, token=hf_token): return "LoRA"
49
+ if not api.file_exists(repo_id=repo_id, filename=diffusers_filename, token=hf_token): return "None"
50
+ model = api.model_info(repo_id=repo_id, token=hf_token)
51
+ tags = model.tags
52
+ for tag in tags:
53
+ if tag in MODEL_TYPE_CLASS.keys(): return MODEL_TYPE_CLASS.get(tag, default)
54
+ except Exception:
55
+ return default
56
+ return default
57
+
58
+
59
+ def list_uniq(l):
60
+ return sorted(set(l), key=l.index)
61
+
62
+
63
+ def list_sub(a, b):
64
+ return [e for e in a if e not in b]
65
+
66
+
67
+ def is_repo_name(s):
68
+ return re.fullmatch(r'^[\w_\-\.]+/[\w_\-\.]+$', s)
69
+
70
+
71
+
72
+
73
+ def download_thing(directory, url, civitai_api_key="", progress=gr.Progress(track_tqdm=True)): # requires aria2, gdown
74
+ try:
75
+ url = url.strip()
76
+ if "drive.google.com" in url:
77
+ original_dir = os.getcwd()
78
+ os.chdir(directory)
79
+ subprocess.run(f"gdown --fuzzy {url}", shell=True)
80
+ os.chdir(original_dir)
81
+ elif "huggingface.co" in url:
82
+ url = url.replace("?download=true", "")
83
+ if "/blob/" in url: url = url.replace("/blob/", "/resolve/")
84
+ download_hf_file(directory, url)
85
+ elif "civitai.com" in url:
86
+ if civitai_api_key:
87
+ url = f"'{url}&token={civitai_api_key}'" if "?" in url else f"{url}?token={civitai_api_key}"
88
+ print(f"Downloading {url}")
89
+ subprocess.run(f"aria2c --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 -d {directory} {url}", shell=True)
90
+ else:
91
+ print("You need an API key to download Civitai models.")
92
+ else:
93
+ os.system(f"aria2c --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 -d {directory} {url}")
94
+ except Exception as e:
95
+ print(f"Failed to download: {e}")
96
+
97
+
98
+ def get_local_file_list(dir_path, recursive=False):
99
+ file_list = []
100
+ pattern = "**/*.*" if recursive else "*/*.*"
101
+ for file in Path(dir_path).glob(pattern):
102
+ if file.is_file():
103
+ file_path = str(file)
104
+ file_list.append(file_path)
105
+ return file_list
106
+
107
+
108
+ def get_download_file(temp_dir, url, civitai_key, progress=gr.Progress(track_tqdm=True)):
109
+ try:
110
+ if not "http" in url and is_repo_name(url) and not Path(url).exists():
111
+ print(f"Use HF Repo: {url}")
112
+ new_file = url
113
+ elif not "http" in url and Path(url).exists():
114
+ print(f"Use local file: {url}")
115
+ new_file = url
116
+ elif Path(f"{temp_dir}/{url.split('/')[-1]}").exists():
117
+ print(f"File to download alreday exists: {url}")
118
+ new_file = f"{temp_dir}/{url.split('/')[-1]}"
119
+ else:
120
+ print(f"Start downloading: {url}")
121
+ recursive = False if "huggingface.co" in url else True
122
+ before = get_local_file_list(temp_dir, recursive)
123
+ download_thing(temp_dir, url.strip(), civitai_key)
124
+ after = get_local_file_list(temp_dir, recursive)
125
+ new_file = list_sub(after, before)[0] if list_sub(after, before) else ""
126
+ if not new_file:
127
+ print(f"Download failed: {url}")
128
+ return ""
129
+ print(f"Download completed: {url}")
130
+ return new_file
131
+ except Exception as e:
132
+ print(f"Download failed: {url} {e}")
133
+ return ""
134
+
135
+
136
+
137
+
138
+ def gate_repo(repo_id: str, gated_str: str, repo_type: str="model"):
139
+ hf_token = get_token()
140
+ api = HfApi(token=hf_token)
141
+ try:
142
+ if gated_str == "auto": gated = "auto"
143
+ elif gated_str == "manual": gated = "manual"
144
+ else: gated = False
145
+ api.update_repo_settings(repo_id=repo_id, gated=gated, repo_type=repo_type, token=hf_token)
146
+ except Exception as e:
147
+ print(f"Error: Failed to update settings {repo_id}. {e}")
148
+
149
+
150
+ HF_SUBFOLDER_NAME = ["None", "user_repo"]
151
+
152
+
153
+
154
+
155
+ BASE_DIR = os.getcwd()
156
+ CIVITAI_API_KEY = os.environ.get("CIVITAI_API_KEY")
157
+
158
+
159
+ def get_file(url: str, path: str): # requires aria2, gdown
160
+ print(f"Downloading {url} to {path}...")
161
+ get_download_file(path, url, CIVITAI_API_KEY)
162
+
163
+
164
+ def git_clone(url: str, path: str, pip: bool=False, addcmd: str=""): # requires git
165
+ os.makedirs(str(Path(BASE_DIR, path)), exist_ok=True)
166
+ os.chdir(Path(BASE_DIR, path))
167
+ print(f"Cloning {url} to {path}...")
168
+ cmd = f'git clone {url}'
169
+ print(f'Running {cmd} at {Path.cwd()}')
170
+ i = subprocess.run(cmd, shell=True).returncode
171
+ if i != 0: print(f'Error occured at running {cmd}')
172
+ p = url.split("/")[-1]
173
+ if not Path(p).exists: return
174
+ if pip:
175
+ os.chdir(Path(BASE_DIR, path, p))
176
+ cmd = f'pip install -r requirements.txt'
177
+ print(f'Running {cmd} at {Path.cwd()}')
178
+ i = subprocess.run(cmd, shell=True).returncode
179
+ if i != 0: print(f'Error occured at running {cmd}')
180
+ if addcmd:
181
+ os.chdir(Path(BASE_DIR, path, p))
182
+ cmd = addcmd
183
+ print(f'Running {cmd} at {Path.cwd()}')
184
+ i = subprocess.run(cmd, shell=True).returncode
185
+ if i != 0: print(f'Error occured at running {cmd}')
186
+
187
+
188
+ def run(cmd: str, timeout: float=0):
189
+ print(f'Running {cmd} at {Path.cwd()}')
190
+ if timeout == 0:
191
+ i = subprocess.run(cmd, shell=True).returncode
192
+ if i != 0: print(f'Error occured at running {cmd}')
193
+ else:
194
+ p = subprocess.Popen(cmd, shell=True)
195
+ time.sleep(timeout)
196
+ p.terminate()
197
+ print(f'Terminated in {timeout} seconds')