forward_features with DataParallel
Hi,
Is there a way to easily get patch embeddings with DataParallel?
1) The forward function outputs only the cls-token embedding.
2) Accessing .module.forward_features removes the parallelism (OOM error on larger batch sizes).
3) Replacing the forward function with forward_features causes a device mismatch:
dino_model = timm.create_model('vit_large_patch14_dinov2.lvd142m', pretrained=True, num_classes=0).to(args.device)
dino_model = torch.nn.DataParallel(dino_model, device_ids=args.device_ids)
dino_model.forward = copy.deepcopy(dino_model.forward_features)
self.dino_model = torch.nn.DataParallel(dino_model, device_ids=args.device_ids)
patch_ft = self.dino_model(imgs)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:5 and cuda:3! (when checking argument for argument weight in method wrapper__cudnn_convolution)
This happens in a convolution layer in the patch_embed function.
Thanks for your help!
@soorooshak DataParallel wraps the forward() method, so you might want to try doing the .forward patch you have there before wrapping? Or maybe further wrap the vit model in an nn.Module that calls forward_features()?
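A rough sketch of that wrapper idea (FeatureExtractor is just an illustrative name, and the device ids / batch size are placeholders):

import timm
import torch
import torch.nn as nn

class FeatureExtractor(nn.Module):
    # Thin wrapper so the forward() that DataParallel replicates calls forward_features()
    def __init__(self, vit):
        super().__init__()
        self.vit = vit

    def forward(self, x):
        # returns the full token sequence: (B, num_prefix_tokens + num_patches, embed_dim)
        return self.vit.forward_features(x)

vit = timm.create_model('vit_large_patch14_dinov2.lvd142m', pretrained=True, num_classes=0)
model = nn.DataParallel(FeatureExtractor(vit).cuda(), device_ids=[0, 1])  # illustrative device ids
imgs = torch.randn(8, 3, 518, 518).cuda()  # this model's default input size is 518x518
patch_ft = model(imgs)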
Otherwise, the equivalent of forward_features should be to set a null pooling on the model:
dino_model = timm.create_model('vit_large_patch14_dinov2.lvd142m', pretrained=True, num_classes=0, global_pool='')
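With num_classes=0 and global_pool='' the head is an identity, so forward() returns the unpooled token sequence and you can keep the DataParallel wrapping from your snippet unchanged. A sketch, reusing your args/imgs, that also slices off the cls token if you only want the patch embeddings (num_prefix_tokens is an attribute on timm ViTs, 1 for this model):

dino_model = timm.create_model('vit_large_patch14_dinov2.lvd142m', pretrained=True, num_classes=0, global_pool='').to(args.device)
dino_model = torch.nn.DataParallel(dino_model, device_ids=args.device_ids)
tokens = dino_model(imgs)                                    # (B, 1 + num_patches, 1024)
patch_ft = tokens[:, dino_model.module.num_prefix_tokens:]   # drop the cls token, keep patch tokens only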