Spaces:
Runtime error
Runtime error
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import glob
import os.path

import torch

# The JIT extension loader lives in torch.utils.cpp_extension; if it cannot be
# imported, re-raise with a message that points at the PyTorch version
# requirement while keeping the original failure attached as the cause.
try:
    from torch.utils.cpp_extension import load as load_ext
    from torch.utils.cpp_extension import CUDA_HOME
except ImportError as err:
    raise ImportError("The cpp layer extensions requires PyTorch 0.4 or higher") from err
def _load_C_extensions():
    """Build and load the native sources found under the ``csrc`` directory.

    The ``csrc`` directory is looked up one level above the directory that
    contains this file. All ``*.cpp`` files directly inside it and inside its
    ``cpu`` subdirectory are always compiled. ``*.cu`` files from its ``cuda``
    subdirectory are added, together with the ``-DWITH_CUDA`` compiler flag,
    only when ``torch.cuda.is_available()`` is true and ``CUDA_HOME`` is set.

    Returns:
        Whatever ``load_ext`` (``torch.utils.cpp_extension.load``) returns for
        the extension named ``"torchvision"``.
    """
    # <dir of this file>/../csrc
    package_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    csrc_dir = os.path.join(package_root, "csrc")

    # The glob patterns are rooted at the absolute csrc_dir, so every match is
    # already an absolute path and needs no further joining. Each group is
    # sorted because glob's result order depends on the file system, and a
    # stable source order keeps the build inputs deterministic.
    sources = sorted(glob.glob(os.path.join(csrc_dir, "*.cpp")))
    sources += sorted(glob.glob(os.path.join(csrc_dir, "cpu", "*.cpp")))

    extra_cflags = []
    if torch.cuda.is_available() and CUDA_HOME is not None:
        sources += sorted(glob.glob(os.path.join(csrc_dir, "cuda", "*.cu")))
        extra_cflags = ["-DWITH_CUDA"]

    # NOTE(review): the extension is registered under the name "torchvision",
    # which is also the name of a widely used package -- confirm this is
    # intentional before changing it, since callers may depend on the name.
    return load_ext(
        "torchvision",
        sources,
        extra_cflags=extra_cflags,
        extra_include_paths=[csrc_dir],
    )
# Build/load the native extension once at import time; the result is kept in
# the module-level name ``_C``.
_C = _load_C_extensions()