Spaces:
Build error
Build error
Create describe.py
Browse files- describe.py +80 -0
describe.py
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import torch
|
3 |
+
from torchvision import transforms
|
4 |
+
from PIL import Image
|
5 |
+
import requests
|
6 |
+
from diffusers import StableDiffusionPipeline
|
7 |
+
|
8 |
+
# Load models using diffusers
|
9 |
+
general_model = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
|
10 |
+
anime_model = StableDiffusionPipeline.from_pretrained("hakurei/waifu-diffusion")
|
11 |
+
|
12 |
+
# Placeholder functions for the actual implementations
|
13 |
+
def check_anime_image(image):
|
14 |
+
# Use SauceNAO or similar service to check if the image is anime
|
15 |
+
# and fetch similar images and tags
|
16 |
+
return False, [], []
|
17 |
+
|
18 |
+
def describe_image_general(image):
|
19 |
+
# Use the general model to describe the image
|
20 |
+
description = general_model(image)
|
21 |
+
return description
|
22 |
+
|
23 |
+
def describe_image_anime(image):
|
24 |
+
# Use the anime model to describe the image
|
25 |
+
description = anime_model(image)
|
26 |
+
return description
|
27 |
+
|
28 |
+
def merge_tags(tags1, tags2):
|
29 |
+
# Merge tags, removing duplicates
|
30 |
+
return list(set(tags1 + tags2))
|
31 |
+
|
32 |
+
# Gradio app functions
|
33 |
+
def process_image(image, mode):
|
34 |
+
# Convert the image to a format suitable for the models
|
35 |
+
transform = transforms.Compose([
|
36 |
+
transforms.Resize((256, 256)),
|
37 |
+
transforms.ToTensor()
|
38 |
+
])
|
39 |
+
image = transform(image).unsqueeze(0)
|
40 |
+
|
41 |
+
if mode == "Anime":
|
42 |
+
is_anime, similar_images, original_tags = check_anime_image(image)
|
43 |
+
if is_anime:
|
44 |
+
tags = describe_image_anime(image)
|
45 |
+
return tags, original_tags
|
46 |
+
else:
|
47 |
+
return ["Not an anime image"], []
|
48 |
+
else:
|
49 |
+
tags = describe_image_general(image)
|
50 |
+
return tags, []
|
51 |
+
|
52 |
+
def describe(image, mode):
|
53 |
+
tags, original_tags = process_image(image, mode)
|
54 |
+
return gr.update(value="\n".join(tags)), gr.update(value="\n".join(original_tags))
|
55 |
+
|
56 |
+
def merge(tags, original_tags):
|
57 |
+
merged_tags = merge_tags(tags.split("\n"), original_tags.split("\n"))
|
58 |
+
return "\n".join(merged_tags)
|
59 |
+
|
60 |
+
# Gradio interface
|
61 |
+
with gr.Blocks() as demo:
|
62 |
+
with gr.Row():
|
63 |
+
image_input = gr.Image(type="pil", tool="editor", label="Upload/Paste Image")
|
64 |
+
mode = gr.Dropdown(choices=["Anime", "General"], label="Mode")
|
65 |
+
|
66 |
+
describe_button = gr.Button("Describe")
|
67 |
+
merge_button = gr.Button("Merge Tags")
|
68 |
+
|
69 |
+
with gr.TabGroup() as tab_group:
|
70 |
+
with gr.TabItem("Described Tags"):
|
71 |
+
described_tags = gr.TextArea(label="Described Tags")
|
72 |
+
with gr.TabItem("Original Tags"):
|
73 |
+
original_tags = gr.TextArea(label="Original Tags")
|
74 |
+
|
75 |
+
merged_tags = gr.TextArea(label="Merged Tags")
|
76 |
+
|
77 |
+
describe_button.click(describe, inputs=[image_input, mode], outputs=[described_tags, original_tags])
|
78 |
+
merge_button.click(merge, inputs=[described_tags, original_tags], outputs=merged_tags)
|
79 |
+
|
80 |
+
demo.launch()
|