import torch
from PIL import Image
from transformers import AutoModelForZeroShotObjectDetection, AutoProcessor

from lang_sam.models.utils import get_device_type
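
# Resolve the compute device once at import time (e.g. "cuda", "mps", or "cpu").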
device_type = get_device_type()
DEVICE = torch.device(device_type)
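
# On CUDA, enter a global bfloat16 autocast context and, on Ampere
# (compute capability >= 8) or newer GPUs, allow TF32 matmuls for faster inference.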
if torch.cuda.is_available():
    torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
    if torch.cuda.get_device_properties(0).major >= 8:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

class GDINO:
    def __init__(self):
        self.build_model()

    def build_model(self, ckpt_path: str | None = None):
        # ckpt_path is accepted for interface compatibility but unused: the
        # pretrained checkpoint is always fetched from the Hugging Face Hub.
        model_id = "IDEA-Research/grounding-dino-base"
        self.processor = AutoProcessor.from_pretrained(model_id)
        self.model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(DEVICE)
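
    # Zero-shot detection: pil_images and text_prompt are parallel lists
    # (one free-text prompt per image).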
    def predict(
        self,
        pil_images: list[Image.Image],
        text_prompt: list[str],
        box_threshold: float,
        text_threshold: float,
    ) -> list[dict]:
        # Grounding DINO expects each text query to end with a period.
        for i, prompt in enumerate(text_prompt):
            if not prompt.endswith("."):
                text_prompt[i] += "."
        inputs = self.processor(images=pil_images, text=text_prompt, return_tensors="pt").to(DEVICE)
        with torch.no_grad():
            outputs = self.model(**inputs)
        results = self.processor.post_process_grounded_object_detection(
            outputs,
            inputs.input_ids,
            box_threshold=box_threshold,
            text_threshold=text_threshold,
            # PIL reports (width, height); post-processing wants (height, width).
            target_sizes=[k.size[::-1] for k in pil_images],
        )
        return results
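
# Smoke test: run the detector on two copies of the sample image with the
# prompt "wheel".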
if __name__ == "__main__":
    gdino = GDINO()  # __init__ already builds the model; no second call needed
    out = gdino.predict(
        [Image.open("./assets/car.jpeg"), Image.open("./assets/car.jpeg")],
        ["wheel", "wheel"],
        0.3,
        0.25,
    )
    print(out)