Spaces:
Build error
Build error
Update app.py
Browse files
app.py
CHANGED
@@ -11,13 +11,13 @@ from huggingface_hub import hf_hub_download
|
|
11 |
# Initialize models
|
12 |
anime_model_path = hf_hub_download("SmilingWolf/wd-convnext-tagger-v3", "model.onnx")
|
13 |
anime_model = ort.InferenceSession(anime_model_path)
|
14 |
-
photo_model =
|
15 |
processor = AutoProcessor.from_pretrained("microsoft/Florence-2-base", trust_remote_code=True)
|
16 |
|
17 |
# Load labels for the anime model
|
18 |
labels_path = hf_hub_download("SmilingWolf/wd-convnext-tagger-v3", "selected_tags.csv")
|
19 |
with open(labels_path, 'r') as f:
|
20 |
-
|
21 |
|
22 |
def preprocess_image(image):
|
23 |
image = image.convert('RGB')
|
@@ -28,6 +28,39 @@ def preprocess_image(image):
|
|
28 |
image = image / 255.0
|
29 |
return image[np.newaxis, ...]
|
30 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
31 |
def get_booru_image(booru, image_id):
|
32 |
if booru == "Gelbooru":
|
33 |
url = f"https://gelbooru.com/index.php?page=dapi&s=post&q=index&json=1&id={image_id}"
|
@@ -51,27 +84,6 @@ def get_booru_image(booru, image_id):
|
|
51 |
|
52 |
return img, tags
|
53 |
|
54 |
-
def transcribe_image(image, image_type, transcriber, booru_tags=None):
|
55 |
-
if image_type == "Anime":
|
56 |
-
input_image = preprocess_image(image)
|
57 |
-
input_name = anime_model.get_inputs()[0].name
|
58 |
-
output_name = anime_model.get_outputs()[0].name
|
59 |
-
probs = anime_model.run([output_name], {input_name: input_image})[0]
|
60 |
-
|
61 |
-
# Get top 50 tags
|
62 |
-
top_indices = probs[0].argsort()[-50:][::-1]
|
63 |
-
tags = [labels[i] for i in top_indices]
|
64 |
-
else:
|
65 |
-
inputs = processor(images=image, return_tensors="pt")
|
66 |
-
outputs = photo_model(**inputs)
|
67 |
-
tags = outputs.logits.topk(50).indices.squeeze().tolist()
|
68 |
-
tags = [processor.config.id2label[t] for t in tags]
|
69 |
-
|
70 |
-
if booru_tags:
|
71 |
-
tags = list(set(tags + booru_tags))
|
72 |
-
|
73 |
-
return ", ".join(tags)
|
74 |
-
|
75 |
def update_image(image_type, booru, image_id, uploaded_image):
|
76 |
if image_type == "Anime" and booru != "Upload":
|
77 |
image, booru_tags = get_booru_image(booru, image_id)
|
|
|
11 |
# Initialize models
# The anime tagger is an ONNX export, run through onnxruntime.
anime_model_path = hf_hub_download("SmilingWolf/wd-convnext-tagger-v3", "model.onnx")
anime_model = ort.InferenceSession(anime_model_path)
# Florence-2 ships custom modelling code on the Hub, hence trust_remote_code=True.
photo_model = AutoModelForCausalLM.from_pretrained("microsoft/Florence-2-base", trust_remote_code=True)
processor = AutoProcessor.from_pretrained("microsoft/Florence-2-base", trust_remote_code=True)

# Load labels for the anime model
labels_path = hf_hub_download("SmilingWolf/wd-convnext-tagger-v3", "selected_tags.csv")
with open(labels_path, 'r', encoding='utf-8') as f:
    next(f, None)  # Skip header
    # Column 0 of selected_tags.csv is the numeric tag_id; the human-readable
    # tag name is column 1, which is what callers expect to see in the output.
    # Row order is kept as-is because rows are looked up by model output index.
    anime_labels = [line.strip().split(',')[1] for line in f if line.strip()]
|
21 |
|
22 |
def preprocess_image(image):
|
23 |
image = image.convert('RGB')
|
|
|
28 |
image = image / 255.0
|
29 |
return image[np.newaxis, ...]
|
30 |
|
31 |
+
def transcribe_image(image, image_type, transcriber, booru_tags=None):
    """Return a comma-separated string of tags describing ``image``.

    Args:
        image: Image to tag. It is passed to ``preprocess_image`` for the
            anime path, and its ``width``/``height`` attributes are read for
            the Florence-2 path.
        image_type: ``"Anime"`` selects the ONNX anime tagger; any other
            value selects the Florence-2 model.
        transcriber: Not used by this function; kept so existing callers
            that pass it keep working.
        booru_tags: Optional list of extra tags to merge into the result.

    Returns:
        The tags joined with ``", "``. Model tags come first, followed by any
        ``booru_tags`` not already present.
    """
    if image_type == "Anime":
        input_image = preprocess_image(image)
        input_name = anime_model.get_inputs()[0].name
        output_name = anime_model.get_outputs()[0].name
        probs = anime_model.run([output_name], {input_name: input_image})[0]

        # Get top 50 tags, highest score first
        top_indices = probs[0].argsort()[-50:][::-1]
        tags = [anime_labels[i] for i in top_indices]
    else:
        # The same task token must be used for the prompt and for
        # post-processing; otherwise the generated text is parsed with the
        # wrong grammar. Object detection is used because its labels are
        # tag-like.
        task = "<OD>"
        inputs = processor(text=task, images=image, return_tensors="pt")

        generated_ids = photo_model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=1024,
            do_sample=False,
            num_beams=3,
        )
        # Special tokens are kept because the post-processor parses them.
        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
        parsed_answer = processor.post_process_generation(
            generated_text,
            task=task,
            image_size=(image.width, image.height),
        )

        # For "<OD>" the result is {"<OD>": {"bboxes": [...], "labels": [...]}}.
        # The same label can be detected several times; keep one of each,
        # in first-seen order.
        detected = parsed_answer.get(task, {}).get("labels", [])
        tags = list(dict.fromkeys(detected))

    if booru_tags:
        # De-duplicate while preserving order so the output is deterministic.
        tags = list(dict.fromkeys(tags + booru_tags))

    return ", ".join(tags)
|
62 |
+
|
63 |
+
|
64 |
def get_booru_image(booru, image_id):
|
65 |
if booru == "Gelbooru":
|
66 |
url = f"https://gelbooru.com/index.php?page=dapi&s=post&q=index&json=1&id={image_id}"
|
|
|
84 |
|
85 |
return img, tags
|
86 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
87 |
def update_image(image_type, booru, image_id, uploaded_image):
|
88 |
if image_type == "Anime" and booru != "Upload":
|
89 |
image, booru_tags = get_booru_image(booru, image_id)
|