RobustSAM: Segment Anything Robustly on Degraded Images (CVPR 2024 Highlight)

Model Card for ViT Base (ViT-B) version


Official repository for RobustSAM: Segment Anything Robustly on Degraded Images


Introduction

Segment Anything Model (SAM) has emerged as a transformative approach in image segmentation, acclaimed for its robust zero-shot segmentation capabilities and flexible prompting system. Nonetheless, its performance is challenged by images with degraded quality. Addressing this limitation, we propose the Robust Segment Anything Model (RobustSAM), which enhances SAM's performance on low-quality images while preserving its promptability and zero-shot generalization.

Our method builds on the pre-trained SAM model and adds only marginal parameters and computational cost. The additional parameters of RobustSAM can be optimized within 30 hours on eight GPUs, which makes training feasible for typical research laboratories. We also introduce the Robust-Seg dataset, a collection of 688K image-mask pairs with different degradations, designed for training and evaluating the model. Extensive experiments across various segmentation tasks and datasets confirm RobustSAM's superior performance, especially under zero-shot conditions, underscoring its potential for real-world applications. Additionally, our method has been shown to improve the performance of SAM-based downstream tasks such as single image dehazing and deblurring.


Disclaimer: Content from this model card has been written by the Hugging Face team, and parts of it were copied from the original SAM model card.

Model Details

The RobustSAM model is made up of the following modules:

  • The VisionEncoder: a ViT-based image encoder. It computes the image embeddings using attention on patches of the image, with relative positional embeddings.
  • The PromptEncoder: generates embeddings for points and bounding boxes.
  • The MaskDecoder: a two-way transformer that performs cross-attention in both directions: between the image embedding and the point embeddings, and between the point embeddings and the image embedding. Its outputs are fed to the Neck.
  • The Neck: predicts the output masks based on the contextualized masks produced by the MaskDecoder.
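The two-way attention in the MaskDecoder can be illustrated with a toy NumPy sketch. This is not the RobustSAM or transformers implementation. It assumes a single attention head and identity query/key/value projections, and it omits the self-attention, MLP, and normalization layers that SAM's decoder also contains. `two_way_block` is a hypothetical name. The 64x64 grid of 256-dimensional tokens is an illustrative shape chosen to resemble SAM's image embedding.

```python
import numpy as np

def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    # Subtract the max for numerical stability before exponentiating.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def cross_attention(queries: np.ndarray, context: np.ndarray) -> np.ndarray:
    """Single-head scaled dot-product attention: each query row attends over `context`.

    Simplification (assumption): identity query/key/value projections, no learned weights.
    """
    d = queries.shape[-1]
    weights = softmax(queries @ context.T / np.sqrt(d), axis=-1)  # (n_queries, n_context)
    return weights @ context  # (n_queries, d)

def two_way_block(point_tokens: np.ndarray, image_tokens: np.ndarray):
    # Direction 1: prompt (point) tokens gather information from the image tokens.
    point_tokens = point_tokens + cross_attention(point_tokens, image_tokens)
    # Direction 2: image tokens gather information from the updated prompt tokens.
    image_tokens = image_tokens + cross_attention(image_tokens, point_tokens)
    return point_tokens, image_tokens

rng = np.random.default_rng(0)
dim = 256  # illustrative embedding width
image_tokens = rng.standard_normal((64 * 64, dim))  # flattened 64x64 grid of image tokens
point_tokens = rng.standard_normal((2, dim))  # e.g. two prompt tokens
new_points, new_image = two_way_block(point_tokens, image_tokens)
print(new_points.shape, new_image.shape)  # (2, 256) (4096, 256)
```

Both directions use the same attention primitive. Only the roles of queries and context swap, so each token set is updated with information from the other.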

Usage

Prompted-Mask-Generation

import torch
import requests
from PIL import Image
from transformers import AutoProcessor, AutoModelForMaskGeneration

# pick a device: GPU if available, otherwise CPU
device = "cuda" if torch.cuda.is_available() else "cpu"

# load the RobustSAM model and processor, and move the model to the chosen device
processor = AutoProcessor.from_pretrained("jadechoghari/robustsam-vit-base")
model = AutoModelForMaskGeneration.from_pretrained("jadechoghari/robustsam-vit-base").to(device)

# load an image from a URL
img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")

# define the input points (2D location of an object in the image, as x, y)
input_points = [[[450, 600]]]  # example point
# process the image and input points, and move the tensors to the model's device
inputs = processor(raw_image, input_points=input_points, return_tensors="pt").to(device)

# generate masks using the model
with torch.no_grad():
    outputs = model(**inputs)

# resize the predicted low-resolution masks back to the original image size
masks = processor.image_processor.post_process_masks(
    outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu()
)
scores = outputs.iou_scores
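By default, SAM-style models return several candidate masks per prompt, together with a predicted IoU score for each one in `outputs.iou_scores`. A common follow-up is to keep only the highest-scoring candidate. The sketch below shows one way to do that, and it is not part of the official API. `select_best_mask` is a hypothetical helper, and the shapes in its docstring are our assumption about the post-processed outputs.

```python
import numpy as np

def select_best_mask(image_masks: np.ndarray, image_scores: np.ndarray) -> np.ndarray:
    """Keep the highest-scoring candidate mask for every prompt of one image.

    image_masks:  (num_prompts, num_candidates, H, W) boolean masks
    image_scores: (num_prompts, num_candidates) predicted IoU scores
    returns:      (num_prompts, H, W) boolean masks
    """
    best = image_scores.argmax(axis=-1)  # index of the best candidate for each prompt
    return image_masks[np.arange(image_masks.shape[0]), best]

# Hypothetical usage with the snippet above (first image in the batch):
# best_masks = select_best_mask(masks[0].numpy(), scores[0].cpu().numpy())
```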

Among other arguments to generate masks, you can pass any of the following:

  • 2D locations on the approximate position of your object of interest.
  • A bounding box wrapping the object of interest, given as the x, y coordinates of the top-left corner followed by those of the bottom-right corner.
  • A segmentation mask.

At the time of writing, passing text as input is not supported by the official model, according to the official repository. For more details, refer to this notebook, which walks through how to use the model with a visual example.
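Box prompts are easy to get wrong when the two corners come from a user's click-and-drag, because the drag may start at any corner. The following sketch normalizes any two opposite corners into the [x_min, y_min, x_max, y_max] order and wraps the result in the nested list layout (image, then box, then four coordinates) used by the processor's `input_boxes` argument. `to_xyxy` is a hypothetical helper, and the corner values are arbitrary examples.

```python
def to_xyxy(corner_a, corner_b):
    """Return [x_min, y_min, x_max, y_max] from any two opposite box corners given as (x, y)."""
    (xa, ya), (xb, yb) = corner_a, corner_b
    return [min(xa, xb), min(ya, yb), max(xa, xb), max(ya, yb)]

# Nesting: one list per image, one entry per box, four coordinates per box.
input_boxes = [[to_xyxy((650, 900), (75, 275))]]
print(input_boxes)  # [[[75, 275, 650, 900]]]

# Hypothetical follow-up, reusing `processor` and `raw_image` from the snippet above:
# inputs = processor(raw_image, input_boxes=input_boxes, return_tensors="pt")
```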

Automatic-Mask-Generation

The model can also generate segmentation masks for a whole image in a "zero-shot" fashion, without any user prompt. It is automatically prompted with a grid of 1024 points, all of which are fed to the model.
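A grid of 1024 points corresponds to 32 points per side (32 x 32 = 1024). The minimal NumPy sketch below builds such a grid of evenly spaced points, each placed at the centre of a grid cell, in normalized coordinates. `build_point_grid` mirrors the helper of the same name in the original SAM code as we understand it. The transformers pipeline may construct its grid differently internally, and the 1800 x 1200 image size is a hypothetical example.

```python
import numpy as np

def build_point_grid(points_per_side: int) -> np.ndarray:
    """Evenly spaced (x, y) points in normalized [0, 1] coordinates.

    Each point sits at the centre of a cell, so the grid is offset by half a
    cell from the image border.
    """
    offset = 1.0 / (2 * points_per_side)
    one_side = np.linspace(offset, 1.0 - offset, points_per_side)
    xs, ys = np.meshgrid(one_side, one_side)  # xs varies along columns, ys along rows
    return np.stack([xs.ravel(), ys.ravel()], axis=-1)  # (points_per_side**2, 2)

grid = build_point_grid(32)
print(grid.shape)  # (1024, 2)

# Scale to pixel coordinates for a hypothetical 1800 x 1200 (width x height) image.
pixel_points = grid * np.array([1800, 1200])
```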

The mask-generation pipeline handles automatic mask generation for you. The following snippet shows how easily you can run it on any device. Adjust the points_per_batch argument to fit the available memory.

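import torch
from transformers import pipeline

# use the first GPU if available, otherwise fall back to the CPU
device = 0 if torch.cuda.is_available() else -1

# initialize the pipeline for mask generation
generator = pipeline("mask-generation", model="jadechoghari/robustsam-vit-base", device=device)

image_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
# points_per_batch sets how many grid points are processed at once; lower it to reduce memory use
outputs = generator(image_url, points_per_batch=256)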
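`points_per_batch` trades memory for speed. The grid prompts are split into batches, and each batch goes through the prompt encoder and mask decoder in a single pass. The arithmetic below is a sketch that rests on three assumptions: the image embedding is computed once and reused across batches, no extra crop layers are used, and the grid holds the 1024 points described above. `num_prompt_batches` is a hypothetical helper, not a pipeline function.

```python
import math

def num_prompt_batches(points_per_side: int, points_per_batch: int) -> int:
    """Number of decoder passes needed to cover a points_per_side x points_per_side grid."""
    total_points = points_per_side ** 2
    return math.ceil(total_points / points_per_batch)

print(num_prompt_batches(32, 256))  # 4
print(num_prompt_batches(32, 64))   # 16
```

Halving `points_per_batch` roughly halves the peak decoder memory but doubles the number of passes.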

To display the generated masks on the image:

import matplotlib.pyplot as plt
import numpy as np
import requests
from PIL import Image

# simple function to overlay a single mask on a matplotlib axis
def show_mask(mask, ax, random_color=False):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])

    # get the height and width from the mask
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)

# load the original image (image_url is defined in the pipeline snippet above) and display it
raw_image = Image.open(requests.get(image_url, stream=True).raw).convert("RGB")
plt.imshow(np.array(raw_image))
ax = plt.gca()

# loop through the masks and display each one
for mask in outputs["masks"]:
    show_mask(mask, ax=ax, random_color=True)

plt.axis("off")

# show the image with the masks
plt.show()
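When many masks overlap, the drawing order matters: a large mask drawn last can hide smaller ones underneath it. The sketch below handles this in two steps. First, it filters masks by the pipeline's per-mask scores in `outputs["scores"]`. Second, it sorts the remaining masks so that larger ones are drawn first. `order_masks` is a hypothetical helper, and the 0.9 threshold is an arbitrary illustrative value.

```python
import numpy as np

def order_masks(masks, scores, min_score=0.0):
    """Drop masks scoring below `min_score`, then sort the rest by area, largest first.

    Drawing large masks first keeps smaller masks visible on top of them.
    """
    kept = [(np.asarray(m), float(s)) for m, s in zip(masks, scores) if float(s) >= min_score]
    kept.sort(key=lambda pair: int(pair[0].sum()), reverse=True)
    return [m for m, _ in kept]

# Hypothetical usage with the pipeline output and show_mask from above:
# for mask in order_masks(outputs["masks"], outputs["scores"], min_score=0.9):
#     show_mask(mask, ax=ax, random_color=True)
```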

Visual Comparison


Reference

If you find this work useful, please consider citing us!

@inproceedings{chen2024robustsam,
  title={RobustSAM: Segment Anything Robustly on Degraded Images},
  author={Chen, Wei-Ting and Vong, Yu-Jiet and Kuo, Sy-Yen and Ma, Sizhou and Wang, Jian},
  booktitle={CVPR},
  year={2024}
}

Acknowledgements

We thank the authors of SAM, on whose work this repository is based.

Model size: 93.7M params (Safetensors, tensor type F32)