|
|
|
import functools
import os

from .parrots_wrapper import TORCH_VERSION
|
|
|
# Opt-in switch for the real parrots JIT; read once, at import time.
parrots_jit_option = os.getenv('PARROTS_JIT_OPTION')

if TORCH_VERSION == 'parrots' and parrots_jit_option == 'ON':
    from parrots.jit import pat as jit
else:

    def jit(func=None,
            check_input=None,
            full_shape=True,
            derivate=False,
            coderize=False,
            optimize=False):
        """No-op stand-in for ``parrots.jit.pat``.

        Used when not running on parrots, or when the environment variable
        ``PARROTS_JIT_OPTION`` is not exactly ``'ON'``. All keyword arguments
        are accepted so that call sites can stay unchanged, but they are
        ignored.

        Supports both decorator forms: bare ``@jit`` and ``@jit(...)``.

        Args:
            func (callable, optional): The decorated function when used as
                bare ``@jit``; ``None`` when used as ``@jit(...)``.
            check_input: Ignored.
            full_shape: Ignored.
            derivate: Ignored.
            coderize: Ignored.
            optimize: Ignored.

        Returns:
            callable: ``func`` unchanged if it was given. Otherwise, a
            decorator that wraps its argument in a pass-through function.
        """

        def wrapper(func):
            # functools.wraps copies __name__/__doc__ etc. and sets
            # __wrapped__, so introspection (inspect.signature, logging,
            # test runners) sees the original function, not
            # wrapper_inner(*args, **kwargs).
            @functools.wraps(func)
            def wrapper_inner(*args, **kwargs):
                return func(*args, **kwargs)

            return wrapper_inner

        if func is None:
            # Called as @jit(...): return the decorator itself.
            return wrapper
        # Called as bare @jit: nothing to compile, hand the function back.
        return func
|
|
|
|
|
if TORCH_VERSION == 'parrots':
    from parrots.utils.tester import skip_no_elena
else:

    def skip_no_elena(func):
        """No-op stand-in for ``parrots.utils.tester.skip_no_elena``.

        Outside parrots the decorated function is never skipped; it is
        simply called with whatever arguments it receives.

        Args:
            func (callable): The function to decorate.

        Returns:
            callable: A pass-through wrapper around ``func``.
        """

        # functools.wraps preserves the wrapped function's name, docstring
        # and (via __wrapped__) its signature. Tools that introspect the
        # function, such as test runners resolving fixture arguments, then
        # see the real parameters instead of (*args, **kwargs).
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            return func(*args, **kwargs)

        return wrapper
|
|