Spaces:
Runtime error
Runtime error
import torch | |
from PIL import Image | |
from transformers import AutoModelForZeroShotObjectDetection, AutoProcessor | |
from lang_sam.models.utils import get_device_type | |
# Resolve the torch device once, at import time; every model in this module is moved here.
device_type = get_device_type()
DEVICE = torch.device(device_type)

# Process-wide CUDA tuning, applied once when the module is imported.
if torch.cuda.is_available():
    # NOTE(review): this enters a bfloat16 autocast context and never exits it,
    # so autocast stays enabled for subsequent CUDA ops (autocast state is
    # thread-local, so presumably only in the importing thread) — confirm intended.
    torch.autocast("cuda", dtype=torch.bfloat16).__enter__()
    # TF32 tensor-core math is only available on compute capability 8.x+ GPUs.
    if torch.cuda.get_device_properties(0).major >= 8:
        torch.backends.cuda.matmul.allow_tf32 = torch.backends.cudnn.allow_tf32 = True
class GDINO:
    """Wrapper around the Hugging Face Grounding DINO zero-shot object detector.

    The processor and model are loaded eagerly in ``__init__`` and the model
    is moved to the module-level ``DEVICE``.
    """

    def __init__(self):
        self.build_model()

    def build_model(self, ckpt_path: str | None = None):
        """Load the Grounding DINO processor and model and move the model to ``DEVICE``.

        Args:
            ckpt_path: Accepted for interface compatibility but currently unused;
                weights are always loaded from the
                ``IDEA-Research/grounding-dino-base`` checkpoint.
        """
        model_id = "IDEA-Research/grounding-dino-base"
        self.processor = AutoProcessor.from_pretrained(model_id)
        self.model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(
            DEVICE
        )

    def predict(
        self,
        pil_images: list[Image.Image],
        text_prompt: list[str],
        box_threshold: float,
        text_threshold: float,
    ) -> list[dict]:
        """Run text-prompted object detection on a batch of images.

        Args:
            pil_images: Images to detect objects in.
            text_prompt: Text prompts passed to the processor alongside the
                images (presumably one per image — confirm against callers).
                Each prompt is terminated with "." if it does not already end
                with one. The caller's list is NOT modified.
            box_threshold: Forwarded unchanged to the processor's
                ``post_process_grounded_object_detection``.
            text_threshold: Forwarded unchanged to the same post-processing call.

        Returns:
            The ``list[dict]`` produced by
            ``post_process_grounded_object_detection``.
        """
        # Build a new list instead of editing text_prompt in place, so callers'
        # data is untouched; endswith() also avoids IndexError on "".
        prompts = [p if p.endswith(".") else p + "." for p in text_prompt]

        inputs = self.processor(
            images=pil_images, text=prompts, return_tensors="pt"
        ).to(DEVICE)
        with torch.no_grad():
            outputs = self.model(**inputs)

        # PIL's Image.size is (width, height); post-processing wants (height, width).
        target_sizes = [img.size[::-1] for img in pil_images]
        # NOTE(review): newer transformers releases rename ``box_threshold`` to
        # ``threshold`` in this method — confirm against the installed version.
        results = self.processor.post_process_grounded_object_detection(
            outputs,
            inputs.input_ids,
            box_threshold=box_threshold,
            text_threshold=text_threshold,
            target_sizes=target_sizes,
        )
        return results
if __name__ == "__main__":
    # Smoke test: detect "wheel" in the same image, batched twice.
    # GDINO() already calls build_model() in __init__; calling it again would
    # reload the processor and weights for no benefit.
    gdino = GDINO()
    # Context manager closes the underlying file once prediction is done.
    with Image.open("./assets/car.jpeg") as image:
        out = gdino.predict(
            [image, image],
            ["wheel", "wheel"],
            0.3,
            0.25,
        )
    print(out)