Commit d73d6bf · verified · committed by nikigoli
Parent(s): 2f1d1a1

add-demo-notebook (#5)

- Refactor app.py - extract reusable functions (aedd89b11c7db4e19dbd7d72566a0fbccea3bd85)
- Add sample notebook (96f9e24b2cf04f41c38a00d4706abb6b38ec88e4)
- Remove notebook output (eb3994e008d734612930f2bdd5a091882ba18603)

Files changed (4)
  1. .gitignore +2 -2
  2. app.py +118 -148
  3. notebooks/demo.ipynb +492 -0
  4. requirements.txt +2 -0
.gitignore CHANGED
@@ -2,7 +2,7 @@
 env/
 __pycache__
 .python-version
-
+*.py[od]
 
 # vim
-*.sw[op]
+*.sw[op]
app.py CHANGED
@@ -14,11 +14,6 @@ import matplotlib.pyplot as plt
 import io
 from enum import Enum
 import os
-import subprocess
-from subprocess import call
-import shlex
-import shutil
-#os.environ["GRADIO_TEMP_DIR"] = os.path.join(os.getcwd(), "tmp")
 cwd = os.getcwd()
 # Suppress warnings to avoid overflowing the log.
 import warnings
@@ -145,22 +140,6 @@ def build_model_and_transforms(args):
 
     return model, data_transform
 
-examples = [
-    ["strawberry.jpg", "strawberry", {"image": "strawberry.jpg"}],
-    ["strawberry.jpg", "blueberry", {"image": "strawberry.jpg"}],
-    ["bird-1.JPG", "bird", {"image": "bird-2.JPG"}],
-    ["fish.jpg", "fish", {"image": "fish.jpg"}],
-    ["women.jpg", "girl", {"image": "women.jpg"}],
-    ["women.jpg", "boy", {"image": "women.jpg"}],
-    ["balloon.jpg", "hot air balloon", {"image": "balloon.jpg"}],
-    ["deer.jpg", "deer", {"image": "deer.jpg"}],
-    ["apple.jpg", "apple", {"image": "apple.jpg"}],
-    ["egg.jpg", "egg", {"image": "egg.jpg"}],
-    ["stamp.jpg", "stamp", {"image": "stamp.jpg"}],
-    ["green-pea.jpg", "green pea", {"image": "green-pea.jpg"}],
-    ["lego.jpg", "lego", {"image": "lego.jpg"}]
-]
-
 # APP:
 def get_box_inputs(prompts):
     box_inputs = []
@@ -197,6 +176,107 @@ def get_ind_to_filter(text, word_ids, keywords):
 
     return inds_to_filter
 
+def generate_heatmap(image, boxes):
+    # Plot results.
+    (w, h) = image.size
+    det_map = np.zeros((h, w))
+    det_map[(h * boxes[:, 1]).astype(int), (w * boxes[:, 0]).astype(int)] = 1
+    det_map = ndimage.gaussian_filter(
+        det_map, sigma=(w // 200, w // 200), order=0
+    )
+    plt.imshow(image)
+    plt.imshow(det_map[None, :].transpose(1, 2, 0), 'jet', interpolation='none', alpha=0.7)
+    plt.axis('off')
+    img_buf = io.BytesIO()
+    plt.savefig(img_buf, format='png', bbox_inches='tight')
+    plt.close()
+
+    output_img = Image.open(img_buf)
+    return output_img
+
+def generate_output_label(text, num_exemplars):
+    out_label = "Detected instances predicted with"
+    if len(text.strip()) > 0:
+        out_label += " text"
+        if num_exemplars == 1:
+            out_label += " and " + str(num_exemplars) + " visual exemplar."
+        elif num_exemplars > 1:
+            out_label += " and " + str(num_exemplars) + " visual exemplars."
+        else:
+            out_label += "."
+    elif num_exemplars > 0:
+        if num_exemplars == 1:
+            out_label += " " + str(num_exemplars) + " visual exemplar."
+        else:
+            out_label += " " + str(num_exemplars) + " visual exemplars."
+    else:
+        out_label = "Nothing specified to detect."
+
+    return out_label
+
+def preprocess(transform, image, input_prompts = None):
+    if input_prompts == None:
+        prompts = { "image": image, "points": []}
+    else:
+        prompts = input_prompts
+
+    input_image, _ = transform(image, None)
+    exemplar = get_box_inputs(prompts["points"])
+    # Wrapping exemplar in a dictionary to apply only relevant transforms
+    input_image_exemplar, exemplar = transform(prompts['image'], {"exemplars": torch.tensor(exemplar)})
+    exemplar = exemplar["exemplars"]
+
+    return input_image, input_image_exemplar, exemplar
+
+def get_boxes_from_prediction(model_output, text, keywords = ""):
+    ind_to_filter = get_ind_to_filter(text, model_output["token"][0].word_ids, keywords)
+    logits = model_output["pred_logits"].sigmoid()[0][:, ind_to_filter]
+    boxes = model_output["pred_boxes"][0]
+    if len(keywords.strip()) > 0:
+        box_mask = (logits > CONF_THRESH).sum(dim=-1) == len(ind_to_filter)
+    else:
+        box_mask = logits.max(dim=-1).values > CONF_THRESH
+    boxes = boxes[box_mask, :].cpu().numpy()
+    logits = logits[box_mask, :].cpu().numpy()
+    return boxes, logits
+
+def predict(model, transform, image, text, prompts, device):
+    keywords = "" # do not handle this for now
+    input_image, input_image_exemplar, exemplar = preprocess(transform, image, prompts)
+
+    input_images = input_image.unsqueeze(0).to(device)
+    input_image_exemplars = input_image_exemplar.unsqueeze(0).to(device)
+    exemplars = [exemplar.to(device)]
+
+    with torch.no_grad():
+        model_output = model(
+            nested_tensor_from_tensor_list(input_images),
+            nested_tensor_from_tensor_list(input_image_exemplars),
+            exemplars,
+            [torch.tensor([0]).to(device) for _ in range(len(input_images))],
+            captions=[text + " ."] * len(input_images),
+        )
+
+    keywords = ""
+    return get_boxes_from_prediction(model_output, text, keywords)
+
+examples = [
+    ["strawberry.jpg", "strawberry", {"image": "strawberry.jpg"}],
+    ["strawberry.jpg", "blueberry", {"image": "strawberry.jpg"}],
+    ["bird-1.JPG", "bird", {"image": "bird-2.JPG"}],
+    ["fish.jpg", "fish", {"image": "fish.jpg"}],
+    ["women.jpg", "girl", {"image": "women.jpg"}],
+    ["women.jpg", "boy", {"image": "women.jpg"}],
+    ["balloon.jpg", "hot air balloon", {"image": "balloon.jpg"}],
+    ["deer.jpg", "deer", {"image": "deer.jpg"}],
+    ["apple.jpg", "apple", {"image": "apple.jpg"}],
+    ["egg.jpg", "egg", {"image": "egg.jpg"}],
+    ["stamp.jpg", "stamp", {"image": "stamp.jpg"}],
+    ["green-pea.jpg", "green pea", {"image": "green-pea.jpg"}],
+    ["lego.jpg", "lego", {"image": "lego.jpg"}]
+]
+
+
 if __name__ == '__main__':
 
     parser = argparse.ArgumentParser("Counting Application", parents=[get_args_parser()])
@@ -205,56 +285,19 @@ if __name__ == '__main__':
     model, transform = build_model_and_transforms(args)
     model = model.to(device)
 
+    _predict = lambda image, text, prompts: predict(model, transform, image, text, prompts, device)
+
     @spaces.GPU(duration=120)
     def count(image, text, prompts, state, device):
-
-        keywords = "" # do not handle this for now
-
-        # Handle no prompt case.
         if prompts is None:
             prompts = {"image": image, "points": []}
-        input_image, _ = transform(image, {"exemplars": torch.tensor([])})
-        input_image = input_image.unsqueeze(0).to(device)
-        exemplars = get_box_inputs(prompts["points"])
-
-        input_image_exemplars, exemplars = transform(prompts["image"], {"exemplars": torch.tensor(exemplars)})
-        input_image_exemplars = input_image_exemplars.unsqueeze(0).to(device)
-        exemplars = [exemplars["exemplars"].to(device)]
-
-        with torch.no_grad():
-            model_output = model(
-                nested_tensor_from_tensor_list(input_image),
-                nested_tensor_from_tensor_list(input_image_exemplars),
-                exemplars,
-                [torch.tensor([0]).to(device) for _ in range(len(input_image))],
-                captions=[text + " ."] * len(input_image),
-            )
+
+        boxes, _ = _predict(image, text, prompts)
+        count = len(boxes)
+        output_img = generate_heatmap(image, boxes)
 
-        ind_to_filter = get_ind_to_filter(text, model_output["token"][0].word_ids, keywords)
-        logits = model_output["pred_logits"].sigmoid()[0][:, ind_to_filter]
-        boxes = model_output["pred_boxes"][0]
-        if len(keywords.strip()) > 0:
-            box_mask = (logits > CONF_THRESH).sum(dim=-1) == len(ind_to_filter)
-        else:
-            box_mask = logits.max(dim=-1).values > CONF_THRESH
-        logits = logits[box_mask, :].cpu().numpy()
-        boxes = boxes[box_mask, :].cpu().numpy()
-
-        # Plot results.
-        (w, h) = image.size
-        det_map = np.zeros((h, w))
-        det_map[(h * boxes[:, 1]).astype(int), (w * boxes[:, 0]).astype(int)] = 1
-        det_map = ndimage.gaussian_filter(
-            det_map, sigma=(w // 200, w // 200), order=0
-        )
-        plt.imshow(image)
-        plt.imshow(det_map[None, :].transpose(1, 2, 0), 'jet', interpolation='none', alpha=0.7)
-        plt.axis('off')
-        img_buf = io.BytesIO()
-        plt.savefig(img_buf, format='png', bbox_inches='tight')
-        plt.close()
-
-        output_img = Image.open(img_buf)
+        num_exemplars = len(get_box_inputs(prompts["points"]))
+        out_label = generate_output_label(text, num_exemplars)
 
         if AppSteps.TEXT_AND_EXEMPLARS not in state:
             exemplar_image = ImagePrompter(type='pil', label='Visual Exemplar Image', value=prompts, interactive=True, visible=True)
@@ -274,92 +317,19 @@ if __name__ == '__main__':
             main_instructions_comp = gr.Markdown(visible=True)
             step_3 = gr.Tab(visible=True)
 
-        out_label = "Detected instances predicted with"
-        if len(text.strip()) > 0:
-            out_label += " text"
-            if exemplars[0].size()[0] == 1:
-                out_label += " and " + str(exemplars[0].size()[0]) + " visual exemplar."
-            elif exemplars[0].size()[0] > 1:
-                out_label += " and " + str(exemplars[0].size()[0]) + " visual exemplars."
-            else:
-                out_label += "."
-        elif exemplars[0].size()[0] > 0:
-            if exemplars[0].size()[0] == 1:
-                out_label += " " + str(exemplars[0].size()[0]) + " visual exemplar."
-            else:
-                out_label += " " + str(exemplars[0].size()[0]) + " visual exemplars."
-        else:
-            out_label = "Nothing specified to detect."
-
-        return (gr.Image(output_img, visible=True, label=out_label, show_label=True), gr.Number(label="Predicted Count", visible=True, value=boxes.shape[0]), new_submit_btn, gr.Tab(visible=True), step_3, state)
+        return (gr.Image(output_img, visible=True, label=out_label, show_label=True), gr.Number(label="Predicted Count", visible=True, value=count), new_submit_btn, gr.Tab(visible=True), step_3, state)
 
     @spaces.GPU
     def count_main(image, text, prompts, device):
-        keywords = "" # do not handle this for now
-        # Handle no prompt case.
         if prompts is None:
             prompts = {"image": image, "points": []}
-        input_image, _ = transform(image, {"exemplars": torch.tensor([])})
-        input_image = input_image.unsqueeze(0).to(device)
-        exemplars = get_box_inputs(prompts["points"])
-
-        input_image_exemplars, exemplars = transform(prompts["image"], {"exemplars": torch.tensor(exemplars)})
-        input_image_exemplars = input_image_exemplars.unsqueeze(0).to(device)
-        exemplars = [exemplars["exemplars"].to(device)]
-
-        with torch.no_grad():
-            model_output = model(
-                nested_tensor_from_tensor_list(input_image),
-                nested_tensor_from_tensor_list(input_image_exemplars),
-                exemplars,
-                [torch.tensor([0]).to(device) for _ in range(len(input_image))],
-                captions=[text + " ."] * len(input_image),
-            )
-
-        ind_to_filter = get_ind_to_filter(text, model_output["token"][0].word_ids, keywords)
-        logits = model_output["pred_logits"].sigmoid()[0][:, ind_to_filter]
-        boxes = model_output["pred_boxes"][0]
-        if len(keywords.strip()) > 0:
-            box_mask = (logits > CONF_THRESH).sum(dim=-1) == len(ind_to_filter)
-        else:
-            box_mask = logits.max(dim=-1).values > CONF_THRESH
-        logits = logits[box_mask, :].cpu().numpy()
-        boxes = boxes[box_mask, :].cpu().numpy()
-
-        # Plot results.
-        (w, h) = image.size
-        det_map = np.zeros((h, w))
-        det_map[(h * boxes[:, 1]).astype(int), (w * boxes[:, 0]).astype(int)] = 1
-        det_map = ndimage.gaussian_filter(
-            det_map, sigma=(w // 200, w // 200), order=0
-        )
-        plt.imshow(image)
-        plt.imshow(det_map[None, :].transpose(1, 2, 0), 'jet', interpolation='none', alpha=0.7)
-        plt.axis('off')
-        img_buf = io.BytesIO()
-        plt.savefig(img_buf, format='png', bbox_inches='tight')
-        plt.close()
-
-        output_img = Image.open(img_buf)
-
-        out_label = "Detected instances predicted with"
-        if len(text.strip()) > 0:
-            out_label += " text"
-            if exemplars[0].size()[0] == 1:
-                out_label += " and " + str(exemplars[0].size()[0]) + " visual exemplar."
-            elif exemplars[0].size()[0] > 1:
-                out_label += " and " + str(exemplars[0].size()[0]) + " visual exemplars."
-            else:
-                out_label += "."
-        elif exemplars[0].size()[0] > 0:
-            if exemplars[0].size()[0] == 1:
-                out_label += " " + str(exemplars[0].size()[0]) + " visual exemplar."
-            else:
-                out_label += " " + str(exemplars[0].size()[0]) + " visual exemplars."
-        else:
-            out_label = "Nothing specified to detect."
+        boxes, _ = _predict(image, text, prompts)
+        count = len(boxes)
+        output_img = generate_heatmap(image, boxes)
+        num_exemplars = len(get_box_inputs(prompts["points"]))
+        out_label = generate_output_label(text, num_exemplars)
 
-        return (gr.Image(output_img, visible=True, label=out_label, show_label=True), gr.Number(label="Predicted Count", visible=True, value=boxes.shape[0]))
+        return (gr.Image(output_img, visible=True, label=out_label, show_label=True), gr.Number(label="Predicted Count", visible=True, value=count))
 
     def remove_label(image):
         return gr.Image(show_label=False)
@@ -401,12 +371,12 @@
             with gr.Accordion("Open for Further Information", open=False):
                 gr.Markdown(exemplar_img_drawing_instructions_part_2)
             with gr.Tab("Step 1", visible=True) as step_1:
-                input_image = gr.Image(type='pil', label='Input Image', show_label='True', value="strawberry.jpg", interactive=False, width="30vw")
+                input_image = gr.Image(type='pil', label='Input Image', show_label='True', value="strawberry.jpg", interactive=False)
                 gr.Markdown('# Click "Count" to count the strawberries.')
 
         with gr.Column():
             with gr.Tab("Output Image"):
-                detected_instances = gr.Image(label="Detected Instances", show_label='True', interactive=False, visible=True, width="40vw")
+                detected_instances = gr.Image(label="Detected Instances", show_label='True', interactive=False, visible=True)
 
             with gr.Row():
                 input_text = gr.Textbox(label="What would you like to count?", value="strawberry", interactive=True)
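Editor's note: the point of this refactor is that the model call no longer lives inside the Gradio handlers, so the same helpers can be reused outside the UI (which is what the new notebook does). A minimal sketch of how the extracted functions compose for a single text-only query is shown below; the image path and prompt are placeholders, and the snippet itself is not part of the commit.

```python
from PIL import Image

from app import (
    build_model_and_transforms,
    get_args_parser,
    get_device,
    predict,
    generate_heatmap,
    generate_output_label,
)

# Build the model once, mirroring app.py's __main__ block and the demo notebook.
args = get_args_parser().parse_args([])
device = get_device()
model, transform = build_model_and_transforms(args)
model = model.to(device)

# Text-only counting on one image; passing prompts=None lets preprocess()
# fall back to an empty exemplar set.
image = Image.open("strawberry.jpg")  # placeholder path
boxes, _ = predict(model, transform, image, "strawberry", None, device)

count = len(boxes)
heatmap = generate_heatmap(image, boxes)        # PIL image with the density overlay
label = generate_output_label("strawberry", 0)  # no visual exemplars in this sketch
print(count, label)
```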
notebooks/demo.ipynb ADDED
@@ -0,0 +1,492 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {
6
+ "id": "yxig5CdZuHb9"
7
+ },
8
+ "source": [
9
+ "# CountGD - Multimodela open-world object counting\n",
10
+ "\n"
11
+ ]
12
+ },
13
+ {
14
+ "cell_type": "markdown",
15
+ "metadata": {
16
+ "id": "9wyM6J2HuHb-"
17
+ },
18
+ "source": [
19
+ "## Setup\n",
20
+ "\n",
21
+ "The following cells will setup the runtime environment with the following\n",
22
+ "\n",
23
+ "- Mount Google Drive\n",
24
+ "- Install dependencies for running the model\n",
25
+ "- Load the model into memory"
26
+ ]
27
+ },
28
+ {
29
+ "cell_type": "markdown",
30
+ "metadata": {
31
+ "id": "jn061Tl8uHb-"
32
+ },
33
+ "source": [
34
+ "### Mount Google Drive (if running on colab)\n",
35
+ "\n",
36
+ "The following bit of code will mount your Google Drive folder at `/content/drive`, allowing you to process files directly from it as well as store the results alongside it.\n",
37
+ "\n",
38
+ "Once you execute the next cell, you will be requested to share access with the notebook. Please follow the instructions on screen to do so.\n",
39
+ "If you are not running this on colab, you will still be able to use the files available on your environment."
40
+ ]
41
+ },
42
+ {
43
+ "cell_type": "code",
44
+ "execution_count": null,
45
+ "metadata": {
46
+ "colab": {
47
+ "base_uri": "https://localhost:8080/"
48
+ },
49
+ "collapsed": true,
50
+ "id": "DkSUXqMPuHb-",
51
+ "outputId": "6b82521e-3afd-4545-b13f-8cfea0975d95"
52
+ },
53
+ "outputs": [],
54
+ "source": [
55
+ "# Check if running colab\n",
56
+ "import logging\n",
57
+ "\n",
58
+ "logging.basicConfig(\n",
59
+ " level=logging.INFO,\n",
60
+ " format='%(asctime)s %(levelname)-8s %(name)s %(message)s'\n",
61
+ ")\n",
62
+ "try:\n",
63
+ " import google.colab\n",
64
+ " RUNNING_IN_COLAB = True\n",
65
+ "except:\n",
66
+ " RUNNING_IN_COLAB = False\n",
67
+ "\n",
68
+ "if RUNNING_IN_COLAB:\n",
69
+ " from google.colab import drive\n",
70
+ " drive.mount('/content/drive')\n",
71
+ "\n",
72
+ "from IPython.core.magic import register_cell_magic\n",
73
+ "from IPython import get_ipython\n",
74
+ "@register_cell_magic\n",
75
+ "def skip_if(line, cell):\n",
76
+ " if eval(line):\n",
77
+ " return\n",
78
+ " get_ipython().run_cell(cell)\n",
79
+ "\n",
80
+ "\n",
81
+ "%env RUNNING_IN_COLAB {RUNNING_IN_COLAB}\n"
82
+ ]
83
+ },
84
+ {
85
+ "cell_type": "markdown",
86
+ "metadata": {
87
+ "id": "kas5YtyluHb_"
88
+ },
89
+ "source": [
90
+ "### Install Dependencies\n",
91
+ "\n",
92
+ "The environment will be setup with the code, models and required dependencies."
93
+ ]
94
+ },
95
+ {
96
+ "cell_type": "code",
97
+ "execution_count": null,
98
+ "metadata": {
99
+ "colab": {
100
+ "base_uri": "https://localhost:8080/"
101
+ },
102
+ "id": "982Yiv5tuHb_",
103
+ "outputId": "2f570d1a-c6cc-49c3-c336-1d784d33a169"
104
+ },
105
+ "outputs": [],
106
+ "source": [
107
+ "%%bash\n",
108
+ "\n",
109
+ "set -euxo pipefail\n",
110
+ "\n",
111
+ "if [ \"${RUNNING_IN_COLAB}\" == \"True\" ]; then\n",
112
+ " echo \"Downloading the repository...\"\n",
113
+ " if [ ! -d /content/countgd ]; then\n",
114
+ " git clone \"https://huggingface.co/spaces/nikigoli/countgd\" /content/countgd\n",
115
+ " fi\n",
116
+ " cd /content/countgd\n",
117
+ " git fetch origin refs/pr/5:refs/remotes/origin/pr/5\n",
118
+ " git checkout pr/5\n",
119
+ "else\n",
120
+ " # TODO check if cwd is the correct git repo\n",
121
+ " # If users use vscode, then we set the default start directory to root of the repo\n",
122
+ " echo \"Running in $(pwd)\"\n",
123
+ "fi\n",
124
+ "\n",
125
+ "# TODO check for gcc-11 or above\n",
126
+ "\n",
127
+ "# Install pip packages\n",
128
+ "pip install --upgrade pip setuptools wheel\n",
129
+ "pip install -r requirements.txt\n",
130
+ "\n",
131
+ "# Compile modules\n",
132
+ "export CUDA_HOME=/usr/local/cuda/\n",
133
+ "cd models/GroundingDINO/ops\n",
134
+ "python3 setup.py build\n",
135
+ "pip install .\n",
136
+ "python3 test.py"
137
+ ]
138
+ },
139
+ {
140
+ "cell_type": "code",
141
+ "execution_count": null,
142
+ "metadata": {
143
+ "colab": {
144
+ "base_uri": "https://localhost:8080/"
145
+ },
146
+ "id": "58iD_HGnvcRJ",
147
+ "outputId": "fe356a68-dced-4f6f-93cc-d83da2f84e28"
148
+ },
149
+ "outputs": [],
150
+ "source": [
151
+ "%cd {\"/content/countgd\" if RUNNING_IN_COLAB else '.'}"
152
+ ]
153
+ },
154
+ {
155
+ "cell_type": "markdown",
156
+ "metadata": {
157
+ "id": "gH7A8zthuHb_"
158
+ },
159
+ "source": [
160
+ "## Inference"
161
+ ]
162
+ },
163
+ {
164
+ "cell_type": "markdown",
165
+ "metadata": {
166
+ "id": "IspbBV0XuHb_"
167
+ },
168
+ "source": [
169
+ "### Loading the model"
170
+ ]
171
+ },
172
+ {
173
+ "cell_type": "code",
174
+ "execution_count": null,
175
+ "metadata": {
176
+ "colab": {
177
+ "base_uri": "https://localhost:8080/"
178
+ },
179
+ "id": "5nBT_HCUuHb_",
180
+ "outputId": "95ceb6c6-bee8-4921-8bff-d28937045f78"
181
+ },
182
+ "outputs": [],
183
+ "source": [
184
+ "import app\n",
185
+ "import importlib\n",
186
+ "importlib.reload(app)\n",
187
+ "from app import (\n",
188
+ " build_model_and_transforms,\n",
189
+ " get_device,\n",
190
+ " get_args_parser,\n",
191
+ " generate_heatmap,\n",
192
+ " predict,\n",
193
+ ")\n",
194
+ "args = get_args_parser().parse_args([])\n",
195
+ "device = get_device()\n",
196
+ "model, transform = build_model_and_transforms(args)\n",
197
+ "model = model.to(device)\n",
198
+ "\n",
199
+ "run = lambda image, text: predict(model, transform, image, text, None, device)\n",
200
+ "get_output = lambda image, boxes: (len(boxes), generate_heatmap(image, boxes))\n"
201
+ ]
202
+ },
203
+ {
204
+ "cell_type": "markdown",
205
+ "metadata": {
206
+ "id": "gfjraK3vuHb_"
207
+ },
208
+ "source": [
209
+ "### Input / Output Utils\n",
210
+ "\n",
211
+ "Helper functions for reading / writing to zipfiles and csv"
212
+ ]
213
+ },
214
+ {
215
+ "cell_type": "code",
216
+ "execution_count": 17,
217
+ "metadata": {
218
+ "id": "qg0g5B-fuHb_"
219
+ },
220
+ "outputs": [],
221
+ "source": [
222
+ "import io\n",
223
+ "import csv\n",
224
+ "from pathlib import Path\n",
225
+ "from contextlib import contextmanager\n",
226
+ "import zipfile\n",
227
+ "import filetype\n",
228
+ "from PIL import Image\n",
229
+ "logger = logging.getLogger()\n",
230
+ "\n",
231
+ "def images_from_zipfile(p: Path):\n",
232
+ " if not zipfile.is_zipfile(p):\n",
233
+ " raise ValueError(f'{p} is not a zipfile!')\n",
234
+ "\n",
235
+ " with zipfile.ZipFile(p, 'r') as zipf:\n",
236
+ " def process_entry(info: zipfile.ZipInfo):\n",
237
+ " with zipf.open(info) as f:\n",
238
+ " if not filetype.is_image(f):\n",
239
+ " logger.debug(f'Skipping file - {info.filename} as it is not an image')\n",
240
+ " return\n",
241
+ " # Try loading the file\n",
242
+ " try:\n",
243
+ " with Image.open(f) as im:\n",
244
+ " im.load()\n",
245
+ " return (info.filename, im)\n",
246
+ " except:\n",
247
+ " logger.exception(f'Error reading file {info.filename}')\n",
248
+ "\n",
249
+ " num_files = sum(1 for info in zipf.infolist() if info.is_dir() == False)\n",
250
+ " logger.info(f'Found {num_files} file(s) in the zip')\n",
251
+ " yield from (process_entry(info) for info in zipf.infolist() if info.is_dir() == False)\n",
252
+ "\n",
253
+ "@contextmanager\n",
254
+ "def zipfile_writer(p: Path):\n",
255
+ " with zipfile.ZipFile(p, 'w') as zipf:\n",
256
+ " def write_output(image, image_filename):\n",
257
+ " buf = io.BytesIO()\n",
258
+ " image.save(buf, 'PNG')\n",
259
+ " zipf.writestr(image_filename, buf.getvalue())\n",
260
+ " yield write_output\n",
261
+ "\n",
262
+ "@contextmanager\n",
263
+ "def csvfile_writer(p: Path):\n",
264
+ " with p.open('w', newline='') as csvfile:\n",
265
+ " fieldnames = ['filename', 'count']\n",
266
+ " csv_writer = csv.DictWriter(csvfile, fieldnames = fieldnames)\n",
267
+ " csv_writer.writeheader()\n",
268
+ "\n",
269
+ " yield csv_writer.writerow"
270
+ ]
271
+ },
272
+ {
273
+ "cell_type": "code",
274
+ "execution_count": 15,
275
+ "metadata": {
276
+ "id": "rFXRk-_uuHb_"
277
+ },
278
+ "outputs": [],
279
+ "source": [
280
+ "from tqdm import tqdm\n",
281
+ "import os\n",
282
+ "def process_zipfile(input_zipfile: Path, text: str):\n",
283
+ " if not input_zipfile.exists() or not input_zipfile.is_file() or not os.access(input_zipfile, os.R_OK):\n",
284
+ " logger.error(f'Cannot open / read zipfile: {input_zipfile}. Please check if it exists')\n",
285
+ " return\n",
286
+ "\n",
287
+ " if text == \"\":\n",
288
+ " logger.error('Please provide the object you would like to count')\n",
289
+ " return\n",
290
+ "\n",
291
+ " output_zipfile = input_zipfile.parent / f'{input_zipfile.stem}_countgd.zip'\n",
292
+ " output_csvfile = input_zipfile.parent / f'{input_zipfile.stem}.csv'\n",
293
+ "\n",
294
+ " logger.info(f'Writing outputs to {output_zipfile.name} and {output_csvfile.name} in {input_zipfile.parent} folder')\n",
295
+ " with zipfile_writer(output_zipfile) as add_to_zip, csvfile_writer(output_csvfile) as write_row:\n",
296
+ " for filename, im in tqdm(images_from_zipfile(input_zipfile)):\n",
297
+ " boxes, _ = run(im, text)\n",
298
+ " count, heatmap = get_output(im, boxes)\n",
299
+ " write_row({'filename': filename, 'count': count})\n",
300
+ " add_to_zip(heatmap, filename)"
301
+ ]
302
+ },
303
+ {
304
+ "cell_type": "markdown",
305
+ "metadata": {
306
+ "id": "TmqsSxrsuHb_"
307
+ },
308
+ "source": [
309
+ "### Run\n",
310
+ "\n",
311
+ "Use the form on colab to set the parameters, providing the zipfile with input images and a promt text representing the object you want to count.\n",
312
+ "\n",
313
+ "If you are not running on colab, change the values in the next cell\n",
314
+ "\n",
315
+ "Make sure to run the cell once you change the value."
316
+ ]
317
+ },
318
+ {
319
+ "cell_type": "code",
320
+ "execution_count": 8,
321
+ "metadata": {
322
+ "id": "ZaN918EkuHb_"
323
+ },
324
+ "outputs": [],
325
+ "source": [
326
+ "# @title ## Parameters { display-mode: \"form\", run: \"auto\" }\n",
327
+ "# @markdown Set the following options to pass to the CountGD Model\n",
328
+ "\n",
329
+ "# @markdown ---\n",
330
+ "# @markdown ### Enter a file path to a zip:\n",
331
+ "zipfile_path = \"test_images.zip\" # @param {type:\"string\"}\n",
332
+ "# @markdown\n",
333
+ "# @markdown ### Which object would you like to count?\n",
334
+ "prompt = \"strawberry\" # @param {type:\"string\"}\n",
335
+ "# @markdown ---"
336
+ ]
337
+ },
338
+ {
339
+ "cell_type": "code",
340
+ "execution_count": null,
341
+ "metadata": {
342
+ "colab": {
343
+ "base_uri": "https://localhost:8080/",
344
+ "height": 66,
345
+ "referenced_widgets": [
346
+ "b14c910dd2594285bb4ad4740099e70c",
347
+ "01631442369e43138c2c5c4a9fe38ceb",
348
+ "ff84907ef88a431bab4bd3d1567cc42a"
349
+ ]
350
+ },
351
+ "id": "fd-ShBCsuHb_",
352
+ "outputId": "5b36bb90-ac6e-46fe-a853-ff11d43dd9f6"
353
+ },
354
+ "outputs": [],
355
+ "source": [
356
+ "import ipywidgets as widgets\n",
357
+ "from IPython.display import display\n",
358
+ "button = widgets.Button(description=\"Run\")\n",
359
+ "\n",
360
+ "def on_button_clicked(b):\n",
361
+ " # Display the message within the output widget.\n",
362
+ " process_zipfile(Path(zipfile_path), prompt)\n",
363
+ "\n",
364
+ "button.on_click(on_button_clicked)\n",
365
+ "display(button)"
366
+ ]
367
+ }
368
+ ],
369
+ "metadata": {
370
+ "accelerator": "GPU",
371
+ "colab": {
372
+ "collapsed_sections": [
373
+ "gfjraK3vuHb_"
374
+ ],
375
+ "gpuType": "T4",
376
+ "provenance": []
377
+ },
378
+ "kernelspec": {
379
+ "display_name": "env",
380
+ "language": "python",
381
+ "name": "python3"
382
+ },
383
+ "language_info": {
384
+ "codemirror_mode": {
385
+ "name": "ipython",
386
+ "version": 3
387
+ },
388
+ "file_extension": ".py",
389
+ "mimetype": "text/x-python",
390
+ "name": "python",
391
+ "nbconvert_exporter": "python",
392
+ "pygments_lexer": "ipython3",
393
+ "version": "3.12.7"
394
+ },
395
+ "widgets": {
396
+ "application/vnd.jupyter.widget-state+json": {
397
+ "01631442369e43138c2c5c4a9fe38ceb": {
398
+ "model_module": "@jupyter-widgets/base",
399
+ "model_module_version": "1.2.0",
400
+ "model_name": "LayoutModel",
401
+ "state": {
402
+ "_model_module": "@jupyter-widgets/base",
403
+ "_model_module_version": "1.2.0",
404
+ "_model_name": "LayoutModel",
405
+ "_view_count": null,
406
+ "_view_module": "@jupyter-widgets/base",
407
+ "_view_module_version": "1.2.0",
408
+ "_view_name": "LayoutView",
409
+ "align_content": null,
410
+ "align_items": null,
411
+ "align_self": null,
412
+ "border": null,
413
+ "bottom": null,
414
+ "display": null,
415
+ "flex": null,
416
+ "flex_flow": null,
417
+ "grid_area": null,
418
+ "grid_auto_columns": null,
419
+ "grid_auto_flow": null,
420
+ "grid_auto_rows": null,
421
+ "grid_column": null,
422
+ "grid_gap": null,
423
+ "grid_row": null,
424
+ "grid_template_areas": null,
425
+ "grid_template_columns": null,
426
+ "grid_template_rows": null,
427
+ "height": null,
428
+ "justify_content": null,
429
+ "justify_items": null,
430
+ "left": null,
431
+ "margin": null,
432
+ "max_height": null,
433
+ "max_width": null,
434
+ "min_height": null,
435
+ "min_width": null,
436
+ "object_fit": null,
437
+ "object_position": null,
438
+ "order": null,
439
+ "overflow": null,
440
+ "overflow_x": null,
441
+ "overflow_y": null,
442
+ "padding": null,
443
+ "right": null,
444
+ "top": null,
445
+ "visibility": null,
446
+ "width": null
447
+ }
448
+ },
449
+ "b14c910dd2594285bb4ad4740099e70c": {
450
+ "model_module": "@jupyter-widgets/controls",
451
+ "model_module_version": "1.5.0",
452
+ "model_name": "ButtonModel",
453
+ "state": {
454
+ "_dom_classes": [],
455
+ "_model_module": "@jupyter-widgets/controls",
456
+ "_model_module_version": "1.5.0",
457
+ "_model_name": "ButtonModel",
458
+ "_view_count": null,
459
+ "_view_module": "@jupyter-widgets/controls",
460
+ "_view_module_version": "1.5.0",
461
+ "_view_name": "ButtonView",
462
+ "button_style": "",
463
+ "description": "Run",
464
+ "disabled": false,
465
+ "icon": "",
466
+ "layout": "IPY_MODEL_01631442369e43138c2c5c4a9fe38ceb",
467
+ "style": "IPY_MODEL_ff84907ef88a431bab4bd3d1567cc42a",
468
+ "tooltip": ""
469
+ }
470
+ },
471
+ "ff84907ef88a431bab4bd3d1567cc42a": {
472
+ "model_module": "@jupyter-widgets/controls",
473
+ "model_module_version": "1.5.0",
474
+ "model_name": "ButtonStyleModel",
475
+ "state": {
476
+ "_model_module": "@jupyter-widgets/controls",
477
+ "_model_module_version": "1.5.0",
478
+ "_model_name": "ButtonStyleModel",
479
+ "_view_count": null,
480
+ "_view_module": "@jupyter-widgets/base",
481
+ "_view_module_version": "1.2.0",
482
+ "_view_name": "StyleView",
483
+ "button_color": null,
484
+ "font_weight": ""
485
+ }
486
+ }
487
+ }
488
+ }
489
+ },
490
+ "nbformat": 4,
491
+ "nbformat_minor": 0
492
+ }
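Editor's note: the final notebook cell wires `process_zipfile` to an ipywidgets button, but the batch run can also be triggered directly, e.g. from a plain script or outside Colab. A minimal sketch, assuming the notebook cells above have already been executed so `process_zipfile` is defined; the zip path and prompt are the notebook's default placeholders.

```python
from pathlib import Path

# Writes <stem>_countgd.zip (heatmaps) and <stem>.csv (per-image counts)
# next to the input zip, exactly as the button handler does.
process_zipfile(Path("test_images.zip"), "strawberry")
```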
requirements.txt CHANGED
@@ -12,6 +12,8 @@ ushlex
 gradio>=4.0.0,<5
 gradio_image_prompter-0.1.0-py3-none-any.whl
 spaces
+filetype
+tqdm
 --extra-index-url https://download.pytorch.org/whl/cu121
 torch<2.6
 torchvision
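Editor's note: the two new entries back the notebook's zip handling (`filetype` skips non-image entries) and its progress reporting (`tqdm`). A quick post-install sanity check might look like the sketch below; the file path is a placeholder and the snippet is not part of the commit.

```python
import filetype
from tqdm import tqdm

# filetype sniffs magic bytes, so non-image files inside the zip are skipped cheaply.
print(filetype.is_image("strawberry.jpg"))  # True for a real image file

# tqdm wraps any iterable with a progress bar, as in the notebook's batch loop.
for _ in tqdm(range(3), desc="demo"):
    pass
```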