Spaces:
Sleeping
Sleeping
Zongsheng
commited on
Commit
•
d6fbbca
1
Parent(s):
efc301f
change examples
Browse files- app.py +13 -14
- sampler.py +6 -1
app.py
CHANGED
@@ -20,15 +20,17 @@ from sampler import DifIRSampler
|
|
20 |
from ResizeRight.resize_right import resize
|
21 |
from basicsr.utils.download_util import load_file_from_url
|
22 |
|
23 |
-
|
24 |
-
|
|
|
|
|
|
|
|
|
|
|
25 |
|
26 |
-
|
27 |
-
configs = OmegaConf.load(cfg_path)
|
28 |
-
configs.aligned = False
|
29 |
configs.background_enhance = background_enhance
|
30 |
configs.face_upsample = face_upsample
|
31 |
-
|
32 |
started_timesteps = int(started_timesteps)
|
33 |
assert started_timesteps < int(configs.diffusion.params.timestep_respacing)
|
34 |
|
@@ -56,9 +58,6 @@ def predict(im_path, background_enhance, face_upsample, upscale, started_timeste
|
|
56 |
upscale = 2 # avoid momory exceeded due to too large img resolution
|
57 |
configs.detection.upscale = int(upscale)
|
58 |
|
59 |
-
# build the sampler for diffusion
|
60 |
-
sampler_dist = DifIRSampler(configs)
|
61 |
-
|
62 |
image_restored, face_restored, face_cropped = sampler_dist.sample_func_bfr_unaligned(
|
63 |
y0=im_lq,
|
64 |
start_timesteps=started_timesteps,
|
@@ -71,7 +70,7 @@ def predict(im_path, background_enhance, face_upsample, upscale, started_timeste
|
|
71 |
restored_image_dir.mkdir()
|
72 |
# save the whole image
|
73 |
save_path = restored_image_dir / Path(im_path).name
|
74 |
-
util_image.imwrite(image_restored, save_path, chn='
|
75 |
|
76 |
return image_restored, str(save_path)
|
77 |
|
@@ -114,10 +113,10 @@ If you have any questions, please feel free to contact me via <b>[email protected]
|
|
114 |
demo = gr.Interface(
|
115 |
predict,
|
116 |
inputs=[
|
117 |
-
gr.
|
118 |
-
gr.
|
119 |
-
gr.
|
120 |
-
gr.
|
121 |
gr.Slider(1, 200, value=100, step=10, label='Realism-Fidelity Trade-off')
|
122 |
],
|
123 |
outputs=[
|
|
|
20 |
from ResizeRight.resize_right import resize
|
21 |
from basicsr.utils.download_util import load_file_from_url
|
22 |
|
23 |
+
# setting configurations
|
24 |
+
cfg_path = 'configs/sample/iddpm_ffhq512_swinir.yaml'
|
25 |
+
configs = OmegaConf.load(cfg_path)
|
26 |
+
configs.aligned = False
|
27 |
+
|
28 |
+
# build the sampler for diffusion
|
29 |
+
sampler_dist = DifIRSampler(configs)
|
30 |
|
31 |
+
def predict(im_path, background_enhance, face_upsample, upscale, started_timesteps):
|
|
|
|
|
32 |
configs.background_enhance = background_enhance
|
33 |
configs.face_upsample = face_upsample
|
|
|
34 |
started_timesteps = int(started_timesteps)
|
35 |
assert started_timesteps < int(configs.diffusion.params.timestep_respacing)
|
36 |
|
|
|
58 |
upscale = 2 # avoid momory exceeded due to too large img resolution
|
59 |
configs.detection.upscale = int(upscale)
|
60 |
|
|
|
|
|
|
|
61 |
image_restored, face_restored, face_cropped = sampler_dist.sample_func_bfr_unaligned(
|
62 |
y0=im_lq,
|
63 |
start_timesteps=started_timesteps,
|
|
|
70 |
restored_image_dir.mkdir()
|
71 |
# save the whole image
|
72 |
save_path = restored_image_dir / Path(im_path).name
|
73 |
+
util_image.imwrite(image_restored, save_path, chn='rgb', dtype_in='uint8')
|
74 |
|
75 |
return image_restored, str(save_path)
|
76 |
|
|
|
113 |
demo = gr.Interface(
|
114 |
predict,
|
115 |
inputs=[
|
116 |
+
gr.Image(type="filepath", label="Input"),
|
117 |
+
gr.Checkbox(default=True, label="Background_Enhance"),
|
118 |
+
gr.Checkbox(default=True, label="Face_Upsample"),
|
119 |
+
gr.Number(default=2, label="Rescaling_Factor (up to 4)"),
|
120 |
gr.Slider(1, 200, value=100, step=10, label='Realism-Fidelity Trade-off')
|
121 |
],
|
122 |
outputs=[
|
sampler.py
CHANGED
@@ -54,7 +54,12 @@ class BaseSampler:
|
|
54 |
torch.cuda.manual_seed_all(seed)
|
55 |
|
56 |
def setup_dist(self):
|
57 |
-
|
|
|
|
|
|
|
|
|
|
|
58 |
self.rank = 0
|
59 |
|
60 |
def build_model(self):
|
|
|
54 |
torch.cuda.manual_seed_all(seed)
|
55 |
|
56 |
def setup_dist(self):
|
57 |
+
if torch.cuda.is_available():
|
58 |
+
self.device = torch.device('cuda')
|
59 |
+
print(f'Runing on GPU...')
|
60 |
+
else:
|
61 |
+
self.device = torch.device('cpu')
|
62 |
+
print(f'Runing on CPU...')
|
63 |
self.rank = 0
|
64 |
|
65 |
def build_model(self):
|