52Hz commited on
Commit
24aaf37
·
verified ·
1 Parent(s): 5422deb

Update main_test_CMFNet.py

Browse files
Files changed (1) hide show
  1. main_test_CMFNet.py +29 -26
main_test_CMFNet.py CHANGED
@@ -14,6 +14,34 @@ from natsort import natsorted
14
  from model.CMFNet import CMFNet
15
 
16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  def main():
18
  parser = argparse.ArgumentParser(description='Demo Image Deraindrop')
19
  parser.add_argument('--input_dir', default='test/', type=str, help='Input images')
@@ -66,33 +94,8 @@ def main():
66
  f = os.path.splitext(os.path.split(file_)[-1])[0]
67
  save_img((os.path.join(out_dir, f + '.png')), restored)
68
 
 
69
 
70
- def save_img(filepath, img):
71
- cv2.imwrite(filepath, cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
72
-
73
-
74
- def load_checkpoint(model, weights):
75
- checkpoint = torch.load(weights, map_location=torch.device('cpu'))
76
- try:
77
- model.load_state_dict(checkpoint["state_dict"])
78
- except:
79
- state_dict = checkpoint["state_dict"]
80
- new_state_dict = OrderedDict()
81
- for k, v in state_dict.items():
82
- name = k[7:] # remove `module.`
83
- new_state_dict[name] = v
84
- model.load_state_dict(new_state_dict)
85
-
86
- def clean_folder(folder):
87
- for filename in os.listdir(folder):
88
- file_path = os.path.join(folder, filename)
89
- try:
90
- if os.path.isfile(file_path) or os.path.islink(file_path):
91
- os.unlink(file_path)
92
- elif os.path.isdir(file_path):
93
- shutil.rmtree(file_path)
94
- except Exception as e:
95
- print('Failed to delete %s. Reason: %s' % (file_path, e))
96
 
97
  if __name__ == '__main__':
98
  main()
 
14
  from model.CMFNet import CMFNet
15
 
16
 
17
+ def save_img(filepath, img):
18
+ cv2.imwrite(filepath, cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
19
+
20
+
21
+ def load_checkpoint(model, weights):
22
+ checkpoint = torch.load(weights, map_location=torch.device('cpu'))
23
+ try:
24
+ model.load_state_dict(checkpoint["state_dict"])
25
+ except:
26
+ state_dict = checkpoint["state_dict"]
27
+ new_state_dict = OrderedDict()
28
+ for k, v in state_dict.items():
29
+ name = k[7:] # remove `module.`
30
+ new_state_dict[name] = v
31
+ model.load_state_dict(new_state_dict)
32
+
33
+ def clean_folder(folder):
34
+ for filename in os.listdir(folder):
35
+ file_path = os.path.join(folder, filename)
36
+ try:
37
+ if os.path.isfile(file_path) or os.path.islink(file_path):
38
+ os.unlink(file_path)
39
+ elif os.path.isdir(file_path):
40
+ shutil.rmtree(file_path)
41
+ except Exception as e:
42
+ print('Failed to delete %s. Reason: %s' % (file_path, e))
43
+
44
+
45
  def main():
46
  parser = argparse.ArgumentParser(description='Demo Image Deraindrop')
47
  parser.add_argument('--input_dir', default='test/', type=str, help='Input images')
 
94
  f = os.path.splitext(os.path.split(file_)[-1])[0]
95
  save_img((os.path.join(out_dir, f + '.png')), restored)
96
 
97
+ clean_folder(inp_dir)
98
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
 
100
  if __name__ == '__main__':
101
  main()