Optimize inference using torch.compile()

This guide benchmarks the inference speed-ups that torch.compile() provides for computer vision models in 🤗 Transformers.

Benefits of torch.compile

Depending on the model and the GPU, torch.compile() yields up to 30% speed-up during inference. To use torch.compile(), install torch 2.0 or later.
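Before compiling, you can quickly confirm that the installed PyTorch is 2.0 or later. The helper below is a hypothetical sketch, not part of PyTorch or 🤗 Transformers. It parses the major and minor numbers from a version string such as `torch.__version__` and ignores local build suffixes like `+cu118`. For released versions, `hasattr(torch, "compile")` is a simpler alternative.

```python
import re


def supports_torch_compile(version: str) -> bool:
    """Return True if a torch version string is 2.0 or later.

    Hypothetical helper: accepts strings such as "2.0.1",
    "2.0.1+cu118" or "2.1.0dev20230601".
    """
    # Drop a local build suffix such as "+cu118"
    public = version.split("+", 1)[0]
    # Keep only the leading "major.minor" numbers
    match = re.match(r"(\d+)\.(\d+)", public)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    major, minor = int(match.group(1)), int(match.group(2))
    return (major, minor) >= (2, 0)


for v in ("1.13.1", "2.0.1", "2.1.0dev20230601"):
    print(v, supports_torch_compile(v))
```

In practice you would call `supports_torch_compile(torch.__version__)` once at startup.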

Compiling a model takes time, so it pays off when you compile the model once and reuse it for many inferences, rather than recompiling before every inference. To compile a computer vision model of your choice, call torch.compile() on the model as shown below:

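import torch
from transformers import AutoModelForImageClassification

MODEL_ID = "google/vit-base-patch16-224"  # any image classification checkpoint

model = AutoModelForImageClassification.from_pretrained(MODEL_ID).to("cuda")
# Compile once, then reuse the compiled model for every inference call
model = torch.compile(model)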
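To make "compile once" concrete: compilation is a one-off cost, while the latency saving applies to every later call. Ignoring the first (compiling) call's own inference time, compiling pays off after n = compile_time / (uncompiled latency - compiled latency) inferences. The sketch below computes that number. The function name is hypothetical, and the 30-second compile time in the example is a made-up placeholder, not a measurement from this guide. The latencies are the ViT, A100, batch size 1 numbers from the benchmark tables.

```python
import math
from typing import Optional


def break_even_inferences(
    compile_overhead_ms: float, eager_ms: float, compiled_ms: float
) -> Optional[int]:
    """Smallest number of inferences after which compiling pays for itself.

    Compiling costs `compile_overhead_ms` once; afterwards every inference
    saves `eager_ms - compiled_ms`. Returns None when there is no saving.
    """
    saving_ms = eager_ms - compiled_ms
    if saving_ms <= 0:
        return None
    # Break-even when compile_overhead_ms + n * compiled_ms <= n * eager_ms
    return math.ceil(compile_overhead_ms / saving_ms)


# Hypothetical 30 s compile time; latencies: ViT on A100, batch size 1
print(break_even_inferences(30_000.0, 9.325, 7.584))
```

For models whose compiled latency barely improves (for example ConvNeXT on A100 at batch size 1), the break-even point moves much further out.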

torch.compile() offers several compilation modes, which trade off compilation time against inference overhead. "max-autotune" takes longer to compile than "reduce-overhead" but results in faster inference. The "default" mode compiles fastest but gives slower inference than "reduce-overhead". This guide uses the default mode. See the PyTorch documentation for torch.compile() to learn more about the modes.

We benchmarked torch.compile with different computer vision models, tasks, types of hardware, and batch sizes on torch version 2.0.1.

Benchmarking code

Below you can find the benchmarking code for each task. We warm up the GPU before inference and take the mean time of 300 inferences, using the same image each time.
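The snippets below show only the model calls, not the timing loop. A timing helper that follows the procedure described above (warm-up calls, then the mean over 300 runs on the same input) might look like the following sketch. The function name and the default of 10 warm-up calls are assumptions; this is not the script used to produce the numbers in this guide. Because CUDA kernels launch asynchronously, the helper takes an optional `sync` callback such as `torch.cuda.synchronize`, so each measurement stops only after the GPU has finished.

```python
import statistics
import time
from typing import Callable, Optional


def mean_latency_ms(
    fn: Callable[[], object],
    warmup: int = 10,
    iters: int = 300,
    sync: Optional[Callable[[], None]] = None,
) -> float:
    """Call `fn` repeatedly and return its mean latency in milliseconds."""
    # Warm-up: for a compiled model the first call triggers compilation,
    # so these calls are excluded from the measurement.
    for _ in range(warmup):
        fn()
    if sync is not None:
        sync()

    timings = []
    for _ in range(iters):
        start = time.perf_counter()
        fn()
        if sync is not None:
            # Wait for queued GPU work before stopping the clock
            sync()
        timings.append((time.perf_counter() - start) * 1000.0)
    return statistics.mean(timings)
```

With the ViT snippet, you could call `mean_latency_ms(lambda: model(**processed_input), sync=torch.cuda.synchronize)` inside the `torch.no_grad()` block, once for the compiled model and once for the uncompiled one.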

Image Classification with ViT

import torch
from PIL import Image
import requests
from transformers import AutoImageProcessor, AutoModelForImageClassification

# Sample COCO image, reused by every benchmark below
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image = Image.open(requests.get(url, stream=True).raw)

processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224")
model = AutoModelForImageClassification.from_pretrained("google/vit-base-patch16-224").to("cuda")
model = torch.compile(model)

processed_input = processor(image, return_tensors='pt').to(device="cuda")

# Inference only, so gradients are not needed
with torch.no_grad():
    _ = model(**processed_input)

Object Detection with DETR

# Reuses `torch` and `image` from the ViT example above
from transformers import AutoImageProcessor, AutoModelForObjectDetection

processor = AutoImageProcessor.from_pretrained("facebook/detr-resnet-50")
model = AutoModelForObjectDetection.from_pretrained("facebook/detr-resnet-50").to("cuda")
model = torch.compile(model)

# DETR takes only images; it has no text inputs
inputs = processor(images=image, return_tensors="pt").to("cuda")

with torch.no_grad():
    _ = model(**inputs)

Image Segmentation with Segformer

# Reuses `torch` and `image` from the ViT example above
from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation

processor = SegformerImageProcessor.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")
model = SegformerForSemanticSegmentation.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512").to("cuda")
model = torch.compile(model)

seg_inputs = processor(images=image, return_tensors="pt").to("cuda")

with torch.no_grad():
    _ = model(**seg_inputs)

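We benchmarked the following models, grouped by task:

Image Classification: ViT, BeiT, ConvNeXT, ResNet

Image Segmentation: Segformer, Mask2former, Maskformer, MobileNet

Object Detection: OwlViT, DETR, Resnet-101, Conditional-DETR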

The figures below compare inference durations with and without torch.compile(), along with the percentage improvement for each model, across different hardware and batch sizes.

[Figure: Duration Comparison on V100 with Batch Size of 1]

[Figure: Percentage Improvement on T4 with Batch Size of 4]

The tables below list inference durations in milliseconds for each model with and without torch.compile(). OwlViT runs out of memory (OOM) at larger batch sizes, so it appears only in the batch size 1 tables.

A100 (batch size: 1)

| Task/Model | torch 2.0 - no compile (ms) | torch 2.0 - compile (ms) |
|---|---|---|
| Image Classification/ViT | 9.325 | 7.584 |
| Image Segmentation/Segformer | 11.759 | 10.500 |
| Object Detection/OwlViT | 24.978 | 18.420 |
| Image Classification/BeiT | 11.282 | 8.448 |
| Object Detection/DETR | 34.619 | 19.040 |
| Image Classification/ConvNeXT | 10.410 | 10.208 |
| Image Classification/ResNet | 6.531 | 4.124 |
| Image Segmentation/Mask2former | 60.188 | 49.117 |
| Image Segmentation/Maskformer | 75.764 | 59.487 |
| Image Segmentation/MobileNet | 8.583 | 3.974 |
| Object Detection/Resnet-101 | 36.276 | 18.197 |
| Object Detection/Conditional-DETR | 31.219 | 17.993 |
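The plots report percentage improvements, but the guide does not state the exact formula. Assuming "improvement" means the reduction in latency relative to the uncompiled run, it can be computed from any table row as in the sketch below. A speed-up factor (uncompiled latency divided by compiled latency) is included for comparison. Both function names are illustrative, and the rows are copied from the A100 (batch size: 1) table above.

```python
def percent_improvement(no_compile_ms: float, compile_ms: float) -> float:
    """Latency reduction from compiling, as a percentage of the uncompiled latency."""
    return (no_compile_ms - compile_ms) / no_compile_ms * 100.0


def speedup(no_compile_ms: float, compile_ms: float) -> float:
    """How many times faster the compiled model is."""
    return no_compile_ms / compile_ms


# Rows from the A100 (batch size: 1) table: (no compile, compile) in ms
a100_bs1 = {
    "Image Classification/ViT": (9.325, 7.584),
    "Image Classification/ResNet": (6.531, 4.124),
    "Image Segmentation/MobileNet": (8.583, 3.974),
}

for name, (eager, compiled) in a100_bs1.items():
    print(
        f"{name}: {percent_improvement(eager, compiled):.1f}% lower latency, "
        f"{speedup(eager, compiled):.2f}x speed-up"
    )
```

The two measures diverge as gains grow: a 50% latency reduction corresponds to a 2x speed-up.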

A100 (batch size: 4)

| Task/Model | torch 2.0 - no compile (ms) | torch 2.0 - compile (ms) |
|---|---|---|
| Image Classification/ViT | 14.832 | 14.499 |
| Image Segmentation/Segformer | 18.838 | 16.476 |
| Image Classification/BeiT | 13.205 | 13.048 |
| Object Detection/DETR | 48.657 | 32.418 |
| Image Classification/ConvNeXT | 22.940 | 21.631 |
| Image Classification/ResNet | 6.657 | 4.268 |
| Image Segmentation/Mask2former | 74.277 | 61.781 |
| Image Segmentation/Maskformer | 180.700 | 159.116 |
| Image Segmentation/MobileNet | 14.174 | 8.515 |
| Object Detection/Resnet-101 | 68.101 | 44.998 |
| Object Detection/Conditional-DETR | 56.470 | 35.552 |

A100 (batch size: 16)

| Task/Model | torch 2.0 - no compile (ms) | torch 2.0 - compile (ms) |
|---|---|---|
| Image Classification/ViT | 40.944 | 40.010 |
| Image Segmentation/Segformer | 37.005 | 31.144 |
| Image Classification/BeiT | 41.854 | 41.048 |
| Object Detection/DETR | 164.382 | 161.902 |
| Image Classification/ConvNeXT | 82.258 | 75.561 |
| Image Classification/ResNet | 7.018 | 5.024 |
| Image Segmentation/Mask2former | 178.945 | 154.814 |
| Image Segmentation/Maskformer | 638.570 | 579.826 |
| Image Segmentation/MobileNet | 51.693 | 30.310 |
| Object Detection/Resnet-101 | 232.887 | 155.021 |
| Object Detection/Conditional-DETR | 180.491 | 124.032 |

V100 (batch size: 1)

| Task/Model | torch 2.0 - no compile (ms) | torch 2.0 - compile (ms) |
|---|---|---|
| Image Classification/ViT | 10.495 | 6.00 |
| Image Segmentation/Segformer | 13.321 | 5.862 |
| Object Detection/OwlViT | 25.769 | 22.395 |
| Image Classification/BeiT | 11.347 | 7.234 |
| Object Detection/DETR | 33.951 | 19.388 |
| Image Classification/ConvNeXT | 11.623 | 10.412 |
| Image Classification/ResNet | 6.484 | 3.820 |
| Image Segmentation/Mask2former | 64.640 | 49.873 |
| Image Segmentation/Maskformer | 95.532 | 72.207 |
| Image Segmentation/MobileNet | 9.217 | 4.753 |
| Object Detection/Resnet-101 | 52.818 | 28.367 |
| Object Detection/Conditional-DETR | 39.512 | 20.816 |

V100 (batch size: 4)

| Task/Model | torch 2.0 - no compile (ms) | torch 2.0 - compile (ms) |
|---|---|---|
| Image Classification/ViT | 15.181 | 14.501 |
| Image Segmentation/Segformer | 16.787 | 16.188 |
| Image Classification/BeiT | 15.171 | 14.753 |
| Object Detection/DETR | 88.529 | 64.195 |
| Image Classification/ConvNeXT | 29.574 | 27.085 |
| Image Classification/ResNet | 6.109 | 4.731 |
| Image Segmentation/Mask2former | 90.402 | 76.926 |
| Image Segmentation/Maskformer | 234.261 | 205.456 |
| Image Segmentation/MobileNet | 24.623 | 14.816 |
| Object Detection/Resnet-101 | 134.672 | 101.304 |
| Object Detection/Conditional-DETR | 97.464 | 69.739 |

V100 (batch size: 16)

| Task/Model | torch 2.0 - no compile (ms) | torch 2.0 - compile (ms) |
|---|---|---|
| Image Classification/ViT | 52.209 | 51.633 |
| Image Segmentation/Segformer | 61.013 | 55.499 |
| Image Classification/BeiT | 53.938 | 53.581 |
| Object Detection/DETR | OOM | OOM |
| Image Classification/ConvNeXT | 109.682 | 100.771 |
| Image Classification/ResNet | 14.857 | 12.089 |
| Image Segmentation/Mask2former | 249.605 | 222.801 |
| Image Segmentation/Maskformer | 831.142 | 743.645 |
| Image Segmentation/MobileNet | 93.129 | 55.365 |
| Object Detection/Resnet-101 | 482.425 | 361.843 |
| Object Detection/Conditional-DETR | 344.661 | 255.298 |

T4 (batch size: 1)

| Task/Model | torch 2.0 - no compile (ms) | torch 2.0 - compile (ms) |
|---|---|---|
| Image Classification/ViT | 16.520 | 15.786 |
| Image Segmentation/Segformer | 16.116 | 14.205 |
| Object Detection/OwlViT | 53.634 | 51.105 |
| Image Classification/BeiT | 16.464 | 15.710 |
| Object Detection/DETR | 73.100 | 53.99 |
| Image Classification/ConvNeXT | 32.932 | 30.845 |
| Image Classification/ResNet | 6.031 | 4.321 |
| Image Segmentation/Mask2former | 79.192 | 66.815 |
| Image Segmentation/Maskformer | 200.026 | 188.268 |
| Image Segmentation/MobileNet | 18.908 | 11.997 |
| Object Detection/Resnet-101 | 106.622 | 82.566 |
| Object Detection/Conditional-DETR | 77.594 | 56.984 |

T4 (batch size: 4)

| Task/Model | torch 2.0 - no compile (ms) | torch 2.0 - compile (ms) |
|---|---|---|
| Image Classification/ViT | 43.653 | 43.626 |
| Image Segmentation/Segformer | 45.327 | 42.445 |
| Image Classification/BeiT | 52.007 | 51.354 |
| Object Detection/DETR | 277.850 | 268.003 |
| Image Classification/ConvNeXT | 119.259 | 105.580 |
| Image Classification/ResNet | 13.039 | 11.388 |
| Image Segmentation/Mask2former | 201.540 | 184.670 |
| Image Segmentation/Maskformer | 764.052 | 711.280 |
| Image Segmentation/MobileNet | 74.289 | 48.677 |
| Object Detection/Resnet-101 | 421.859 | 357.614 |
| Object Detection/Conditional-DETR | 289.002 | 226.945 |

T4 (batch size: 16)

| Task/Model | torch 2.0 - no compile (ms) | torch 2.0 - compile (ms) |
|---|---|---|
| Image Classification/ViT | 163.914 | 160.907 |
| Image Segmentation/Segformer | 192.412 | 163.620 |
| Image Classification/BeiT | 188.978 | 187.976 |
| Object Detection/DETR | OOM | OOM |
| Image Classification/ConvNeXT | 422.886 | 388.078 |
| Image Classification/ResNet | 44.114 | 37.604 |
| Image Segmentation/Mask2former | 756.337 | 695.291 |
| Image Segmentation/Maskformer | 2842.940 | 2656.88 |
| Image Segmentation/MobileNet | 299.003 | 201.942 |
| Object Detection/Resnet-101 | 1619.505 | 1262.758 |
| Object Detection/Conditional-DETR | 1137.513 | 897.390 |

PyTorch Nightly

We also benchmarked with PyTorch nightly (2.1.0dev) and observed lower latency for both uncompiled and compiled models.

A100

| Task/Model | Batch Size | torch nightly - no compile (ms) | torch nightly - compile (ms) |
|---|---|---|---|
| Image Classification/BeiT | Unbatched | 12.462 | 6.954 |
| Image Classification/BeiT | 4 | 14.109 | 12.851 |
| Image Classification/BeiT | 16 | 42.179 | 42.147 |
| Object Detection/DETR | Unbatched | 30.484 | 15.221 |
| Object Detection/DETR | 4 | 46.816 | 30.942 |
| Object Detection/DETR | 16 | 163.749 | 163.706 |

T4

| Task/Model | Batch Size | torch nightly - no compile (ms) | torch nightly - compile (ms) |
|---|---|---|---|
| Image Classification/BeiT | Unbatched | 14.408 | 14.052 |
| Image Classification/BeiT | 4 | 47.381 | 46.604 |
| Image Classification/BeiT | 16 | 42.179 | 42.147 |
| Object Detection/DETR | Unbatched | 68.382 | 53.481 |
| Object Detection/DETR | 4 | 269.615 | 204.785 |
| Object Detection/DETR | 16 | OOM | OOM |

V100

| Task/Model | Batch Size | torch nightly - no compile (ms) | torch nightly - compile (ms) |
|---|---|---|---|
| Image Classification/BeiT | Unbatched | 13.477 | 7.926 |
| Image Classification/BeiT | 4 | 15.103 | 14.378 |
| Image Classification/BeiT | 16 | 52.517 | 51.691 |
| Object Detection/DETR | Unbatched | 28.706 | 19.077 |
| Object Detection/DETR | 4 | 88.402 | 62.949 |
| Object Detection/DETR | 16 | OOM | OOM |

Reduce Overhead

We also benchmarked the reduce-overhead compilation mode on A100 and T4 with PyTorch nightly.

A100

| Task/Model | Batch Size | torch nightly - no compile (ms) | torch nightly - compile, reduce-overhead (ms) |
|---|---|---|---|
| Image Classification/ConvNeXT | Unbatched | 11.758 | 7.335 |
| Image Classification/ConvNeXT | 4 | 23.171 | 21.490 |
| Image Classification/ResNet | Unbatched | 7.435 | 3.801 |
| Image Classification/ResNet | 4 | 7.261 | 2.187 |
| Object Detection/Conditional-DETR | Unbatched | 32.823 | 11.627 |
| Object Detection/Conditional-DETR | 4 | 50.622 | 33.831 |
| Image Segmentation/MobileNet | Unbatched | 9.869 | 4.244 |
| Image Segmentation/MobileNet | 4 | 14.385 | 7.946 |

T4

| Task/Model | Batch Size | torch nightly - no compile (ms) | torch nightly - compile, reduce-overhead (ms) |
|---|---|---|---|
| Image Classification/ConvNeXT | Unbatched | 32.137 | 31.84 |
| Image Classification/ConvNeXT | 4 | 120.944 | 110.209 |
| Image Classification/ResNet | Unbatched | 9.761 | 7.698 |
| Image Classification/ResNet | 4 | 15.215 | 13.871 |
| Object Detection/Conditional-DETR | Unbatched | 72.150 | 57.660 |
| Object Detection/Conditional-DETR | 4 | 301.494 | 247.543 |
| Image Segmentation/MobileNet | Unbatched | 22.266 | 19.339 |
| Image Segmentation/MobileNet | 4 | 78.311 | 50.983 |