import spaces
# from transformers import Owlv2Processor, Owlv2ForObjectDetection, AutoProcessor, AutoModelForZeroShotObjectDetection
from transformers import Owlv2Processor, Owlv2ForObjectDetection
import torch
import gradio as gr

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Move the model to `device` (not a hard-coded "cuda") so it always lives on the
# same device as the inputs prepared in `infer`, and the app can start on CPU-only hosts.
owl_model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16-ensemble").to(device)
owl_processor = Owlv2Processor.from_pretrained("google/owlv2-base-patch16-ensemble")

# Grounding DINO is currently disabled; uncomment these lines (and the DINO UI
# pieces below) to make the "dino" branch of `infer` usable.
# dino_processor = AutoProcessor.from_pretrained("IDEA-Research/grounding-dino-base")
# dino_model = AutoModelForZeroShotObjectDetection.from_pretrained("IDEA-Research/grounding-dino-base").to(device)

english_candidate_labels = ["hat", "sunglass", "hair band", "glove", "arm sleeve", "watch", "singlet", "t-shirts", "energy gel", "half pants", "socks", "shoes", "ear phone"]
korean_candidate_labels = ["모자", "썬글라스", "헤어밴드", "장갑", "팔토시", "시계", "싱글렛", "티셔츠", "에너지젤", "쇼츠바지", "양말", "신발", "이어폰"]
english_candidate_labels_string = ",".join(english_candidate_labels)

# Dictionary mapping each English label to its Korean label (the two lists are index-aligned).
label_mapping = dict(zip(english_candidate_labels, korean_candidate_labels))


def _split_queries(text_queries):
    """Split a comma-separated label string into trimmed, non-empty labels."""
    return [query.strip() for query in text_queries.split(",") if query.strip()]


@spaces.GPU
def infer(img, text_queries, score_threshold, model):
    """Run zero-shot object detection on a single image.

    Args:
        img: Image array whose first two dimensions are (height, width).
        text_queries: List of label strings to look for.
        score_threshold: Minimum confidence for a detection to be kept.
        model: "owl" for OWLv2, or "dino" for Grounding DINO (the latter needs
            the commented-out DINO model/processor lines to be enabled).

    Returns:
        List of (box, label) tuples; box is [x0, y0, x1, y1] as ints in pixel
        coordinates and label is the matching query string.

    Raises:
        ValueError: if `model` is neither "owl" nor "dino".
    """
    height, width = img.shape[:2]

    if model == "dino":
        # Grounding DINO takes a single prompt of period-terminated phrases.
        queries = "".join(f"{query}. " for query in text_queries)
        target_sizes = [(height, width)]
        inputs = dino_processor(text=queries, images=img, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = dino_model(**inputs)
        outputs.logits = outputs.logits.cpu()
        outputs.pred_boxes = outputs.pred_boxes.cpu()
        results = dino_processor.post_process_grounded_object_detection(
            outputs=outputs,
            input_ids=inputs.input_ids,
            box_threshold=score_threshold,
            target_sizes=target_sizes,
        )
    elif model == "owl":
        # OWLv2's processor pads images to a square, so boxes are rescaled
        # against the longer side for both dimensions.
        size = max(height, width)
        target_sizes = torch.Tensor([[size, size]])
        inputs = owl_processor(text=text_queries, images=img, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = owl_model(**inputs)
        outputs.logits = outputs.logits.cpu()
        outputs.pred_boxes = outputs.pred_boxes.cpu()
        # Pass the user threshold through; the post-processor's default (0.1)
        # would otherwise silently drop detections for slider values below it.
        results = owl_processor.post_process_object_detection(
            outputs=outputs,
            threshold=score_threshold,
            target_sizes=target_sizes,
        )
    else:
        raise ValueError(f"Unknown model: {model!r} (expected 'owl' or 'dino')")

    detections = results[0]
    result_labels = []
    for box, score, label in zip(detections["boxes"], detections["scores"], detections["labels"]):
        if score < score_threshold:
            continue
        box = [int(coord) for coord in box.tolist()]
        if model == "owl":
            # OWL reports the index of the matching query.
            result_labels.append((box, text_queries[label.cpu().item()]))
        elif label != "":
            # DINO reports the matched phrase itself; skip empty matches.
            result_labels.append((box, label))
    return result_labels


# def query_image(img, text_queries, owl_threshold, dino_threshold):
def query_image(img, text_queries, owl_threshold, flag_output_korean):
    """Gradio callback: detect the comma-separated labels in `img` with OWLv2.

    Args:
        img: Uploaded image (None when the user submitted without one).
        text_queries: Comma-separated candidate labels.
        owl_threshold: Confidence threshold for OWLv2 detections.
        flag_output_korean: When true, translate known labels to Korean.

    Returns:
        (img, annotations) tuple for `gr.AnnotatedImage`.

    Raises:
        gr.Error: if no image was provided.
    """
    if img is None:
        raise gr.Error("Please upload an input image.")

    queries = _split_queries(text_queries)
    if not queries:
        # Nothing to search for: show the image without annotations.
        return (img, [])

    owl_output = infer(img, queries, owl_threshold, "owl")
    # dino_output = infer(img, queries, dino_threshold, "dino")

    if flag_output_korean:
        # Labels outside the predefined list have no translation; keep them as typed.
        owl_output = [(box, label_mapping.get(label, label)) for box, label in owl_output]

    # return (img, owl_output), (img, dino_output)
    return (img, owl_output)


owl_threshold = gr.Slider(0, 1, value=0.16, label="OWL Threshold")
# dino_threshold = gr.Slider(0, 1, value=0.12, label="Grounding DINO Threshold")
owl_output = gr.AnnotatedImage(label="OWL Output")
# dino_output = gr.AnnotatedImage(label="Grounding DINO Output")
demo = gr.Interface(
    query_image,
    # inputs=[gr.Image(label="Input Image"), gr.Textbox(label="Candidate Labels"), owl_threshold, dino_threshold],
    inputs=[
        gr.Image(label="Input Image"),
        gr.Textbox(label="Candidate Labels", value=english_candidate_labels_string),
        owl_threshold,
        gr.Checkbox(label="Output labels Korean"),
    ],
    # outputs=[owl_output, dino_output],
    outputs=[owl_output],
    title="OWLv2 Demo",
    description="Zero-shot object detection with [OWLv2](https://huggingface.co/google/owlv2-base-patch16). Simply enter an image and the objects you want to find, separated by commas, or try one of the examples. Play with the threshold to filter out low confidence predictions.",
    # examples=[["./bee.jpg", "bee, flower", 0.16, 0.12], ["./cats.png", "cat, fishnet", 0.16, 0.12]]
    # Each example row matches `inputs`: image, labels, OWL threshold, Korean-output flag.
    examples=[
        ["./rs_sample1.jpg", english_candidate_labels_string, 0.16, False],
        ["./rs_sample2.jpg", english_candidate_labels_string, 0.13, False],
    ],
)
demo.launch(debug=True)