John6666 committed on
Commit
ca34c27
·
verified ·
1 Parent(s): f2a4a58

Upload 2 files

Browse files
Files changed (2) hide show
  1. env.py +2 -1
  2. tagger.py +11 -4
env.py CHANGED
@@ -40,7 +40,8 @@ load_diffusers_format_model = [
40
  'rubbrband/realcartoonRealistic_v14',
41
  'KBlueLeaf/Kohaku-XL-Epsilon-rev2',
42
  'KBlueLeaf/Kohaku-XL-Epsilon-rev3',
43
- 'kayfahaarukku/UrangDiffusion-1.1',
 
44
  'Eugeoter/artiwaifu-diffusion-1.0',
45
  'Raelina/Rae-Diffusion-XL-V2',
46
  'Raelina/Raemu-XL-V4',
 
40
  'rubbrband/realcartoonRealistic_v14',
41
  'KBlueLeaf/Kohaku-XL-Epsilon-rev2',
42
  'KBlueLeaf/Kohaku-XL-Epsilon-rev3',
43
+ 'KBlueLeaf/Kohaku-XL-Zeta',
44
+ 'kayfahaarukku/UrangDiffusion-1.2',
45
  'Eugeoter/artiwaifu-diffusion-1.0',
46
  'Raelina/Rae-Diffusion-XL-V2',
47
  'Raelina/Raemu-XL-V4',
tagger.py CHANGED
@@ -12,10 +12,15 @@ from pathlib import Path
12
  WD_MODEL_NAMES = ["p1atdev/wd-swinv2-tagger-v3-hf"]
13
  WD_MODEL_NAME = WD_MODEL_NAMES[0]
14
 
15
- wd_model = AutoModelForImageClassification.from_pretrained(WD_MODEL_NAME, trust_remote_code=True)
16
- wd_model.to("cuda" if torch.cuda.is_available() else "cpu")
17
- wd_processor = AutoImageProcessor.from_pretrained(WD_MODEL_NAME, trust_remote_code=True)
18
 
 
 
 
 
 
 
19
 
20
  def _people_tag(noun: str, minimum: int = 1, maximum: int = 5):
21
  return (
@@ -506,7 +511,7 @@ def gen_prompt(rating: list[str], character: list[str], general: list[str]):
506
  return ", ".join(all_tags)
507
 
508
 
509
- @spaces.GPU()
510
  def predict_tags(image: Image.Image, general_threshold: float = 0.3, character_threshold: float = 0.8):
511
  inputs = wd_processor.preprocess(image, return_tensors="pt")
512
 
@@ -514,9 +519,11 @@ def predict_tags(image: Image.Image, general_threshold: float = 0.3, character_t
514
  logits = torch.sigmoid(outputs.logits[0]) # take the first logits
515
 
516
  # get probabilities
 
517
  results = {
518
  wd_model.config.id2label[i]: float(logit.float()) for i, logit in enumerate(logits)
519
  }
 
520
  # rating, character, general
521
  rating, character, general = postprocess_results(
522
  results, general_threshold, character_threshold
 
12
  WD_MODEL_NAMES = ["p1atdev/wd-swinv2-tagger-v3-hf"]
13
  WD_MODEL_NAME = WD_MODEL_NAMES[0]
14
 
15
+ device = "cuda" if torch.cuda.is_available() else "cpu"
16
+ default_device = device
 
17
 
18
+ try:
19
+ wd_model = AutoModelForImageClassification.from_pretrained(WD_MODEL_NAME, trust_remote_code=True).to(default_device).eval()
20
+ wd_processor = AutoImageProcessor.from_pretrained(WD_MODEL_NAME, trust_remote_code=True)
21
+ except Exception as e:
22
+ print(e)
23
+ wd_model = wd_processor = None
24
 
25
  def _people_tag(noun: str, minimum: int = 1, maximum: int = 5):
26
  return (
 
511
  return ", ".join(all_tags)
512
 
513
 
514
+ @spaces.GPU(duration=30)
515
  def predict_tags(image: Image.Image, general_threshold: float = 0.3, character_threshold: float = 0.8):
516
  inputs = wd_processor.preprocess(image, return_tensors="pt")
517
 
 
519
  logits = torch.sigmoid(outputs.logits[0]) # take the first logits
520
 
521
  # get probabilities
522
+ if device != default_device: wd_model.to(device=device)
523
  results = {
524
  wd_model.config.id2label[i]: float(logit.float()) for i, logit in enumerate(logits)
525
  }
526
+ if device != default_device: wd_model.to(device=default_device)
527
  # rating, character, general
528
  rating, character, general = postprocess_results(
529
  results, general_threshold, character_threshold