Kohaku-Blueleaf commited on
Commit
317b678
·
1 Parent(s): 26d4aa7

add missing files

Browse files
lineart_models/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .lineart import LineartDetector
2
+ from .lineart_anime import LineartAnimeDetector
3
+ from .mangaline_preprocessor import MangaLineExtraction
lineart_models/lineart.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # From https://github.com/carolineec/informative-drawings
2
+ # MIT License
3
+ '''
4
+ MIT License
5
+
6
+ Copyright (c) 2022 Caroline Chan
7
+
8
+ Permission is hereby granted, free of charge, to any person obtaining a copy
9
+ of this software and associated documentation files (the "Software"), to deal
10
+ in the Software without restriction, including without limitation the rights
11
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12
+ copies of the Software, and to permit persons to whom the Software is
13
+ furnished to do so, subject to the following conditions:
14
+
15
+ The above copyright notice and this permission notice shall be included in all
16
+ copies or substantial portions of the Software.
17
+
18
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24
+ SOFTWARE.
25
+ '''
26
+
27
+ import os
28
+ import cv2
29
+ import torch
30
+ import numpy as np
31
+
32
+ import torch.nn as nn
33
+ from einops import rearrange
34
+ from .utils import load_file_from_url
35
+
36
+
37
+ norm_layer = nn.InstanceNorm2d
38
+
39
+
40
+ class ResidualBlock(nn.Module):
41
+ def __init__(self, in_features):
42
+ super(ResidualBlock, self).__init__()
43
+
44
+ conv_block = [ nn.ReflectionPad2d(1),
45
+ nn.Conv2d(in_features, in_features, 3),
46
+ norm_layer(in_features),
47
+ nn.ReLU(inplace=True),
48
+ nn.ReflectionPad2d(1),
49
+ nn.Conv2d(in_features, in_features, 3),
50
+ norm_layer(in_features)
51
+ ]
52
+
53
+ self.conv_block = nn.Sequential(*conv_block)
54
+
55
+ def forward(self, x):
56
+ return x + self.conv_block(x)
57
+
58
+
59
+ class Generator(nn.Module):
60
+ def __init__(self, input_nc, output_nc, n_residual_blocks=9, sigmoid=True):
61
+ super(Generator, self).__init__()
62
+
63
+ # Initial convolution block
64
+ model0 = [ nn.ReflectionPad2d(3),
65
+ nn.Conv2d(input_nc, 64, 7),
66
+ norm_layer(64),
67
+ nn.ReLU(inplace=True) ]
68
+ self.model0 = nn.Sequential(*model0)
69
+
70
+ # Downsampling
71
+ model1 = []
72
+ in_features = 64
73
+ out_features = in_features*2
74
+ for _ in range(2):
75
+ model1 += [ nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
76
+ norm_layer(out_features),
77
+ nn.ReLU(inplace=True) ]
78
+ in_features = out_features
79
+ out_features = in_features*2
80
+ self.model1 = nn.Sequential(*model1)
81
+
82
+ model2 = []
83
+ # Residual blocks
84
+ for _ in range(n_residual_blocks):
85
+ model2 += [ResidualBlock(in_features)]
86
+ self.model2 = nn.Sequential(*model2)
87
+
88
+ # Upsampling
89
+ model3 = []
90
+ out_features = in_features//2
91
+ for _ in range(2):
92
+ model3 += [ nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1),
93
+ norm_layer(out_features),
94
+ nn.ReLU(inplace=True) ]
95
+ in_features = out_features
96
+ out_features = in_features//2
97
+ self.model3 = nn.Sequential(*model3)
98
+
99
+ # Output layer
100
+ model4 = [ nn.ReflectionPad2d(3),
101
+ nn.Conv2d(64, output_nc, 7)]
102
+ if sigmoid:
103
+ model4 += [nn.Sigmoid()]
104
+
105
+ self.model4 = nn.Sequential(*model4)
106
+
107
+ def forward(self, x, cond=None):
108
+ out = self.model0(x)
109
+ out = self.model1(out)
110
+ out = self.model2(out)
111
+ out = self.model3(out)
112
+ out = self.model4(out)
113
+
114
+ return out
115
+
116
+
117
+ class LineartDetector:
118
+ def __init__(self, model_path="hf_download"):
119
+ self.model = self.load_model('sk_model.pth', model_path)
120
+ self.model_coarse = self.load_model('sk_model2.pth', model_path)
121
+
122
+ def load_model(self, name, model_path="hf_download"):
123
+ remote_model_path = "https://huggingface.co/lllyasviel/Annotators/resolve/main/" + name
124
+ modelpath = os.path.join(model_path, name)
125
+ if not os.path.exists(modelpath):
126
+ load_file_from_url(remote_model_path, model_dir=model_path)
127
+ model = Generator(3, 1, 3)
128
+ model.load_state_dict(torch.load(modelpath, map_location=torch.device('cpu')))
129
+ model.eval()
130
+ model = model.cuda()
131
+ return model
132
+
133
+ def __call__(self, input_image, coarse=False):
134
+ model = self.model_coarse if coarse else self.model
135
+ assert input_image.ndim == 3
136
+ image = input_image
137
+ with torch.no_grad():
138
+ image = torch.from_numpy(image).float().cuda()
139
+ image = image / 255.0
140
+ image = rearrange(image, 'h w c -> 1 c h w')
141
+ line = model(image)[0][0]
142
+
143
+ line = line.cpu().numpy()
144
+ line = (line * 255.0).clip(0, 255).astype(np.uint8)
145
+
146
+ return line
lineart_models/lineart_anime.py ADDED
@@ -0,0 +1,226 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Anime2sketch
2
+ # https://github.com/Mukosame/Anime2Sketch
3
+ '''
4
+ MIT License
5
+
6
+ Copyright (c) 2022 Caroline Chan
7
+
8
+ Permission is hereby granted, free of charge, to any person obtaining a copy
9
+ of this software and associated documentation files (the "Software"), to deal
10
+ in the Software without restriction, including without limitation the rights
11
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12
+ copies of the Software, and to permit persons to whom the Software is
13
+ furnished to do so, subject to the following conditions:
14
+
15
+ The above copyright notice and this permission notice shall be included in all
16
+ copies or substantial portions of the Software.
17
+
18
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24
+ SOFTWARE.
25
+ '''
26
+
27
+ import numpy as np
28
+ import torch
29
+ import torch.nn as nn
30
+ import functools
31
+
32
+ import os
33
+ import cv2
34
+ from einops import rearrange
35
+
36
+
37
+ class UnetGenerator(nn.Module):
38
+ """Create a Unet-based generator"""
39
+
40
+ def __init__(
41
+ self,
42
+ input_nc,
43
+ output_nc,
44
+ num_downs,
45
+ ngf=64,
46
+ norm_layer=nn.BatchNorm2d,
47
+ use_dropout=False,
48
+ ):
49
+ """Construct a Unet generator
50
+ Parameters:
51
+ input_nc (int) -- the number of channels in input images
52
+ output_nc (int) -- the number of channels in output images
53
+ num_downs (int) -- the number of downsamplings in UNet. For example, # if |num_downs| == 7,
54
+ image of size 128x128 will become of size 1x1 # at the bottleneck
55
+ ngf (int) -- the number of filters in the last conv layer
56
+ norm_layer -- normalization layer
57
+ We construct the U-Net from the innermost layer to the outermost layer.
58
+ It is a recursive process.
59
+ """
60
+ super(UnetGenerator, self).__init__()
61
+ # construct unet structure
62
+ unet_block = UnetSkipConnectionBlock(
63
+ ngf * 8,
64
+ ngf * 8,
65
+ input_nc=None,
66
+ submodule=None,
67
+ norm_layer=norm_layer,
68
+ innermost=True,
69
+ ) # add the innermost layer
70
+ for _ in range(num_downs - 5): # add intermediate layers with ngf * 8 filters
71
+ unet_block = UnetSkipConnectionBlock(
72
+ ngf * 8,
73
+ ngf * 8,
74
+ input_nc=None,
75
+ submodule=unet_block,
76
+ norm_layer=norm_layer,
77
+ use_dropout=use_dropout,
78
+ )
79
+ # gradually reduce the number of filters from ngf * 8 to ngf
80
+ unet_block = UnetSkipConnectionBlock(
81
+ ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer
82
+ )
83
+ unet_block = UnetSkipConnectionBlock(
84
+ ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer
85
+ )
86
+ unet_block = UnetSkipConnectionBlock(
87
+ ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer
88
+ )
89
+ self.model = UnetSkipConnectionBlock(
90
+ output_nc,
91
+ ngf,
92
+ input_nc=input_nc,
93
+ submodule=unet_block,
94
+ outermost=True,
95
+ norm_layer=norm_layer,
96
+ ) # add the outermost layer
97
+
98
+ def forward(self, input):
99
+ """Standard forward"""
100
+ return self.model(input)
101
+
102
+
103
+ class UnetSkipConnectionBlock(nn.Module):
104
+ """Defines the Unet submodule with skip connection.
105
+ X -------------------identity----------------------
106
+ |-- downsampling -- |submodule| -- upsampling --|
107
+ """
108
+
109
+ def __init__(
110
+ self,
111
+ outer_nc,
112
+ inner_nc,
113
+ input_nc=None,
114
+ submodule=None,
115
+ outermost=False,
116
+ innermost=False,
117
+ norm_layer=nn.BatchNorm2d,
118
+ use_dropout=False,
119
+ ):
120
+ """Construct a Unet submodule with skip connections.
121
+ Parameters:
122
+ outer_nc (int) -- the number of filters in the outer conv layer
123
+ inner_nc (int) -- the number of filters in the inner conv layer
124
+ input_nc (int) -- the number of channels in input images/features
125
+ submodule (UnetSkipConnectionBlock) -- previously defined submodules
126
+ outermost (bool) -- if this module is the outermost module
127
+ innermost (bool) -- if this module is the innermost module
128
+ norm_layer -- normalization layer
129
+ use_dropout (bool) -- if use dropout layers.
130
+ """
131
+ super(UnetSkipConnectionBlock, self).__init__()
132
+ self.outermost = outermost
133
+ if type(norm_layer) == functools.partial:
134
+ use_bias = norm_layer.func == nn.InstanceNorm2d
135
+ else:
136
+ use_bias = norm_layer == nn.InstanceNorm2d
137
+ if input_nc is None:
138
+ input_nc = outer_nc
139
+ downconv = nn.Conv2d(
140
+ input_nc, inner_nc, kernel_size=4, stride=2, padding=1, bias=use_bias
141
+ )
142
+ downrelu = nn.LeakyReLU(0.2, True)
143
+ downnorm = norm_layer(inner_nc)
144
+ uprelu = nn.ReLU(True)
145
+ upnorm = norm_layer(outer_nc)
146
+
147
+ if outermost:
148
+ upconv = nn.ConvTranspose2d(
149
+ inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1
150
+ )
151
+ down = [downconv]
152
+ up = [uprelu, upconv, nn.Tanh()]
153
+ model = down + [submodule] + up
154
+ elif innermost:
155
+ upconv = nn.ConvTranspose2d(
156
+ inner_nc, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias
157
+ )
158
+ down = [downrelu, downconv]
159
+ up = [uprelu, upconv, upnorm]
160
+ model = down + up
161
+ else:
162
+ upconv = nn.ConvTranspose2d(
163
+ inner_nc * 2,
164
+ outer_nc,
165
+ kernel_size=4,
166
+ stride=2,
167
+ padding=1,
168
+ bias=use_bias,
169
+ )
170
+ down = [downrelu, downconv, downnorm]
171
+ up = [uprelu, upconv, upnorm]
172
+
173
+ if use_dropout:
174
+ model = down + [submodule] + up + [nn.Dropout(0.5)]
175
+ else:
176
+ model = down + [submodule] + up
177
+
178
+ self.model = nn.Sequential(*model)
179
+
180
+ def forward(self, x):
181
+ if self.outermost:
182
+ return self.model(x)
183
+ else: # add skip connections
184
+ return torch.cat([x, self.model(x)], 1)
185
+
186
+
187
+ class LineartAnimeDetector:
188
+ def __init__(self, model_path="hf_download"):
189
+ remote_model_path = (
190
+ "https://huggingface.co/lllyasviel/Annotators/resolve/main/netG.pth"
191
+ )
192
+ modelpath = os.path.join(model_path, "netG.pth")
193
+ if not os.path.exists(modelpath):
194
+ from .utils import load_file_from_url
195
+
196
+ load_file_from_url(remote_model_path, model_dir=model_path)
197
+ norm_layer = functools.partial(
198
+ nn.InstanceNorm2d, affine=False, track_running_stats=False
199
+ )
200
+ net = UnetGenerator(3, 1, 8, 64, norm_layer=norm_layer, use_dropout=False)
201
+ ckpt = torch.load(modelpath)
202
+ for key in list(ckpt.keys()):
203
+ if "module." in key:
204
+ ckpt[key.replace("module.", "")] = ckpt[key]
205
+ del ckpt[key]
206
+ net.load_state_dict(ckpt)
207
+ net = net.cuda()
208
+ net.eval()
209
+ self.model = net
210
+
211
+ def __call__(self, input_image):
212
+ H, W, C = input_image.shape
213
+ Hn = 256 * int(np.ceil(float(H) / 256.0))
214
+ Wn = 256 * int(np.ceil(float(W) / 256.0))
215
+ img = cv2.resize(input_image, (Wn, Hn), interpolation=cv2.INTER_CUBIC)
216
+ with torch.no_grad():
217
+ image_feed = torch.from_numpy(img).float().cuda()
218
+ image_feed = image_feed / 127.5 - 1.0
219
+ image_feed = rearrange(image_feed, "h w c -> 1 c h w")
220
+
221
+ line = self.model(image_feed)[0, 0] * 127.5 + 127.5
222
+ line = line.cpu().numpy()
223
+
224
+ line = cv2.resize(line, (W, H), interpolation=cv2.INTER_CUBIC)
225
+ line = line.clip(0, 255).astype(np.uint8)
226
+ return line
lineart_models/mangaline_preprocessor.py ADDED
@@ -0,0 +1,356 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ MIT License
3
+
4
+ Copyright (c) 2021 Miaomiao Li
5
+
6
+ Permission is hereby granted, free of charge, to any person obtaining a copy
7
+ of this software and associated documentation files (the "Software"), to deal
8
+ in the Software without restriction, including without limitation the rights
9
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10
+ copies of the Software, and to permit persons to whom the Software is
11
+ furnished to do so, subject to the following conditions:
12
+
13
+ The above copyright notice and this permission notice shall be included in all
14
+ copies or substantial portions of the Software.
15
+
16
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22
+ SOFTWARE.
23
+ '''
24
+ import os
25
+
26
+ import cv2
27
+ import numpy as np
28
+ import torch
29
+ import torch.nn as nn
30
+ from einops import rearrange
31
+
32
+ from .utils import load_file_from_url
33
+
34
+
35
+ class _bn_relu_conv(nn.Module):
36
+ def __init__(self, in_filters, nb_filters, fw, fh, subsample=1):
37
+ super(_bn_relu_conv, self).__init__()
38
+ self.model = nn.Sequential(
39
+ nn.BatchNorm2d(in_filters, eps=1e-3),
40
+ nn.LeakyReLU(0.2),
41
+ nn.Conv2d(
42
+ in_filters,
43
+ nb_filters,
44
+ (fw, fh),
45
+ stride=subsample,
46
+ padding=(fw // 2, fh // 2),
47
+ padding_mode="zeros",
48
+ ),
49
+ )
50
+
51
+ def forward(self, x):
52
+ return self.model(x)
53
+
54
+ # the following are for debugs
55
+ print(
56
+ "****",
57
+ np.max(x.cpu().numpy()),
58
+ np.min(x.cpu().numpy()),
59
+ np.mean(x.cpu().numpy()),
60
+ np.std(x.cpu().numpy()),
61
+ x.shape,
62
+ )
63
+ for i, layer in enumerate(self.model):
64
+ if i != 2:
65
+ x = layer(x)
66
+ else:
67
+ x = layer(x)
68
+ # x = nn.functional.pad(x, (1, 1, 1, 1), mode='constant', value=0)
69
+ print(
70
+ "____",
71
+ np.max(x.cpu().numpy()),
72
+ np.min(x.cpu().numpy()),
73
+ np.mean(x.cpu().numpy()),
74
+ np.std(x.cpu().numpy()),
75
+ x.shape,
76
+ )
77
+ print(x[0])
78
+ return x
79
+
80
+
81
+ class _u_bn_relu_conv(nn.Module):
82
+ def __init__(self, in_filters, nb_filters, fw, fh, subsample=1):
83
+ super(_u_bn_relu_conv, self).__init__()
84
+ self.model = nn.Sequential(
85
+ nn.BatchNorm2d(in_filters, eps=1e-3),
86
+ nn.LeakyReLU(0.2),
87
+ nn.Conv2d(
88
+ in_filters,
89
+ nb_filters,
90
+ (fw, fh),
91
+ stride=subsample,
92
+ padding=(fw // 2, fh // 2),
93
+ ),
94
+ nn.Upsample(scale_factor=2, mode="nearest"),
95
+ )
96
+
97
+ def forward(self, x):
98
+ return self.model(x)
99
+
100
+
101
+ class _shortcut(nn.Module):
102
+ def __init__(self, in_filters, nb_filters, subsample=1):
103
+ super(_shortcut, self).__init__()
104
+ self.process = False
105
+ self.model = None
106
+ if in_filters != nb_filters or subsample != 1:
107
+ self.process = True
108
+ self.model = nn.Sequential(
109
+ nn.Conv2d(in_filters, nb_filters, (1, 1), stride=subsample)
110
+ )
111
+
112
+ def forward(self, x, y):
113
+ # print(x.size(), y.size(), self.process)
114
+ if self.process:
115
+ y0 = self.model(x)
116
+ # print("merge+", torch.max(y0+y), torch.min(y0+y),torch.mean(y0+y), torch.std(y0+y), y0.shape)
117
+ return y0 + y
118
+ else:
119
+ # print("merge", torch.max(x+y), torch.min(x+y),torch.mean(x+y), torch.std(x+y), y.shape)
120
+ return x + y
121
+
122
+
123
+ class _u_shortcut(nn.Module):
124
+ def __init__(self, in_filters, nb_filters, subsample):
125
+ super(_u_shortcut, self).__init__()
126
+ self.process = False
127
+ self.model = None
128
+ if in_filters != nb_filters:
129
+ self.process = True
130
+ self.model = nn.Sequential(
131
+ nn.Conv2d(
132
+ in_filters,
133
+ nb_filters,
134
+ (1, 1),
135
+ stride=subsample,
136
+ padding_mode="zeros",
137
+ ),
138
+ nn.Upsample(scale_factor=2, mode="nearest"),
139
+ )
140
+
141
+ def forward(self, x, y):
142
+ if self.process:
143
+ return self.model(x) + y
144
+ else:
145
+ return x + y
146
+
147
+
148
+ class basic_block(nn.Module):
149
+ def __init__(self, in_filters, nb_filters, init_subsample=1):
150
+ super(basic_block, self).__init__()
151
+ self.conv1 = _bn_relu_conv(
152
+ in_filters, nb_filters, 3, 3, subsample=init_subsample
153
+ )
154
+ self.residual = _bn_relu_conv(nb_filters, nb_filters, 3, 3)
155
+ self.shortcut = _shortcut(in_filters, nb_filters, subsample=init_subsample)
156
+
157
+ def forward(self, x):
158
+ x1 = self.conv1(x)
159
+ x2 = self.residual(x1)
160
+ return self.shortcut(x, x2)
161
+
162
+
163
+ class _u_basic_block(nn.Module):
164
+ def __init__(self, in_filters, nb_filters, init_subsample=1):
165
+ super(_u_basic_block, self).__init__()
166
+ self.conv1 = _u_bn_relu_conv(
167
+ in_filters, nb_filters, 3, 3, subsample=init_subsample
168
+ )
169
+ self.residual = _bn_relu_conv(nb_filters, nb_filters, 3, 3)
170
+ self.shortcut = _u_shortcut(in_filters, nb_filters, subsample=init_subsample)
171
+
172
+ def forward(self, x):
173
+ y = self.residual(self.conv1(x))
174
+ return self.shortcut(x, y)
175
+
176
+
177
+ class _residual_block(nn.Module):
178
+ def __init__(self, in_filters, nb_filters, repetitions, is_first_layer=False):
179
+ super(_residual_block, self).__init__()
180
+ layers = []
181
+ for i in range(repetitions):
182
+ init_subsample = 1
183
+ if i == repetitions - 1 and not is_first_layer:
184
+ init_subsample = 2
185
+ if i == 0:
186
+ l = basic_block(
187
+ in_filters=in_filters,
188
+ nb_filters=nb_filters,
189
+ init_subsample=init_subsample,
190
+ )
191
+ else:
192
+ l = basic_block(
193
+ in_filters=nb_filters,
194
+ nb_filters=nb_filters,
195
+ init_subsample=init_subsample,
196
+ )
197
+ layers.append(l)
198
+
199
+ self.model = nn.Sequential(*layers)
200
+
201
+ def forward(self, x):
202
+ return self.model(x)
203
+
204
+
205
+ class _upsampling_residual_block(nn.Module):
206
+ def __init__(self, in_filters, nb_filters, repetitions):
207
+ super(_upsampling_residual_block, self).__init__()
208
+ layers = []
209
+ for i in range(repetitions):
210
+ l = None
211
+ if i == 0:
212
+ l = _u_basic_block(
213
+ in_filters=in_filters, nb_filters=nb_filters
214
+ ) # (input)
215
+ else:
216
+ l = basic_block(in_filters=nb_filters, nb_filters=nb_filters) # (input)
217
+ layers.append(l)
218
+
219
+ self.model = nn.Sequential(*layers)
220
+
221
+ def forward(self, x):
222
+ return self.model(x)
223
+
224
+
225
+ class res_skip(nn.Module):
226
+ def __init__(self):
227
+ super(res_skip, self).__init__()
228
+ self.block0 = _residual_block(
229
+ in_filters=1, nb_filters=24, repetitions=2, is_first_layer=True
230
+ ) # (input)
231
+ self.block1 = _residual_block(
232
+ in_filters=24, nb_filters=48, repetitions=3
233
+ ) # (block0)
234
+ self.block2 = _residual_block(
235
+ in_filters=48, nb_filters=96, repetitions=5
236
+ ) # (block1)
237
+ self.block3 = _residual_block(
238
+ in_filters=96, nb_filters=192, repetitions=7
239
+ ) # (block2)
240
+ self.block4 = _residual_block(
241
+ in_filters=192, nb_filters=384, repetitions=12
242
+ ) # (block3)
243
+
244
+ self.block5 = _upsampling_residual_block(
245
+ in_filters=384, nb_filters=192, repetitions=7
246
+ ) # (block4)
247
+ self.res1 = _shortcut(
248
+ in_filters=192, nb_filters=192
249
+ ) # (block3, block5, subsample=(1,1))
250
+
251
+ self.block6 = _upsampling_residual_block(
252
+ in_filters=192, nb_filters=96, repetitions=5
253
+ ) # (res1)
254
+ self.res2 = _shortcut(
255
+ in_filters=96, nb_filters=96
256
+ ) # (block2, block6, subsample=(1,1))
257
+
258
+ self.block7 = _upsampling_residual_block(
259
+ in_filters=96, nb_filters=48, repetitions=3
260
+ ) # (res2)
261
+ self.res3 = _shortcut(
262
+ in_filters=48, nb_filters=48
263
+ ) # (block1, block7, subsample=(1,1))
264
+
265
+ self.block8 = _upsampling_residual_block(
266
+ in_filters=48, nb_filters=24, repetitions=2
267
+ ) # (res3)
268
+ self.res4 = _shortcut(
269
+ in_filters=24, nb_filters=24
270
+ ) # (block0,block8, subsample=(1,1))
271
+
272
+ self.block9 = _residual_block(
273
+ in_filters=24, nb_filters=16, repetitions=2, is_first_layer=True
274
+ ) # (res4)
275
+ self.conv15 = _bn_relu_conv(
276
+ in_filters=16, nb_filters=1, fh=1, fw=1, subsample=1
277
+ ) # (block7)
278
+
279
+ def forward(self, x):
280
+ x0 = self.block0(x)
281
+ x1 = self.block1(x0)
282
+ x2 = self.block2(x1)
283
+ x3 = self.block3(x2)
284
+ x4 = self.block4(x3)
285
+
286
+ x5 = self.block5(x4)
287
+ res1 = self.res1(x3, x5)
288
+
289
+ x6 = self.block6(res1)
290
+ res2 = self.res2(x2, x6)
291
+
292
+ x7 = self.block7(res2)
293
+ res3 = self.res3(x1, x7)
294
+
295
+ x8 = self.block8(res3)
296
+ res4 = self.res4(x0, x8)
297
+
298
+ x9 = self.block9(res4)
299
+ y = self.conv15(x9)
300
+
301
+ return y
302
+
303
+
304
+ class MangaLineExtraction:
305
+ def __init__(self, device=None, model_dir=None):
306
+ self.model = None
307
+ self.device = device
308
+ MangaLineExtraction.model_dir = model_dir
309
+
310
+ def load_model(self):
311
+ remote_model_path = (
312
+ "https://huggingface.co/lllyasviel/Annotators/resolve/main/erika.pth"
313
+ )
314
+ modelpath = os.path.join(self.model_dir, "erika.pth")
315
+ if not os.path.exists(modelpath):
316
+ load_file_from_url(remote_model_path, model_dir=self.model_dir)
317
+ # norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
318
+ net = res_skip()
319
+ ckpt = torch.load(modelpath)
320
+ for key in list(ckpt.keys()):
321
+ if "module." in key:
322
+ ckpt[key.replace("module.", "")] = ckpt[key]
323
+ del ckpt[key]
324
+ net.load_state_dict(ckpt)
325
+ net.eval()
326
+ self.model = net.to(self.device)
327
+
328
+ def unload_model(self):
329
+ if self.model is not None:
330
+ self.model.cpu()
331
+
332
+ def __call__(self, input_image):
333
+ if self.model is None:
334
+ self.load_model()
335
+ self.model.to(self.device)
336
+ # if width or height is not divisible by 16, pad the image
337
+ h, w = input_image.shape[:2]
338
+ # get adjusted pixel amount to max 1280x1280
339
+ total_pixels = h * w
340
+ if total_pixels > 1280 * 1280:
341
+ ratio = (1280 * 1280) / total_pixels
342
+ ratio = ratio**0.5
343
+ h = int(h * ratio)
344
+ w = int(w * ratio)
345
+ divisible = 16
346
+ h = h + (divisible - h % divisible) % divisible
347
+ w = w + (divisible - w % divisible) % divisible
348
+ input_image = cv2.resize(input_image, (w, h))
349
+ img = cv2.cvtColor(input_image, cv2.COLOR_RGB2GRAY)
350
+ img = np.ascontiguousarray(img.copy()).copy()
351
+ with torch.no_grad():
352
+ image_feed = torch.from_numpy(img).float().to(self.device)
353
+ image_feed = rearrange(image_feed, "h w -> 1 1 h w")
354
+ line = self.model(image_feed).cpu().numpy()[0, 0]
355
+ # line = 255 - line
356
+ return line.clip(0, 255).astype(np.uint8)
lineart_models/utils.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import cv2
4
+ import numpy as np
5
+ from urllib.parse import urlparse
6
+
7
+
8
+ def load_file_from_url(
9
+ url: str,
10
+ *,
11
+ model_dir: str,
12
+ progress: bool = True,
13
+ file_name: str | None = None,
14
+ ) -> str:
15
+ """Download a file from `url` into `model_dir`, using the file present if possible.
16
+
17
+ Returns the path to the downloaded file.
18
+ """
19
+ os.makedirs(model_dir, exist_ok=True)
20
+ if not file_name:
21
+ parts = urlparse(url)
22
+ file_name = os.path.basename(parts.path)
23
+ cached_file = os.path.abspath(os.path.join(model_dir, file_name))
24
+ if not os.path.exists(cached_file):
25
+ print(f'Downloading: "{url}" to {cached_file}\n')
26
+ from torch.hub import download_url_to_file
27
+
28
+ download_url_to_file(url, cached_file, progress=progress)
29
+ return cached_file
30
+
31
+
32
+ def combine_linearts(lineart1: np.ndarray, lineart2: np.ndarray, erode=[False, False]) -> np.ndarray:
33
+ if erode[0]:
34
+ lineart1 = cv2.erode(lineart1, np.ones((3, 3), np.uint8))
35
+ if erode[1]:
36
+ lineart2 = cv2.erode(lineart2, np.ones((3, 3), np.uint8))
37
+ # unify the dark part of lineart1 and lineart2
38
+ union = np.where(lineart1 < lineart2, lineart1, lineart2)
39
+ return union