import torch
import torch.nn.functional as F
import torch.utils.benchmark as benchmark
from torch.backends.cuda import sdp_kernel, SDPBackend

device = "cuda" if torch.cuda.is_available() else "cpu"
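
# Times f(*args, **kwargs) with torch.utils.benchmark and returns the mean
# runtime in milliseconds (blocked_autorange reports seconds).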
def benchmark_torch_function_in_milliseconds(f, *args, **kwargs):
    t0 = benchmark.Timer(
        stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
    )
    return t0.blocked_autorange().mean * 1e3
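
# Benchmark problem size: batch of 32, 32 attention heads, sequence length 1024,
# and a per-head embedding dimension of 32, in half precision.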
batch_size = 32
max_sequence_len = 1024
num_heads = 32
embed_dimension = 32

dtype = torch.float16
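
# Random query/key/value tensors in the (batch_size, num_heads, seq_len, head_dim)
# layout accepted by F.scaled_dot_product_attention.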
query = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)
key = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)
value = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)
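
# Baseline: with no backend restrictions, scaled_dot_product_attention chooses
# its own implementation for these inputs.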
print(f"The default implementation runs in {benchmark_torch_function_in_milliseconds(F.scaled_dot_product_attention, query, key, value):.3f} milliseconds")
backend_map = {
    SDPBackend.MATH: {"enable_math": True, "enable_flash": False, "enable_mem_efficient": False},
    SDPBackend.EFFICIENT_ATTENTION: {"enable_math": False, "enable_flash": False, "enable_mem_efficient": True}
}

with sdp_kernel(**backend_map[SDPBackend.MATH]):
    print(f"The math implementation runs in {benchmark_torch_function_in_milliseconds(F.scaled_dot_product_attention, query, key, value):.3f} milliseconds")
with sdp_kernel(**backend_map[SDPBackend.EFFICIENT_ATTENTION]):
    try:
        print(f"The memory efficient implementation runs in {benchmark_torch_function_in_milliseconds(F.scaled_dot_product_attention, query, key, value):.3f} milliseconds")
    except RuntimeError:
        print("EfficientAttention is not supported. See warnings for reasons.")