Spaces:
Running on Zero
Upload 2 files
Browse files
env.py
CHANGED
@@ -40,7 +40,8 @@ load_diffusers_format_model = [
|
|
40 |
'rubbrband/realcartoonRealistic_v14',
|
41 |
'KBlueLeaf/Kohaku-XL-Epsilon-rev2',
|
42 |
'KBlueLeaf/Kohaku-XL-Epsilon-rev3',
|
43 |
-
'
|
|
|
44 |
'Eugeoter/artiwaifu-diffusion-1.0',
|
45 |
'Raelina/Rae-Diffusion-XL-V2',
|
46 |
'Raelina/Raemu-XL-V4',
|
|
|
40 |
'rubbrband/realcartoonRealistic_v14',
|
41 |
'KBlueLeaf/Kohaku-XL-Epsilon-rev2',
|
42 |
'KBlueLeaf/Kohaku-XL-Epsilon-rev3',
|
43 |
+
'KBlueLeaf/Kohaku-XL-Zeta',
|
44 |
+
'kayfahaarukku/UrangDiffusion-1.2',
|
45 |
'Eugeoter/artiwaifu-diffusion-1.0',
|
46 |
'Raelina/Rae-Diffusion-XL-V2',
|
47 |
'Raelina/Raemu-XL-V4',
|
tagger.py
CHANGED
@@ -12,10 +12,15 @@ from pathlib import Path
|
|
12 |
WD_MODEL_NAMES = ["p1atdev/wd-swinv2-tagger-v3-hf"]
|
13 |
WD_MODEL_NAME = WD_MODEL_NAMES[0]
|
14 |
|
15 |
-
|
16 |
-
|
17 |
-
wd_processor = AutoImageProcessor.from_pretrained(WD_MODEL_NAME, trust_remote_code=True)
|
18 |
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
|
20 |
def _people_tag(noun: str, minimum: int = 1, maximum: int = 5):
|
21 |
return (
|
@@ -506,7 +511,7 @@ def gen_prompt(rating: list[str], character: list[str], general: list[str]):
|
|
506 |
return ", ".join(all_tags)
|
507 |
|
508 |
|
509 |
-
@spaces.GPU()
|
510 |
def predict_tags(image: Image.Image, general_threshold: float = 0.3, character_threshold: float = 0.8):
|
511 |
inputs = wd_processor.preprocess(image, return_tensors="pt")
|
512 |
|
@@ -514,9 +519,11 @@ def predict_tags(image: Image.Image, general_threshold: float = 0.3, character_t
|
|
514 |
logits = torch.sigmoid(outputs.logits[0]) # take the first logits
|
515 |
|
516 |
# get probabilities
|
|
|
517 |
results = {
|
518 |
wd_model.config.id2label[i]: float(logit.float()) for i, logit in enumerate(logits)
|
519 |
}
|
|
|
520 |
# rating, character, general
|
521 |
rating, character, general = postprocess_results(
|
522 |
results, general_threshold, character_threshold
|
|
|
12 |
WD_MODEL_NAMES = ["p1atdev/wd-swinv2-tagger-v3-hf"]
|
13 |
WD_MODEL_NAME = WD_MODEL_NAMES[0]
|
14 |
|
15 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
16 |
+
default_device = device
|
|
|
17 |
|
18 |
+
try:
|
19 |
+
wd_model = AutoModelForImageClassification.from_pretrained(WD_MODEL_NAME, trust_remote_code=True).to(default_device).eval()
|
20 |
+
wd_processor = AutoImageProcessor.from_pretrained(WD_MODEL_NAME, trust_remote_code=True)
|
21 |
+
except Exception as e:
|
22 |
+
print(e)
|
23 |
+
wd_model = wd_processor = None
|
24 |
|
25 |
def _people_tag(noun: str, minimum: int = 1, maximum: int = 5):
|
26 |
return (
|
|
|
511 |
return ", ".join(all_tags)
|
512 |
|
513 |
|
514 |
+
@spaces.GPU(duration=30)
|
515 |
def predict_tags(image: Image.Image, general_threshold: float = 0.3, character_threshold: float = 0.8):
|
516 |
inputs = wd_processor.preprocess(image, return_tensors="pt")
|
517 |
|
|
|
519 |
logits = torch.sigmoid(outputs.logits[0]) # take the first logits
|
520 |
|
521 |
# get probabilities
|
522 |
+
if device != default_device: wd_model.to(device=device)
|
523 |
results = {
|
524 |
wd_model.config.id2label[i]: float(logit.float()) for i, logit in enumerate(logits)
|
525 |
}
|
526 |
+
if device != default_device: wd_model.to(device=default_device)
|
527 |
# rating, character, general
|
528 |
rating, character, general = postprocess_results(
|
529 |
results, general_threshold, character_threshold
|