Harisreedhar committed on
Commit
71c9afb
•
1 Parent(s): 47353b7
app.py CHANGED
@@ -4,7 +4,6 @@ import glob
4
  import time
5
  import torch
6
  import shutil
7
- import gfpgan
8
  import argparse
9
  import platform
10
  import datetime
@@ -13,22 +12,22 @@ import insightface
13
  import onnxruntime
14
  import numpy as np
15
  import gradio as gr
16
- from moviepy.editor import VideoFileClip, ImageSequenceClip
 
17
 
18
- from face_analyser import detect_conditions, analyse_face
19
- from utils import trim_video, StreamerThread, ProcessBar, open_directory
 
 
20
  from face_parsing import init_parser, swap_regions, mask_regions, mask_regions_to_list, SoftErosion
21
- from swapper import (
22
- swap_face,
23
- swap_face_with_condition,
24
- swap_specific,
25
- swap_options_list,
26
- )
27
 
28
  ## ------------------------------ USER ARGS ------------------------------
29
 
30
  parser = argparse.ArgumentParser(description="Swap-Mukham Face Swapper")
31
  parser.add_argument("--out_dir", help="Default Output directory", default=os.getcwd())
 
32
  parser.add_argument("--cuda", action="store_true", help="Enable cuda", default=False)
33
  parser.add_argument(
34
  "--colab", action="store_true", help="Enable colab mode", default=False
@@ -40,11 +39,12 @@ user_args = parser.parse_args()
40
  USE_COLAB = user_args.colab
41
  USE_CUDA = user_args.cuda
42
  DEF_OUTPUT_PATH = user_args.out_dir
 
43
  WORKSPACE = None
44
  OUTPUT_FILE = None
45
  CURRENT_FRAME = None
46
  STREAMER = None
47
- DETECT_CONDITION = "left most"
48
  DETECT_SIZE = 640
49
  DETECT_THRESH = 0.6
50
  NUM_OF_SRC_SPECIFIC = 10
@@ -67,6 +67,7 @@ FACE_SWAPPER = None
67
  FACE_ANALYSER = None
68
  FACE_ENHANCER = None
69
  FACE_PARSER = None
 
70
 
71
  ## ------------------------------ SET EXECUTION PROVIDER ------------------------------
72
  # Note: For AMD, Mac, or non-CUDA users, change settings here
@@ -99,25 +100,22 @@ def load_face_analyser_model(name="buffalo_l"):
99
  )
100
 
101
 
102
- def load_face_swapper_model(name="./assets/pretrained_models/inswapper_128.onnx"):
103
  global FACE_SWAPPER
104
- path = os.path.join(os.path.abspath(os.path.dirname(__file__)), name)
105
  if FACE_SWAPPER is None:
106
- FACE_SWAPPER = insightface.model_zoo.get_model(path, providers=PROVIDER)
107
-
108
-
109
- def load_face_enhancer_model(name="./assets/pretrained_models/GFPGANv1.4.pth"):
110
- global FACE_ENHANCER
111
- path = os.path.join(os.path.abspath(os.path.dirname(__file__)), name)
112
- if FACE_ENHANCER is None:
113
- FACE_ENHANCER = gfpgan.GFPGANer(model_path=path, upscale=1)
114
 
115
 
116
- def load_face_parser_model(name="./assets/pretrained_models/79999_iter.pth"):
117
  global FACE_PARSER
118
- path = os.path.join(os.path.abspath(os.path.dirname(__file__)), name)
119
  if FACE_PARSER is None:
120
- FACE_PARSER = init_parser(name, mode=device)
 
 
 
 
 
121
 
122
 
123
  load_face_analyser_model()
@@ -138,12 +136,18 @@ def process(
138
  condition,
139
  age,
140
  distance,
141
- face_enhance,
142
  enable_face_parser,
143
  mask_includes,
144
  mask_soft_kernel,
145
  mask_soft_iterations,
146
  blur_amount,
 
 
 
 
 
 
147
  *specifics,
148
  ):
149
  global WORKSPACE
@@ -177,12 +181,13 @@ def process(
177
  gr.update(value=OUTPUT_FILE, visible=True),
178
  )
179
 
180
- ## ------------------------------ LOAD PENDING MODELS ------------------------------
181
  start_time = time.time()
182
- specifics = list(specifics)
183
- half = len(specifics) // 2
184
- sources = specifics[:half]
185
- specifics = specifics[half:]
 
 
186
 
187
  yield "### \n βŒ› Loading face analyser model...", *ui_before()
188
  load_face_analyser_model()
@@ -190,87 +195,100 @@ def process(
190
  yield "### \n βŒ› Loading face swapper model...", *ui_before()
191
  load_face_swapper_model()
192
 
193
- if face_enhance:
194
- yield "### \n βŒ› Loading face enhancer model...", *ui_before()
195
- load_face_enhancer_model()
 
 
196
 
197
  if enable_face_parser:
198
  yield "### \n βŒ› Loading face parsing model...", *ui_before()
199
  load_face_parser_model()
200
 
201
- yield "### \n βŒ› Analysing Face...", *ui_before()
202
-
203
  includes = mask_regions_to_list(mask_includes)
204
- if mask_soft_iterations > 0:
205
- smooth_mask = SoftErosion(kernel_size=17, threshold=0.9, iterations=int(mask_soft_iterations)).to(device)
206
- else:
207
- smooth_mask = None
208
-
209
- models = {
210
- "swap": FACE_SWAPPER,
211
- "enhance": FACE_ENHANCER,
212
- "enhance_sett": face_enhance,
213
- "face_parser": FACE_PARSER,
214
- "face_parser_sett": (enable_face_parser, includes, smooth_mask, int(blur_amount))
215
- }
216
-
217
- ## ------------------------------ ANALYSE SOURCE & SPECIFIC ------------------------------
218
-
219
- analysed_source_specific = []
220
- if condition == "Specific Face":
221
- for source, specific in zip(sources, specifics):
222
- if source is None or specific is None:
223
- continue
224
- analysed_source = analyse_face(
225
- source,
226
- FACE_ANALYSER,
227
- return_single_face=True,
228
- detect_condition=DETECT_CONDITION,
229
- )
230
- analysed_specific = analyse_face(
231
- specific,
232
- FACE_ANALYSER,
233
- return_single_face=True,
234
- detect_condition=DETECT_CONDITION,
235
- )
236
- analysed_source_specific.append([analysed_source, analysed_specific])
237
- else:
238
- source = cv2.imread(source_path)
239
- analysed_source = analyse_face(
240
- source,
241
  FACE_ANALYSER,
242
- return_single_face=True,
 
 
243
  detect_condition=DETECT_CONDITION,
 
244
  )
245
246
  ## ------------------------------ IMAGE ------------------------------
247
 
248
  if input_type == "Image":
249
  target = cv2.imread(image_path)
250
- analysed_target = analyse_face(target, FACE_ANALYSER, return_single_face=False)
251
- if condition == "Specific Face":
252
- swapped = swap_specific(
253
- analysed_source_specific,
254
- analysed_target,
255
- target,
256
- models,
257
- threshold=distance,
258
- )
259
- else:
260
- swapped = swap_face_with_condition(
261
- target, analysed_target, analysed_source, condition, age, models
262
- )
263
 
264
- filename = os.path.join(output_path, output_name + ".png")
265
- cv2.imwrite(filename, swapped)
266
- OUTPUT_FILE = filename
267
- WORKSPACE = output_path
268
- PREVIEW = swapped[:, :, ::-1]
269
 
270
- tot_exec_time = time.time() - start_time
271
- _min, _sec = divmod(tot_exec_time, 60)
 
272
 
273
- yield f"Completed in {int(_min)} min {int(_sec)} sec.", *ui_after()
274
 
275
  ## ------------------------------ VIDEO ------------------------------
276
 
@@ -278,72 +296,26 @@ def process(
278
  temp_path = os.path.join(output_path, output_name, "sequence")
279
  os.makedirs(temp_path, exist_ok=True)
280
 
281
- video_clip = VideoFileClip(video_path)
282
- duration = video_clip.duration
283
- fps = video_clip.fps
284
- total_frames = video_clip.reader.nframes
285
-
286
- analysed_targets = []
287
- process_bar = ProcessBar(30, total_frames)
288
- yield "### \n βŒ› Analysing...", *ui_before()
289
- for i, frame in enumerate(video_clip.iter_frames()):
290
- analysed_targets.append(
291
- analyse_face(frame, FACE_ANALYSER, return_single_face=False)
292
- )
293
- info_text = "Analysing Faces || "
294
- info_text += process_bar.get(i)
295
- print("\033[1A\033[K", end="", flush=True)
296
- print(info_text)
297
- if i % 10 == 0:
298
- yield "### \n" + info_text, *ui_before()
299
- video_clip.close()
300
-
301
  image_sequence = []
302
- video_clip = VideoFileClip(video_path)
303
- audio_clip = video_clip.audio if video_clip.audio is not None else None
304
- process_bar = ProcessBar(30, total_frames)
305
- yield "### \n βŒ› Swapping...", *ui_before()
306
- for i, frame in enumerate(video_clip.iter_frames()):
307
- swapped = frame
308
- analysed_target = analysed_targets[i]
309
-
310
- if condition == "Specific Face":
311
- swapped = swap_specific(
312
- analysed_source_specific,
313
- analysed_target,
314
- frame,
315
- models,
316
- threshold=distance,
317
- )
318
- else:
319
- swapped = swap_face_with_condition(
320
- frame, analysed_target, analysed_source, condition, age, models
321
- )
322
-
323
- image_path = os.path.join(temp_path, f"frame_{i}.png")
324
- cv2.imwrite(image_path, swapped[:, :, ::-1])
325
- image_sequence.append(image_path)
326
-
327
- info_text = "Swapping Faces || "
328
- info_text += process_bar.get(i)
329
- print("\033[1A\033[K", end="", flush=True)
330
- print(info_text)
331
- if i % 6 == 0:
332
- PREVIEW = swapped
333
- yield "### \n" + info_text, *ui_before()
334
-
335
- yield "### \n βŒ› Merging...", *ui_before()
336
- edited_video_clip = ImageSequenceClip(image_sequence, fps=fps)
337
-
338
- if audio_clip is not None:
339
- edited_video_clip = edited_video_clip.set_audio(audio_clip)
340
-
341
  output_video_path = os.path.join(output_path, output_name + ".mp4")
342
- edited_video_clip.set_duration(duration).write_videofile(
343
- output_video_path, codec="libx264"
344
- )
345
- edited_video_clip.close()
346
- video_clip.close()
347
 
348
  if os.path.exists(temp_path) and not keep_output_sequence:
349
  yield "### \n βŒ› Removing temporary files...", *ui_before()
@@ -352,99 +324,38 @@ def process(
352
  WORKSPACE = output_path
353
  OUTPUT_FILE = output_video_path
354
 
355
- tot_exec_time = time.time() - start_time
356
- _min, _sec = divmod(tot_exec_time, 60)
357
-
358
- yield f"βœ”οΈ Completed in {int(_min)} min {int(_sec)} sec.", *ui_after_vid()
359
 
360
  ## ------------------------------ DIRECTORY ------------------------------
361
 
362
  elif input_type == "Directory":
363
- source = cv2.imread(source_path)
364
- source = analyse_face(
365
- source,
366
- FACE_ANALYSER,
367
- return_single_face=True,
368
- detect_condition=DETECT_CONDITION,
369
- )
370
  extensions = ["jpg", "jpeg", "png", "bmp", "tiff", "ico", "webp"]
371
  temp_path = os.path.join(output_path, output_name)
372
  if os.path.exists(temp_path):
373
  shutil.rmtree(temp_path)
374
  os.mkdir(temp_path)
375
- swapped = None
376
 
377
- files = []
378
  for file_path in glob.glob(os.path.join(directory_path, "*")):
379
  if any(file_path.lower().endswith(ext) for ext in extensions):
380
- files.append(file_path)
381
-
382
- files_length = len(files)
383
- filename = None
384
- for i, file_path in enumerate(files):
385
- target = cv2.imread(file_path)
386
- analysed_target = analyse_face(
387
- target, FACE_ANALYSER, return_single_face=False
388
- )
389
-
390
- if condition == "Specific Face":
391
- swapped = swap_specific(
392
- analysed_source_specific,
393
- analysed_target,
394
- target,
395
- models,
396
- threshold=distance,
397
- )
398
- else:
399
- swapped = swap_face_with_condition(
400
- target, analysed_target, analysed_source, condition, age, models
401
- )
402
 
403
- filename = os.path.join(temp_path, os.path.basename(file_path))
404
- cv2.imwrite(filename, swapped)
405
- info_text = f"### \n βŒ› Processing file {i+1} of {files_length}"
406
- PREVIEW = swapped[:, :, ::-1]
407
- yield info_text, *ui_before()
408
 
 
409
  WORKSPACE = temp_path
410
- OUTPUT_FILE = filename
411
-
412
- tot_exec_time = time.time() - start_time
413
- _min, _sec = divmod(tot_exec_time, 60)
414
 
415
- yield f"βœ”οΈ Completed in {int(_min)} min {int(_sec)} sec.", *ui_after()
416
 
417
  ## ------------------------------ STREAM ------------------------------
418
 
419
  elif input_type == "Stream":
420
- yield "### \n βŒ› Starting...", *ui_before()
421
- global STREAMER
422
- STREAMER = StreamerThread(src=directory_path)
423
- STREAMER.start()
424
-
425
- while True:
426
- try:
427
- target = STREAMER.frame
428
- analysed_target = analyse_face(
429
- target, FACE_ANALYSER, return_single_face=False
430
- )
431
- if condition == "Specific Face":
432
- swapped = swap_specific(
433
- target,
434
- analysed_target,
435
- analysed_source_specific,
436
- models,
437
- threshold=distance,
438
- )
439
- else:
440
- swapped = swap_face_with_condition(
441
- target, analysed_target, analysed_source, condition, age, models
442
- )
443
- PREVIEW = swapped[:, :, ::-1]
444
- yield f"Streaming...", *ui_before()
445
- except AttributeError:
446
- yield "Streaming...", *ui_before()
447
- STREAMER.stop()
448
 
449
 
450
  ## ------------------------------ GRADIO FUNC ------------------------------
@@ -626,10 +537,6 @@ with gr.Blocks(css=css) as interface:
626
  )
627
 
628
  with gr.Tab("πŸͺ„ Other Settings"):
629
- with gr.Accordion("Enhance Face", open=True):
630
- enable_face_enhance = gr.Checkbox(
631
- label="Enable GFPGAN", value=False, interactive=True
632
- )
633
  with gr.Accordion("Advanced Mask", open=False):
634
  enable_face_parser_mask = gr.Checkbox(
635
  label="Enable Face Parsing",
@@ -665,6 +572,30 @@ with gr.Blocks(css=css) as interface:
665
  interactive=True,
666
  )
667
668
  source_image_input = gr.Image(
669
  label="Source face", type="filepath", interactive=True
670
  )
@@ -690,7 +621,7 @@ with gr.Blocks(css=css) as interface:
690
 
691
  with gr.Group():
692
  input_type = gr.Radio(
693
- ["Image", "Video"],#["Image", "Video", "Directory", "Stream"],
694
  label="Target Type",
695
  value="Video",
696
  )
@@ -701,7 +632,7 @@ with gr.Blocks(css=css) as interface:
701
  )
702
 
703
  with gr.Box(visible=True) as input_video_group:
704
- vid_widget = gr.Video #gr.Video if USE_COLAB else gr.Text
705
  video_input = vid_widget(
706
  label="Target Video Path", interactive=True
707
  )
@@ -794,14 +725,14 @@ with gr.Blocks(css=css) as interface:
794
  fn=slider_changed,
795
  inputs=[show_trim_preview_btn, video_input, start_frame],
796
  outputs=[preview_image, preview_video],
797
- show_progress=False,
798
  )
799
 
800
  end_frame_event = end_frame.release(
801
  fn=slider_changed,
802
  inputs=[show_trim_preview_btn, video_input, end_frame],
803
  outputs=[preview_image, preview_video],
804
- show_progress=False,
805
  )
806
 
807
  input_type.change(
@@ -839,12 +770,18 @@ with gr.Blocks(css=css) as interface:
839
  swap_option,
840
  age,
841
  distance_slider,
842
- enable_face_enhance,
843
  enable_face_parser_mask,
844
  mask_include,
845
  mask_soft_kernel,
846
  mask_soft_iterations,
847
  blur_amount,
 
 
 
 
 
 
848
  *src_specific_inputs,
849
  ]
850
 
@@ -857,7 +794,7 @@ with gr.Blocks(css=css) as interface:
857
  ]
858
 
859
  swap_event = swap_button.click(
860
- fn=process, inputs=swap_inputs, outputs=swap_outputs, show_progress=False
861
  )
862
 
863
  cancel_button.click(
@@ -871,7 +808,7 @@ with gr.Blocks(css=css) as interface:
871
  start_frame_event,
872
  end_frame_event,
873
  ],
874
- show_progress=False,
875
  )
876
  output_directory_button.click(
877
  lambda: open_directory(path=WORKSPACE), inputs=None, outputs=None
 
4
  import time
5
  import torch
6
  import shutil
 
7
  import argparse
8
  import platform
9
  import datetime
 
12
  import onnxruntime
13
  import numpy as np
14
  import gradio as gr
15
+ from tqdm import tqdm
16
+ from moviepy.editor import VideoFileClip
17
 
18
+ from nsfw_detector import get_nsfw_detector
19
+ from face_swapper import Inswapper, paste_to_whole
20
+ from face_analyser import detect_conditions, get_analysed_data, swap_options_list
21
+ from face_enhancer import load_face_enhancer_model, face_enhancer_list, gfpgan_enhance, realesrgan_enhance
22
  from face_parsing import init_parser, swap_regions, mask_regions, mask_regions_to_list, SoftErosion
23
+ from utils import trim_video, StreamerThread, ProcessBar, open_directory, split_list_by_lengths, merge_img_sequence_from_ref
24
+
 
 
 
 
25
 
26
  ## ------------------------------ USER ARGS ------------------------------
27
 
28
  parser = argparse.ArgumentParser(description="Swap-Mukham Face Swapper")
29
  parser.add_argument("--out_dir", help="Default Output directory", default=os.getcwd())
30
+ parser.add_argument("--batch_size", help="Gpu batch size", default=32)
31
  parser.add_argument("--cuda", action="store_true", help="Enable cuda", default=False)
32
  parser.add_argument(
33
  "--colab", action="store_true", help="Enable colab mode", default=False
 
39
  USE_COLAB = user_args.colab
40
  USE_CUDA = user_args.cuda
41
  DEF_OUTPUT_PATH = user_args.out_dir
42
+ BATCH_SIZE = user_args.batch_size
43
  WORKSPACE = None
44
  OUTPUT_FILE = None
45
  CURRENT_FRAME = None
46
  STREAMER = None
47
+ DETECT_CONDITION = "best detection"
48
  DETECT_SIZE = 640
49
  DETECT_THRESH = 0.6
50
  NUM_OF_SRC_SPECIFIC = 10
 
67
  FACE_ANALYSER = None
68
  FACE_ENHANCER = None
69
  FACE_PARSER = None
70
+ NSFW_DETECTOR = None
71
 
72
  ## ------------------------------ SET EXECUTION PROVIDER ------------------------------
73
  # Note: For AMD, Mac, or non-CUDA users, change settings here
 
100
  )
101
 
102
 
103
+ def load_face_swapper_model(path="./assets/pretrained_models/inswapper_128.onnx"):
104
  global FACE_SWAPPER
 
105
  if FACE_SWAPPER is None:
106
+ batch = int(BATCH_SIZE) if device == "cuda" else 1
107
+ FACE_SWAPPER = Inswapper(model_file=path, batch_size=batch, providers=PROVIDER)
 
 
 
 
 
 
108
 
109
 
110
+ def load_face_parser_model(path="./assets/pretrained_models/79999_iter.pth"):
111
  global FACE_PARSER
 
112
  if FACE_PARSER is None:
113
+ FACE_PARSER = init_parser(path, mode=device)
114
+
115
+ def load_nsfw_detector_model(path="./assets/pretrained_models/nsfwmodel_281.pth"):
116
+ global NSFW_DETECTOR
117
+ if NSFW_DETECTOR is None:
118
+ NSFW_DETECTOR = get_nsfw_detector(model_path=path, device=device)
119
 
120
 
121
  load_face_analyser_model()
 
136
  condition,
137
  age,
138
  distance,
139
+ face_enhancer_name,
140
  enable_face_parser,
141
  mask_includes,
142
  mask_soft_kernel,
143
  mask_soft_iterations,
144
  blur_amount,
145
+ face_scale,
146
+ enable_laplacian_blend,
147
+ crop_top,
148
+ crop_bott,
149
+ crop_left,
150
+ crop_right,
151
  *specifics,
152
  ):
153
  global WORKSPACE
 
181
  gr.update(value=OUTPUT_FILE, visible=True),
182
  )
183
 
 
184
  start_time = time.time()
185
+ total_exec_time = lambda start_time: divmod(time.time() - start_time, 60)
186
+ get_finsh_text = lambda start_time: f"βœ”οΈ Completed in {int(total_exec_time(start_time)[0])} min {int(total_exec_time(start_time)[1])} sec."
187
+
188
+ ## ------------------------------ PREPARE INPUTS & LOAD MODELS ------------------------------
189
+ yield "### \n βŒ› Loading NSFW detector model...", *ui_before()
190
+ load_nsfw_detector_model()
191
 
192
  yield "### \n βŒ› Loading face analyser model...", *ui_before()
193
  load_face_analyser_model()
 
195
  yield "### \n βŒ› Loading face swapper model...", *ui_before()
196
  load_face_swapper_model()
197
 
198
+ if face_enhancer_name != "NONE":
199
+ yield f"### \n βŒ› Loading {face_enhancer_name} model...", *ui_before()
200
+ FACE_ENHANCER = load_face_enhancer_model(name=face_enhancer_name, device=device)
201
+ else:
202
+ FACE_ENHANCER = None
203
 
204
  if enable_face_parser:
205
  yield "### \n βŒ› Loading face parsing model...", *ui_before()
206
  load_face_parser_model()
207
 
 
 
208
  includes = mask_regions_to_list(mask_includes)
209
+ smooth_mask = SoftErosion(kernel_size=17, threshold=0.9, iterations=int(mask_soft_iterations)).to(device) if mask_soft_iterations > 0 else None
210
+ specifics = list(specifics)
211
+ half = len(specifics) // 2
212
+ sources = specifics[:half]
213
+ specifics = specifics[half:]
214
+
215
+ ## ------------------------------ ANALYSE & SWAP FUNC ------------------------------
216
+
217
+ def swap_process(image_sequence):
218
+ yield "### \n βŒ› Checking contents...", *ui_before()
219
+ nsfw = NSFW_DETECTOR.is_nsfw(image_sequence)
220
+ if nsfw:
221
+ message = "NSFW Content detected !!!"
222
+ yield f"### \n πŸ”ž {message}", *ui_before()
223
+ assert not nsfw, message
224
+ return False  # unreachable: the assert above already raises when NSFW content is found
225
+ if device == "cuda": torch.cuda.empty_cache()
226
+
227
+ yield "### \n βŒ› Analysing face data...", *ui_before()
228
+ if condition != "Specific Face":
229
+ source_data = source_path, age
230
+ else:
231
+ source_data = ((sources, specifics), distance)
232
+ analysed_targets, analysed_sources, whole_frame_list, num_faces_per_frame = get_analysed_data(
233
  FACE_ANALYSER,
234
+ image_sequence,
235
+ source_data,
236
+ swap_condition=condition,
237
  detect_condition=DETECT_CONDITION,
238
+ scale=face_scale
239
  )
240
 
241
+ yield "### \n βŒ› Swapping faces...", *ui_before()
242
+ preds, aimgs, matrs = FACE_SWAPPER.batch_forward(whole_frame_list, analysed_targets, analysed_sources)
243
+ torch.cuda.empty_cache()
244
+
245
+ if enable_face_parser:
246
+ yield "### \n βŒ› Applying face-parsing mask...", *ui_before()
247
+ for idx, (pred, aimg) in tqdm(enumerate(zip(preds, aimgs)), total=len(preds), desc="Face parsing"):
248
+ preds[idx] = swap_regions(pred, aimg, FACE_PARSER, smooth_mask, includes=includes, blur=int(blur_amount))
249
+ torch.cuda.empty_cache()
250
+
251
+ if face_enhancer_name != "NONE":
252
+ yield f"### \n βŒ› Enhancing faces with {face_enhancer_name}...", *ui_before()
253
+ for idx, pred in tqdm(enumerate(preds), total=len(preds), desc=f"{face_enhancer_name}"):
254
+ if face_enhancer_name == 'GFPGAN':
255
+ pred = gfpgan_enhance(pred, FACE_ENHANCER)
256
+ elif face_enhancer_name.startswith("REAL-ESRGAN"):
257
+ pred = realesrgan_enhance(pred, FACE_ENHANCER)
258
+
259
+ preds[idx] = cv2.resize(pred, (512,512))
260
+ aimgs[idx] = cv2.resize(aimgs[idx], (512,512))
261
+ matrs[idx] /= 0.25
262
+ torch.cuda.empty_cache()
263
+
264
+ split_preds = split_list_by_lengths(preds, num_faces_per_frame)
265
+ split_aimgs = split_list_by_lengths(aimgs, num_faces_per_frame)
266
+ split_matrs = split_list_by_lengths(matrs, num_faces_per_frame)
267
+
268
+ yield "### \n βŒ› Post-processing...", *ui_before()
269
+ for idx, frame_img in tqdm(enumerate(image_sequence), total=len(image_sequence), desc="Post-Processing"):
270
+ whole_img_path = frame_img
271
+ whole_img = cv2.imread(whole_img_path)
272
+ for p, a, m in zip(split_preds[idx], split_aimgs[idx], split_matrs[idx]):
273
+ whole_img = paste_to_whole(p, a, m, whole_img, laplacian_blend=enable_laplacian_blend, crop_mask=(crop_top,crop_bott,crop_left,crop_right))
274
+ cv2.imwrite(whole_img_path, whole_img)
275
+
276
+
277
  ## ------------------------------ IMAGE ------------------------------
278
 
279
  if input_type == "Image":
280
  target = cv2.imread(image_path)
281
+ output_file = os.path.join(output_path, output_name + ".png")
282
+ cv2.imwrite(output_file, target)
283
 
284
+ for info_update in swap_process([output_file]):
285
+ yield info_update
 
 
 
286
 
287
+ OUTPUT_FILE = output_file
288
+ WORKSPACE = output_path
289
+ PREVIEW = cv2.imread(output_file)[:, :, ::-1]
290
 
291
+ yield get_finsh_text(start_time), *ui_after()
292
 
293
  ## ------------------------------ VIDEO ------------------------------
294
 
 
296
  temp_path = os.path.join(output_path, output_name, "sequence")
297
  os.makedirs(temp_path, exist_ok=True)
298
 
299
+ yield "### \n βŒ› Extracting video frames...", *ui_before()
300
  image_sequence = []
301
+ cap = cv2.VideoCapture(video_path)
302
+ curr_idx = 0
303
+ while True:
304
+ ret, frame = cap.read()
305
+ if not ret: break
306
+ frame_path = os.path.join(temp_path, f"frame_{curr_idx}.jpg")
307
+ cv2.imwrite(frame_path, frame)
308
+ image_sequence.append(frame_path)
309
+ curr_idx += 1
310
+ cap.release()
311
+ cv2.destroyAllWindows()
312
+
313
+ for info_update in swap_process(image_sequence):
314
+ yield info_update
315
+
316
+ yield "### \n βŒ› Merging sequence...", *ui_before()
317
  output_video_path = os.path.join(output_path, output_name + ".mp4")
318
+ merge_img_sequence_from_ref(video_path, image_sequence, output_video_path)
 
 
 
 
319
 
320
  if os.path.exists(temp_path) and not keep_output_sequence:
321
  yield "### \n βŒ› Removing temporary files...", *ui_before()
 
324
  WORKSPACE = output_path
325
  OUTPUT_FILE = output_video_path
326
 
327
+ yield get_finsh_text(start_time), *ui_after_vid()
 
 
 
328
 
329
  ## ------------------------------ DIRECTORY ------------------------------
330
 
331
  elif input_type == "Directory":
 
 
 
 
 
 
 
332
  extensions = ["jpg", "jpeg", "png", "bmp", "tiff", "ico", "webp"]
333
  temp_path = os.path.join(output_path, output_name)
334
  if os.path.exists(temp_path):
335
  shutil.rmtree(temp_path)
336
  os.mkdir(temp_path)
 
337
 
338
+ file_paths = []
339
  for file_path in glob.glob(os.path.join(directory_path, "*")):
340
  if any(file_path.lower().endswith(ext) for ext in extensions):
341
+ img = cv2.imread(file_path)
342
+ new_file_path = os.path.join(temp_path, os.path.basename(file_path))
343
+ cv2.imwrite(new_file_path, img)
344
+ file_paths.append(new_file_path)
345
 
346
+ for info_update in swap_process(file_paths):
347
+ yield info_update
 
 
 
348
 
349
+ PREVIEW = cv2.imread(file_paths[-1])[:, :, ::-1]
350
  WORKSPACE = temp_path
351
+ OUTPUT_FILE = file_paths[-1]
 
 
 
352
 
353
+ yield get_finsh_text(start_time), *ui_after()
354
 
355
  ## ------------------------------ STREAM ------------------------------
356
 
357
  elif input_type == "Stream":
358
+ pass
359
 
360
 
361
  ## ------------------------------ GRADIO FUNC ------------------------------
 
537
  )
538
 
539
  with gr.Tab("πŸͺ„ Other Settings"):
 
 
 
 
540
  with gr.Accordion("Advanced Mask", open=False):
541
  enable_face_parser_mask = gr.Checkbox(
542
  label="Enable Face Parsing",
 
572
  interactive=True,
573
  )
574
 
575
+ face_scale = gr.Slider(
576
+ label="Face Scale",
577
+ minimum=0,
578
+ maximum=2,
579
+ value=1,
580
+ interactive=True,
581
+ )
582
+
583
+ with gr.Accordion("Crop Mask", open=False):
584
+ crop_top = gr.Number(label="Top", value=0, minimum=0, interactive=True)
585
+ crop_bott = gr.Number(label="Bottom", value=0, minimum=0, interactive=True)
586
+ crop_left = gr.Number(label="Left", value=0, minimum=0, interactive=True)
587
+ crop_right = gr.Number(label="Right", value=0, minimum=0, interactive=True)
588
+
589
+ enable_laplacian_blend = gr.Checkbox(
590
+ label="Laplacian Blending",
591
+ value=True,
592
+ interactive=True,
593
+ )
594
+
595
+ face_enhancer_name = gr.Dropdown(
596
+ face_enhancer_list, label="Face Enhancer", value="NONE", multiselect=False, interactive=True
597
+ )
598
+
599
  source_image_input = gr.Image(
600
  label="Source face", type="filepath", interactive=True
601
  )
 
621
 
622
  with gr.Group():
623
  input_type = gr.Radio(
624
+ ["Image", "Video"],
625
  label="Target Type",
626
  value="Video",
627
  )
 
632
  )
633
 
634
  with gr.Box(visible=True) as input_video_group:
635
+ vid_widget = gr.Video if USE_COLAB else gr.Text
636
  video_input = vid_widget(
637
  label="Target Video Path", interactive=True
638
  )
 
725
  fn=slider_changed,
726
  inputs=[show_trim_preview_btn, video_input, start_frame],
727
  outputs=[preview_image, preview_video],
728
+ show_progress=True,
729
  )
730
 
731
  end_frame_event = end_frame.release(
732
  fn=slider_changed,
733
  inputs=[show_trim_preview_btn, video_input, end_frame],
734
  outputs=[preview_image, preview_video],
735
+ show_progress=True,
736
  )
737
 
738
  input_type.change(
 
770
  swap_option,
771
  age,
772
  distance_slider,
773
+ face_enhancer_name,
774
  enable_face_parser_mask,
775
  mask_include,
776
  mask_soft_kernel,
777
  mask_soft_iterations,
778
  blur_amount,
779
+ face_scale,
780
+ enable_laplacian_blend,
781
+ crop_top,
782
+ crop_bott,
783
+ crop_left,
784
+ crop_right,
785
  *src_specific_inputs,
786
  ]
787
 
 
794
  ]
795
 
796
  swap_event = swap_button.click(
797
+ fn=process, inputs=swap_inputs, outputs=swap_outputs, show_progress=True
798
  )
799
 
800
  cancel_button.click(
 
808
  start_frame_event,
809
  end_frame_event,
810
  ],
811
+ show_progress=True,
812
  )
813
  output_directory_button.click(
814
  lambda: open_directory(path=WORKSPACE), inputs=None, outputs=None
assets/pretrained_models/RealESRGAN_x2.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c830d067d54fc767b9543a8432f36d91bc2de313584e8bbfe4ac26a47339e899
3
+ size 67061725
assets/pretrained_models/RealESRGAN_x4.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:aa00f09ad753d88576b21ed977e97d634976377031b178acc3b5b238df463400
3
+ size 67040989
assets/pretrained_models/RealESRGAN_x8.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8b72fb469d12f05a4770813d2603eb1b550f40df6fb8b37d6c7bc2db3d2bff5e
3
+ size 67189359
assets/pretrained_models/codeformer.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1009e537e0c2a07d4cabce6355f53cb66767cd4b4297ec7a4a64ca4b8a5684b7
3
+ size 376637898
assets/pretrained_models/nsfwmodel_281.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ac92f5326f0d83f24f51ba4ac9f2a79314d29199e900a8ea495a74816ad3eb67
3
+ size 4925
face_analyser.py CHANGED
@@ -1,3 +1,9 @@
 
 
 
 
 
 
1
  detect_conditions = [
2
  "left most",
3
  "right most",
@@ -5,11 +11,27 @@ detect_conditions = [
5
  "bottom most",
6
  "most width",
7
  "most height",
 
8
  ]
9
 
 
 
 
 
 
 
 
 
10
 
11
- def analyse_face(image, model, return_single_face=True, detect_condition="left most"):
12
  faces = model.get(image)
 
 
 
 
 
 
 
13
  if not return_single_face:
14
  return faces
15
 
@@ -30,3 +52,79 @@ def analyse_face(image, model, return_single_face=True, detect_condition="left m
30
  return sorted(faces, key=lambda face: face["bbox"][2])[-1]
31
  elif detect_condition == "most height":
32
  return sorted(faces, key=lambda face: face["bbox"][3])[-1]
1
+ import os
2
+ import cv2
3
+ import numpy as np
4
+ from tqdm import tqdm
5
+ from utils import scale_bbox_from_center
6
+
7
  detect_conditions = [
8
  "left most",
9
  "right most",
 
11
  "bottom most",
12
  "most width",
13
  "most height",
14
+ "best detection",
15
  ]
16
 
17
+ swap_options_list = [
18
+ "All face",
19
+ "Age less than",
20
+ "Age greater than",
21
+ "All Male",
22
+ "All Female",
23
+ "Specific Face",
24
+ ]
25
 
26
+ def analyse_face(image, model, return_single_face=True, detect_condition="best detection", scale=1.0):
27
  faces = model.get(image)
28
+ if scale != 1: # landmark-scale
29
+ for i, face in enumerate(faces):
30
+ landmark = face['kps']
31
+ center = np.mean(landmark, axis=0)
32
+ landmark = center + (landmark - center) * scale
33
+ faces[i]['kps'] = landmark
34
+
35
  if not return_single_face:
36
  return faces
37
 
 
52
  return sorted(faces, key=lambda face: face["bbox"][2])[-1]
53
  elif detect_condition == "most height":
54
  return sorted(faces, key=lambda face: face["bbox"][3])[-1]
55
+ elif detect_condition == "best detection":
56
+ return sorted(faces, key=lambda face: face["det_score"])[-1]
57
+
58
+
59
+ def cosine_distance(a, b):
60
+ a = a / np.linalg.norm(a)  # normalize copies so the caller's embeddings are not modified in place
61
+ b = b / np.linalg.norm(b)
62
+ return 1 - np.dot(a, b)
63
+
64
+
65
+ def get_analysed_data(face_analyser, image_sequence, source_data, swap_condition="All face", detect_condition="left most", scale=1.0):
66
+ if swap_condition != "Specific Face":
67
+ source_path, age = source_data
68
+ source_image = cv2.imread(source_path)
69
+ analysed_source = analyse_face(source_image, face_analyser, return_single_face=True, detect_condition=detect_condition, scale=scale)
70
+ else:
71
+ analysed_source_specifics = []
72
+ source_specifics, threshold = source_data
73
+ for source, specific in zip(*source_specifics):
74
+ if source is None or specific is None:
75
+ continue
76
+ analysed_source = analyse_face(source, face_analyser, return_single_face=True, detect_condition=detect_condition, scale=scale)
77
+ analysed_specific = analyse_face(specific, face_analyser, return_single_face=True, detect_condition=detect_condition, scale=scale)
78
+ analysed_source_specifics.append([analysed_source, analysed_specific])
79
+
80
+ analysed_target_list = []
81
+ analysed_source_list = []
82
+ whole_frame_eql_list = []
83
+ num_faces_per_frame = []
84
+
85
+ total_frames = len(image_sequence)
86
+ curr_idx = 0
87
+ for curr_idx, frame_path in tqdm(enumerate(image_sequence), total=total_frames, desc="Analysing face data"):
88
+ frame = cv2.imread(frame_path)
89
+ analysed_faces = analyse_face(frame, face_analyser, return_single_face=False, detect_condition=detect_condition, scale=scale)
90
+
91
+ n_faces = 0
92
+ for analysed_face in analysed_faces:
93
+ if swap_condition == "All face":
94
+ analysed_target_list.append(analysed_face)
95
+ analysed_source_list.append(analysed_source)
96
+ whole_frame_eql_list.append(frame_path)
97
+ n_faces += 1
98
+ elif swap_condition == "Age less than" and analysed_face["age"] < age:
99
+ analysed_target_list.append(analysed_face)
100
+ analysed_source_list.append(analysed_source)
101
+ whole_frame_eql_list.append(frame_path)
102
+ n_faces += 1
103
+ elif swap_condition == "Age greater than" and analysed_face["age"] > age:
104
+ analysed_target_list.append(analysed_face)
105
+ analysed_source_list.append(analysed_source)
106
+ whole_frame_eql_list.append(frame_path)
107
+ n_faces += 1
108
+ elif swap_condition == "All Male" and analysed_face["gender"] == 1:
109
+ analysed_target_list.append(analysed_face)
110
+ analysed_source_list.append(analysed_source)
111
+ whole_frame_eql_list.append(frame_path)
112
+ n_faces += 1
113
+ elif swap_condition == "All Female" and analysed_face["gender"] == 0:
114
+ analysed_target_list.append(analysed_face)
115
+ analysed_source_list.append(analysed_source)
116
+ whole_frame_eql_list.append(frame_path)
117
+ n_faces += 1
118
+ elif swap_condition == "Specific Face":
119
+ for analysed_source, analysed_specific in analysed_source_specifics:
120
+ distance = cosine_distance(analysed_specific["embedding"], analysed_face["embedding"])
121
+ if distance < threshold:
122
+ analysed_target_list.append(analysed_face)
123
+ analysed_source_list.append(analysed_source)
124
+ whole_frame_eql_list.append(frame_path)
125
+ n_faces += 1
126
+
127
+ num_faces_per_frame.append(n_faces)
128
+
129
+ return analysed_target_list, analysed_source_list, whole_frame_eql_list, num_faces_per_frame
130
+
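Usage note: `get_analysed_data` returns four parallel, per-face lists (analysed targets, matched sources, the frame path each face came from, and a per-frame face count) that the swapper consumes index by index. A minimal sketch, assuming the insightface `buffalo_l` weights are installed and the image paths below are stand-ins:

```python
import insightface
from face_analyser import get_analysed_data

analyser = insightface.app.FaceAnalysis(name="buffalo_l", providers=["CPUExecutionProvider"])
analyser.prepare(ctx_id=0, det_size=(640, 640), det_thresh=0.6)

frames = ["frames/frame_0.jpg", "frames/frame_1.jpg"]  # hypothetical frame paths
targets, sources, frame_refs, faces_per_frame = get_analysed_data(
    analyser,
    frames,
    ("source.jpg", 25),            # (source_path, age) for every condition except "Specific Face"
    swap_condition="All face",
    detect_condition="best detection",
    scale=1.0,
)
# len(targets) == len(sources) == len(frame_refs) == sum(faces_per_frame)
```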
face_enhancer.py ADDED
@@ -0,0 +1,39 @@
1
+ import os
2
+ import torch
3
+ import gfpgan
4
+ from PIL import Image
5
+ from upscaler.RealESRGAN import RealESRGAN
6
+
7
+ face_enhancer_list = ['NONE', 'GFPGAN', 'REAL-ESRGAN 2x', 'REAL-ESRGAN 4x', 'REAL-ESRGAN 8x']
8
+
9
+ def load_face_enhancer_model(name='GFPGAN', device="cpu"):
10
+ if name == 'GFPGAN':
11
+ model_path = "./assets/pretrained_models/GFPGANv1.4.pth"
12
+ model_path = os.path.join(os.path.abspath(os.path.dirname(__file__)), model_path)
13
+ model = gfpgan.GFPGANer(model_path=model_path, upscale=1)
14
+ elif name == 'REAL-ESRGAN 2x':
15
+ model_path = "./assets/pretrained_models/RealESRGAN_x2.pth"
16
+ model_path = os.path.join(os.path.abspath(os.path.dirname(__file__)), model_path)
17
+ model = RealESRGAN(device, scale=2)
18
+ model.load_weights(model_path, download=False)
19
+ elif name == 'REAL-ESRGAN 4x':
20
+ model_path = "./assets/pretrained_models/RealESRGAN_x4.pth"
21
+ model_path = os.path.join(os.path.abspath(os.path.dirname(__file__)), model_path)
22
+ model = RealESRGAN(device, scale=4)
23
+ model.load_weights(model_path, download=False)
24
+ elif name == 'REAL-ESRGAN 8x':
25
+ model_path = "./assets/pretrained_models/RealESRGAN_x8.pth"
26
+ model_path = os.path.join(os.path.abspath(os.path.dirname(__file__)), model_path)
27
+ model = RealESRGAN(device, scale=8)
28
+ model.load_weights(model_path, download=False)
29
+ else:
30
+ model = None
31
+ return model
32
+
33
+ def gfpgan_enhance(img, model, has_aligned=True):
34
+ _, imgs, _ = model.enhance(img, paste_back=True, has_aligned=has_aligned)
35
+ return imgs[0]
36
+
37
+ def realesrgan_enhance(img, model):
38
+ img = model.predict(img)
39
+ return img
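Usage note: `load_face_enhancer_model` hands back either a `GFPGANer` or a `RealESRGAN` instance, and the two enhance helpers operate on a single BGR face crop. A rough sketch, assuming the weights already sit under `./assets/pretrained_models/` and the crop path is a stand-in:

```python
import cv2
from face_enhancer import load_face_enhancer_model, gfpgan_enhance, realesrgan_enhance

# Hypothetical aligned face crop (BGR); app.py feeds the 128x128 inswapper output here.
crop = cv2.imread("aligned_face.png")

enhancer = load_face_enhancer_model(name="GFPGAN", device="cuda")
restored = gfpgan_enhance(crop, enhancer)  # GFPGANer restores the aligned crop at 512x512
cv2.imwrite("restored_face.png", restored)

# The REAL-ESRGAN variants follow the same pattern:
# esrgan = load_face_enhancer_model(name="REAL-ESRGAN 2x", device="cuda")
# upscaled = realesrgan_enhance(crop, esrgan)
```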
face_parsing/__init__.py CHANGED
@@ -1 +1,3 @@
1
- from .swap import init_parser, swap_regions, mask_regions, mask_regions_to_list, SoftErosion
 
 
 
1
+ from .swap import init_parser, swap_regions, mask_regions, mask_regions_to_list, SoftErosion
2
+ from .model import BiSeNet
3
+ from .parse_mask import init_parsing_model, get_parsed_mask
face_parsing/parse_mask.py ADDED
@@ -0,0 +1,50 @@
1
+ import cv2
2
+ import torch
3
+ import torchvision
4
+ import numpy as np
5
+ import torch.nn as nn
6
+ from PIL import Image
7
+ from tqdm import tqdm
8
+ import torch.nn.functional as F
9
+ import torchvision.transforms as transforms
10
+
11
+ from . model import BiSeNet
12
+
13
+ transform = transforms.Compose([
14
+ transforms.Resize((512, 512)),
15
+ transforms.ToTensor(),
16
+ transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
17
+ ])
18
+
19
+ def init_parsing_model(model_path, device="cpu"):
20
+ net = BiSeNet(19)
21
+ net.to(device)
22
+ net.load_state_dict(torch.load(model_path))
23
+ net.eval()
24
+ return net
25
+
26
+ def transform_images(imgs):
27
+ tensor_images = torch.stack([transform(Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))) for img in imgs], dim=0)
28
+ return tensor_images
29
+
30
+ def get_parsed_mask(net, imgs, classes=[1, 2, 3, 4, 5, 10, 11, 12, 13], device="cpu", batch_size=8):
31
+ masks = []
32
+ for i in tqdm(range(0, len(imgs), batch_size), total=len(imgs) // batch_size, desc="Face-parsing"):
33
+ batch_imgs = imgs[i:i + batch_size]
34
+
35
+ tensor_images = transform_images(batch_imgs).to(device)
36
+ with torch.no_grad():
37
+ out = net(tensor_images)[0]
38
+ parsing = out.argmax(dim=1).cpu().numpy()
39
+ batch_masks = np.isin(parsing, classes)
40
+
41
+ masks.append(batch_masks)
42
+
43
+ masks = np.concatenate(masks, axis=0)
44
+ # masks = np.repeat(np.expand_dims(masks, axis=1), 3, axis=1)
45
+
46
+ for i, mask in enumerate(masks):
47
+ cv2.imwrite(f"mask/{i}.jpg", (mask * 255).astype("uint8"))
48
+
49
+ return masks
50
+
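Usage note: `get_parsed_mask` batches BGR frames through BiSeNet and returns one boolean mask per frame for the selected facial classes; it also writes debug masks into a `mask/` folder, so that directory has to exist. A sketch, assuming a CUDA setup (the checkpoint is loaded without `map_location`) and stand-in frame paths:

```python
import os
import cv2
from face_parsing import init_parsing_model, get_parsed_mask

os.makedirs("mask", exist_ok=True)  # get_parsed_mask writes per-frame debug masks here

net = init_parsing_model("./assets/pretrained_models/79999_iter.pth", device="cuda")
frames = [cv2.imread(p) for p in ["frames/frame_0.jpg", "frames/frame_1.jpg"]]  # hypothetical paths
masks = get_parsed_mask(net, frames, device="cuda", batch_size=8)
print(masks.shape)  # (num_frames, 512, 512) boolean array
```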
face_parsing/swap.py CHANGED
@@ -98,6 +98,7 @@ def get_mask(parsing, classes):
98
  res += parsing == val
99
  return res
100
 
 
101
  def swap_regions(source, target, net, smooth_mask, includes=[1,2,3,4,5,10,11,12,13], blur=10):
102
  parsing = image_to_parsing(source, net)
103
 
@@ -117,12 +118,10 @@ def swap_regions(source, target, net, smooth_mask, includes=[1,2,3,4,5,10,11,12,
117
  if blur > 0:
118
  mask = cv2.GaussianBlur(mask, (0, 0), blur)
119
 
120
- resized_source = cv2.resize((source/255).astype("float32"), (512, 512))
121
- resized_target = cv2.resize((target/255).astype("float32"), (512, 512))
122
-
123
  result = mask * resized_source + (1 - mask) * resized_target
124
- normalized_result = (result - np.min(result)) / (np.max(result) - np.min(result))
125
- result = cv2.resize((result*255).astype("uint8"), (source.shape[1], source.shape[0]))
126
 
127
  return result
128
 
 
98
  res += parsing == val
99
  return res
100
 
101
+
102
  def swap_regions(source, target, net, smooth_mask, includes=[1,2,3,4,5,10,11,12,13], blur=10):
103
  parsing = image_to_parsing(source, net)
104
 
 
118
  if blur > 0:
119
  mask = cv2.GaussianBlur(mask, (0, 0), blur)
120
 
121
+ resized_source = cv2.resize((source).astype("float32"), (512, 512))
122
+ resized_target = cv2.resize((target).astype("float32"), (512, 512))
 
123
  result = mask * resized_source + (1 - mask) * resized_target
124
+ result = cv2.resize(result.astype("uint8"), (source.shape[1], source.shape[0]))
 
125
 
126
  return result
127
 
face_swapper.py ADDED
@@ -0,0 +1,203 @@
1
+ import time
2
+ import torch
3
+ import onnx
4
+ import cv2
5
+ import onnxruntime
6
+ import numpy as np
7
+ from tqdm import tqdm
8
+ from onnx import numpy_helper
9
+ from skimage import transform as trans
10
+
11
+ arcface_dst = np.array(
12
+ [[38.2946, 51.6963], [73.5318, 51.5014], [56.0252, 71.7366],
13
+ [41.5493, 92.3655], [70.7299, 92.2041]],
14
+ dtype=np.float32)
15
+
16
+ def estimate_norm(lmk, image_size=112, mode='arcface'):
17
+ assert lmk.shape == (5, 2)
18
+ assert image_size % 112 == 0 or image_size % 128 == 0
19
+ if image_size % 112 == 0:
20
+ ratio = float(image_size) / 112.0
21
+ diff_x = 0
22
+ else:
23
+ ratio = float(image_size) / 128.0
24
+ diff_x = 8.0 * ratio
25
+ dst = arcface_dst * ratio
26
+ dst[:, 0] += diff_x
27
+ tform = trans.SimilarityTransform()
28
+ tform.estimate(lmk, dst)
29
+ M = tform.params[0:2, :]
30
+ return M
31
+
32
+
33
+ def norm_crop2(img, landmark, image_size=112, mode='arcface'):
34
+ M = estimate_norm(landmark, image_size, mode)
35
+ warped = cv2.warpAffine(img, M, (image_size, image_size), borderValue=0.0)
36
+ return warped, M
37
+
38
+
39
+ class Inswapper():
40
+ def __init__(self, model_file=None, batch_size=32, providers=['CPUExecutionProvider']):
41
+ self.model_file = model_file
42
+ self.batch_size = batch_size
43
+
44
+ model = onnx.load(self.model_file)
45
+ graph = model.graph
46
+ self.emap = numpy_helper.to_array(graph.initializer[-1])
47
+ self.input_mean = 0.0
48
+ self.input_std = 255.0
49
+
50
+ self.session_options = onnxruntime.SessionOptions()
51
+ self.session = onnxruntime.InferenceSession(self.model_file, sess_options=self.session_options, providers=providers)
52
+
53
+ inputs = self.session.get_inputs()
54
+ self.input_names = [inp.name for inp in inputs]
55
+ outputs = self.session.get_outputs()
56
+ self.output_names = [out.name for out in outputs]
57
+ assert len(self.output_names) == 1
58
+ self.output_shape = outputs[0].shape
59
+ input_cfg = inputs[0]
60
+ input_shape = input_cfg.shape
61
+ self.input_shape = input_shape
62
+ self.input_size = tuple(input_shape[2:4][::-1])
63
+
64
+ def forward(self, imgs, latents):
65
+ batch_preds = []
66
+ for img, latent in zip(imgs, latents):
67
+ img = (img - self.input_mean) / self.input_std
68
+ pred = self.session.run(self.output_names, {self.input_names[0]: img, self.input_names[1]: latent})[0]
69
+ batch_preds.append(pred)
70
+ return batch_preds
71
+
72
+ def get(self, imgs, target_faces, source_faces):
73
+ batch_preds = []
74
+ batch_aimgs = []
75
+ batch_ms = []
76
+ for img, target_face, source_face in zip(imgs, target_faces, source_faces):
77
+ if isinstance(img, str):
78
+ img = cv2.imread(img)
79
+ aimg, M = norm_crop2(img, target_face.kps, self.input_size[0])
80
+ blob = cv2.dnn.blobFromImage(aimg, 1.0 / self.input_std, self.input_size,
81
+ (self.input_mean, self.input_mean, self.input_mean), swapRB=True)
82
+ latent = source_face.normed_embedding.reshape((1, -1))
83
+ latent = np.dot(latent, self.emap)
84
+ latent /= np.linalg.norm(latent)
85
+ pred = self.session.run(self.output_names, {self.input_names[0]: blob, self.input_names[1]: latent})[0]
86
+ pred = pred.transpose((0, 2, 3, 1))[0]
87
+ pred = np.clip(255 * pred, 0, 255).astype(np.uint8)[:, :, ::-1]
88
+ batch_preds.append(pred)
89
+ batch_aimgs.append(aimg)
90
+ batch_ms.append(M)
91
+ return batch_preds, batch_aimgs, batch_ms
92
+
93
+ def batch_forward(self, img_list, target_f_list, source_f_list):
94
+ num_samples = len(img_list)
95
+ num_batches = (num_samples + self.batch_size - 1) // self.batch_size
96
+
97
+ preds = []
98
+ aimgs = []
99
+ ms = []
100
+ for i in tqdm(range(num_batches), desc="Swapping face by batch"):
101
+ start_idx = i * self.batch_size
102
+ end_idx = min((i + 1) * self.batch_size, num_samples)
103
+
104
+ batch_img = img_list[start_idx:end_idx]
105
+ batch_target_f = target_f_list[start_idx:end_idx]
106
+ batch_source_f = source_f_list[start_idx:end_idx]
107
+
108
+ batch_pred, batch_aimg, batch_m = self.get(batch_img, batch_target_f, batch_source_f)
109
+ preds.extend(batch_pred)
110
+ aimgs.extend(batch_aimg)
111
+ ms.extend(batch_m)
112
+ return preds, aimgs, ms
113
+
114
+
115
+ def laplacian_blending(A, B, m, num_levels=4):
116
+ assert A.shape == B.shape
117
+ assert B.shape == m.shape
118
+ height = m.shape[0]
119
+ width = m.shape[1]
120
+ size_list = np.array([4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096])
121
+ size = size_list[np.where(size_list > max(height, width))][0]
122
+ GA = np.zeros((size, size, 3), dtype=np.float32)
123
+ GA[:height, :width, :] = A
124
+ GB = np.zeros((size, size, 3), dtype=np.float32)
125
+ GB[:height, :width, :] = B
126
+ GM = np.zeros((size, size, 3), dtype=np.float32)
127
+ GM[:height, :width, :] = m
128
+ gpA = [GA]
129
+ gpB = [GB]
130
+ gpM = [GM]
131
+ for i in range(num_levels):
132
+ GA = cv2.pyrDown(GA)
133
+ GB = cv2.pyrDown(GB)
134
+ GM = cv2.pyrDown(GM)
135
+ gpA.append(np.float32(GA))
136
+ gpB.append(np.float32(GB))
137
+ gpM.append(np.float32(GM))
138
+ lpA = [gpA[num_levels-1]]
139
+ lpB = [gpB[num_levels-1]]
140
+ gpMr = [gpM[num_levels-1]]
141
+ for i in range(num_levels-1,0,-1):
142
+ LA = np.subtract(gpA[i-1], cv2.pyrUp(gpA[i]))
143
+ LB = np.subtract(gpB[i-1], cv2.pyrUp(gpB[i]))
144
+ lpA.append(LA)
145
+ lpB.append(LB)
146
+ gpMr.append(gpM[i-1])
147
+ LS = []
148
+ for la,lb,gm in zip(lpA,lpB,gpMr):
149
+ ls = la * gm + lb * (1.0 - gm)
150
+ LS.append(ls)
151
+ ls_ = LS[0]
152
+ for i in range(1,num_levels):
153
+ ls_ = cv2.pyrUp(ls_)
154
+ ls_ = cv2.add(ls_, LS[i])
155
+ ls_ = np.clip(ls_[:height, :width, :], 0, 255)
156
+ return ls_
157
+
158
+
159
+ def paste_to_whole(bgr_fake, aimg, M, whole_img, laplacian_blend=True, crop_mask=(0,0,0,0)):
160
+ IM = cv2.invertAffineTransform(M)
161
+
162
+ img_white = np.full((aimg.shape[0], aimg.shape[1]), 255, dtype=np.float32)
163
+
164
+ top = int(crop_mask[0])
165
+ bottom = int(crop_mask[1])
166
+ if top + bottom < aimg.shape[1]:
167
+ if top > 0: img_white[:top, :] = 0
168
+ if bottom > 0: img_white[-bottom:, :] = 0
169
+
170
+ left = int(crop_mask[2])
171
+ right = int(crop_mask[3])
172
+ if left + right < aimg.shape[0]:
173
+ if left > 0: img_white[:, :left] = 0
174
+ if right > 0: img_white[:, -right:] = 0
175
+
176
+ bgr_fake = cv2.warpAffine(
177
+ bgr_fake, IM, (whole_img.shape[1], whole_img.shape[0]), borderValue=0.0
178
+ )
179
+ img_white = cv2.warpAffine(
180
+ img_white, IM, (whole_img.shape[1], whole_img.shape[0]), borderValue=0.0
181
+ )
182
+ img_white[img_white > 20] = 255
183
+ img_mask = img_white
184
+ mask_h_inds, mask_w_inds = np.where(img_mask == 255)
185
+ mask_h = np.max(mask_h_inds) - np.min(mask_h_inds)
186
+ mask_w = np.max(mask_w_inds) - np.min(mask_w_inds)
187
+ mask_size = int(np.sqrt(mask_h * mask_w))
188
+
189
+ k = max(mask_size // 10, 10)
190
+ img_mask = cv2.erode(img_mask, np.ones((k, k), np.uint8), iterations=1)
191
+
192
+ k = max(mask_size // 20, 5)
193
+ kernel_size = (k, k)
194
+ blur_size = tuple(2 * i + 1 for i in kernel_size)
195
+ img_mask = cv2.GaussianBlur(img_mask, blur_size, 0) / 255
196
+ img_mask = np.tile(np.expand_dims(img_mask, axis=-1), (1, 1, 3))
197
+
198
+ if laplacian_blend:
199
+ bgr_fake = laplacian_blending(bgr_fake.astype("float32").clip(0,255), whole_img.astype("float32").clip(0,255), img_mask.clip(0,1))
200
+ bgr_fake = bgr_fake.astype("float32")
201
+
202
+ fake_merged = img_mask * bgr_fake + (1 - img_mask) * whole_img.astype(np.float32)
203
+ return fake_merged.astype("uint8")
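Usage note: `Inswapper.batch_forward` and `paste_to_whole` form the core per-face loop: the swapper yields a 128x128 prediction, the aligned crop and the affine matrix for each face, and `paste_to_whole` inverts the matrix to warp the prediction back and feather it into the frame. A condensed sketch, assuming the per-face lists were already produced by the analyser step and the paths are stand-ins:

```python
import cv2
from face_swapper import Inswapper, paste_to_whole

swapper = Inswapper(
    model_file="./assets/pretrained_models/inswapper_128.onnx",
    batch_size=32,
    providers=["CPUExecutionProvider"],
)

# frame_paths, target_faces, source_faces: parallel per-face lists from get_analysed_data
preds, aimgs, matrs = swapper.batch_forward(frame_paths, target_faces, source_faces)

frame = cv2.imread(frame_paths[0])
frame = paste_to_whole(preds[0], aimgs[0], matrs[0], frame,
                       laplacian_blend=True, crop_mask=(0, 0, 0, 0))
cv2.imwrite("swapped_frame.jpg", frame)
```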
nsfw_detector.py ADDED
@@ -0,0 +1,61 @@
1
+ from torchvision.transforms import Normalize
2
+ import torchvision.transforms as T
3
+ import torch.nn as nn
4
+ from PIL import Image
5
+ import numpy as np
6
+ import torch
7
+ import timm
8
+ from tqdm import tqdm
9
+
10
+ normalize_t = Normalize((0.4814, 0.4578, 0.4082), (0.2686, 0.2613, 0.2757))
11
+
12
+ #nsfw classifier
13
+ class NSFWClassifier(nn.Module):
14
+ def __init__(self):
15
+ super().__init__()
16
+ nsfw_model=self
17
+ nsfw_model.root_model = timm.create_model('convnext_base_in22ft1k', pretrained=True)
18
+ nsfw_model.linear_probe = nn.Linear(1024, 1, bias=False)
19
+
20
+ def forward(self, x):
21
+ nsfw_model = self
22
+ x = normalize_t(x)
23
+ x = nsfw_model.root_model.stem(x)
24
+ x = nsfw_model.root_model.stages(x)
25
+ x = nsfw_model.root_model.head.global_pool(x)
26
+ x = nsfw_model.root_model.head.norm(x)
27
+ x = nsfw_model.root_model.head.flatten(x)
28
+ x = nsfw_model.linear_probe(x)
29
+ return x
30
+
31
+ def is_nsfw(self, img_paths, threshold = 0.93):
32
+ skip_step = 1
33
+ total_len = len(img_paths)
34
+ if total_len < 100: skip_step = 1
35
+ if total_len > 100 and total_len < 500: skip_step = 10
36
+ if total_len > 500 and total_len < 1000: skip_step = 20
37
+ if total_len > 1000 and total_len < 10000: skip_step = 50
38
+ if total_len > 10000: skip_step = 100
39
+
40
+ for idx in tqdm(range(0, total_len, skip_step), total=total_len, desc="Checking for NSFW contents"):
41
+ img = Image.open(img_paths[idx]).convert('RGB')
42
+ img = img.resize((224, 224))
43
+ img = np.array(img)/255
44
+ img = T.ToTensor()(img).unsqueeze(0).float()
45
+ if next(self.parameters()).is_cuda:
46
+ img = img.cuda()
47
+ with torch.no_grad():
48
+ score = self.forward(img).sigmoid()[0].item()
49
+ if score > threshold:return True
50
+ return False
51
+
52
+ def get_nsfw_detector(model_path='nsfwmodel_281.pth', device="cpu"):
53
+ #load base model
54
+ nsfw_model = NSFWClassifier()
55
+ nsfw_model = nsfw_model.eval()
56
+ #load linear weights
57
+ linear_pth = model_path
58
+ linear_state_dict = torch.load(linear_pth, map_location='cpu')
59
+ nsfw_model.linear_probe.load_state_dict(linear_state_dict)
60
+ nsfw_model = nsfw_model.to(device)
61
+ return nsfw_model
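Usage note: the detector samples the frame list at a stride that grows with its length (every 10th frame between 100 and 500 frames, and so on) and stops at the first sigmoid score above the 0.93 threshold. A minimal sketch, assuming the checkpoint sits where app.py expects it (the convnext backbone itself is fetched through timm on first use):

```python
from nsfw_detector import get_nsfw_detector

detector = get_nsfw_detector(
    model_path="./assets/pretrained_models/nsfwmodel_281.pth", device="cpu"
)
flagged = detector.is_nsfw(["frames/frame_0.jpg", "frames/frame_1.jpg"])  # hypothetical frame paths
print("NSFW content detected" if flagged else "Content looks safe")
```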
requirements.txt CHANGED
@@ -4,8 +4,10 @@ gradio>=3.33.1
4
  insightface==0.7.3
5
  moviepy>=1.0.3
6
  numpy
7
- opencv-python>=4.7.0.72
8
- opencv-python-headless>=4.7.0.72
9
  onnx==1.14.0
10
  onnxruntime==1.15.0
 
 
11
  gfpgan==1.3.8
 
 
 
4
  insightface==0.7.3
5
  moviepy>=1.0.3
6
  numpy
 
 
7
  onnx==1.14.0
8
  onnxruntime==1.15.0
9
+ opencv-python>=4.7.0.72
10
+ opencv-python-headless>=4.7.0.72
11
  gfpgan==1.3.8
12
+ timm==0.9.2
13
+
upscaler/RealESRGAN/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .model import RealESRGAN
upscaler/RealESRGAN/arch_utils.py ADDED
@@ -0,0 +1,197 @@
1
+ import math
2
+ import torch
3
+ from torch import nn as nn
4
+ from torch.nn import functional as F
5
+ from torch.nn import init as init
6
+ from torch.nn.modules.batchnorm import _BatchNorm
7
+
8
+ @torch.no_grad()
9
+ def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs):
10
+ """Initialize network weights.
11
+
12
+ Args:
13
+ module_list (list[nn.Module] | nn.Module): Modules to be initialized.
14
+ scale (float): Scale initialized weights, especially for residual
15
+ blocks. Default: 1.
16
+ bias_fill (float): The value to fill bias. Default: 0
17
+ kwargs (dict): Other arguments for initialization function.
18
+ """
19
+ if not isinstance(module_list, list):
20
+ module_list = [module_list]
21
+ for module in module_list:
22
+ for m in module.modules():
23
+ if isinstance(m, nn.Conv2d):
24
+ init.kaiming_normal_(m.weight, **kwargs)
25
+ m.weight.data *= scale
26
+ if m.bias is not None:
27
+ m.bias.data.fill_(bias_fill)
28
+ elif isinstance(m, nn.Linear):
29
+ init.kaiming_normal_(m.weight, **kwargs)
30
+ m.weight.data *= scale
31
+ if m.bias is not None:
32
+ m.bias.data.fill_(bias_fill)
33
+ elif isinstance(m, _BatchNorm):
34
+ init.constant_(m.weight, 1)
35
+ if m.bias is not None:
36
+ m.bias.data.fill_(bias_fill)
37
+
38
+
39
+ def make_layer(basic_block, num_basic_block, **kwarg):
40
+ """Make layers by stacking the same blocks.
41
+
42
+ Args:
43
+ basic_block (nn.module): nn.module class for basic block.
44
+ num_basic_block (int): number of blocks.
45
+
46
+ Returns:
47
+ nn.Sequential: Stacked blocks in nn.Sequential.
48
+ """
49
+ layers = []
50
+ for _ in range(num_basic_block):
51
+ layers.append(basic_block(**kwarg))
52
+ return nn.Sequential(*layers)
53
+
54
+
55
+ class ResidualBlockNoBN(nn.Module):
56
+ """Residual block without BN.
57
+
58
+ It has a style of:
59
+ ---Conv-ReLU-Conv-+-
60
+ |________________|
61
+
62
+ Args:
63
+ num_feat (int): Channel number of intermediate features.
64
+ Default: 64.
65
+ res_scale (float): Residual scale. Default: 1.
66
+ pytorch_init (bool): If set to True, use pytorch default init,
67
+ otherwise, use default_init_weights. Default: False.
68
+ """
69
+
70
+ def __init__(self, num_feat=64, res_scale=1, pytorch_init=False):
71
+ super(ResidualBlockNoBN, self).__init__()
72
+ self.res_scale = res_scale
73
+ self.conv1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
74
+ self.conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
75
+ self.relu = nn.ReLU(inplace=True)
76
+
77
+ if not pytorch_init:
78
+ default_init_weights([self.conv1, self.conv2], 0.1)
79
+
80
+ def forward(self, x):
81
+ identity = x
82
+ out = self.conv2(self.relu(self.conv1(x)))
83
+ return identity + out * self.res_scale
84
+
85
+
86
+ class Upsample(nn.Sequential):
87
+ """Upsample module.
88
+
89
+ Args:
90
+ scale (int): Scale factor. Supported scales: 2^n and 3.
91
+ num_feat (int): Channel number of intermediate features.
92
+ """
93
+
94
+ def __init__(self, scale, num_feat):
95
+ m = []
96
+ if (scale & (scale - 1)) == 0: # scale = 2^n
97
+ for _ in range(int(math.log(scale, 2))):
98
+ m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
99
+ m.append(nn.PixelShuffle(2))
100
+ elif scale == 3:
101
+ m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
102
+ m.append(nn.PixelShuffle(3))
103
+ else:
104
+ raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
105
+ super(Upsample, self).__init__(*m)
106
+
107
+
108
+ def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros', align_corners=True):
109
+ """Warp an image or feature map with optical flow.
110
+
111
+ Args:
112
+ x (Tensor): Tensor with size (n, c, h, w).
113
+ flow (Tensor): Tensor with size (n, h, w, 2), normal value.
114
+ interp_mode (str): 'nearest' or 'bilinear'. Default: 'bilinear'.
115
+ padding_mode (str): 'zeros' or 'border' or 'reflection'.
116
+ Default: 'zeros'.
117
+ align_corners (bool): Before pytorch 1.3, the default value is
118
+ align_corners=True. After pytorch 1.3, the default value is
119
+ align_corners=False. Here, we use the True as default.
120
+
121
+ Returns:
122
+ Tensor: Warped image or feature map.
123
+ """
124
+ assert x.size()[-2:] == flow.size()[1:3]
125
+ _, _, h, w = x.size()
126
+ # create mesh grid
127
+ grid_y, grid_x = torch.meshgrid(torch.arange(0, h).type_as(x), torch.arange(0, w).type_as(x))
128
+ grid = torch.stack((grid_x, grid_y), 2).float() # W(x), H(y), 2
129
+ grid.requires_grad = False
130
+
131
+ vgrid = grid + flow
132
+ # scale grid to [-1,1]
133
+ vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(w - 1, 1) - 1.0
134
+ vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(h - 1, 1) - 1.0
135
+ vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3)
136
+ output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode, align_corners=align_corners)
137
+
138
+ # TODO, what if align_corners=False
139
+ return output
140
+
141
+
142
+ def resize_flow(flow, size_type, sizes, interp_mode='bilinear', align_corners=False):
143
+ """Resize a flow according to ratio or shape.
144
+
145
+ Args:
146
+ flow (Tensor): Precomputed flow. shape [N, 2, H, W].
147
+ size_type (str): 'ratio' or 'shape'.
148
+ sizes (list[int | float]): the ratio for resizing or the final output
149
+ shape.
150
+ 1) The order of ratio should be [ratio_h, ratio_w]. For
151
+ downsampling, the ratio should be smaller than 1.0 (i.e., ratio
152
+ < 1.0). For upsampling, the ratio should be larger than 1.0 (i.e.,
153
+ ratio > 1.0).
154
+ 2) The order of output_size should be [out_h, out_w].
155
+ interp_mode (str): The mode of interpolation for resizing.
156
+ Default: 'bilinear'.
157
+ align_corners (bool): Whether align corners. Default: False.
158
+
159
+ Returns:
160
+ Tensor: Resized flow.
161
+ """
162
+ _, _, flow_h, flow_w = flow.size()
163
+ if size_type == 'ratio':
164
+ output_h, output_w = int(flow_h * sizes[0]), int(flow_w * sizes[1])
165
+ elif size_type == 'shape':
166
+ output_h, output_w = sizes[0], sizes[1]
167
+ else:
168
+ raise ValueError(f'Size type should be ratio or shape, but got type {size_type}.')
169
+
170
+ input_flow = flow.clone()
171
+ ratio_h = output_h / flow_h
172
+ ratio_w = output_w / flow_w
173
+ input_flow[:, 0, :, :] *= ratio_w
174
+ input_flow[:, 1, :, :] *= ratio_h
175
+ resized_flow = F.interpolate(
176
+ input=input_flow, size=(output_h, output_w), mode=interp_mode, align_corners=align_corners)
177
+ return resized_flow
178
+
179
+
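resize_flow rescales both the spatial grid of the flow and its values, since a displacement measured in pixels has to grow or shrink with the resolution. A hedged sketch with made-up sizes:

import torch
from upscaler.RealESRGAN.arch_utils import resize_flow

flow = torch.randn(1, 2, 30, 40)              # [N, 2, H, W]
up = resize_flow(flow, 'ratio', [2.0, 2.0])   # 60x80 grid, values doubled
down = resize_flow(flow, 'shape', [15, 20])   # 15x20 grid, values halved
print(up.shape, down.shape)  # torch.Size([1, 2, 60, 80]) torch.Size([1, 2, 15, 20])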
180
+ # TODO: may write a cpp file
181
+ def pixel_unshuffle(x, scale):
182
+ """ Pixel unshuffle.
183
+
184
+ Args:
185
+ x (Tensor): Input feature with shape (b, c, hh, hw).
186
+ scale (int): Downsample ratio.
187
+
188
+ Returns:
189
+ Tensor: the pixel unshuffled feature.
190
+ """
191
+ b, c, hh, hw = x.size()
192
+ out_channel = c * (scale**2)
193
+ assert hh % scale == 0 and hw % scale == 0
194
+ h = hh // scale
195
+ w = hw // scale
196
+ x_view = x.view(b, c, h, scale, w, scale)
197
+ return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w)
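pixel_unshuffle is the exact inverse of PixelShuffle: it trades spatial resolution for channels, which is how RRDBNet (added below) supports the x1 and x2 scales without changing its trunk. A minimal round-trip check, as a sketch:

import torch
import torch.nn.functional as F
from upscaler.RealESRGAN.arch_utils import pixel_unshuffle

x = torch.rand(2, 3, 64, 64)
down = pixel_unshuffle(x, scale=2)                  # (2, 12, 32, 32)
restored = F.pixel_shuffle(down, upscale_factor=2)  # back to (2, 3, 64, 64)
assert torch.equal(restored, x)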
upscaler/RealESRGAN/model.py ADDED
@@ -0,0 +1,90 @@
1
+ import os
2
+ import torch
3
+ from torch.nn import functional as F
4
+ from PIL import Image
5
+ import numpy as np
6
+ import cv2
7
+
8
+ from .rrdbnet_arch import RRDBNet
9
+ from .utils import pad_reflect, split_image_into_overlapping_patches, stich_together, \
10
+ unpad_image
11
+
12
+
13
+ HF_MODELS = {
14
+ 2: dict(
15
+ repo_id='sberbank-ai/Real-ESRGAN',
16
+ filename='RealESRGAN_x2.pth',
17
+ ),
18
+ 4: dict(
19
+ repo_id='sberbank-ai/Real-ESRGAN',
20
+ filename='RealESRGAN_x4.pth',
21
+ ),
22
+ 8: dict(
23
+ repo_id='sberbank-ai/Real-ESRGAN',
24
+ filename='RealESRGAN_x8.pth',
25
+ ),
26
+ }
27
+
28
+
29
+ class RealESRGAN:
30
+ def __init__(self, device, scale=4):
31
+ self.device = device
32
+ self.scale = scale
33
+ self.model = RRDBNet(
34
+ num_in_ch=3, num_out_ch=3, num_feat=64,
35
+ num_block=23, num_grow_ch=32, scale=scale
36
+ )
37
+
38
+ def load_weights(self, model_path, download=True):
39
+ if not os.path.exists(model_path) and download:
40
+ from huggingface_hub import hf_hub_url, cached_download
41
+ assert self.scale in [2,4,8], 'You can download models only with scales: 2, 4, 8'
42
+ config = HF_MODELS[self.scale]
43
+ cache_dir = os.path.dirname(model_path)
44
+ local_filename = os.path.basename(model_path)
45
+ config_file_url = hf_hub_url(repo_id=config['repo_id'], filename=config['filename'])
46
+ cached_download(config_file_url, cache_dir=cache_dir, force_filename=local_filename)
47
+ print('Weights downloaded to:', os.path.join(cache_dir, local_filename))
48
+
49
+ loadnet = torch.load(model_path)
50
+ if 'params' in loadnet:
51
+ self.model.load_state_dict(loadnet['params'], strict=True)
52
+ elif 'params_ema' in loadnet:
53
+ self.model.load_state_dict(loadnet['params_ema'], strict=True)
54
+ else:
55
+ self.model.load_state_dict(loadnet, strict=True)
56
+ self.model.eval()
57
+ self.model.to(self.device)
58
+
59
+ @torch.cuda.amp.autocast()
60
+ def predict(self, lr_image, batch_size=4, patches_size=192,
61
+ padding=24, pad_size=15):
62
+ scale = self.scale
63
+ device = self.device
64
+ lr_image = np.array(lr_image)
65
+ lr_image = pad_reflect(lr_image, pad_size)
66
+
67
+ patches, p_shape = split_image_into_overlapping_patches(
68
+ lr_image, patch_size=patches_size, padding_size=padding
69
+ )
70
+ img = torch.FloatTensor(patches/255).permute((0,3,1,2)).to(device).detach()
71
+
72
+ with torch.no_grad():
73
+ res = self.model(img[0:batch_size])
74
+ for i in range(batch_size, img.shape[0], batch_size):
75
+ res = torch.cat((res, self.model(img[i:i+batch_size])), 0)
76
+
77
+ sr_image = res.permute((0,2,3,1)).clamp_(0, 1).cpu()
78
+ np_sr_image = sr_image.numpy()
79
+
80
+ padded_size_scaled = tuple(np.multiply(p_shape[0:2], scale)) + (3,)
81
+ scaled_image_shape = tuple(np.multiply(lr_image.shape[0:2], scale)) + (3,)
82
+ np_sr_image = stich_together(
83
+ np_sr_image, padded_image_shape=padded_size_scaled,
84
+ target_shape=scaled_image_shape, padding_size=padding * scale
85
+ )
86
+ sr_img = (np_sr_image*255).astype(np.uint8)
87
+ sr_img = unpad_image(sr_img, pad_size*scale)
88
+ #sr_img = Image.fromarray(sr_img)
89
+
90
+ return sr_img
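For orientation, a hedged usage sketch of this wrapper: the weight path and image names below are hypothetical, and on first use load_weights pulls RealESRGAN_x4.pth from the sberbank-ai/Real-ESRGAN repo on the Hugging Face Hub. The class is imported from the module directly in case the package __init__ does not re-export it.

import cv2
import torch
from upscaler.RealESRGAN.model import RealESRGAN

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = RealESRGAN(device, scale=4)
model.load_weights('./assets/pretrained_models/RealESRGAN_x4.pth', download=True)

bgr = cv2.imread('input.jpg')
rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
sr = model.predict(rgb)                      # uint8 RGB array, 4x larger
cv2.imwrite('output.jpg', cv2.cvtColor(sr, cv2.COLOR_RGB2BGR))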
upscaler/RealESRGAN/rrdbnet_arch.py ADDED
@@ -0,0 +1,121 @@
1
+ import torch
2
+ from torch import nn as nn
3
+ from torch.nn import functional as F
4
+
5
+ from .arch_utils import default_init_weights, make_layer, pixel_unshuffle
6
+
7
+
8
+ class ResidualDenseBlock(nn.Module):
9
+ """Residual Dense Block.
10
+
11
+ Used in RRDB block in ESRGAN.
12
+
13
+ Args:
14
+ num_feat (int): Channel number of intermediate features.
15
+ num_grow_ch (int): Channels for each growth.
16
+ """
17
+
18
+ def __init__(self, num_feat=64, num_grow_ch=32):
19
+ super(ResidualDenseBlock, self).__init__()
20
+ self.conv1 = nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1)
21
+ self.conv2 = nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1)
22
+ self.conv3 = nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1)
23
+ self.conv4 = nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1)
24
+ self.conv5 = nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1)
25
+
26
+ self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
27
+
28
+ # initialization
29
+ default_init_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
30
+
31
+ def forward(self, x):
32
+ x1 = self.lrelu(self.conv1(x))
33
+ x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
34
+ x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
35
+ x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
36
+ x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
37
+ # Empirically, we use 0.2 to scale the residual for better performance
38
+ return x5 * 0.2 + x
39
+
40
+
41
+ class RRDB(nn.Module):
42
+ """Residual in Residual Dense Block.
43
+
44
+ Used in RRDB-Net in ESRGAN.
45
+
46
+ Args:
47
+ num_feat (int): Channel number of intermediate features.
48
+ num_grow_ch (int): Channels for each growth.
49
+ """
50
+
51
+ def __init__(self, num_feat, num_grow_ch=32):
52
+ super(RRDB, self).__init__()
53
+ self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch)
54
+ self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch)
55
+ self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch)
56
+
57
+ def forward(self, x):
58
+ out = self.rdb1(x)
59
+ out = self.rdb2(out)
60
+ out = self.rdb3(out)
61
+ # Empirically, we use 0.2 to scale the residual for better performance
62
+ return out * 0.2 + x
63
+
64
+
65
+ class RRDBNet(nn.Module):
66
+ """Networks consisting of Residual in Residual Dense Block, which is used
67
+ in ESRGAN.
68
+
69
+ ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks.
70
+
71
+ We extend ESRGAN for scale x2 and scale x1.
72
+ Note: This is one option for scale 1, scale 2 in RRDBNet.
73
+ We first employ the pixel-unshuffle (an inverse operation of pixelshuffle) to reduce the spatial size
74
+ and enlarge the channel size before feeding inputs into the main ESRGAN architecture.
75
+
76
+ Args:
77
+ num_in_ch (int): Channel number of inputs.
78
+ num_out_ch (int): Channel number of outputs.
79
+ num_feat (int): Channel number of intermediate features.
80
+ Default: 64.
81
+ num_block (int): Block number in the trunk network. Default: 23.
82
+ num_grow_ch (int): Channels for each growth. Default: 32.
83
+ """
84
+
85
+ def __init__(self, num_in_ch, num_out_ch, scale=4, num_feat=64, num_block=23, num_grow_ch=32):
86
+ super(RRDBNet, self).__init__()
87
+ self.scale = scale
88
+ if scale == 2:
89
+ num_in_ch = num_in_ch * 4
90
+ elif scale == 1:
91
+ num_in_ch = num_in_ch * 16
92
+ self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
93
+ self.body = make_layer(RRDB, num_block, num_feat=num_feat, num_grow_ch=num_grow_ch)
94
+ self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
95
+ # upsample
96
+ self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
97
+ self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
98
+ if scale == 8:
99
+ self.conv_up3 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
100
+ self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
101
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
102
+
103
+ self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
104
+
105
+ def forward(self, x):
106
+ if self.scale == 2:
107
+ feat = pixel_unshuffle(x, scale=2)
108
+ elif self.scale == 1:
109
+ feat = pixel_unshuffle(x, scale=4)
110
+ else:
111
+ feat = x
112
+ feat = self.conv_first(feat)
113
+ body_feat = self.conv_body(self.body(feat))
114
+ feat = feat + body_feat
115
+ # upsample
116
+ feat = self.lrelu(self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest')))
117
+ feat = self.lrelu(self.conv_up2(F.interpolate(feat, scale_factor=2, mode='nearest')))
118
+ if self.scale == 8:
119
+ feat = self.lrelu(self.conv_up3(F.interpolate(feat, scale_factor=2, mode='nearest')))
120
+ out = self.conv_last(self.lrelu(self.conv_hr(feat)))
121
+ return out
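The forward pass always applies two nearest-neighbour x2 upsamples (three for scale 8); scales 1 and 2 compensate by pixel-unshuffling the input by 4 or 2 first, so the end-to-end factor equals scale in every case. A quick shape check with deliberately tiny hyper-parameters (sketch only; for scales 1 and 2 the input H and W must be divisible by 4 and 2 respectively):

import torch
from upscaler.RealESRGAN.rrdbnet_arch import RRDBNet

x = torch.rand(1, 3, 32, 32)
for scale in (1, 2, 4, 8):
    net = RRDBNet(num_in_ch=3, num_out_ch=3, scale=scale,
                  num_feat=8, num_block=2, num_grow_ch=8)
    with torch.no_grad():
        y = net(x)
    print(scale, tuple(y.shape))   # (1, 3, 32 * scale, 32 * scale)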
upscaler/RealESRGAN/utils.py ADDED
@@ -0,0 +1,133 @@
1
+ import numpy as np
2
+ import torch
3
+ from PIL import Image
4
+ import os
5
+ import io
6
+
7
+ def pad_reflect(image, pad_size):
8
+ imsize = image.shape
9
+ height, width = imsize[:2]
10
+ new_img = np.zeros([height+pad_size*2, width+pad_size*2, imsize[2]]).astype(np.uint8)
11
+ new_img[pad_size:-pad_size, pad_size:-pad_size, :] = image
12
+
13
+ new_img[0:pad_size, pad_size:-pad_size, :] = np.flip(image[0:pad_size, :, :], axis=0) #top
14
+ new_img[-pad_size:, pad_size:-pad_size, :] = np.flip(image[-pad_size:, :, :], axis=0) #bottom
15
+ new_img[:, 0:pad_size, :] = np.flip(new_img[:, pad_size:pad_size*2, :], axis=1) #left
16
+ new_img[:, -pad_size:, :] = np.flip(new_img[:, -pad_size*2:-pad_size, :], axis=1) #right
17
+
18
+ return new_img
19
+
20
+ def unpad_image(image, pad_size):
21
+ return image[pad_size:-pad_size, pad_size:-pad_size, :]
22
+
23
+
24
+ def process_array(image_array, expand=True):
25
+ """ Process a 3-dimensional array into a scaled, 4 dimensional batch of size 1. """
26
+
27
+ image_batch = image_array / 255.0
28
+ if expand:
29
+ image_batch = np.expand_dims(image_batch, axis=0)
30
+ return image_batch
31
+
32
+
33
+ def process_output(output_tensor):
34
+ """ Transforms the 4-dimensional output tensor into a suitable image format. """
35
+
36
+ sr_img = output_tensor.clip(0, 1) * 255
37
+ sr_img = np.uint8(sr_img)
38
+ return sr_img
39
+
40
+
41
+ def pad_patch(image_patch, padding_size, channel_last=True):
42
+ """ Pads image_patch with with padding_size edge values. """
43
+
44
+ if channel_last:
45
+ return np.pad(
46
+ image_patch,
47
+ ((padding_size, padding_size), (padding_size, padding_size), (0, 0)),
48
+ 'edge',
49
+ )
50
+ else:
51
+ return np.pad(
52
+ image_patch,
53
+ ((0, 0), (padding_size, padding_size), (padding_size, padding_size)),
54
+ 'edge',
55
+ )
56
+
57
+
58
+ def unpad_patches(image_patches, padding_size):
59
+ return image_patches[:, padding_size:-padding_size, padding_size:-padding_size, :]
60
+
61
+
62
+ def split_image_into_overlapping_patches(image_array, patch_size, padding_size=2):
63
+ """ Splits the image into partially overlapping patches.
64
+ The patches overlap by padding_size pixels.
65
+ Pads the image twice:
66
+ - first to have a size multiple of the patch size,
67
+ - then to have equal padding at the borders.
68
+ Args:
69
+ image_array: numpy array of the input image.
70
+ patch_size: size of the patches from the original image (without padding).
71
+ padding_size: size of the overlapping area.
72
+ """
73
+
74
+ xmax, ymax, _ = image_array.shape
75
+ x_remainder = xmax % patch_size
76
+ y_remainder = ymax % patch_size
77
+
78
+ # modulo avoids padding by a full patch_size when the remainder is already 0
79
+ x_extend = (patch_size - x_remainder) % patch_size
80
+ y_extend = (patch_size - y_remainder) % patch_size
81
+
82
+ # make sure the image is divisible into regular patches
83
+ extended_image = np.pad(image_array, ((0, x_extend), (0, y_extend), (0, 0)), 'edge')
84
+
85
+ # add padding around the image to simplify computations
86
+ padded_image = pad_patch(extended_image, padding_size, channel_last=True)
87
+
88
+ xmax, ymax, _ = padded_image.shape
89
+ patches = []
90
+
91
+ x_lefts = range(padding_size, xmax - padding_size, patch_size)
92
+ y_tops = range(padding_size, ymax - padding_size, patch_size)
93
+
94
+ for x in x_lefts:
95
+ for y in y_tops:
96
+ x_left = x - padding_size
97
+ y_top = y - padding_size
98
+ x_right = x + patch_size + padding_size
99
+ y_bottom = y + patch_size + padding_size
100
+ patch = padded_image[x_left:x_right, y_top:y_bottom, :]
101
+ patches.append(patch)
102
+
103
+ return np.array(patches), padded_image.shape
104
+
105
+
106
+ def stich_together(patches, padded_image_shape, target_shape, padding_size=4):
107
+ """ Reconstruct the image from overlapping patches.
108
+ After scaling, shapes and padding should be scaled too.
109
+ Args:
110
+ patches: patches obtained with split_image_into_overlapping_patches
111
+ padded_image_shape: shape of the padded image constructed in split_image_into_overlapping_patches
112
+ target_shape: shape of the final image
113
+ padding_size: size of the overlapping area.
114
+ """
115
+
116
+ xmax, ymax, _ = padded_image_shape
117
+ patches = unpad_patches(patches, padding_size)
118
+ patch_size = patches.shape[1]
119
+ n_patches_per_row = ymax // patch_size
120
+
121
+ complete_image = np.zeros((xmax, ymax, 3))
122
+
123
+ row = -1
124
+ col = 0
125
+ for i in range(len(patches)):
126
+ if i % n_patches_per_row == 0:
127
+ row += 1
128
+ col = 0
129
+ complete_image[
130
+ row * patch_size: (row + 1) * patch_size, col * patch_size: (col + 1) * patch_size,:
131
+ ] = patches[i]
132
+ col += 1
133
+ return complete_image[0: target_shape[0], 0: target_shape[1], :]
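These helpers tile the low-resolution image into overlapping patches so the model can run in batches, then stitch the enlarged patches back together (model.py passes padding * scale to the stitch step). Without any scaling in between, the round trip should reproduce the input exactly; a hedged sketch with made-up sizes:

import numpy as np
from upscaler.RealESRGAN.utils import (
    split_image_into_overlapping_patches, stich_together)

img = np.random.randint(0, 255, (200, 300, 3), dtype=np.uint8)
patches, padded_shape = split_image_into_overlapping_patches(
    img, patch_size=64, padding_size=8)

restored = stich_together(
    patches, padded_image_shape=padded_shape,
    target_shape=img.shape, padding_size=8)
assert np.array_equal(restored, img)   # 1:1 round trip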
upscaler/__init__.py ADDED
File without changes
utils.py CHANGED
@@ -110,3 +110,60 @@ def add_logo_to_image(img, logo=logo_image):
110
  roi[0] : roi[0] + logo_size, roi[1] : roi[1] + logo_size, c
111
  ]
112
  return img
113
+
114
+ def split_list_by_lengths(data, length_list):
115
+ split_data = []
116
+ start_idx = 0
117
+ for length in length_list:
118
+ end_idx = start_idx + length
119
+ sublist = data[start_idx:end_idx]
120
+ split_data.append(sublist)
121
+ start_idx = end_idx
122
+ return split_data
123
+
124
+ def merge_img_sequence_from_ref(ref_video_path, image_sequence, output_file_name):
125
+ video_clip = VideoFileClip(ref_video_path)
126
+ fps = video_clip.fps
127
+ duration = video_clip.duration
128
+ total_frames = video_clip.reader.nframes
129
+ audio_clip = video_clip.audio if video_clip.audio is not None else None
130
+ edited_video_clip = ImageSequenceClip(image_sequence, fps=fps)
131
+
132
+ if audio_clip is not None:
133
+ edited_video_clip = edited_video_clip.set_audio(audio_clip)
134
+
135
+ edited_video_clip.set_duration(duration).write_videofile(
136
+ output_file_name, codec="libx264"
137
+ )
138
+ edited_video_clip.close()
139
+ video_clip.close()
140
+
141
+ def scale_bbox_from_center(bbox, scale_width, scale_height, image_width, image_height):
142
+ # Extract the coordinates of the bbox
143
+ x1, y1, x2, y2 = bbox
144
+
145
+ # Calculate the center point of the bbox
146
+ center_x = (x1 + x2) / 2
147
+ center_y = (y1 + y2) / 2
148
+
149
+ # Calculate the new width and height of the bbox based on the scaling factors
150
+ width = x2 - x1
151
+ height = y2 - y1
152
+ new_width = width * scale_width
153
+ new_height = height * scale_height
154
+
155
+ # Calculate the new coordinates of the bbox, considering the image boundaries
156
+ new_x1 = center_x - new_width / 2
157
+ new_y1 = center_y - new_height / 2
158
+ new_x2 = center_x + new_width / 2
159
+ new_y2 = center_y + new_height / 2
160
+
161
+ # Adjust the coordinates to ensure the bbox remains within the image boundaries
162
+ new_x1 = max(0, new_x1)
163
+ new_y1 = max(0, new_y1)
164
+ new_x2 = min(image_width - 1, new_x2)
165
+ new_y2 = min(image_height - 1, new_y2)
166
+
167
+ # Return the scaled bbox coordinates
168
+ scaled_bbox = [new_x1, new_y1, new_x2, new_y2]
169
+ return scaled_bbox
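To round out the new helpers, a small hedged example with made-up numbers: split_list_by_lengths partitions a flat list into consecutive chunks, and scale_bbox_from_center grows a box about its center while clamping it to the image. Run from the repo root so utils.py is importable.

from utils import split_list_by_lengths, scale_bbox_from_center

faces = ['a', 'b', 'c', 'd', 'e']
print(split_list_by_lengths(faces, [2, 3]))   # [['a', 'b'], ['c', 'd', 'e']]

# 100x50 box centered at (150, 125), width doubled, clamped to a 300x200 image
bbox = [100, 100, 200, 150]
print(scale_bbox_from_center(bbox, 2.0, 1.0, 300, 200))  # [50.0, 100.0, 250.0, 150.0]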