theNeofr commited on
Commit
d02ba0f
·
verified ·
1 Parent(s): 9eb8e23

Update StableGR.py

Browse files
Files changed (1) hide show
  1. StableGR.py +362 -417
StableGR.py CHANGED
@@ -1,417 +1,362 @@
1
- import gradio as gr
2
- from huggingface_hub import HfApi, hf_hub_url
3
- import os
4
- from pathlib import Path
5
- import gc
6
- import requests
7
- from requests.adapters import HTTPAdapter
8
- from urllib3.util import Retry
9
- from utils import (get_token, set_token, is_repo_exists, get_user_agent, get_download_file,
10
- list_uniq, list_sub, duplicate_hf_repo, HF_SUBFOLDER_NAME, get_state, set_state)
11
- import re
12
- from PIL import Image
13
- import json
14
- import pandas as pd
15
- import tempfile
16
- import hashlib
17
-
18
-
19
- TEMP_DIR = tempfile.mkdtemp()
20
-
21
-
22
- def parse_urls(s):
23
- url_pattern = "https?://[\\w/:%#\\$&\\?\\(\\)~\\.=\\+\\-]+"
24
- try:
25
- urls = re.findall(url_pattern, s)
26
- return list(urls)
27
- except Exception:
28
- return []
29
-
30
-
31
- def parse_repos(s):
32
- repo_pattern = r'[^\w_\-\.]?([\w_\-\.]+/[\w_\-\.]+)[^\w_\-\.]?'
33
- try:
34
- s = re.sub("https?://[\\w/:%#\\$&\\?\\(\\)~\\.=\\+\\-]+", "", s)
35
- repos = re.findall(repo_pattern, s)
36
- return list(repos)
37
- except Exception:
38
- return []
39
-
40
-
41
- def to_urls(l: list[str]):
42
- return "\n".join(l)
43
-
44
-
45
- def uniq_urls(s):
46
- return to_urls(list_uniq(parse_urls(s) + parse_repos(s)))
47
-
48
-
49
- def upload_safetensors_to_repo(filename, repo_id, repo_type, is_private, progress=gr.Progress(track_tqdm=True)):
50
- output_filename = Path(filename).name
51
- hf_token = get_token()
52
- api = HfApi(token=hf_token)
53
- try:
54
- if not is_repo_exists(repo_id, repo_type): api.create_repo(repo_id=repo_id, repo_type=repo_type, token=hf_token, private=is_private)
55
- progress(0, desc=f"Start uploading... {filename} to {repo_id}")
56
- api.upload_file(path_or_fileobj=filename, path_in_repo=output_filename, repo_type=repo_type, revision="main", token=hf_token, repo_id=repo_id)
57
- progress(1, desc="Uploaded.")
58
- url = hf_hub_url(repo_id=repo_id, repo_type=repo_type, filename=output_filename)
59
- except Exception as e:
60
- print(f"Error: Failed to upload to {repo_id}. {e}")
61
- gr.Warning(f"Error: Failed to upload to {repo_id}. {e}")
62
- return None
63
- finally:
64
- if Path(filename).exists(): Path(filename).unlink()
65
- return url
66
-
67
-
68
- def is_same_file(filename: str, cmp_sha256: str, cmp_size: int):
69
- if cmp_sha256:
70
- sha256_hash = hashlib.sha256()
71
- with open(filename, "rb") as f:
72
- for byte_block in iter(lambda: f.read(4096), b""):
73
- sha256_hash.update(byte_block)
74
- sha256 = sha256_hash.hexdigest()
75
- else: sha256 = ""
76
- size = os.path.getsize(filename)
77
- if size == cmp_size and sha256 == cmp_sha256: return True
78
- else: return False
79
-
80
-
81
- def get_safe_filename(filename, repo_id, repo_type):
82
- hf_token = get_token()
83
- api = HfApi(token=hf_token)
84
- new_filename = filename
85
- try:
86
- i = 1
87
- while api.file_exists(repo_id=repo_id, filename=Path(new_filename).name, repo_type=repo_type, token=hf_token):
88
- infos = api.get_paths_info(repo_id=repo_id, paths=[Path(new_filename).name], repo_type=repo_type, token=hf_token)
89
- if infos and len(infos) == 1:
90
- repo_fs = infos[0].size
91
- repo_sha256 = infos[0].lfs.sha256 if infos[0].lfs is not None else ""
92
- if is_same_file(filename, repo_sha256, repo_fs): break
93
- new_filename = str(Path(Path(filename).parent, f"{Path(filename).stem}_{i}{Path(filename).suffix}"))
94
- i += 1
95
- if filename != new_filename:
96
- print(f"{Path(filename).name} is already exists but file content is different. renaming to {Path(new_filename).name}.")
97
- Path(filename).rename(new_filename)
98
- except Exception as e:
99
- print(f"Error occured when renaming {filename}. {e}")
100
- finally:
101
- return new_filename
102
-
103
-
104
- def download_file(dl_url, civitai_key, progress=gr.Progress(track_tqdm=True)):
105
- download_dir = TEMP_DIR
106
- progress(0, desc=f"Start downloading... {dl_url}")
107
- output_filename = get_download_file(download_dir, dl_url, civitai_key)
108
- return output_filename
109
-
110
-
111
- def save_civitai_info(dl_url, filename, civitai_key="", progress=gr.Progress(track_tqdm=True)):
112
- json_str, html_str, image_path = get_civitai_json(dl_url, True, filename, civitai_key)
113
- if not json_str: return "", "", ""
114
- json_path = str(Path(TEMP_DIR, Path(filename).stem + ".json"))
115
- html_path = str(Path(TEMP_DIR, Path(filename).stem + ".html"))
116
- try:
117
- with open(json_path, 'w') as f:
118
- json.dump(json_str, f, indent=2)
119
- with open(html_path, mode='w', encoding="utf-8") as f:
120
- f.write(html_str)
121
- return json_path, html_path, image_path
122
- except Exception as e:
123
- print(f"Error: Failed to save info file {json_path}, {html_path} {e}")
124
- return "", "", ""
125
-
126
-
127
- def upload_info_to_repo(dl_url, filename, repo_id, repo_type, is_private, civitai_key="", progress=gr.Progress(track_tqdm=True)):
128
- def upload_file(api, filename, repo_id, repo_type, hf_token):
129
- if not Path(filename).exists(): return
130
- api.upload_file(path_or_fileobj=filename, path_in_repo=Path(filename).name, repo_type=repo_type, revision="main", token=hf_token, repo_id=repo_id)
131
- Path(filename).unlink()
132
-
133
- hf_token = get_token()
134
- api = HfApi(token=hf_token)
135
- try:
136
- if not is_repo_exists(repo_id, repo_type): api.create_repo(repo_id=repo_id, repo_type=repo_type, token=hf_token, private=is_private)
137
- progress(0, desc=f"Downloading info... {filename}")
138
- json_path, html_path, image_path = save_civitai_info(dl_url, filename, civitai_key)
139
- progress(0, desc=f"Start uploading info... {filename} to {repo_id}")
140
- if not json_path: return
141
- else: upload_file(api, json_path, repo_id, repo_type, hf_token)
142
- if html_path: upload_file(api, html_path, repo_id, repo_type, hf_token)
143
- if image_path: upload_file(api, image_path, repo_id, repo_type, hf_token)
144
- progress(1, desc="Info uploaded.")
145
- return
146
- except Exception as e:
147
- print(f"Error: Failed to upload info to {repo_id}. {e}")
148
- gr.Warning(f"Error: Failed to upload info to {repo_id}. {e}")
149
- return
150
-
151
-
152
- def download_civitai(dl_url, civitai_key, hf_token, urls,
153
- newrepo_id, repo_type="model", is_private=True, is_info=False, is_rename=True, progress=gr.Progress(track_tqdm=True)):
154
- if hf_token: set_token(hf_token)
155
- else: set_token(os.environ.get("HF_TOKEN")) # default huggingface write token
156
- if not civitai_key: civitai_key = os.environ.get("CIVITAI_API_KEY") # default Civitai API key
157
- if not newrepo_id: newrepo_id = os.environ.get("HF_REPO") # default repo to upload
158
- if not get_token() or not civitai_key: raise gr.Error("HF write token and Civitai API key is required.")
159
- if not urls: urls = []
160
- dl_urls = parse_urls(dl_url)
161
- remain_urls = dl_urls.copy()
162
- try:
163
- md = f'### Your repo: [{newrepo_id}]({"https://huggingface.co/datasets/" if repo_type == "dataset" else "https://huggingface.co/"}{newrepo_id})\n'
164
- for u in dl_urls:
165
- file = download_file(u, civitai_key)
166
- if not Path(file).exists() or not Path(file).is_file(): continue
167
- if is_rename: file = get_safe_filename(file, newrepo_id, repo_type)
168
- url = upload_safetensors_to_repo(file, newrepo_id, repo_type, is_private)
169
- if url:
170
- if is_info: upload_info_to_repo(u, file, newrepo_id, repo_type, is_private, civitai_key)
171
- urls.append(url)
172
- remain_urls.remove(u)
173
- md += f"- Uploaded [{str(u)}]({str(u)})\n"
174
- dp_repos = parse_repos(dl_url)
175
- for r in dp_repos:
176
- url = duplicate_hf_repo(r, newrepo_id, "model", repo_type, is_private, HF_SUBFOLDER_NAME[1])
177
- if url: urls.append(url)
178
- return gr.update(value=urls, choices=urls), gr.update(value=md), gr.update(value="\n".join(remain_urls), visible=False)
179
- except Exception as e:
180
- gr.Info(f"Error occured: {e}")
181
- return gr.update(value=urls, choices=urls), gr.update(value=md), gr.update(value="\n".join(remain_urls), visible=True)
182
- finally:
183
- gc.collect()
184
-
185
-
186
- CIVITAI_TYPE = ["Checkpoint", "TextualInversion", "Hypernetwork", "AestheticGradient", "LORA", "LoCon", "DoRA",
187
- "Controlnet", "Upscaler", "MotionModule", "VAE", "Poses", "Wildcards", "Workflows", "Other"]
188
- CIVITAI_FILETYPE = ["Model", "VAE", "Config", "Training Data"]
189
- CIVITAI_BASEMODEL = ["Pony", "Illustrious", "SDXL 1.0", "SD 1.5", "Flux.1 D", "Flux.1 S", "SD 3.5"]
190
- #CIVITAI_SORT = ["Highest Rated", "Most Downloaded", "Newest"]
191
- CIVITAI_SORT = ["Highest Rated", "Most Downloaded", "Most Liked", "Most Discussed", "Most Collected", "Most Buzz", "Newest"]
192
- CIVITAI_PERIOD = ["AllTime", "Year", "Month", "Week", "Day"]
193
-
194
-
195
- def search_on_civitai(query: str, types: list[str], allow_model: list[str] = [], limit: int = 100,
196
- sort: str = "Highest Rated", period: str = "AllTime", tag: str = "", user: str = "", page: int = 1,
197
- filetype: list[str] = [], api_key: str = "", progress=gr.Progress(track_tqdm=True)):
198
- user_agent = get_user_agent()
199
- headers = {'User-Agent': user_agent, 'content-type': 'application/json'}
200
- if api_key: headers['Authorization'] = f'Bearer {{{api_key}}}'
201
- base_url = 'https://civitai.com/api/v1/models'
202
- params = {'sort': sort, 'period': period, 'limit': int(limit), 'nsfw': 'true'}
203
- if len(types) != 0: params["types"] = types
204
- if query: params["query"] = query
205
- if tag: params["tag"] = tag
206
- if user: params["username"] = user
207
- if page != 0: params["page"] = int(page)
208
- session = requests.Session()
209
- retries = Retry(total=5, backoff_factor=1, status_forcelist=[500, 502, 503, 504])
210
- session.mount("https://", HTTPAdapter(max_retries=retries))
211
- rs = []
212
- try:
213
- if page == 0:
214
- progress(0, desc="Searching page 1...")
215
- print("Searching page 1...")
216
- r = session.get(base_url, params=params | {'page': 1}, headers=headers, stream=True, timeout=(7.0, 30))
217
- rs.append(r)
218
- if r.ok:
219
- json = r.json()
220
- next_url = json['metadata']['nextPage'] if 'metadata' in json and 'nextPage' in json['metadata'] else None
221
- i = 2
222
- while(next_url is not None):
223
- progress(0, desc=f"Searching page {i}...")
224
- print(f"Searching page {i}...")
225
- r = session.get(next_url, headers=headers, stream=True, timeout=(7.0, 30))
226
- rs.append(r)
227
- if r.ok:
228
- json = r.json()
229
- next_url = json['metadata']['nextPage'] if 'metadata' in json and 'nextPage' in json['metadata'] else None
230
- else: next_url = None
231
- i += 1
232
- else:
233
- progress(0, desc="Searching page 1...")
234
- print("Searching page 1...")
235
- r = session.get(base_url, params=params, headers=headers, stream=True, timeout=(7.0, 30))
236
- rs.append(r)
237
- except requests.exceptions.ConnectTimeout:
238
- print("Request timed out.")
239
- except Exception as e:
240
- print(e)
241
- items = []
242
- for r in rs:
243
- if not r.ok: continue
244
- json = r.json()
245
- if 'items' not in json: continue
246
- for j in json['items']:
247
- for model in j['modelVersions']:
248
- item = {}
249
- if len(allow_model) != 0 and model['baseModel'] not in set(allow_model): continue
250
- item['name'] = j['name']
251
- item['creator'] = j['creator']['username'] if 'creator' in j.keys() and 'username' in j['creator'].keys() else ""
252
- item['tags'] = j['tags'] if 'tags' in j.keys() else []
253
- item['model_name'] = model['name'] if 'name' in model.keys() else ""
254
- item['base_model'] = model['baseModel'] if 'baseModel' in model.keys() else ""
255
- item['description'] = model['description'] if 'description' in model.keys() else ""
256
- item['md'] = ""
257
- if 'images' in model.keys() and len(model["images"]) != 0:
258
- item['img_url'] = model["images"][0]["url"]
259
- item['md'] += f'<img src="{model["images"][0]["url"]}#float" alt="thumbnail" width="150" height="240"><br>'
260
- else: item['img_url'] = "/home/user/app/null.png"
261
- item['md'] += f'''Model URL: [https://civitai.com/models/{j["id"]}](https://civitai.com/models/{j["id"]})<br>Model Name: {item["name"]}<br>
262
- Creator: {item["creator"]}<br>Tags: {", ".join(item["tags"])}<br>Base Model: {item["base_model"]}<br>Description: {item["description"]}'''
263
- if 'files' in model.keys():
264
- for f in model['files']:
265
- i = item.copy()
266
- i['dl_url'] = f['downloadUrl']
267
- if len(filetype) != 0 and f['type'] not in set(filetype): continue
268
- items.append(i)
269
- else:
270
- item['dl_url'] = model['downloadUrl']
271
- items.append(item)
272
- return items if len(items) > 0 else None
273
-
274
-
275
- def search_civitai(query, types, base_model=[], sort=CIVITAI_SORT[0], period=CIVITAI_PERIOD[0], tag="", user="", limit=100, page=1,
276
- filetype=[], api_key="", gallery=[], state={}, progress=gr.Progress(track_tqdm=True)):
277
- civitai_last_results = {}
278
- set_state(state, "civitai_last_choices", [("", "")])
279
- set_state(state, "civitai_last_gallery", [])
280
- set_state(state, "civitai_last_results", civitai_last_results)
281
- results_info = "No item found."
282
- items = search_on_civitai(query, types, base_model, int(limit), sort, period, tag, user, int(page), filetype, api_key)
283
- if not items: return gr.update(choices=[("", "")], value=[], visible=True),\
284
- gr.update(value="", visible=False), gr.update(), gr.update(), gr.update(), gr.update(), results_info, state
285
- choices = []
286
- gallery = []
287
- for item in items:
288
- base_model_name = "Pony🐴" if item['base_model'] == "Pony" else item['base_model']
289
- name = f"{item['name']} (for {base_model_name} / By: {item['creator']})"
290
- value = item['dl_url']
291
- choices.append((name, value))
292
- gallery.append((item['img_url'], name))
293
- civitai_last_results[value] = item
294
- if len(choices) >= 1: results_info = f"{int(len(choices))} items found."
295
- else: choices = [("", "")]
296
- md = ""
297
- set_state(state, "civitai_last_choices", choices)
298
- set_state(state, "civitai_last_gallery", gallery)
299
- set_state(state, "civitai_last_results", civitai_last_results)
300
- return gr.update(choices=choices, value=[], visible=True), gr.update(value=md, visible=True),\
301
- gr.update(), gr.update(), gr.update(value=gallery), gr.update(choices=choices, value=[]), results_info, state
302
-
303
-
304
- def get_civitai_json(dl_url: str, is_html: bool=False, image_baseurl: str="", api_key=""):
305
- if not image_baseurl: image_baseurl = dl_url
306
- default = ("", "", "") if is_html else ""
307
- if "https://civitai.com/api/download/models/" not in dl_url: return default
308
- user_agent = get_user_agent()
309
- headers = {'User-Agent': user_agent, 'content-type': 'application/json'}
310
- if api_key: headers['Authorization'] = f'Bearer {{{api_key}}}'
311
- base_url = 'https://civitai.com/api/v1/model-versions/'
312
- params = {}
313
- session = requests.Session()
314
- retries = Retry(total=5, backoff_factor=1, status_forcelist=[500, 502, 503, 504])
315
- session.mount("https://", HTTPAdapter(max_retries=retries))
316
- model_id = re.sub('https://civitai.com/api/download/models/(\\d+)(?:.+)?', '\\1', dl_url)
317
- url = base_url + model_id
318
- #url = base_url + str(dl_url.split("/")[-1])
319
- try:
320
- r = session.get(url, params=params, headers=headers, stream=True, timeout=(5.0, 15))
321
- if not r.ok: return default
322
- json = dict(r.json()).copy()
323
- html = ""
324
- image = ""
325
- if "modelId" in json.keys():
326
- url = f"https://civitai.com/models/{json['modelId']}"
327
- r = session.get(url, params=params, headers=headers, stream=True, timeout=(5.0, 15))
328
- if not r.ok: return json, html, image
329
- html = r.text
330
- if 'images' in json.keys() and len(json["images"]) != 0:
331
- url = json["images"][0]["url"]
332
- r = session.get(url, params=params, headers=headers, stream=True, timeout=(5.0, 15))
333
- if not r.ok: return json, html, image
334
- image_temp = str(Path(TEMP_DIR, "image" + Path(url.split("/")[-1]).suffix))
335
- image = str(Path(TEMP_DIR, Path(image_baseurl.split("/")[-1]).stem + ".png"))
336
- with open(image_temp, 'wb') as f:
337
- f.write(r.content)
338
- Image.open(image_temp).convert('RGBA').save(image)
339
- return json, html, image
340
- except Exception as e:
341
- print(e)
342
- return default
343
-
344
-
345
- def get_civitai_tag():
346
- default = [""]
347
- user_agent = get_user_agent()
348
- headers = {'User-Agent': user_agent, 'content-type': 'application/json'}
349
- base_url = 'https://civitai.com/api/v1/tags'
350
- params = {'limit': 200}
351
- session = requests.Session()
352
- retries = Retry(total=5, backoff_factor=1, status_forcelist=[500, 502, 503, 504])
353
- session.mount("https://", HTTPAdapter(max_retries=retries))
354
- url = base_url
355
- try:
356
- r = session.get(url, params=params, headers=headers, stream=True, timeout=(7.0, 15))
357
- if not r.ok: return default
358
- j = dict(r.json()).copy()
359
- if "items" not in j.keys(): return default
360
- items = []
361
- for item in j["items"]:
362
- items.append([str(item.get("name", "")), int(item.get("modelCount", 0))])
363
- df = pd.DataFrame(items)
364
- df.sort_values(1, ascending=False)
365
- tags = df.values.tolist()
366
- tags = [""] + [l[0] for l in tags]
367
- return tags
368
- except Exception as e:
369
- print(e)
370
- return default
371
-
372
-
373
- def select_civitai_item(results: list[str], state: dict):
374
- json = {}
375
- if "http" not in "".join(results) or len(results) == 0: return gr.update(value="", visible=True), gr.update(value=json, visible=False), state
376
- result = get_state(state, "civitai_last_results")
377
- last_selects = get_state(state, "civitai_last_selects")
378
- selects = list_sub(results, last_selects if last_selects else [])
379
- md = result.get(selects[-1]).get('md', "") if result and isinstance(result, dict) and len(selects) > 0 else ""
380
- set_state(state, "civitai_last_selects", results)
381
- return gr.update(value=md, visible=True), gr.update(value=json, visible=False), state
382
-
383
-
384
- def add_civitai_item(results: list[str], dl_url: str):
385
- if "http" not in "".join(results): return gr.update(value=dl_url)
386
- new_url = dl_url if dl_url else ""
387
- for result in results:
388
- if "http" not in result: continue
389
- new_url += f"\n{result}" if new_url else f"{result}"
390
- new_url = uniq_urls(new_url)
391
- return gr.update(value=new_url)
392
-
393
-
394
- def select_civitai_all_item(button_name: str, state: dict):
395
- if button_name not in ["Select All", "Deselect All"]: return gr.update(value=button_name), gr.Update(visible=True)
396
- civitai_last_choices = get_state(state, "civitai_last_choices")
397
- selected = [t[1] for t in civitai_last_choices if t[1] != ""] if button_name == "Select All" else []
398
- new_button_name = "Select All" if button_name == "Deselect All" else "Deselect All"
399
- return gr.update(value=new_button_name), gr.update(value=selected, choices=civitai_last_choices)
400
-
401
-
402
- def update_civitai_selection(evt: gr.SelectData, value: list[str], state: dict):
403
- try:
404
- civitai_last_choices = get_state(state, "civitai_last_choices")
405
- selected_index = evt.index
406
- selected = list_uniq([v for v in value if v != ""] + [civitai_last_choices[selected_index][1]])
407
- return gr.update(value=selected)
408
- except Exception:
409
- return gr.update()
410
-
411
-
412
- def update_civitai_checkbox(selected: list[str]):
413
- return gr.update(value=selected)
414
-
415
-
416
- def from_civitai_checkbox(selected: list[str]):
417
- return gr.update(value=selected)
 
1
+ import gradio as gr
2
+ from huggingface_hub import HfApi, hf_hub_url
3
+ import os
4
+ from pathlib import Path
5
+ import gc
6
+ import requests
7
+ from requests.adapters import HTTPAdapter
8
+ from urllib3.util import Retry
9
+ from utils import (get_token, set_token, is_repo_exists, get_user_agent, get_download_file,
10
+ list_uniq, list_sub, duplicate_hf_repo, HF_SUBFOLDER_NAME, get_state, set_state)
11
+ import re
12
+ from PIL import Image
13
+ import json
14
+ import pandas as pd
15
+ import tempfile
16
+ import hashlib
17
+
18
+
19
+ TEMP_DIR = os.getcwd())
20
+
21
+ def parse_urls(s):
22
+ url_pattern = "https?://[\\w/:%#\\$&\\?\\(\\)~\\.=\\+\\-]+"
23
+ try:
24
+ urls = re.findall(url_pattern, s)
25
+ return list(urls)
26
+ except Exception:
27
+ return []
28
+
29
+
30
+ def parse_repos(s):
31
+ repo_pattern = r'[^\w_\-\.]?([\w_\-\.]+/[\w_\-\.]+)[^\w_\-\.]?'
32
+ try:
33
+ s = re.sub("https?://[\\w/:%#\\$&\\?\\(\\)~\\.=\\+\\-]+", "", s)
34
+ repos = re.findall(repo_pattern, s)
35
+ return list(repos)
36
+ except Exception:
37
+ return []
38
+
39
+
40
+ def to_urls(l: list[str]):
41
+ return "\n".join(l)
42
+
43
+
44
+ def uniq_urls(s):
45
+ return to_urls(list_uniq(parse_urls(s) + parse_repos(s)))
46
+
47
+
48
+
49
+ def download_file(dl_url, civitai_key, progress=gr.Progress(track_tqdm=True)):
50
+ download_dir = TEMP_DIR
51
+ progress(0, desc=f"Start downloading... {dl_url}")
52
+ output_filename = get_download_file(download_dir, dl_url, civitai_key)
53
+ return output_filename
54
+
55
+
56
+ def save_civitai_info(dl_url, filename, civitai_key="", progress=gr.Progress(track_tqdm=True)):
57
+ json_str, html_str, image_path = get_civitai_json(dl_url, True, filename, civitai_key)
58
+ if not json_str: return "", "", ""
59
+ json_path = str(Path(TEMP_DIR, Path(filename).stem + ".json"))
60
+ html_path = str(Path(TEMP_DIR, Path(filename).stem + ".html"))
61
+ try:
62
+ with open(json_path, 'w') as f:
63
+ json.dump(json_str, f, indent=2)
64
+ with open(html_path, mode='w', encoding="utf-8") as f:
65
+ f.write(html_str)
66
+ return json_path, html_path, image_path
67
+ except Exception as e:
68
+ print(f"Error: Failed to save info file {json_path}, {html_path} {e}")
69
+ return "", "", ""
70
+
71
+
72
+ def upload_info_to_repo(dl_url, filename, repo_id, repo_type, is_private, civitai_key="", progress=gr.Progress(track_tqdm=True)):
73
+ def upload_file(api, filename, repo_id, repo_type, hf_token):
74
+ if not Path(filename).exists(): return
75
+ api.upload_file(path_or_fileobj=filename, path_in_repo=Path(filename).name, repo_type=repo_type, revision="main", token=hf_token, repo_id=repo_id)
76
+ Path(filename).unlink()
77
+
78
+ hf_token = get_token()
79
+ api = HfApi(token=hf_token)
80
+ try:
81
+ if not is_repo_exists(repo_id, repo_type): api.create_repo(repo_id=repo_id, repo_type=repo_type, token=hf_token, private=is_private)
82
+ progress(0, desc=f"Downloading info... {filename}")
83
+ json_path, html_path, image_path = save_civitai_info(dl_url, filename, civitai_key)
84
+ progress(0, desc=f"Start uploading info... {filename} to {repo_id}")
85
+ if not json_path: return
86
+ else: upload_file(api, json_path, repo_id, repo_type, hf_token)
87
+ if html_path: upload_file(api, html_path, repo_id, repo_type, hf_token)
88
+ if image_path: upload_file(api, image_path, repo_id, repo_type, hf_token)
89
+ progress(1, desc="Info uploaded.")
90
+ return
91
+ except Exception as e:
92
+ print(f"Error: Failed to upload info to {repo_id}. {e}")
93
+ gr.Warning(f"Error: Failed to upload info to {repo_id}. {e}")
94
+ return
95
+
96
+
97
+ def download_civitai(dl_url, civitai_key, hf_token, urls,
98
+ newrepo_id, repo_type="model", is_private=True, is_info=False, is_rename=True, progress=gr.Progress(track_tqdm=True)):
99
+ if hf_token: set_token(hf_token)
100
+ else: set_token(os.environ.get("HF_TOKEN")) # default huggingface write token
101
+ if not civitai_key: civitai_key = os.environ.get("CIVITAI_API_KEY") # default Civitai API key
102
+ if not newrepo_id: newrepo_id = os.environ.get("HF_REPO") # default repo to upload
103
+ if not get_token() or not civitai_key: raise gr.Error("HF write token and Civitai API key is required.")
104
+ if not urls: urls = []
105
+ dl_urls = parse_urls(dl_url)
106
+ remain_urls = dl_urls.copy()
107
+ try:
108
+ md = f'### Your repo: [{newrepo_id}]({"https://huggingface.co/datasets/" if repo_type == "dataset" else "https://huggingface.co/"}{newrepo_id})\n'
109
+ for u in dl_urls:
110
+ file = download_file(u, civitai_key)
111
+ if not Path(file).exists() or not Path(file).is_file(): continue
112
+ if is_rename: file = get_safe_filename(file, newrepo_id, repo_type)
113
+ url = upload_safetensors_to_repo(file, newrepo_id, repo_type, is_private)
114
+ if url:
115
+ if is_info: upload_info_to_repo(u, file, newrepo_id, repo_type, is_private, civitai_key)
116
+ urls.append(url)
117
+ remain_urls.remove(u)
118
+ md += f"- Uploaded [{str(u)}]({str(u)})\n"
119
+ dp_repos = parse_repos(dl_url)
120
+ for r in dp_repos:
121
+ url = duplicate_hf_repo(r, newrepo_id, "model", repo_type, is_private, HF_SUBFOLDER_NAME[1])
122
+ if url: urls.append(url)
123
+ return gr.update(value=urls, choices=urls), gr.update(value=md), gr.update(value="\n".join(remain_urls), visible=False)
124
+ except Exception as e:
125
+ gr.Info(f"Error occured: {e}")
126
+ return gr.update(value=urls, choices=urls), gr.update(value=md), gr.update(value="\n".join(remain_urls), visible=True)
127
+ finally:
128
+ gc.collect()
129
+
130
+
131
+ CIVITAI_TYPE = ["Checkpoint", "TextualInversion", "Hypernetwork", "AestheticGradient", "LORA", "LoCon", "DoRA",
132
+ "Controlnet", "Upscaler", "MotionModule", "VAE", "Poses", "Wildcards", "Workflows", "Other"]
133
+ CIVITAI_FILETYPE = ["Model", "VAE", "Config", "Training Data"]
134
+ CIVITAI_BASEMODEL = ["Pony", "Illustrious", "SDXL 1.0", "SD 1.5", "Flux.1 D", "Flux.1 S", "SD 3.5"]
135
+ #CIVITAI_SORT = ["Highest Rated", "Most Downloaded", "Newest"]
136
+ CIVITAI_SORT = ["Highest Rated", "Most Downloaded", "Most Liked", "Most Discussed", "Most Collected", "Most Buzz", "Newest"]
137
+ CIVITAI_PERIOD = ["AllTime", "Year", "Month", "Week", "Day"]
138
+
139
+
140
+ def search_on_civitai(query: str, types: list[str], allow_model: list[str] = [], limit: int = 100,
141
+ sort: str = "Highest Rated", period: str = "AllTime", tag: str = "", user: str = "", page: int = 1,
142
+ filetype: list[str] = [], api_key: str = "", progress=gr.Progress(track_tqdm=True)):
143
+ user_agent = get_user_agent()
144
+ headers = {'User-Agent': user_agent, 'content-type': 'application/json'}
145
+ if api_key: headers['Authorization'] = f'Bearer {{{api_key}}}'
146
+ base_url = 'https://civitai.com/api/v1/models'
147
+ params = {'sort': sort, 'period': period, 'limit': int(limit), 'nsfw': 'true'}
148
+ if len(types) != 0: params["types"] = types
149
+ if query: params["query"] = query
150
+ if tag: params["tag"] = tag
151
+ if user: params["username"] = user
152
+ if page != 0: params["page"] = int(page)
153
+ session = requests.Session()
154
+ retries = Retry(total=5, backoff_factor=1, status_forcelist=[500, 502, 503, 504])
155
+ session.mount("https://", HTTPAdapter(max_retries=retries))
156
+ rs = []
157
+ try:
158
+ if page == 0:
159
+ progress(0, desc="Searching page 1...")
160
+ print("Searching page 1...")
161
+ r = session.get(base_url, params=params | {'page': 1}, headers=headers, stream=True, timeout=(7.0, 30))
162
+ rs.append(r)
163
+ if r.ok:
164
+ json = r.json()
165
+ next_url = json['metadata']['nextPage'] if 'metadata' in json and 'nextPage' in json['metadata'] else None
166
+ i = 2
167
+ while(next_url is not None):
168
+ progress(0, desc=f"Searching page {i}...")
169
+ print(f"Searching page {i}...")
170
+ r = session.get(next_url, headers=headers, stream=True, timeout=(7.0, 30))
171
+ rs.append(r)
172
+ if r.ok:
173
+ json = r.json()
174
+ next_url = json['metadata']['nextPage'] if 'metadata' in json and 'nextPage' in json['metadata'] else None
175
+ else: next_url = None
176
+ i += 1
177
+ else:
178
+ progress(0, desc="Searching page 1...")
179
+ print("Searching page 1...")
180
+ r = session.get(base_url, params=params, headers=headers, stream=True, timeout=(7.0, 30))
181
+ rs.append(r)
182
+ except requests.exceptions.ConnectTimeout:
183
+ print("Request timed out.")
184
+ except Exception as e:
185
+ print(e)
186
+ items = []
187
+ for r in rs:
188
+ if not r.ok: continue
189
+ json = r.json()
190
+ if 'items' not in json: continue
191
+ for j in json['items']:
192
+ for model in j['modelVersions']:
193
+ item = {}
194
+ if len(allow_model) != 0 and model['baseModel'] not in set(allow_model): continue
195
+ item['name'] = j['name']
196
+ item['creator'] = j['creator']['username'] if 'creator' in j.keys() and 'username' in j['creator'].keys() else ""
197
+ item['tags'] = j['tags'] if 'tags' in j.keys() else []
198
+ item['model_name'] = model['name'] if 'name' in model.keys() else ""
199
+ item['base_model'] = model['baseModel'] if 'baseModel' in model.keys() else ""
200
+ item['description'] = model['description'] if 'description' in model.keys() else ""
201
+ item['md'] = ""
202
+ if 'images' in model.keys() and len(model["images"]) != 0:
203
+ item['img_url'] = model["images"][0]["url"]
204
+ item['md'] += f'<img src="{model["images"][0]["url"]}#float" alt="thumbnail" width="150" height="240"><br>'
205
+ else: item['img_url'] = "/home/user/app/null.png"
206
+ item['md'] += f'''Model URL: [https://civitai.com/models/{j["id"]}](https://civitai.com/models/{j["id"]})<br>Model Name: {item["name"]}<br>
207
+ Creator: {item["creator"]}<br>Tags: {", ".join(item["tags"])}<br>Base Model: {item["base_model"]}<br>Description: {item["description"]}'''
208
+ if 'files' in model.keys():
209
+ for f in model['files']:
210
+ i = item.copy()
211
+ i['dl_url'] = f['downloadUrl']
212
+ if len(filetype) != 0 and f['type'] not in set(filetype): continue
213
+ items.append(i)
214
+ else:
215
+ item['dl_url'] = model['downloadUrl']
216
+ items.append(item)
217
+ return items if len(items) > 0 else None
218
+
219
+
220
+ def search_civitai(query, types, base_model=[], sort=CIVITAI_SORT[0], period=CIVITAI_PERIOD[0], tag="", user="", limit=100, page=1,
221
+ filetype=[], api_key="", gallery=[], state={}, progress=gr.Progress(track_tqdm=True)):
222
+ civitai_last_results = {}
223
+ set_state(state, "civitai_last_choices", [("", "")])
224
+ set_state(state, "civitai_last_gallery", [])
225
+ set_state(state, "civitai_last_results", civitai_last_results)
226
+ results_info = "No item found."
227
+ items = search_on_civitai(query, types, base_model, int(limit), sort, period, tag, user, int(page), filetype, api_key)
228
+ if not items: return gr.update(choices=[("", "")], value=[], visible=True),\
229
+ gr.update(value="", visible=False), gr.update(), gr.update(), gr.update(), gr.update(), results_info, state
230
+ choices = []
231
+ gallery = []
232
+ for item in items:
233
+ base_model_name = "Pony🐴" if item['base_model'] == "Pony" else item['base_model']
234
+ name = f"{item['name']} (for {base_model_name} / By: {item['creator']})"
235
+ value = item['dl_url']
236
+ choices.append((name, value))
237
+ gallery.append((item['img_url'], name))
238
+ civitai_last_results[value] = item
239
+ if len(choices) >= 1: results_info = f"{int(len(choices))} items found."
240
+ else: choices = [("", "")]
241
+ md = ""
242
+ set_state(state, "civitai_last_choices", choices)
243
+ set_state(state, "civitai_last_gallery", gallery)
244
+ set_state(state, "civitai_last_results", civitai_last_results)
245
+ return gr.update(choices=choices, value=[], visible=True), gr.update(value=md, visible=True),\
246
+ gr.update(), gr.update(), gr.update(value=gallery), gr.update(choices=choices, value=[]), results_info, state
247
+
248
+
249
+ def get_civitai_json(dl_url: str, is_html: bool=False, image_baseurl: str="", api_key=""):
250
+ if not image_baseurl: image_baseurl = dl_url
251
+ default = ("", "", "") if is_html else ""
252
+ if "https://civitai.com/api/download/models/" not in dl_url: return default
253
+ user_agent = get_user_agent()
254
+ headers = {'User-Agent': user_agent, 'content-type': 'application/json'}
255
+ if api_key: headers['Authorization'] = f'Bearer {{{api_key}}}'
256
+ base_url = 'https://civitai.com/api/v1/model-versions/'
257
+ params = {}
258
+ session = requests.Session()
259
+ retries = Retry(total=5, backoff_factor=1, status_forcelist=[500, 502, 503, 504])
260
+ session.mount("https://", HTTPAdapter(max_retries=retries))
261
+ model_id = re.sub('https://civitai.com/api/download/models/(\\d+)(?:.+)?', '\\1', dl_url)
262
+ url = base_url + model_id
263
+ #url = base_url + str(dl_url.split("/")[-1])
264
+ try:
265
+ r = session.get(url, params=params, headers=headers, stream=True, timeout=(5.0, 15))
266
+ if not r.ok: return default
267
+ json = dict(r.json()).copy()
268
+ html = ""
269
+ image = ""
270
+ if "modelId" in json.keys():
271
+ url = f"https://civitai.com/models/{json['modelId']}"
272
+ r = session.get(url, params=params, headers=headers, stream=True, timeout=(5.0, 15))
273
+ if not r.ok: return json, html, image
274
+ html = r.text
275
+ if 'images' in json.keys() and len(json["images"]) != 0:
276
+ url = json["images"][0]["url"]
277
+ r = session.get(url, params=params, headers=headers, stream=True, timeout=(5.0, 15))
278
+ if not r.ok: return json, html, image
279
+ image_temp = str(Path(TEMP_DIR, "image" + Path(url.split("/")[-1]).suffix))
280
+ image = str(Path(TEMP_DIR, Path(image_baseurl.split("/")[-1]).stem + ".png"))
281
+ with open(image_temp, 'wb') as f:
282
+ f.write(r.content)
283
+ Image.open(image_temp).convert('RGBA').save(image)
284
+ return json, html, image
285
+ except Exception as e:
286
+ print(e)
287
+ return default
288
+
289
+
290
+ def get_civitai_tag():
291
+ default = [""]
292
+ user_agent = get_user_agent()
293
+ headers = {'User-Agent': user_agent, 'content-type': 'application/json'}
294
+ base_url = 'https://civitai.com/api/v1/tags'
295
+ params = {'limit': 200}
296
+ session = requests.Session()
297
+ retries = Retry(total=5, backoff_factor=1, status_forcelist=[500, 502, 503, 504])
298
+ session.mount("https://", HTTPAdapter(max_retries=retries))
299
+ url = base_url
300
+ try:
301
+ r = session.get(url, params=params, headers=headers, stream=True, timeout=(7.0, 15))
302
+ if not r.ok: return default
303
+ j = dict(r.json()).copy()
304
+ if "items" not in j.keys(): return default
305
+ items = []
306
+ for item in j["items"]:
307
+ items.append([str(item.get("name", "")), int(item.get("modelCount", 0))])
308
+ df = pd.DataFrame(items)
309
+ df.sort_values(1, ascending=False)
310
+ tags = df.values.tolist()
311
+ tags = [""] + [l[0] for l in tags]
312
+ return tags
313
+ except Exception as e:
314
+ print(e)
315
+ return default
316
+
317
+
318
+ def select_civitai_item(results: list[str], state: dict):
319
+ json = {}
320
+ if "http" not in "".join(results) or len(results) == 0: return gr.update(value="", visible=True), gr.update(value=json, visible=False), state
321
+ result = get_state(state, "civitai_last_results")
322
+ last_selects = get_state(state, "civitai_last_selects")
323
+ selects = list_sub(results, last_selects if last_selects else [])
324
+ md = result.get(selects[-1]).get('md', "") if result and isinstance(result, dict) and len(selects) > 0 else ""
325
+ set_state(state, "civitai_last_selects", results)
326
+ return gr.update(value=md, visible=True), gr.update(value=json, visible=False), state
327
+
328
+
329
+ def add_civitai_item(results: list[str], dl_url: str):
330
+ if "http" not in "".join(results): return gr.update(value=dl_url)
331
+ new_url = dl_url if dl_url else ""
332
+ for result in results:
333
+ if "http" not in result: continue
334
+ new_url += f"\n{result}" if new_url else f"{result}"
335
+ new_url = uniq_urls(new_url)
336
+ return gr.update(value=new_url)
337
+
338
+
339
+ def select_civitai_all_item(button_name: str, state: dict):
340
+ if button_name not in ["Select All", "Deselect All"]: return gr.update(value=button_name), gr.Update(visible=True)
341
+ civitai_last_choices = get_state(state, "civitai_last_choices")
342
+ selected = [t[1] for t in civitai_last_choices if t[1] != ""] if button_name == "Select All" else []
343
+ new_button_name = "Select All" if button_name == "Deselect All" else "Deselect All"
344
+ return gr.update(value=new_button_name), gr.update(value=selected, choices=civitai_last_choices)
345
+
346
+
347
+ def update_civitai_selection(evt: gr.SelectData, value: list[str], state: dict):
348
+ try:
349
+ civitai_last_choices = get_state(state, "civitai_last_choices")
350
+ selected_index = evt.index
351
+ selected = list_uniq([v for v in value if v != ""] + [civitai_last_choices[selected_index][1]])
352
+ return gr.update(value=selected)
353
+ except Exception:
354
+ return gr.update()
355
+
356
+
357
+ def update_civitai_checkbox(selected: list[str]):
358
+ return gr.update(value=selected)
359
+
360
+
361
+ def from_civitai_checkbox(selected: list[str]):
362
+ return gr.update(value=selected)