import numpy as np
import torch
from hydra import compose
from hydra.utils import instantiate
from omegaconf import OmegaConf
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
from sam2.sam2_image_predictor import SAM2ImagePredictor

from lang_sam.models.utils import get_device_type

DEVICE = torch.device(get_device_type())

if torch.cuda.is_available():
    # Enable bfloat16 autocast for the whole process (the same pattern the
    # official SAM2 examples use) and TF32 matmuls on Ampere-or-newer GPUs
    # (compute capability major version >= 8).
    torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
    if torch.cuda.get_device_properties(0).major >= 8:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

SAM_MODELS = {
    "sam2.1_hiera_tiny": {
        "url": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_tiny.pt",
        "config": "configs/sam2.1/sam2.1_hiera_t.yaml",
    },
    "sam2.1_hiera_small": {
        "url": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_small.pt",
        "config": "configs/sam2.1/sam2.1_hiera_s.yaml",
    },
    "sam2.1_hiera_base_plus": {
        "url": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_base_plus.pt",
        "config": "configs/sam2.1/sam2.1_hiera_b+.yaml",
    },
    "sam2.1_hiera_large": {
        "url": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt",
        "config": "configs/sam2.1/sam2.1_hiera_l.yaml",
    },
}
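
# Illustrative sketch only (mirrors what _load_checkpoint below does): torch.hub
# caches downloads under its checkpoints directory, so an entry from the table
# above can also be fetched and inspected manually, e.g.
#
#   url = SAM_MODELS["sam2.1_hiera_small"]["url"]
#   state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu")["model"]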


class SAM:
    def build_model(self, sam_type: str, ckpt_path: str | None = None):
        """Instantiate the SAM2 model from its Hydra config, load pretrained
        weights, and build the automatic mask generator and box-prompt predictor."""
        self.sam_type = sam_type
        self.ckpt_path = ckpt_path
        cfg = compose(config_name=SAM_MODELS[self.sam_type]["config"], overrides=[])
        OmegaConf.resolve(cfg)
        self.model = instantiate(cfg.model, _recursive_=True)
        self._load_checkpoint(self.model)
        self.model = self.model.to(DEVICE)
        self.model.eval()
        self.mask_generator = SAM2AutomaticMaskGenerator(self.model)
        self.predictor = SAM2ImagePredictor(self.model)

    def _load_checkpoint(self, model: torch.nn.Module):
        if self.ckpt_path is None:
            checkpoint_path = SAM_MODELS[self.sam_type]["url"]
            state_dict = torch.hub.load_state_dict_from_url(checkpoint_path, map_location="cpu")["model"]
        else:
            checkpoint_path = self.ckpt_path
            state_dict = torch.load(self.ckpt_path, map_location="cpu", weights_only=True)
        try:
            model.load_state_dict(state_dict, strict=True)
        except Exception as e:
            raise ValueError(
                f"Problem loading SAM. Please make sure you have the right model type ({self.sam_type}) "
                f"and a working checkpoint ({checkpoint_path}). Recommend deleting the checkpoint "
                f"and re-downloading it. Error: {e}"
            ) from e

    def generate(self, image_rgb: np.ndarray) -> list[dict]:
        """
        Run automatic mask generation over the whole image.

        SAM2AutomaticMaskGenerator returns a list of masks, where each mask is a dict containing
        various information about the mask:

        segmentation - [np.ndarray] - the mask with (H, W) shape, and bool type
        area - [int] - the area of the mask in pixels
        bbox - [list[int]] - the bounding box of the mask in xywh format
        predicted_iou - [float] - the model's own prediction for the quality of the mask
        point_coords - [list[list[float]]] - the sampled input point that generated this mask
        stability_score - [float] - an additional measure of mask quality
        crop_box - [list[int]] - the crop of the image used to generate this mask in xywh format
        """
        sam2_result = self.mask_generator.generate(image_rgb)
        return sam2_result

    def predict(self, image_rgb: np.ndarray, xyxy: np.ndarray) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Predict masks for one image from xyxy box prompts."""
        self.predictor.set_image(image_rgb)
        masks, scores, logits = self.predictor.predict(box=xyxy, multimask_output=False)
        if len(masks.shape) > 3:
            # With multiple boxes the predictor returns (num_boxes, 1, H, W);
            # drop the singleton multimask dimension.
            masks = np.squeeze(masks, axis=1)
        return masks, scores, logits

    def predict_batch(
        self,
        images_rgb: list[np.ndarray],
        xyxy: list[np.ndarray],
    ) -> tuple[list[np.ndarray], list[np.ndarray], list[np.ndarray]]:
        """Predict masks for a batch of images, each with its own xyxy box prompts."""
        self.predictor.set_image_batch(images_rgb)
        masks, scores, logits = self.predictor.predict_batch(box_batch=xyxy, multimask_output=False)
        # Drop singleton multimask dimensions so each per-image entry is (num_boxes, H, W).
        masks = [np.squeeze(mask, axis=1) if len(mask.shape) > 3 else mask for mask in masks]
        scores = [np.squeeze(score) for score in scores]
        logits = [np.squeeze(logit, axis=1) if len(logit.shape) > 3 else logit for logit in logits]
        return masks, scores, logits
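

if __name__ == "__main__":
    # Minimal usage sketch, not part of the library: assumes Pillow is installed
    # and that a local RGB image "example.jpg" exists (both illustrative
    # assumptions, as are the box coordinates below).
    from PIL import Image

    sam = SAM()
    sam.build_model("sam2.1_hiera_tiny")

    image_rgb = np.asarray(Image.open("example.jpg").convert("RGB"))
    h, w = image_rgb.shape[:2]

    # Box-prompted prediction with a single xyxy box over the top-left quadrant.
    boxes = np.array([[0, 0, w // 2, h // 2]])
    masks, scores, _ = sam.predict(image_rgb, boxes)
    print(f"predict: masks {masks.shape}, scores {scores}")

    # Automatic mask generation over the whole image.
    results = sam.generate(image_rgb)
    print(f"generate: {len(results)} masks")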