ashhadahsan's picture
Upload 21 files
8e5dadf verified
import torch
from PIL import Image
from transformers import AutoModelForZeroShotObjectDetection, AutoProcessor
from lang_sam.models.utils import get_device_type
device_type = get_device_type()
DEVICE = torch.device(device_type)
if torch.cuda.is_available():
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
if torch.cuda.get_device_properties(0).major >= 8:
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
class GDINO:
def __init__(self):
self.build_model()
def build_model(self, ckpt_path: str | None = None):
model_id = "IDEA-Research/grounding-dino-base"
self.processor = AutoProcessor.from_pretrained(model_id)
self.model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(
DEVICE
)
def predict(
self,
pil_images: list[Image.Image],
text_prompt: list[str],
box_threshold: float,
text_threshold: float,
) -> list[dict]:
for i, prompt in enumerate(text_prompt):
if prompt[-1] != ".":
text_prompt[i] += "."
inputs = self.processor(images=pil_images, text=text_prompt, return_tensors="pt").to(DEVICE)
with torch.no_grad():
outputs = self.model(**inputs)
results = self.processor.post_process_grounded_object_detection(
outputs,
inputs.input_ids,
box_threshold=box_threshold,
text_threshold=text_threshold,
target_sizes=[k.size[::-1] for k in pil_images],
)
return results
if __name__ == "__main__":
gdino = GDINO()
gdino.build_model()
out = gdino.predict(
[Image.open("./assets/car.jpeg"), Image.open("./assets/car.jpeg")],
["wheel", "wheel"],
0.3,
0.25,
)
print(out)