Forward_features data parallel

#3
by soorooshak - opened

Hi,
Is there an easy way to get patch embeddings with DataParallel?
1) The forward function outputs only the CLS token embedding.
2) Accessing .module.forward_features bypasses the parallelism (OOM error on larger batch sizes).
3) Replacing the forward function with forward_features causes a device mismatch:

    dino_model = timm.create_model('vit_large_patch14_dinov2.lvd142m', pretrained=True, num_classes=0).to(args.device)
    dino_model=torch.nn.DataParallel(dino_model,device_ids=args.device_ids)
    dino_model.forward=copy.deepcopy(dino_model.forward_features)
    self.dino_model=torch.nn.DataParallel(dino_model,device_ids=args.device_ids)
    patch_ft = self.dino_model(imgs)

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:5 and cuda:3! (when checking argument for argument weight in method wrapper__cudnn_convolution)

This happens in a convolution layer in the patch_embed function.

Thanks for your help !

PyTorch Image Models org

@soorooshak DataParallel wraps the forward() method, so you might want to try applying the .forward patch you have there before wrapping. Or maybe further wrap the ViT model in an nn.Module that calls forward_features()?
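A minimal sketch of the wrapper idea, in case it helps. The names `FeatureExtractor` and `TinyBackbone` are made up for illustration; `TinyBackbone` is only a stand-in so the snippet runs without downloading weights, and in practice you would pass the timm ViT in its place. The point is that the wrapper's own forward() calls forward_features(), so DataParallel replicates and scatters it like any other module.

```python
import torch
import torch.nn as nn


class FeatureExtractor(nn.Module):
    """Hypothetical wrapper whose forward() returns backbone.forward_features()."""

    def __init__(self, backbone: nn.Module):
        super().__init__()
        self.backbone = backbone

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.backbone.forward_features(x)


class TinyBackbone(nn.Module):
    """Stand-in for the timm ViT (assumption: anything exposing forward_features)."""

    def __init__(self, dim: int = 8, patch: int = 4):
        super().__init__()
        self.proj = nn.Conv2d(3, dim, kernel_size=patch, stride=patch)

    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
        x = self.proj(x)                      # (B, dim, H/patch, W/patch)
        return x.flatten(2).transpose(1, 2)   # (B, num_patches, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Pooled output, analogous to the CLS-only output in the question.
        return self.forward_features(x).mean(dim=1)


wrapped = FeatureExtractor(TinyBackbone())
if torch.cuda.is_available():
    wrapped = wrapped.cuda()  # parameters must live on the first device

# Wrap exactly once, after forward() already does what we want.
model = nn.DataParallel(wrapped)

imgs = torch.randn(2, 3, 16, 16)
with torch.no_grad():
    patch_ft = model(imgs)  # (batch, num_patches, dim)
```

With no GPU available, DataParallel simply calls the wrapped module directly, so the same code can be checked on CPU first.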

Otherwise, the equivalent of forward_features should be to set a null pooling on the model:

    dino_model = timm.create_model('vit_large_patch14_dinov2.lvd142m', pretrained=True, num_classes=0, global_pool='')
