Image Classification · timm
drhead committed (verified)
Commit eb8bf23 · 1 Parent(s): cdacbad

fuck it, gradio demo

Files changed (1): inference_gradio.py (+169, -0)
inference_gradio.py ADDED
import json

from PIL import Image
import gradio as gr
import torch
from torchvision.transforms import transforms
from torchvision.transforms import InterpolationMode
import torchvision.transforms.functional as TF

import timm
from timm.models import VisionTransformer
import safetensors.torch

# Replace torch.jit.script with a no-op so any scripted helpers run as plain
# Python, and disable autograd globally since this script only does inference.
torch.jit.script = lambda f: f
torch.set_grad_enabled(False)

# Scale an image to fit within `bounds` while preserving aspect ratio,
# optionally padding it out to exactly `bounds`.
class Fit(torch.nn.Module):
    def __init__(
        self,
        bounds: tuple[int, int] | int,
        interpolation=InterpolationMode.LANCZOS,
        grow: bool = True,
        pad: float | None = None
    ):
        super().__init__()

        self.bounds = (bounds, bounds) if isinstance(bounds, int) else bounds
        self.interpolation = interpolation
        self.grow = grow
        self.pad = pad

    def forward(self, img: Image.Image) -> Image.Image:
        wimg, himg = img.size
        hbound, wbound = self.bounds

        hscale = hbound / himg
        wscale = wbound / wimg

        if not self.grow:
            # Only shrink oversized images; never upscale.
            hscale = min(hscale, 1.0)
            wscale = min(wscale, 1.0)

        scale = min(hscale, wscale)
        if scale == 1.0:
            return img

        hnew = min(round(himg * scale), hbound)
        wnew = min(round(wimg * scale), wbound)

        img = TF.resize(img, (hnew, wnew), self.interpolation)

        if self.pad is None:
            return img

        # Pad symmetrically out to the requested bounds.
        hpad = hbound - hnew
        wpad = wbound - wnew

        tpad = hpad // 2
        bpad = hpad - tpad

        lpad = wpad // 2
        rpad = wpad - lpad

        return TF.pad(img, (lpad, tpad, rpad, bpad), self.pad)

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}("
            f"bounds={self.bounds}, "
            f"interpolation={self.interpolation.value}, "
            f"grow={self.grow}, "
            f"pad={self.pad})"
        )

# Composite an RGBA tensor onto a solid background colour and return RGB.
class CompositeAlpha(torch.nn.Module):
    def __init__(
        self,
        background: tuple[float, float, float] | float,
    ):
        super().__init__()

        self.background = (background, background, background) if isinstance(background, float) else background
        self.background = torch.tensor(self.background).unsqueeze(1).unsqueeze(2)

    def forward(self, img: torch.Tensor) -> torch.Tensor:
        if img.shape[-3] == 3:
            return img  # already RGB, nothing to composite

        alpha = img[..., 3, None, :, :]

        img[..., :3, :, :] *= alpha

        background = self.background.expand(-1, img.shape[-2], img.shape[-1])
        if background.ndim == 1:
            background = background[:, None, None]
        elif background.ndim == 2:
            background = background[None, :, :]

        img[..., :3, :, :] += (1.0 - alpha) * background
        return img[..., :3, :, :]

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}("
            f"background={self.background})"
        )

transform = transforms.Compose([
    Fit((384, 384)),
    transforms.ToTensor(),
    CompositeAlpha(0.5),  # flatten any transparency onto 50% grey
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
    transforms.CenterCrop((384, 384)),
])

model = timm.create_model(
    "vit_so400m_patch14_siglip_384.webli",
    pretrained=False,
    num_classes=9083,
)  # type: VisionTransformer

safetensors.torch.load_model(model, "JTP_PILOT/JTP_PILOT-e4-vit_so400m_patch14_siglip_384.safetensors")
model.eval()

with open("JTP_PILOT/tags.json", "r") as file:
    tags = json.load(file)  # type: dict
allowed_tags = list(tags.keys())

def create_tags(image, threshold):
    # Keep the alpha channel: converting straight to RGB would flatten
    # transparency onto black instead of letting CompositeAlpha use grey.
    img = image.convert('RGBA')
    tensor = transform(img).unsqueeze(0)

    with torch.no_grad():
        logits = model(tensor)
        probabilities = torch.nn.functional.sigmoid(logits[0])
        indices = torch.where(probabilities > threshold)[0]
        values = probabilities[indices]

    tag_score = dict()
    for i in range(indices.size(0)):
        tag_score[allowed_tags[indices[i]]] = values[i].item()
    text_no_impl = ", ".join(tag_score.keys())
    return text_no_impl, tag_score

with gr.Blocks() as demo:
    gr.Markdown("""
    ## Joint Tagger Project: PILOT
    This tagger is designed for use on furry images (though it may very well work on out-of-distribution images, potentially with funny results). A threshold of 0.2 is recommended. Lower thresholds often turn up more valid tags, but can also produce some hallucinated tags.

    This tagger is the result of joint efforts between members of the RedRocket team.

    Special thanks to Minotoro at frosting.ai for providing the compute power for this project.
    """)
    gr.Interface(
        create_tags,
        inputs=[
            gr.Image(label="Source", sources=['upload', 'webcam'], type='pil'),
            gr.Slider(minimum=0.00, maximum=1.00, step=0.01, value=0.20, label="Threshold"),
        ],
        outputs=[
            gr.Textbox(label="Tag String"),
            gr.Label(label="Tag Predictions", num_top_classes=200),
        ],
        allow_flagging="never",
    )

if __name__ == "__main__":
    demo.launch()
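
For quick checks outside the web UI, the tagging function can also be called directly once the script has loaded the weights and tags. A minimal sketch, assuming the checkpoint and JTP_PILOT/tags.json referenced above are in place; the image path "sample.png" is a hypothetical placeholder:

# Minimal usage sketch: tag a local image without launching the Gradio UI.
# Assumes the weights and JTP_PILOT/tags.json above loaded successfully;
# "sample.png" is a placeholder, not a file shipped with this repo.
from PIL import Image

image = Image.open("sample.png")
tag_string, tag_score = create_tags(image, threshold=0.2)  # demo's recommended threshold

print(tag_string)  # comma-separated tags scoring above the threshold
for tag, score in sorted(tag_score.items(), key=lambda kv: kv[1], reverse=True)[:10]:
    print(f"{tag}: {score:.3f}")  # ten highest-confidence predictions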