Metal079 commited on
Commit
308d90e
·
1 Parent(s): 6a85f15

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -22
app.py CHANGED
@@ -1,31 +1,33 @@
1
  import gradio as gr
2
- from transformers import pipeline
3
 
4
- pipe_aesthetic = pipeline("image-classification", "./sonic")
5
- def aesthetic(input_img):
6
- data = pipe_aesthetic(input_img, top_k=2)
7
- final = {}
8
- for d in data:
9
- final[d["label"]] = d["score"]
10
- return final
11
- demo_aesthetic = gr.Interface(fn=aesthetic, inputs=gr.Image(type="pil"), outputs=gr.Label(label="aesthetic"))
12
 
13
- pipe_style = pipeline("image-classification", "cafeai/cafe_style")
14
- def style(input_img):
15
- data = pipe_style(input_img, top_k=5)
16
- final = {}
17
- for d in data:
18
- final[d["label"]] = d["score"]
19
- return final
20
- demo_style = gr.Interface(fn=style, inputs=gr.Image(type="pil"), outputs=gr.Label(label="style"))
 
21
 
22
- pipe_waifu = pipeline("image-classification", "cafeai/cafe_waifu")
23
- def waifu(input_img):
24
- data = pipe_waifu(input_img, top_k=5)
 
 
 
 
 
25
  final = {}
26
  for d in data:
27
  final[d["label"]] = d["score"]
28
  return final
29
- demo_waifu = gr.Interface(fn=waifu, inputs=gr.Image(type="pil"), outputs=gr.Label(label="waifu"))
30
 
31
- gr.Parallel(demo_aesthetic, demo_style, demo_waifu).launch()
 
1
  import gradio as gr
2
+ from transformers import pipeline, ImageClassificationPipeline
3
 
4
+ class MultiClassLabel(ImageClassificationPipeline):
5
+ def postprocess(self, model_outputs, top_k=5):
6
+ if top_k > self.model.config.num_labels:
7
+ top_k = self.model.config.num_labels
 
 
 
 
8
 
9
+ if self.framework == "pt":
10
+ probs = model_outputs.logits.sigmoid()[0]
11
+ scores, ids = probs.topk(top_k)
12
+ elif self.framework == "tf":
13
+ probs = stable_softmax(model_outputs.logits, axis=-1)[0]
14
+ topk = tf.math.top_k(probs, k=top_k)
15
+ scores, ids = topk.values.numpy(), topk.indices.numpy()
16
+ else:
17
+ raise ValueError(f"Unsupported framework: {self.framework}")
18
 
19
+ scores = scores.tolist()
20
+ ids = ids.tolist()
21
+ return [{"score": score, "label": self.model.config.id2label[_id]} for score, _id in zip(scores, ids)]
22
+
23
+ pipe_aesthetic = pipeline("image-classification", "./sonic", pipeline_class=MultiClassLabel)
24
+
25
+ def aesthetic(input_img):
26
+ data = pipe_aesthetic(input_img, top_k=5)
27
  final = {}
28
  for d in data:
29
  final[d["label"]] = d["score"]
30
  return final
31
+ demo_aesthetic = gr.Interface(fn=aesthetic, inputs=gr.Image(type="pil"), outputs=gr.Label(label="characters"))
32
 
33
+ gr.Parallel(demo_aesthetic).launch()