aolko committed on
Commit
2923422
·
verified ·
1 Parent(s): bb14bef

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -3
app.py CHANGED
@@ -5,12 +5,29 @@ from diffusers import DiffusionPipeline
5
  import requests
6
  from PIL import Image
7
  from io import BytesIO
 
 
8
 
9
  # Initialize models
10
- anime_model = DiffusionPipeline.from_pretrained("SmilingWolf/wd-convnext-tagger-v3")
 
11
  photo_model = AutoModelForZeroShotImageClassification.from_pretrained("facebook/florence-base-in21k-retrieval")
12
  processor = AutoProcessor.from_pretrained("facebook/florence-base-in21k-retrieval")
13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  def get_booru_image(booru, image_id):
15
  if booru == "Gelbooru":
16
  url = f"https://gelbooru.com/index.php?page=dapi&s=post&q=index&json=1&id={image_id}"
@@ -36,8 +53,14 @@ def get_booru_image(booru, image_id):
36
 
37
  def transcribe_image(image, image_type, transcriber, booru_tags=None):
38
  if image_type == "Anime":
39
- with torch.no_grad():
40
- tags = anime_model(image)
 
 
 
 
 
 
41
  else:
42
  inputs = processor(images=image, return_tensors="pt")
43
  outputs = photo_model(**inputs)
 
5
  import requests
6
  from PIL import Image
7
  from io import BytesIO
8
+ import onnxruntime as ort
9
+ from huggingface_hub import hf_hub_download
10
 
11
  # Initialize models
12
+ anime_model_path = hf_hub_download("SmilingWolf/wd-convnext-tagger-v3", "model.onnx")
13
+ anime_model = ort.InferenceSession(anime_model_path)
14
  photo_model = AutoModelForZeroShotImageClassification.from_pretrained("facebook/florence-base-in21k-retrieval")
15
  processor = AutoProcessor.from_pretrained("facebook/florence-base-in21k-retrieval")
16
 
17
# Load labels for the anime model.
# `csv` is imported here so this block stays self-contained; hoist it to the
# top-of-file import group when tidying the module.
import csv

labels_path = hf_hub_download("SmilingWolf/wd-convnext-tagger-v3", "selected_tags.csv")
with open(labels_path, newline="", encoding="utf-8") as f:
    # selected_tags.csv has the columns tag_id,name,category,count.
    # The model's output index corresponds to the row order, and the
    # human-readable tag is the `name` column -- not the first column,
    # which is the numeric tag_id.
    labels = [row["name"] for row in csv.DictReader(f)]
21
+
22
def preprocess_image(image):
    """Turn a PIL image into the input batch for the WD tagger ONNX model.

    Args:
        image: A PIL image in any mode; it is converted to RGB first.

    Returns:
        A contiguous float32 NumPy array of shape (1, 448, 448, 3) holding
        BGR pixel values in the raw 0-255 range.
    """
    image = image.convert('RGB')
    image = image.resize((448, 448), Image.LANCZOS)
    pixels = np.asarray(image, dtype=np.float32)
    # The WD v3 taggers are exported with a channels-last (NHWC) input and
    # expect BGR channel order with unscaled 0-255 values, so the array is
    # neither transposed to CHW nor divided by 255.
    pixels = pixels[:, :, ::-1]  # RGB -> BGR
    # The channel flip yields a negatively-strided view; hand the runtime a
    # plain contiguous buffer with a leading batch axis.
    return np.ascontiguousarray(pixels[np.newaxis, ...])
30
+
31
  def get_booru_image(booru, image_id):
32
  if booru == "Gelbooru":
33
  url = f"https://gelbooru.com/index.php?page=dapi&s=post&q=index&json=1&id={image_id}"
 
53
 
54
  def transcribe_image(image, image_type, transcriber, booru_tags=None):
55
  if image_type == "Anime":
56
+ input_image = preprocess_image(image)
57
+ input_name = anime_model.get_inputs()[0].name
58
+ output_name = anime_model.get_outputs()[0].name
59
+ probs = anime_model.run([output_name], {input_name: input_image})[0]
60
+
61
+ # Get top 50 tags
62
+ top_indices = probs[0].argsort()[-50:][::-1]
63
+ tags = [labels[i] for i in top_indices]
64
  else:
65
  inputs = processor(images=image, return_tensors="pt")
66
  outputs = photo_model(**inputs)