merve (HF staff) committed
Commit d2b4796
1 Parent(s): d4be87c

Update app.py

Files changed (1)
  1. app.py +21 -13
app.py CHANGED
@@ -1,6 +1,7 @@
 from transformers import pipeline, SamModel, SamProcessor
 import torch
 import numpy as np
+import gradio as gr
 import spaces
 
 checkpoint = "google/owlv2-base-patch16-ensemble"
@@ -8,8 +9,9 @@ detector = pipeline(model=checkpoint, task="zero-shot-object-detection", device=
 sam_model = SamModel.from_pretrained("facebook/sam-vit-base").to("cuda")
 sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
 
+
 @spaces.GPU
-def query(image, texts, threshold):
+def query(image, texts, threshold, sam_threshold):
     texts = texts.split(",")
 
     predictions = detector(
@@ -20,42 +22,48 @@ def query(image, texts, threshold):
 
     result_labels = []
     for pred in predictions:
-
+
         box = pred["box"]
         score = pred["score"]
        label = pred["label"]
-        box = [round(pred["box"]["xmin"], 2), round(pred["box"]["ymin"], 2),
+        box = [round(pred["box"]["xmin"], 2), round(pred["box"]["ymin"], 2),
               round(pred["box"]["xmax"], 2), round(pred["box"]["ymax"], 2)]
-
        inputs = sam_processor(
            image,
-            input_boxes=[[[box]]],
+            input_boxes=[[box]],
            return_tensors="pt"
        ).to("cuda")
-
        with torch.no_grad():
-            outputs = sam_model(**inputs)
+            outputs = sam_model(**inputs)
 
        mask = sam_processor.image_processor.post_process_masks(
            outputs.pred_masks.cpu(),
            inputs["original_sizes"].cpu(),
            inputs["reshaped_input_sizes"].cpu()
-        )[0][0][0].numpy()
-        mask = mask[np.newaxis, ...]
-        result_labels.append((mask, label))
+        )
+        iou_scores = outputs["iou_scores"]
+
+        masks, iou_scores, boxes = sam_processor.image_processor.filter_masks(
+            mask[0],
+            iou_scores[0].cpu(),
+            inputs["original_sizes"][0].cpu(),
+            box,
+            pred_iou_thresh=sam_threshold,
+        )
+
+        result_labels.append((mask[0][0][0].numpy(), label))
    return image, result_labels
 
-import gradio as gr
 
 description = "This Space combines OWLv2, the state-of-the-art zero-shot object detection model with SAM, the state-of-the-art mask generation model. SAM normally doesn't accept text input. Combining SAM with OWLv2 makes SAM text promptable. Try the example or input an image and comma separated candidate labels to segment."
 demo = gr.Interface(
    query,
-    inputs=[gr.Image(type="pil", label="Image Input"), gr.Textbox(label = "Candidate Labels"), gr.Slider(0, 1, value=0.05, label="Confidence Threshold")],
+    inputs=[gr.Image(type="pil", label="Image Input"), gr.Textbox(label = "Candidate Labels"), gr.Slider(0, 1, value=0.05, label="Confidence Threshold for OWL"), gr.Slider(0, 1, value=0.88, label="IoU threshold for SAM")],
    outputs="annotatedimage",
    title="OWL 🤝 SAM",
    description=description,
    examples=[
-        ["./cats.png", "cat", 0.1],
+        ["./cats.png", "cat", 0.1, 0.88],
    ],
    cache_examples=True
 )
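For context, below is a minimal standalone sketch of the OWLv2 → SAM pipeline this commit adjusts: OWLv2 turns the text prompt into boxes, and SAM segments each box. The checkpoints, processor calls, and post-processing mirror the diff above; the device fallback, the PIL import, the hard-coded "cat" prompt, and the final print are assumptions added only so the snippet runs on its own.

from PIL import Image
import torch
from transformers import pipeline, SamModel, SamProcessor

device = "cuda" if torch.cuda.is_available() else "cpu"  # assumption: fall back to CPU outside the Space

# Text-conditioned detector (OWLv2) and box-conditioned segmenter (SAM), as in the Space.
detector = pipeline(model="google/owlv2-base-patch16-ensemble",
                    task="zero-shot-object-detection", device=device)
sam_model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)
sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

image = Image.open("./cats.png")  # example image shipped with the Space

# 1. OWLv2 turns the text prompt into candidate boxes above the confidence threshold.
predictions = detector(image, candidate_labels=["cat"], threshold=0.1)

for pred in predictions:
    box = [pred["box"]["xmin"], pred["box"]["ymin"],
           pred["box"]["xmax"], pred["box"]["ymax"]]

    # 2. SAM segments each detected box; input_boxes is nested per image, then per box.
    inputs = sam_processor(image, input_boxes=[[box]], return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = sam_model(**inputs)

    # 3. Resize the predicted masks back to the original image resolution.
    #    The commit above additionally filters these masks by predicted IoU via
    #    sam_processor.image_processor.filter_masks(..., pred_iou_thresh=sam_threshold).
    masks = sam_processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu(),
    )
    print(pred["label"], pred["score"], masks[0].shape)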