Model Overview
The Segment Anything Model (SAM) produces high-quality object masks from input prompts such as points or boxes, and it can also generate masks for every object in an image. It was trained on a dataset of 11 million images and 1.1 billion masks, and it shows strong zero-shot performance on a variety of segmentation tasks. This model is supported in both KerasCV and KerasHub. KerasCV will no longer be actively developed, so KerasHub is the recommended choice.
Links
- Segment Anything Quickstart Notebook: coming soon
- Segment Anything API Documentation
- Segment Anything Model Card
- Segment Anything paper
Installation
Keras and KerasHub can be installed with:
pip install -U -q keras-hub
pip install -U -q "keras>=3"
JAX, TensorFlow, and Torch come preinstalled in Kaggle Notebooks. For instructions on installing them in another environment, see the Keras Getting Started page.
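Keras 3 reads its backend from the `KERAS_BACKEND` environment variable, which has to be set before `keras` is imported for the first time. The sketch below shows one way to do that; `select_backend` is a hypothetical helper written for this example (it is not part of Keras), and it assumes the chosen framework is already installed.

```python
import os

# Backend names accepted by Keras 3.
SUPPORTED_BACKENDS = ("jax", "tensorflow", "torch")


def select_backend(name):
    """Set KERAS_BACKEND. Call this before the first `import keras`."""
    if name not in SUPPORTED_BACKENDS:
        raise ValueError(
            f"Unknown backend {name!r}; expected one of {SUPPORTED_BACKENDS}"
        )
    os.environ["KERAS_BACKEND"] = name
    return name


select_backend("jax")
# Import keras / keras_hub only after this point so the setting takes effect.
```

Changing the variable after Keras has been imported has no effect on the running process, so this belongs at the very top of a script or notebook.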
Presets
The following model checkpoints are provided by the Keras team. Weights have been ported from https://dl.fbaipublicfiles.com/segment_anything/. Full code examples for each are available below.
Preset name | Parameters | Description |
---|---|---|
sam_base_sa1b | 93.74M | The base SAM model trained on the SA1B dataset. |
sam_large_sa1b | 312.34M | The large SAM model trained on the SA1B dataset. |
sam_huge_sa1b | 641.09M | The huge SAM model trained on the SA1B dataset. |
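When choosing between the presets, parameter count is usually the deciding factor. The sketch below copies the counts from the table above and picks the largest preset that fits a parameter budget. The helper names are hypothetical, and the memory estimate assumes float32 weights (4 bytes per parameter), ignoring activations and any framework overhead.

```python
# Parameter counts in millions, copied from the presets table.
PRESET_PARAMS_M = {
    "sam_base_sa1b": 93.74,
    "sam_large_sa1b": 312.34,
    "sam_huge_sa1b": 641.09,
}


def largest_preset_within(budget_millions):
    """Return the largest preset whose parameter count fits the budget."""
    fitting = {
        name: params
        for name, params in PRESET_PARAMS_M.items()
        if params <= budget_millions
    }
    if not fitting:
        raise ValueError(f"No preset fits within {budget_millions}M parameters")
    return max(fitting, key=fitting.get)


def approx_weight_megabytes(preset, bytes_per_param=4):
    """Rough weight size in MB, assuming float32 storage by default."""
    return PRESET_PARAMS_M[preset] * bytes_per_param


print(largest_preset_within(400))  # sam_large_sa1b
```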
Example Usage
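Load a pretrained model using from_preset.
import numpy as np
import keras_hub

image_size = 1024
batch_size = 2
# Dummy inputs: one point prompt and one box prompt per image.
input_data = {
    "images": np.ones(
        (batch_size, image_size, image_size, 3),
        dtype="float32",
    ),
    # Point prompts: (batch, num_points, 2) pixel coordinates.
    "points": np.ones((batch_size, 1, 2), dtype="float32"),
    # One label per point.
    "labels": np.ones((batch_size, 1), dtype="float32"),
    # Box prompts: (batch, num_boxes, 2, 2), two corner points per box.
    "boxes": np.ones((batch_size, 1, 2, 2), dtype="float32"),
    # A zero-length second axis means no mask prompt is supplied.
    "masks": np.zeros(
        (batch_size, 0, image_size, image_size, 1)
    ),
}
sam = keras_hub.models.SAMImageSegmenter.from_preset("sam_base_sa1b")
outputs = sam.predict(input_data)
masks, iou_pred = outputs["masks"], outputs["iou_pred"]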
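The prediction returns several candidate masks per image together with a predicted IoU score for each. A common follow-up step is to keep the highest-scoring candidate and binarize it. The sketch below runs on synthetic arrays rather than real model output. It assumes `masks` has shape `(batch, num_masks, height, width)` and holds logits, that `iou_pred` has shape `(batch, num_masks)`, and that a threshold of 0.0 is appropriate; check these against your actual outputs. `select_best_masks` is a hypothetical helper, not a KerasHub function.

```python
import numpy as np


def select_best_masks(masks, iou_pred, threshold=0.0):
    """Keep the candidate with the highest predicted IoU for each image.

    Assumes masks: (batch, num_masks, H, W) logits,
            iou_pred: (batch, num_masks) scores.
    Returns a boolean array (batch, H, W) and the chosen indices (batch,).
    """
    best = np.argmax(iou_pred, axis=1)
    batch_idx = np.arange(masks.shape[0])
    best_logits = masks[batch_idx, best]
    return best_logits > threshold, best


# Synthetic stand-ins for outputs["masks"] and outputs["iou_pred"].
rng = np.random.default_rng(0)
masks = rng.normal(size=(2, 4, 8, 8)).astype("float32")
iou_pred = np.array(
    [[0.1, 0.9, 0.3, 0.2], [0.7, 0.2, 0.4, 0.6]], dtype="float32"
)

binary_masks, best = select_best_masks(masks, iou_pred)
print(binary_masks.shape, best)  # (2, 8, 8) [1 0]
```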
Load the Segment Anything image segmenter with a custom backbone.
import numpy as np
import keras_hub

image_size = 128
batch_size = 2
# Dummy image batch matching the image encoder's input shape.
images = np.ones(
    (batch_size, image_size, image_size, 3),
    dtype="float32",
)
# Small ViTDet image encoder (toy sizes, randomly initialized).
image_encoder = keras_hub.models.ViTDetBackbone(
    hidden_size=16,
    num_layers=16,
    intermediate_dim=16 * 4,
    num_heads=16,
    global_attention_layer_indices=[2, 5, 8, 11],
    patch_size=16,
    num_output_channels=8,
    window_size=2,
    image_shape=(image_size, image_size, 3),
)
# Encodes point, box, and mask prompts.
prompt_encoder = keras_hub.layers.SAMPromptEncoder(
    hidden_size=8,
    image_embedding_size=(8, 8),
    input_image_size=(
        image_size,
        image_size,
    ),
    mask_in_channels=16,
)
# Predicts masks and IoU scores from image and prompt embeddings.
mask_decoder = keras_hub.layers.SAMMaskDecoder(
    num_layers=2,
    hidden_size=8,
    intermediate_dim=32,
    num_heads=8,
    embedding_dim=8,
    num_multimask_outputs=3,
    iou_head_depth=3,
    iou_head_hidden_dim=8,
)
backbone = keras_hub.models.SAMBackbone(
    image_encoder=image_encoder,
    prompt_encoder=prompt_encoder,
    mask_decoder=mask_decoder,
)
sam = keras_hub.models.SAMImageSegmenter(
    backbone=backbone
)
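The sizes in the custom backbone are not independent. In the example above, `image_embedding_size=(8, 8)` matches the 128-pixel image divided into 16-pixel patches, and the image encoder's `num_output_channels` equals the `hidden_size` of both the prompt encoder and the mask decoder. The sketch below works through that arithmetic in plain Python. `embedding_grid` is a hypothetical helper, and the relationship it encodes is inferred from the example values rather than taken from the KerasHub documentation.

```python
def embedding_grid(image_size, patch_size):
    """Side lengths of the patch grid for a square image."""
    if image_size % patch_size != 0:
        raise ValueError(
            f"image_size {image_size} is not divisible by patch_size {patch_size}"
        )
    side = image_size // patch_size
    return (side, side)


# Values from the custom backbone example.
print(embedding_grid(128, 16))  # (8, 8)

# Same arithmetic for a 1024-pixel input, assuming 16-pixel patches.
print(embedding_grid(1024, 16))  # (64, 64)

# Channel widths that line up in the example configuration.
widths = {
    "image_encoder.num_output_channels": 8,
    "prompt_encoder.hidden_size": 8,
    "mask_decoder.hidden_size": 8,
}
assert len(set(widths.values())) == 1, "embedding widths should agree"
```

Running a check like this before building the layers can catch a mismatched size earlier than a shape error raised deep inside the model.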
Example Usage with Hugging Face URI
Load a pretrained model using from_preset.
import numpy as np
import keras_hub

image_size = 1024
batch_size = 2
input_data = {
    "images": np.ones(
        (batch_size, image_size, image_size, 3),
        dtype="float32",
    ),
    "points": np.ones((batch_size, 1, 2), dtype="float32"),
    "labels": np.ones((batch_size, 1), dtype="float32"),
    "boxes": np.ones((batch_size, 1, 2, 2), dtype="float32"),
    "masks": np.zeros(
        (batch_size, 0, image_size, image_size, 1)
    ),
}
# The hf:// prefix loads the preset from the Hugging Face Hub.
sam = keras_hub.models.SAMImageSegmenter.from_preset("hf://keras/sam_base_sa1b")
outputs = sam.predict(input_data)
masks, iou_pred = outputs["masks"], outputs["iou_pred"]