Dnau15 committed
Commit fac857d · 0 Parent(s)

initial commit

Files changed (5)
  1. .gitattributes +35 -0
  2. README.md +12 -0
  3. app.py +248 -0
  4. configs/prediction/default.yaml +24 -0
  5. requirements.txt +20 -0
.gitattributes ADDED
@@ -0,0 +1,35 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,12 @@
+ ---
+ title: Lama
+ emoji: 🚀
+ colorFrom: yellow
+ colorTo: pink
+ sdk: gradio
+ sdk_version: 5.6.0
+ app_file: app.py
+ pinned: false
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
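The front matter between the `---` markers is plain YAML that Spaces reads to configure the app. A minimal sketch (not part of the Space; it only assumes the README above and PyYAML, which requirements.txt already pins) showing the header can also be parsed programmatically:

```python
# Sketch: parse the README front matter with PyYAML.
import yaml

with open("README.md") as f:
    text = f.read()

# The front matter sits between the first two '---' markers.
_, front_matter, _ = text.split("---", 2)
meta = yaml.safe_load(front_matter)

assert meta["sdk"] == "gradio" and meta["app_file"] == "app.py"
print(meta["sdk_version"])  # 5.6.0
```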
app.py ADDED
@@ -0,0 +1,248 @@
+ import os
+ import zipfile
+
+ import gradio as gr
+ import numpy as np
+ import requests
+ import torch
+ import yaml
+ from omegaconf import OmegaConf
+ from PIL import Image
+ from torch.utils.data._utils.collate import default_collate
+
+ from saicinpainting.evaluation.data import pad_img_to_modulo
+ from saicinpainting.evaluation.refinement import refine_predict
+ from saicinpainting.evaluation.utils import move_to_device
+ from saicinpainting.training.trainers import load_checkpoint
+
+ # URL of the pretrained big-lama checkpoint archive
+ url = "https://huggingface.co/smartywu/big-lama/resolve/main/big-lama.zip"
+
+ # Local filename to save the downloaded archive
+ local_filename = "big-lama.zip"
+
+ # Directory the archive extracts into
+ extract_dir = "big-lama"
+
+ # Download and unpack the checkpoint once; later runs reuse the extracted directory
+ if os.path.exists(extract_dir):
+     print(f"The directory '{extract_dir}' already exists. Skipping download and extraction.")
+ else:
+     if not os.path.exists(local_filename):
+         # Stream the archive to disk to avoid holding it all in memory
+         with requests.get(url, stream=True) as response:
+             response.raise_for_status()
+             with open(local_filename, "wb") as f:
+                 for chunk in response.iter_content(chunk_size=8192):
+                     f.write(chunk)
+         print(f"Downloaded '{local_filename}' successfully.")
+     else:
+         print(f"The file '{local_filename}' already exists. Skipping download.")
+
+     # The archive unpacks into a top-level 'big-lama' directory
+     with zipfile.ZipFile(local_filename, "r") as zip_ref:
+         zip_ref.extractall()
+     print(f"Extracted '{local_filename}' into '{extract_dir}' successfully.")
+
+     # Remove the archive after extraction to free disk space
+     os.remove(local_filename)
+     print(f"Removed '{local_filename}' after extraction.")
+
+ # Seeded generator; fall back to CPU when CUDA is unavailable
+ generator = torch.Generator(
+     device="cuda" if torch.cuda.is_available() else "cpu"
+ ).manual_seed(42)
+
+ size = (1024, 1024)
+
+
+ def image_preprocess(image: Image.Image, mode="RGB", return_orig=False):
+     """Convert a PIL image to a CHW float32 array scaled to [0, 1]."""
+     img = np.array(image.convert(mode))
+     if img.ndim == 3:
+         img = np.transpose(img, (2, 0, 1))
+     out_img = img.astype("float32") / 255
+     if return_orig:
+         return out_img, img
+     return out_img
+
+
+ def infer(prompt, image):
+     # LaMa is unconditional inpainting, so the prompt is currently unused.
+     source = image["background"].convert("RGB").resize(size)
+
+     # The editor returns the drawn mask as an RGBA layer; its alpha channel
+     # marks the painted region.
+     mask = image["layers"][0].resize(size)
+     mask = mask.point(lambda p: 255 if p > 0 else 0).split()[3]
+
+     device = torch.device("cpu")
+
+     predict_config_path = "configs/prediction/default.yaml"
+
+     with open(predict_config_path, "r") as f:
+         predict_config = OmegaConf.create(yaml.safe_load(f))
+
+     train_config_path = os.path.join(predict_config.model.path, "config.yaml")
+     with open(train_config_path, "r") as f:
+         train_config = OmegaConf.create(yaml.safe_load(f))
+
+     train_config.training_model.predict_only = True
+     train_config.visualizer.kind = "noop"
+
+     checkpoint_path = os.path.join(
+         predict_config.model.path, "models", predict_config.model.checkpoint
+     )
+
+     model = load_checkpoint(
+         train_config, checkpoint_path, strict=False, map_location="cpu"
+     )
+     model.freeze()
+     if not predict_config.get("refine", False):
+         model.to(device)
+
+     img = image_preprocess(source, mode="RGB")
+     mask = image_preprocess(mask, mode="L")
+
+     result = dict(image=img, mask=mask[None, ...])
+
+     if (
+         predict_config.dataset.pad_out_to_modulo is not None
+         and predict_config.dataset.pad_out_to_modulo > 1
+     ):
+         result["unpad_to_size"] = result["image"].shape[1:]
+         result["image"] = pad_img_to_modulo(
+             result["image"], predict_config.dataset.pad_out_to_modulo
+         )
+         result["mask"] = pad_img_to_modulo(
+             result["mask"], predict_config.dataset.pad_out_to_modulo
+         )
+
+     batch = default_collate([result])
+     if predict_config.get("refine", False):
+         assert "unpad_to_size" in batch, "Unpadded size is required for the refinement"
+         # Image unpadding is handled inside the refiner, so the output image
+         # is the same size as the input image.
+         cur_res = refine_predict(batch, model, **predict_config.refiner)
+         cur_res = cur_res[0].permute(1, 2, 0).detach().cpu().numpy()
+     else:
+         with torch.no_grad():
+             batch = move_to_device(batch, device)
+             batch["mask"] = (batch["mask"] > 0) * 1
+             batch = model(batch)
+             cur_res = (
+                 batch[predict_config.out_key][0].permute(1, 2, 0).detach().cpu().numpy()
+             )
+             unpad_to_size = batch.get("unpad_to_size", None)
+             if unpad_to_size is not None:
+                 orig_height, orig_width = unpad_to_size
+                 cur_res = cur_res[:orig_height, :orig_width]
+
+     cur_res = np.clip(cur_res * 255, 0, 255).astype("uint8")
+
+     yield cur_res
+
+
+ def clear_result():
+     return gr.update(value=None)
+
+
+ css = """.main-div div{display:inline-flex;align-items:center;gap:.8rem;font-size:1.75rem}.main-div div h1{font-weight:900;margin-bottom:7px}.main-div p{margin-bottom:10px;font-size:94%}a{text-decoration:underline}.tabs{margin-top:0;margin-bottom:0}#gallery{min-height:20rem}
+ """
+
+ title = f"""
+ <div class="main-div">
+   <div>
+     <h1>LaMa Image Inpainting</h1>
+   </div>
+   <p>
+     Demo for the <a href="https://huggingface.co/smartywu/big-lama">big-lama</a> inpainting model: draw a mask over the region to remove.<br>
+   </p>
+   Running on {"<b>GPU 🔥</b>" if torch.cuda.is_available() else "<b>CPU 🥶</b>"}
+ </div>
+ """
+
+ with gr.Blocks(css=css) as demo:
+     gr.HTML(title)
+     with gr.Row():
+         with gr.Column():
+             prompt = gr.Textbox(
+                 label="Prompt",
+                 info="Describe what to inpaint the mask with",
+                 lines=3,
+             )
+         with gr.Column():
+             run_button = gr.Button("Generate")
+     with gr.Row():
+         input_image = gr.ImageMask(
+             type="pil",
+             label="Input Image",
+             crop_size=(1024, 1024),
+             layers=False,
+             height=712,
+             width=712,
+         )
+
+         result = gr.Image(
+             interactive=False,
+             label="Generated Image",
+         )
+     use_as_input_button = gr.Button("Use as Input Image", visible=False)
+
+     def use_output_as_input(output_image):
+         return gr.update(value=output_image)
+
+     use_as_input_button.click(
+         fn=use_output_as_input, inputs=[result], outputs=[input_image]
+     )
+
+     run_button.click(
+         fn=clear_result,
+         inputs=None,
+         outputs=result,
+     ).then(
+         fn=lambda: gr.update(visible=False),
+         inputs=None,
+         outputs=use_as_input_button,
+     ).then(
+         fn=infer,
+         inputs=[prompt, input_image],
+         outputs=result,
+     ).then(
+         fn=lambda: gr.update(visible=True),
+         inputs=None,
+         outputs=use_as_input_button,
+     )
+
+     prompt.submit(
+         fn=clear_result,
+         inputs=None,
+         outputs=result,
+     ).then(
+         fn=lambda: gr.update(visible=False),
+         inputs=None,
+         outputs=use_as_input_button,
+     ).then(
+         fn=infer,
+         inputs=[prompt, input_image],
+         outputs=result,
+     ).then(
+         fn=lambda: gr.update(visible=True),
+         inputs=None,
+         outputs=use_as_input_button,
+     )
+
+ demo.launch()
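For a quick smoke test outside the UI, `infer` can be driven directly; a minimal sketch, assuming the `saicinpainting` package and the downloaded `big-lama` checkpoint are in place ("photo.png" and "mask.png" are hypothetical inputs, not files shipped with this Space). It mirrors the dict Gradio's editor passes in: a `background` image plus one RGBA `layers` entry whose alpha channel marks the region to remove:

```python
# Hypothetical programmatic use of infer(); placeholder file names.
from PIL import Image

background = Image.open("photo.png").convert("RGB")
mask_layer = Image.open("mask.png").convert("RGBA")  # alpha > 0 where inpainting is wanted
editor_value = {"background": background, "layers": [mask_layer]}

# infer is a generator that yields a single uint8 HWC array; the prompt is unused.
inpainted = next(infer("", editor_value))
Image.fromarray(inpainted).save("inpainted.png")
```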
configs/prediction/default.yaml ADDED
@@ -0,0 +1,24 @@
+ indir: no  # to be overridden from the CLI
+ outdir: no  # to be overridden from the CLI
+
+ model:
+   path: big-lama  # to be overridden from the CLI
+   checkpoint: best.ckpt
+
+ dataset:
+   kind: default
+   img_suffix: .png
+   pad_out_to_modulo: 8
+
+ device: cuda
+ out_key: inpainted
+
+ refine: False  # the refiner only runs when this is True
+ refiner:
+   gpu_ids: 0,1  # GPU ids of the machine to use; for a single GPU, use: "0,"
+   modulo: ${dataset.pad_out_to_modulo}
+   n_iters: 15  # number of refinement iterations per scale
+   lr: 0.002  # learning rate
+   min_side: 512  # all sides of the image at every scale should be >= min_side / sqrt(2)
+   max_scales: 3  # max number of downscaling scales for the image-mask pyramid
+   px_budget: 1800000  # pixel budget: any image is resized so that height * width <= px_budget
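Two YAML details worth noting here: a bare `no` parses as boolean `False` under YAML 1.1, and `${dataset.pad_out_to_modulo}` is an OmegaConf interpolation, which is why app.py loads this file through `OmegaConf.create` rather than using the raw `yaml.safe_load` dict. A small sketch of how the values come out:

```python
# Sketch: load the prediction config the same way app.py does.
import yaml
from omegaconf import OmegaConf

with open("configs/prediction/default.yaml") as f:
    cfg = OmegaConf.create(yaml.safe_load(f))

print(cfg.indir)           # False -- YAML 1.1 reads bare 'no' as a boolean
print(cfg.model.path)      # big-lama
print(cfg.refiner.modulo)  # 8, resolved from ${dataset.pad_out_to_modulo}

# When refine is True, app.py splats these settings into refine_predict:
if cfg.get("refine", False):
    print(dict(cfg.refiner))
```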
requirements.txt ADDED
@@ -0,0 +1,20 @@
+ pyyaml
+ tqdm
+ numpy
+ easydict==1.9.0
+ scikit-image==0.17.2
+ scikit-learn==0.24.2
+ opencv-python
+ tensorflow
+ joblib
+ matplotlib
+ pandas
+ albumentations==0.5.2
+ hydra-core==1.1.0
+ pytorch-lightning==1.2.9
+ tabulate
+ kornia==0.5.0
+ webdataset
+ packaging
+ wldhx.yadisk-direct