# ashhadahsan's picture
# Upload 21 files
# 8e5dadf verified
import numpy as np
import torch
from hydra import compose
from hydra.utils import instantiate
from omegaconf import OmegaConf
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
from sam2.sam2_image_predictor import SAM2ImagePredictor
from lang_sam.models.utils import get_device_type
# Pick the best available accelerator ("cuda"/"mps"/"cpu") as reported by the helper.
DEVICE = torch.device(get_device_type())
if torch.cuda.is_available():
    # Enter a process-global bfloat16 autocast context for all subsequent CUDA ops.
    # NOTE(review): the context is deliberately never exited — this mirrors the
    # official SAM 2 example setup, but it affects every CUDA op in this process.
    torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
    # Ampere GPUs and newer (compute capability >= 8.0) support TF32 kernels.
    if torch.cuda.get_device_properties(0).major >= 8:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
# Registry of supported SAM 2.1 variants: the official checkpoint download URL
# plus the Hydra config name (resolved within the sam2 package) for each.
SAM_MODELS = {
    "sam2.1_hiera_tiny": {
        "url": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_tiny.pt",
        "config": "configs/sam2.1/sam2.1_hiera_t.yaml",
    },
    "sam2.1_hiera_small": {
        "url": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_small.pt",
        "config": "configs/sam2.1/sam2.1_hiera_s.yaml",
    },
    "sam2.1_hiera_base_plus": {
        "url": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_base_plus.pt",
        "config": "configs/sam2.1/sam2.1_hiera_b+.yaml",
    },
    "sam2.1_hiera_large": {
        "url": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt",
        "config": "configs/sam2.1/sam2.1_hiera_l.yaml",
    },
}
class SAM:
def build_model(self, sam_type: str, ckpt_path: str | None = None):
self.sam_type = sam_type
self.ckpt_path = ckpt_path
cfg = compose(config_name=SAM_MODELS[self.sam_type]["config"], overrides=[])
OmegaConf.resolve(cfg)
self.model = instantiate(cfg.model, _recursive_=True)
self._load_checkpoint(self.model)
self.model = self.model.to(DEVICE)
self.model.eval()
self.mask_generator = SAM2AutomaticMaskGenerator(self.model)
self.predictor = SAM2ImagePredictor(self.model)
def _load_checkpoint(self, model: torch.nn.Module):
if self.ckpt_path is None:
checkpoint_url = SAM_MODELS[self.sam_type]["url"]
state_dict = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu")["model"]
else:
state_dict = torch.load(self.ckpt_path, map_location="cpu", weights_only=True)
try:
model.load_state_dict(state_dict, strict=True)
except Exception as e:
raise ValueError(f"Problem loading SAM please make sure you have the right model type: {self.sam_type} \
and a working checkpoint: {checkpoint_url}. Recommend deleting the checkpoint and \
re-downloading it. Error: {e}")
def generate(self, image_rgb: np.ndarray) -> list[dict]:
"""
Output format
SAM2AutomaticMaskGenerator returns a list of masks, where each mask is a dict containing various information
about the mask:
segmentation - [np.ndarray] - the mask with (W, H) shape, and bool type
area - [int] - the area of the mask in pixels
bbox - [List[int]] - the boundary box of the mask in xywh format
predicted_iou - [float] - the model's own prediction for the quality of the mask
point_coords - [List[List[float]]] - the sampled input point that generated this mask
stability_score - [float] - an additional measure of mask quality
crop_box - List[int] - the crop of the image used to generate this mask in xywh format
"""
sam2_result = self.mask_generator.generate(image_rgb)
return sam2_result
def predict(self, image_rgb: np.ndarray, xyxy: np.ndarray) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
self.predictor.set_image(image_rgb)
masks, scores, logits = self.predictor.predict(box=xyxy, multimask_output=False)
if len(masks.shape) > 3:
masks = np.squeeze(masks, axis=1)
return masks, scores, logits
def predict_batch(
self,
images_rgb: list[np.ndarray],
xyxy: list[np.ndarray],
) -> tuple[list[np.ndarray], list[np.ndarray], list[np.ndarray]]:
self.predictor.set_image_batch(images_rgb)
masks, scores, logits = self.predictor.predict_batch(box_batch=xyxy, multimask_output=False)
masks = [np.squeeze(mask, axis=1) if len(mask.shape) > 3 else mask for mask in masks]
scores = [np.squeeze(score) for score in scores]
logits = [np.squeeze(logit, axis=1) if len(logit.shape) > 3 else logit for logit in logits]
return masks, scores, logits