ailm committed on
Commit
ed7df54
1 Parent(s): 4cae5ff

Upload 3 files

Files changed (3)
  1. dataTransform.py +29 -0
  2. styleTransfer.py +75 -0
  3. vggModel.py +24 -0
dataTransform.py ADDED
@@ -0,0 +1,29 @@
+ import torch  # needed for torch.clamp in tensor_to_image
+ from PIL import Image
+ import torchvision.transforms as transforms  # to transform the images
+
+
+ def load_image(image_path, device):
+
+     image_size = 356
+
+     loader = transforms.Compose(
+         [
+             transforms.Resize((image_size, image_size)),  # RESIZE IMAGE
+             transforms.ToTensor()  # TRANSFORM IMAGE TO TENSOR
+         ]
+     )
+
+     image = Image.open(image_path)
+     image = loader(image).unsqueeze(0)  # (c, h, w) -> (1, c, h, w) adds batch dim
+
+     return image.to(device)
+
+
+ def tensor_to_image(tensor):
+     tensor = tensor.clone().detach()  # Ensure the tensor is detached from the graph
+     tensor = tensor.squeeze(0)  # Remove batch dimension if present
+     tensor = torch.clamp(tensor, 0, 1)  # Clamp the values to [0, 1] range
+
+     unloader = transforms.ToPILImage()
+     image = unloader(tensor.cpu())
+     return image
styleTransfer.py ADDED
@@ -0,0 +1,75 @@
+ import torch  # for model
+ import numpy as np
+ import torch.nn as nn
+ import torch.optim as optim
+ from PIL import Image  # for importing images
+ import torchvision.models as models  # to load vgg 19 model
+ import torchvision.transforms as transforms
+ from tqdm import tqdm
+
+ from dataTransform import load_image
+ from vggModel import VGGNet
+
+ def style_transfer(content_img, style_img, total_steps, alpha=1e5, beta=1e10, learning_rate=0.001):
+     # Preprocess the input images
+
+     device = 'cuda' if torch.cuda.is_available() else 'cpu'
+     print('-'*30)
+     print(f'Device Initialized: {device}')
+     print('-'*30)
+     content_img = load_image(content_img, device)
+     style_img = load_image(style_img, device)
+     generated_img = content_img.clone().requires_grad_(True)
+     optimizer = optim.Adam([generated_img], lr=learning_rate)
+     model = VGGNet().to(device).eval()
+
+     # print(content_img.shape)
+     # print(style_img.shape)
+     # print(generated_img.shape)
+
+     for step in tqdm(range(total_steps)):
+
+         # first, send the 3 images through the VGG network
+
+         generated_feats = model(generated_img)
+         original_image_feats = model(content_img)
+         style_feats = model(style_img)
+
+         # defining the content and style losses
+
+         style_loss = original_loss = 0
+
+         for gen_feat, orig_image_feat, styl_feat in zip(generated_feats, original_image_feats, style_feats):  # looping over each chosen feature map
+
+             # print(gen_feat.shape)
+             # print(orig_image_feat.shape)
+             # print(styl_feat.shape)
+
+             batch, channel, height, width = gen_feat.shape
+             original_loss += torch.mean((gen_feat - orig_image_feat)**2)
+
+             # computing gram matrices for generated and style features to compute the style loss
+
+             G = gen_feat.view(channel, height*width).mm(
+                 gen_feat.view(channel, height*width).t()
+             )
+
+             # gram (correlation) matrix of the style features
+
+             A = styl_feat.view(channel, height*width).mm(
+                 styl_feat.view(channel, height*width).t()
+             )
+
+             style_loss += torch.mean((G - A)**2)
+
+         total_loss = alpha*original_loss + beta*style_loss
+
+         optimizer.zero_grad()
+         total_loss.backward()
+         optimizer.step()
+
+         if step == total_steps - 1:
+             # Postprocess and return the final generated image
+             return generated_img
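
For reference, a minimal usage sketch (assumed, not part of the commit; the image paths, step count, and output filename below are placeholders): it runs style_transfer and converts the returned tensor back to a PIL image with tensor_to_image from dataTransform.py.

    from styleTransfer import style_transfer
    from dataTransform import tensor_to_image

    # placeholder paths; substitute your own content and style images
    generated = style_transfer('content.jpg', 'style.jpg', total_steps=6000)

    # tensor_to_image clamps the result to [0, 1] and converts it to a PIL image
    tensor_to_image(generated).save('generated.jpg')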
vggModel.py ADDED
@@ -0,0 +1,24 @@
+ import torch  # for model
+ import torch.nn as nn
+ import torchvision.models as models  # to load vgg 19 model
+
+
+ class VGGNet(nn.Module):
+
+     def __init__(self):
+
+         super(VGGNet, self).__init__()
+         self.chosen_features = ['0', '5', '10', '19', '28']
+         self.vgg = models.vgg19(pretrained=True).features  # select only certain layers to extract features
+
+     def forward(self, x):
+         features = []  # returns features from selected conv layers from VGG19 pretrained model
+
+         for layer_num, layer in self.vgg._modules.items():
+             x = layer(x)
+
+             if layer_num in self.chosen_features:
+                 features.append(x)
+
+         return features
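
Note on the chosen layer indices (a fact about torchvision's VGG-19 layout, not stated in the commit): positions '0', '5', '10', '19' and '28' of models.vgg19().features are the conv1_1, conv2_1, conv3_1, conv4_1 and conv5_1 convolutions, the layers typically used for neural style transfer. A quick check sketch:

    import torchvision.models as models

    features = models.vgg19().features  # pretrained weights are not needed just to inspect the layout
    for idx in ['0', '5', '10', '19', '28']:
        print(idx, features[int(idx)])
    # each printed module should be a Conv2d (conv1_1, conv2_1, conv3_1, conv4_1, conv5_1)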