# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import dataclasses
import enum
import importlib
import subprocess
import sys
from typing import Any, Callable, Optional, Tuple


class Framework(enum.Enum):
    """Deep-learning framework a dataset is produced for."""

    TENSORFLOW = "tensorflow"
    PYTORCH = "pytorch"


@dataclasses.dataclass(frozen=True)
class DatasetSpec:
    """Immutable description of a dataset available in the catalogue.

    Attributes:
        framework: Framework the created dataset belongs to. ``None`` is
            allowed by the type; presumably it means framework-agnostic —
            TODO confirm with callers.
        create_dataset_fn: Factory building the dataset. The factories in this
            module accept an optional ``batch_size`` keyword argument.
    """

    framework: Optional[Framework]
    create_dataset_fn: Callable


# Mapping from COCO category id to human-readable label (80 classes).
# The ids are not contiguous (e.g. 12, 26, 29 and 30 are absent); id 0
# ("background") is intentionally left out.
COCO_LABELS = {
    # 0: 'background',
    1: "person",
    2: "bicycle",
    3: "car",
    4: "motorcycle",
    5: "airplane",
    6: "bus",
    7: "train",
    8: "truck",
    9: "boat",
    10: "traffic light",
    11: "fire hydrant",
    13: "stop sign",
    14: "parking meter",
    15: "bench",
    16: "bird",
    17: "cat",
    18: "dog",
    19: "horse",
    20: "sheep",
    21: "cow",
    22: "elephant",
    23: "bear",
    24: "zebra",
    25: "giraffe",
    27: "backpack",
    28: "umbrella",
    31: "handbag",
    32: "tie",
    33: "suitcase",
    34: "frisbee",
    35: "skis",
    36: "snowboard",
    37: "sports ball",
    38: "kite",
    39: "baseball bat",
    40: "baseball glove",
    41: "skateboard",
    42: "surfboard",
    43: "tennis racket",
    44: "bottle",
    46: "wine glass",
    47: "cup",
    48: "fork",
    49: "knife",
    50: "spoon",
    51: "bowl",
    52: "banana",
    53: "apple",
    54: "sandwich",
    55: "orange",
    56: "broccoli",
    57: "carrot",
    58: "hot dog",
    59: "pizza",
    60: "donut",
    61: "cake",
    62: "chair",
    63: "couch",
    64: "potted plant",
    65: "bed",
    67: "dining table",
    70: "toilet",
    72: "tv",
    73: "laptop",
    74: "mouse",
    75: "remote",
    76: "keyboard",
    77: "cell phone",
    78: "microwave",
    79: "oven",
    80: "toaster",
    81: "sink",
    82: "refrigerator",
    84: "book",
    85: "clock",
    86: "vase",
    87: "scissors",
    88: "teddy bear",
    89: "hair drier",
    90: "toothbrush",
}


# Install into the interpreter that is running this code. A bare "pip" from
# PATH may belong to a different Python/virtualenv, leaving the package
# unimportable here. Fall back to "pip" only if the interpreter path is unknown.
_TFDS_PIP_INSTALL_CMD = (
    (sys.executable, "-m", "pip") if sys.executable else ("pip",)
) + ("install", "--upgrade", "tensorflow-datasets")


def _install_and_import_tfds():
    """Install/upgrade ``tensorflow-datasets`` and return the imported module.

    Raises:
        subprocess.CalledProcessError: if the pip invocation fails.
        ImportError: if the package still cannot be imported afterwards.
    """
    subprocess.run(list(_TFDS_PIP_INSTALL_CMD), check=True)
    # Packages installed after interpreter start-up may be missed by the
    # import system's cached path finders until the caches are invalidated.
    importlib.invalidate_caches()
    import tensorflow_datasets as tfds  # pytype: disable=import-error

    return tfds


def _create_tfds_coco2017_validation(batch_size: Optional[int] = None) -> Tuple[Any, Any]:
    """Load the COCO 2017 validation split via TFDS.

    Installs/upgrades ``tensorflow-datasets`` first (requires network access).

    Args:
        batch_size: Forwarded to ``tfds.load``; ``None`` leaves batching to TFDS.

    Returns:
        The ``(dataset, info)`` pair produced by ``tfds.load(..., with_info=True)``.
    """
    tfds = _install_and_import_tfds()
    return tfds.load("coco/2017", split="validation", as_supervised=True, with_info=True, batch_size=batch_size)


TFDS_COCO2017_VALIDATION_DATASET = DatasetSpec(
    framework=Framework.TENSORFLOW,
    create_dataset_fn=_create_tfds_coco2017_validation,
)


def _create_tfds_tf_flowers(batch_size: Optional[int] = None) -> Tuple[Any, Any]:
    """Load the ``tf_flowers`` dataset via TFDS.

    Installs/upgrades ``tensorflow-datasets`` first (requires network access).
    No ``split`` is passed, so TFDS decides which splits are returned.

    Args:
        batch_size: Forwarded to ``tfds.load``; ``None`` leaves batching to TFDS.

    Returns:
        The ``(dataset(s), info)`` pair produced by ``tfds.load(..., with_info=True)``.
    """
    tfds = _install_and_import_tfds()
    return tfds.load("tf_flowers", as_supervised=True, with_info=True, batch_size=batch_size)


TFDS_TF_FLOWERS_DATASET = DatasetSpec(
    framework=Framework.TENSORFLOW,
    create_dataset_fn=_create_tfds_tf_flowers,
)

DATASETS_CATALOGUE = [TFDS_COCO2017_VALIDATION_DATASET, TFDS_TF_FLOWERS_DATASET]