""" Created By: ishwor subedi Date: 2024-07-10 """ import os.path import cv2 import numpy as np import requests import wget from PIL import Image, ImageOps from tqdm import tqdm from ultralytics import YOLO from segment_anything import SamPredictor, sam_model_registry class Segmentation: def __init__(self): model_path = "artifacts/segmentation/yolov8x-seg.pt" self.segmentation_model = YOLO(model=model_path) def segment_image(self, image_path: str): results = self.segmentation_model(image_path, show=True) return results class SegmentAnything: def __init__(self, device="cpu"): self.model_name = "sam_vit_l_0b3195.pth" self.model_download() self.sam = sam_model_registry["vit_l"](checkpoint="artifacts/segmentation/sam_vit_l_0b3195.pth").to(device) self.samPredictor = SamPredictor(self.sam) def model_download(self): if os.path.exists(f"artifacts/segmentation/{self.model_name}"): print(f"{self.model_name} model already exists.") else: print(f"Downloading {self.model_name} model...") url = f"https://dl.fbaipublicfiles.com/segment_anything/{self.model_name}" response = requests.get(url, stream=True) total_size_in_bytes = int(response.headers.get('content-length', 0)) block_size = 1024 progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True) with open(f"artifacts/segmentation/{self.model_name}", 'wb') as file: for data in response.iter_content(block_size): progress_bar.update(len(data)) file.write(data) progress_bar.close() if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes: print("ERROR, something went wrong") def generate_mask(self, image, selected_points, deselected_points): selected_pixels = [] deselected_pixels = [] selected_pixels.append(selected_points) deselected_pixels.append(deselected_points) self.samPredictor.set_image(image) points = np.array(selected_pixels) label = np.ones(points.shape[0]) mask, _, _ = self.samPredictor.predict( point_coords=points, point_labels=label, ) mask = Image.fromarray(mask[0, :, :]) mask_img = ImageOps.invert(mask) return mask_img if __name__ == '__main__': segment_anything = SegmentAnything() image_path = "/home/ishwor/Pictures/01.TEST/alia/5869473_dark_lean.png" image = cv2.imread(image_path) mask = segment_anything.generate_mask(image, (20, 20), (20, 20)) maskimage = np.array(mask) image=cv2.cvtColor(image, cv2.COLOR_BGR2RGB) print(maskimage.shape) cv2.imshow("image", image) cv2.waitKey(0)