# LowKey / util/attack_utils.py
# Jacob Logas
# Fix device usage?
# 0792228 unverified
# Helpers for feature-space adversarial attacks against an ensemble of pre-trained feature extractors
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.autograd import Variable
from util.feature_extraction_utils import warp_image, normalize_batch
from util.prepare_utils import get_ensemble, extract_features
from lpips_pytorch import LPIPS
from tqdm import trange
# Reusable torchvision transform instances: one converting to tensors, one to PIL images.
tensor_transform, pil_transform = transforms.ToTensor(), transforms.ToPILImage()
class Attack(nn.Module):
    """Feature-space adversarial attack against an ensemble of feature extractors.

    :meth:`execute` runs ``n_starts`` random restarts. Each restart performs
    ``n_iters`` signed-gradient ascent steps on
    ``direction * mean((features - dir_vec) ** 2 / ||dir_vec||)``. The objective
    can be penalised by an LPIPS term (``attack_type == "lpips"``, weight
    ``c_sim``) and/or a total-variation term (weight ``c_tv``). For each image,
    the restart with the best feature distance is kept.

    Args:
        models: Forwarded to ``get_ensemble`` to build the extractor ensemble.
        dim: Forwarded to ``extract_features``.
        attack_type: ``"lpips"`` adds the LPIPS penalty. ``"sgd"`` clamps the
            perturbation to ``[-eps, eps]``.
        eps: Perturbation bound, used only when ``attack_type == "sgd"``.
        c_sim: Weight of the LPIPS penalty.
        net_type: Network name passed to ``LPIPS``.
        lr: Step size of each signed-gradient update.
        n_iters: Gradient steps per restart.
        noise_size: Half-width of the uniform noise that initialises each restart.
        n_starts: Number of random restarts.
        c_tv: Weight of the total-variation penalty; ``None`` disables it.
        sigma_gf, kernel_size_gf, combination, V_reduction: Forwarded to
            ``get_ensemble``.
        warp: If true, the LPIPS penalty also compares warped images
            (``warp_image(..., theta_warp)``). Also forwarded to ``get_ensemble``.
        theta_warp: Warp parameters used together with ``warp``.
    """

    def __init__(
        self,
        models,
        dim,
        attack_type,
        eps,
        c_sim=0.5,
        net_type="alex",
        lr=0.05,
        n_iters=100,
        noise_size=0.001,
        n_starts=10,
        c_tv=None,
        sigma_gf=None,
        kernel_size_gf=None,
        combination=False,
        warp=False,
        theta_warp=None,
        V_reduction=None,
    ):
        super(Attack, self).__init__()
        self.extractor_ens = get_ensemble(
            models, sigma_gf, kernel_size_gf, combination, V_reduction, warp, theta_warp
        )
        self.dim = dim
        self.eps = eps
        self.c_sim = c_sim
        self.net_type = net_type
        self.lr = lr
        self.n_iters = n_iters
        self.noise_size = noise_size
        self.n_starts = n_starts
        # Previously hard-coded to None, which silently ignored the argument.
        self.c_tv = c_tv
        self.attack_type = attack_type
        self.warp = warp
        self.theta_warp = theta_warp
        if self.attack_type == "lpips":
            self.lpips_loss = LPIPS(self.net_type)

    @staticmethod
    def _device():
        """Return the CUDA device when available, otherwise the CPU."""
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def execute(self, images, dir_vec, direction):
        """Run the attack and return the adversarial images on the CPU.

        Args:
            images: Batch of clean images. Adversarial outputs are clamped to [0, 1].
            dir_vec: Reference features compared with ``extract_features`` output.
                Its norm is taken over dim 2, presumably giving shape
                (batch, n_models, dim) — TODO confirm.
            direction: ``1`` maximises the feature distance and keeps the
                farthest restart. Any other value scales the loss and keeps the
                closest restart (intended for ``-1``).

        Returns:
            Detached adversarial images on the CPU.
        """
        device = self._device()
        print("Device in Excute:", device)
        if self.attack_type == "lpips":
            # lpips_loss only exists for the LPIPS attack. Moving it unconditionally
            # raised AttributeError for every other attack_type.
            self.lpips_loss.to(device)
        images = images.detach().to(device)
        dir_vec = dir_vec.to(device)
        # Per-vector norm over dim 2, kept broadcastable against the features.
        dir_vec_norm = dir_vec.norm(dim=2).unsqueeze(2)

        maximize = direction == 1
        # Start from the worst possible score so the first restart is always
        # accepted. A zero start made every restart look worse when minimising,
        # so the clean images were returned unchanged.
        initial_dist = 0.0 if maximize else float("inf")
        dist = torch.full((images.shape[0],), initial_dist, device=device)

        adv_images = images.clone()
        if self.warp:
            self.face_img = warp_image(images, self.theta_warp)

        for _ in range(self.n_starts):
            adv_images_old = adv_images.clone()
            dist_old = dist.clone()
            adv_images = self._run_start(images, dir_vec, dir_vec_norm, direction, device)
            with torch.no_grad():
                adv_features = extract_features(
                    adv_images, self.extractor_ens, self.dim
                ).to(device)
                dist = torch.mean(
                    (adv_features - dir_vec) ** 2 / dir_vec_norm, dim=[1, 2]
                )
                # Per image, revert to the previous best where this restart did worse.
                worse = dist < dist_old if maximize else dist > dist_old
                adv_images[worse] = adv_images_old[worse]
                dist[worse] = dist_old[worse]
        return adv_images.detach().cpu()

    def _run_start(self, images, dir_vec, dir_vec_norm, direction, device):
        """Run one random restart of ``n_iters`` signed-gradient steps; return detached images."""
        # Uniform noise in [-noise_size, noise_size), drawn on the default (CPU)
        # generator as before so seeded runs stay reproducible.
        noise_uniform = (
            2 * self.noise_size * torch.rand(images.size()) - self.noise_size
        ).to(device)
        adv_images = images.clone() + noise_uniform
        for _ in trange(self.n_iters):
            # Fresh leaf each step so the autograd graph does not grow across iterations.
            adv_images = adv_images.detach().requires_grad_(True)
            adv_features = extract_features(
                adv_images, self.extractor_ens, self.dim
            ).to(device)
            loss = direction * torch.mean((adv_features - dir_vec) ** 2 / dir_vec_norm)
            if self.c_tv is not None:
                loss = loss - self.c_tv * self.total_var_reg(images, adv_images)
            if self.attack_type == "lpips":
                loss = loss - self.c_sim * self.lpips_reg(images, adv_images)
            (grad,) = torch.autograd.grad(loss, [adv_images])
            with torch.no_grad():
                # Gradient ascent on the (penalised) objective.
                adv_images = adv_images + self.lr * grad.sign()
                perturbation = adv_images - images
                if self.attack_type == "sgd":
                    perturbation = torch.clamp(perturbation, min=-self.eps, max=self.eps)
                adv_images = torch.clamp(images + perturbation, min=0, max=1)
        return adv_images.detach()

    def lpips_reg(self, images, adv_images):
        """Return the LPIPS penalty between clean and adversarial images.

        The LPIPS value is divided by the batch size. With ``warp`` enabled,
        the result is the sum of the warped-pair score (``self.face_img``, set by
        :meth:`execute`, vs. the warped ``adv_images``) and the unwarped-pair
        score, each divided by twice the batch size.
        """
        device = self._device()
        batch_size = adv_images.shape[0]

        def score(clean, adv):
            # [0][0][0][0] takes the first element of the LPIPS output. Assumes
            # lpips_pytorch returns a (1, 1, 1, 1) tensor — TODO confirm.
            return self.lpips_loss(
                normalize_batch(clean).to(device),
                normalize_batch(adv).to(device),
            )[0][0][0][0]

        if self.warp:
            face_adv = warp_image(adv_images, self.theta_warp)
            warped = score(self.face_img, face_adv) / (2 * batch_size)
            return warped + score(images, adv_images) / (2 * batch_size)
        return score(images, adv_images) / batch_size

    @staticmethod
    def total_var_reg(images, adv_images):
        """Return the anisotropic total variation of ``adv_images - images``.

        This is the mean absolute difference between neighbours along the last
        dimension, plus the mean absolute difference between neighbours along the
        second-to-last dimension. Declared static: it was missing ``self``, so
        ``self.total_var_reg(...)`` raised TypeError.
        """
        perturbation = adv_images - images
        tv = torch.mean(
            torch.abs(perturbation[:, :, :, :-1] - perturbation[:, :, :, 1:])
        ) + torch.mean(
            torch.abs(perturbation[:, :, :-1, :] - perturbation[:, :, 1:, :])
        )
        return tv