drhead commited on
Commit
374865c
1 Parent(s): d7bad8b

major UI overhaul

Browse files
Files changed (1) hide show
  1. app.py +32 -12
app.py CHANGED
@@ -128,23 +128,33 @@ allowed_tags = list(tags.keys())
128
  for idx, tag in enumerate(allowed_tags):
129
  allowed_tags[idx] = tag.replace("_", " ")
130
 
 
 
131
  @spaces.GPU(duration=5)
132
- def create_tags(image, threshold):
 
133
  img = image.convert('RGB')
134
  tensor = transform(img).unsqueeze(0)
135
 
136
  with torch.no_grad():
137
  logits = model(tensor)
138
  probabilities = torch.nn.functional.sigmoid(logits[0])
139
- indices = torch.where(probabilities > threshold)[0]
140
  values = probabilities[indices]
141
 
142
  tag_score = dict()
143
  for i in range(indices.size(0)):
144
  tag_score[allowed_tags[indices[i]]] = values[i].item()
145
  sorted_tag_score = dict(sorted(tag_score.items(), key=lambda item: item[1], reverse=True))
146
- text_no_impl = ", ".join(sorted_tag_score.keys())
147
- return text_no_impl, sorted_tag_score
 
 
 
 
 
 
 
148
 
149
  with gr.Blocks(css=".output-class { display: none; }") as demo:
150
  gr.Markdown("""
@@ -153,14 +163,24 @@ with gr.Blocks(css=".output-class { display: none; }") as demo:
153
 
154
  This tagger is the result of joint efforts between members of the RedRocket team. Special thanks to Minotoro at frosting.ai for providing the compute power for this project.
155
  """)
156
- gr.Interface(
157
- create_tags,
158
- inputs=[gr.Image(label="Source", sources=['upload'], type='pil', height="50vh", show_label=False), gr.Slider(minimum=0.00, maximum=1.00, step=0.01, value=0.20, label="Threshold")],
159
- outputs=[
160
- gr.Textbox(label="Tag String"),
161
- gr.Label(label="Tag Predictions", num_top_classes=200, show_label=False),
162
- ],
163
- allow_flagging="never",
 
 
 
 
 
 
 
 
 
 
164
  )
165
 
166
  if __name__ == "__main__":
 
128
  for idx, tag in enumerate(allowed_tags):
129
  allowed_tags[idx] = tag.replace("_", " ")
130
 
131
+ sorted_tag_score = {}
132
+
133
  @spaces.GPU(duration=5)
134
+ def run_classifier(image, threshold):
135
+ global sorted_tag_score
136
  img = image.convert('RGB')
137
  tensor = transform(img).unsqueeze(0)
138
 
139
  with torch.no_grad():
140
  logits = model(tensor)
141
  probabilities = torch.nn.functional.sigmoid(logits[0])
142
+ indices = torch.where(probabilities, 250).indices
143
  values = probabilities[indices]
144
 
145
  tag_score = dict()
146
  for i in range(indices.size(0)):
147
  tag_score[allowed_tags[indices[i]]] = values[i].item()
148
  sorted_tag_score = dict(sorted(tag_score.items(), key=lambda item: item[1], reverse=True))
149
+
150
+ return create_tags(threshold)
151
+
152
+ def create_tags(threshold):
153
+ global sorted_tag_score
154
+ filtered_tag_score = {key: value for key, value in sorted_tag_score.items() if value > threshold}
155
+ text_no_impl = ", ".join(filtered_tag_score.keys())
156
+ return text_no_impl, filtered_tag_score
157
+
158
 
159
  with gr.Blocks(css=".output-class { display: none; }") as demo:
160
  gr.Markdown("""
 
163
 
164
  This tagger is the result of joint efforts between members of the RedRocket team. Special thanks to Minotoro at frosting.ai for providing the compute power for this project.
165
  """)
166
+ with gr.Row():
167
+ with gr.Column():
168
+ image_input = gr.Image(label="Source", sources=['upload'], type='pil', height="60vh", show_label=False)
169
+ threshold_slider = gr.Slider(minimum=0.00, maximum=1.00, step=0.01, value=0.20, label="Threshold")
170
+ with gr.Column():
171
+ tag_string = gr.Textbox(label="Tag String")
172
+ label_box = gr.Label(label="Tag Predictions", num_top_classes=250, show_label=False)
173
+
174
+ image_input.upload(
175
+ fn=run_classifier,
176
+ inputs=[image_input, threshold_slider],
177
+ outputs=[tag_string, label_box]
178
+ )
179
+
180
+ threshold_slider.input(
181
+ fn=create_tags,
182
+ inputs=[threshold_slider],
183
+ outputs=[tag_string, label_box]
184
  )
185
 
186
  if __name__ == "__main__":