Spaces:
Build error
Build error
Commit
·
8728fb1
1
Parent(s):
d44c9b5
Update app.py
Browse files
app.py
CHANGED
@@ -31,20 +31,20 @@ def lab2rgb(L, AB):
|
|
31 |
rgb = color.lab2rgb(Lab) * 255
|
32 |
return rgb
|
33 |
|
34 |
-
def get_transform(params=None, grayscale=False, method=Image.BICUBIC):
|
35 |
#params
|
36 |
-
preprocess = '
|
37 |
load_size = 256
|
38 |
crop_size = 256
|
39 |
transform_list = []
|
40 |
if grayscale:
|
41 |
transform_list.append(transforms.Grayscale(1))
|
42 |
-
if
|
43 |
osize = [load_size, load_size]
|
44 |
transform_list.append(transforms.Resize(osize, method))
|
45 |
-
if 'crop' in preprocess:
|
46 |
-
|
47 |
-
|
48 |
|
49 |
return transforms.Compose(transform_list)
|
50 |
|
@@ -67,7 +67,7 @@ def inferRestoration(img, model_name):
|
|
67 |
return result
|
68 |
|
69 |
def inferColorization(img,model_name):
|
70 |
-
print(model_name)
|
71 |
if model_name == "Pix2Pix Resnet 9block":
|
72 |
model = torch.hub.load('manhkhanhad/ImageRestorationInfer', 'pix2pixColorization_resnet9b')
|
73 |
elif model_name == "Pix2Pix Unet 256":
|
@@ -96,10 +96,12 @@ def inferColorization(img,model_name):
|
|
96 |
image_pil = transforms.ToPILImage()(result)
|
97 |
return image_pil
|
98 |
|
99 |
-
transform_seq = get_transform()
|
100 |
-
|
101 |
-
|
102 |
-
|
|
|
|
|
103 |
lab_t = transforms.ToTensor()(lab)
|
104 |
A = lab_t[[0], ...] / 50.0 - 1.0
|
105 |
B = lab_t[[1, 2], ...] / 110.0
|
@@ -160,4 +162,4 @@ examples = [['example/1.jpeg',"BOPBTL","Deoldify"],['example/2.jpg',"BOPBTL","De
|
|
160 |
iface = gr.Interface(run,
|
161 |
[gr.inputs.Image(),gr.inputs.Radio(["BOPBTL", "Pix2Pix"]),gr.inputs.Radio(["Deoldify", "Pix2Pix Resnet 9block","Pix2Pix Unet 256"])],
|
162 |
outputs="image",
|
163 |
-
examples=examples).launch(debug=True,share=
|
|
|
31 |
rgb = color.lab2rgb(Lab) * 255
|
32 |
return rgb
|
33 |
|
34 |
+
def get_transform(model_name,params=None, grayscale=False, method=Image.BICUBIC):
|
35 |
#params
|
36 |
+
preprocess = 'resize'
|
37 |
load_size = 256
|
38 |
crop_size = 256
|
39 |
transform_list = []
|
40 |
if grayscale:
|
41 |
transform_list.append(transforms.Grayscale(1))
|
42 |
+
if model_name == "Pix2Pix Unet 256":
|
43 |
osize = [load_size, load_size]
|
44 |
transform_list.append(transforms.Resize(osize, method))
|
45 |
+
# if 'crop' in preprocess:
|
46 |
+
# if params is None:
|
47 |
+
# transform_list.append(transforms.RandomCrop(crop_size))
|
48 |
|
49 |
return transforms.Compose(transform_list)
|
50 |
|
|
|
67 |
return result
|
68 |
|
69 |
def inferColorization(img,model_name):
|
70 |
+
#print(model_name)
|
71 |
if model_name == "Pix2Pix Resnet 9block":
|
72 |
model = torch.hub.load('manhkhanhad/ImageRestorationInfer', 'pix2pixColorization_resnet9b')
|
73 |
elif model_name == "Pix2Pix Unet 256":
|
|
|
96 |
image_pil = transforms.ToPILImage()(result)
|
97 |
return image_pil
|
98 |
|
99 |
+
transform_seq = get_transform(model_name)
|
100 |
+
img = transform_seq(img)
|
101 |
+
# if model_name == "Pix2Pix Unet 256":
|
102 |
+
# img.resize((256,256))
|
103 |
+
img = np.array(img)
|
104 |
+
lab = color.rgb2lab(img).astype(np.float32)
|
105 |
lab_t = transforms.ToTensor()(lab)
|
106 |
A = lab_t[[0], ...] / 50.0 - 1.0
|
107 |
B = lab_t[[1, 2], ...] / 110.0
|
|
|
162 |
iface = gr.Interface(run,
|
163 |
[gr.inputs.Image(),gr.inputs.Radio(["BOPBTL", "Pix2Pix"]),gr.inputs.Radio(["Deoldify", "Pix2Pix Resnet 9block","Pix2Pix Unet 256"])],
|
164 |
outputs="image",
|
165 |
+
examples=examples).launch(debug=True,share=True)
|