aolko commited on
Commit
9d051b5
·
verified ·
1 Parent(s): 333abdb

Create describe.py

Browse files
Files changed (1) hide show
  1. describe.py +80 -0
describe.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from torchvision import transforms
4
+ from PIL import Image
5
+ import requests
6
+ from diffusers import StableDiffusionPipeline
7
+
8
+ # Load models using diffusers
9
+ general_model = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
10
+ anime_model = StableDiffusionPipeline.from_pretrained("hakurei/waifu-diffusion")
11
+
12
+ # Placeholder functions for the actual implementations
13
+ def check_anime_image(image):
14
+ # Use SauceNAO or similar service to check if the image is anime
15
+ # and fetch similar images and tags
16
+ return False, [], []
17
+
18
+ def describe_image_general(image):
19
+ # Use the general model to describe the image
20
+ description = general_model(image)
21
+ return description
22
+
23
+ def describe_image_anime(image):
24
+ # Use the anime model to describe the image
25
+ description = anime_model(image)
26
+ return description
27
+
28
+ def merge_tags(tags1, tags2):
29
+ # Merge tags, removing duplicates
30
+ return list(set(tags1 + tags2))
31
+
32
+ # Gradio app functions
33
+ def process_image(image, mode):
34
+ # Convert the image to a format suitable for the models
35
+ transform = transforms.Compose([
36
+ transforms.Resize((256, 256)),
37
+ transforms.ToTensor()
38
+ ])
39
+ image = transform(image).unsqueeze(0)
40
+
41
+ if mode == "Anime":
42
+ is_anime, similar_images, original_tags = check_anime_image(image)
43
+ if is_anime:
44
+ tags = describe_image_anime(image)
45
+ return tags, original_tags
46
+ else:
47
+ return ["Not an anime image"], []
48
+ else:
49
+ tags = describe_image_general(image)
50
+ return tags, []
51
+
52
+ def describe(image, mode):
53
+ tags, original_tags = process_image(image, mode)
54
+ return gr.update(value="\n".join(tags)), gr.update(value="\n".join(original_tags))
55
+
56
+ def merge(tags, original_tags):
57
+ merged_tags = merge_tags(tags.split("\n"), original_tags.split("\n"))
58
+ return "\n".join(merged_tags)
59
+
60
+ # Gradio interface
61
+ with gr.Blocks() as demo:
62
+ with gr.Row():
63
+ image_input = gr.Image(type="pil", tool="editor", label="Upload/Paste Image")
64
+ mode = gr.Dropdown(choices=["Anime", "General"], label="Mode")
65
+
66
+ describe_button = gr.Button("Describe")
67
+ merge_button = gr.Button("Merge Tags")
68
+
69
+ with gr.TabGroup() as tab_group:
70
+ with gr.TabItem("Described Tags"):
71
+ described_tags = gr.TextArea(label="Described Tags")
72
+ with gr.TabItem("Original Tags"):
73
+ original_tags = gr.TextArea(label="Original Tags")
74
+
75
+ merged_tags = gr.TextArea(label="Merged Tags")
76
+
77
+ describe_button.click(describe, inputs=[image_input, mode], outputs=[described_tags, original_tags])
78
+ merge_button.click(merge, inputs=[described_tags, original_tags], outputs=merged_tags)
79
+
80
+ demo.launch()