args
Browse files
wfx.py
CHANGED
@@ -22,6 +22,8 @@ def parse_args():
|
|
22 |
args.add_argument('--model', type=str, required=True)
|
23 |
args.add_argument('--custom-pipeline', type=str, default=None)
|
24 |
args.add_argument('--compile-mode', default='sfast', type=str, choices=['sfast', 'torch', 'no-compile'])
|
|
|
|
|
25 |
return args.parse_args()
|
26 |
|
27 |
def quantize_unet(m):
|
@@ -68,7 +70,10 @@ class WFX():
|
|
68 |
except ImportError:
|
69 |
logger.warning('triton not found, disabling triton')
|
70 |
|
71 |
-
self.compiler_config.enable_cuda_graph =  # [removed line; the assigned value is missing from this diff view]
|
|
|
|
|
|
|
72 |
|
73 |
for key in self.compiler_config.__dict__:
|
74 |
logger.info(f'cc - {key}: {self.compiler_config.__dict__[key]}')
|
|
|
22 |
args.add_argument('--model', type=str, required=True)
|
23 |
args.add_argument('--custom-pipeline', type=str, default=None)
|
24 |
args.add_argument('--compile-mode', default='sfast', type=str, choices=['sfast', 'torch', 'no-compile'])
|
25 |
+
args.add_argument('--enable-cuda-graph', action='store_true', default=False)
|
26 |
+
args.add_argument('--disable-prefer-lowp-gemm', action='store_true', default=False)
|
27 |
return args.parse_args()
|
28 |
|
29 |
def quantize_unet(m):
|
|
|
70 |
except ImportError:
|
71 |
logger.warning('triton not found, disabling triton')
|
72 |
|
73 |
+
self.compiler_config.enable_cuda_graph = args.enable_cuda_graph
|
74 |
+
|
75 |
+
if args.disable_prefer_lowp_gemm:
|
76 |
+
self.compiler_config.prefer_lowp_gemm = False
|
77 |
|
78 |
for key in self.compiler_config.__dict__:
|
79 |
logger.info(f'cc - {key}: {self.compiler_config.__dict__[key]}')
|