justin-zk committed
Commit cdccdd2
1 parent: 68318f4

update code

Files changed (1): app.py (+27 -26)
app.py CHANGED
@@ -88,8 +88,8 @@ def calculate_sigmoid_focal_loss(inputs, targets, num_masks = 1, alpha: float =
 
 def inference(ic_image, ic_mask, image1, image2):
     # in context image and mask
-    ic_image = cv2.cvtColor(ic_image, cv2.COLOR_BGR2RGB)
-    ic_make = cv2.cvtColor(ic_image,cv2.COLOR_BGR2RGB)
+    ic_image = np.array(ic_image.convert("RGB"))
+    ic_mask = np.array(ic_mask.convert("RGB"))
 
     sam_type, sam_ckpt = 'vit_h', 'sam_vit_h_4b8939.pth'
     sam = sam_model_registry[sam_type](checkpoint=sam_ckpt).cuda()
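Note on this hunk: the removed pair carried a bug, the second cv2.cvtColor converted ic_image a second time and bound the result to a misspelled ic_make, so ic_mask was never actually converted. The new lines also drop channel-order juggling altogether: once the inputs arrive as PIL images (see the type='pil' interface changes below), convert("RGB") plus np.array yields the HxWx3 uint8 RGB array the rest of the function expects. A minimal sketch of the new path, with a hypothetical local file standing in for the Gradio input:

```python
import numpy as np
from PIL import Image

# "context.png" is a hypothetical stand-in for the gr.Image(type='pil') input.
ic_image = Image.open("context.png")           # any source mode: L, RGBA, ...
ic_array = np.array(ic_image.convert("RGB"))   # H x W x 3, uint8, RGB order
assert ic_array.dtype == np.uint8 and ic_array.shape[2] == 3
```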
@@ -114,7 +114,7 @@ def inference(ic_image, ic_mask, image1, image2):
 
     for test_image in [image1, image2]:
         print("======> Testing Image" )
-        test_image = cv2.cvtColor(test_image, cv2.COLOR_BGR2RGB)
+        test_image = np.array(test_image.convert("RGB"))
 
         # Image feature encoding
         predictor.set_image(test_image)
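The unchanged predictor.set_image context line is why the RGB layout matters: in the segment_anything package, SamPredictor.set_image defaults to image_format="RGB" and runs the image encoder once per frame. A hedged sketch of that step with a placeholder array, assuming the same checkpoint the app loads:

```python
import numpy as np
from segment_anything import sam_model_registry, SamPredictor

sam = sam_model_registry['vit_h'](checkpoint='sam_vit_h_4b8939.pth').cuda()
predictor = SamPredictor(sam)

test_image = np.zeros((480, 640, 3), dtype=np.uint8)  # placeholder RGB frame
predictor.set_image(test_image)   # heavy step: encodes image features once
```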
@@ -188,8 +188,8 @@ def inference_scribble(image, image1, image2):
     # in context image and mask
     ic_image = image["image"]
     ic_mask = image["mask"]
-    ic_image = cv2.cvtColor(ic_image, cv2.COLOR_BGR2RGB)
-    ic_make = cv2.cvtColor(ic_image,cv2.COLOR_BGR2RGB)
+    ic_image = np.array(ic_image.convert("RGB"))
+    ic_mask = np.array(ic_mask.convert("RGB"))
 
     sam_type, sam_ckpt = 'vit_h', 'sam_vit_h_4b8939.pth'
     sam = sam_model_registry[sam_type](checkpoint=sam_ckpt).cuda()
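In inference_scribble the in-context pair comes from the ImageMask payload rather than two separate uploads: in Gradio 3.x, gr.ImageMask (an Image component with the sketch tool) passes a dict holding the base picture under "image" and the drawn strokes under "mask", and with type='pil' both are PIL images, so the same convert("RGB") path applies. A sketch of the assumed payload shape:

```python
from PIL import Image

# Assumed shape of the Gradio 3.x ImageMask payload when type='pil' is set.
payload = {
    "image": Image.new("RGB", (256, 256)),  # uploaded picture
    "mask": Image.new("RGB", (256, 256)),   # user's scribble strokes
}
ic_image = payload["image"]
ic_mask = payload["mask"]
```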
@@ -214,7 +214,7 @@ def inference_scribble(image, image1, image2):
 
     for test_image in [image1, image2]:
         print("======> Testing Image" )
-        test_image = cv2.cvtColor(test_image, cv2.COLOR_BGR2RGB)
+        test_image = np.array(test_image.convert("RGB"))
 
         # Image feature encoding
         predictor.set_image(test_image)
@@ -286,8 +286,8 @@ def inference_scribble(image, image1, image2):
 
 def inference_finetune(ic_image, ic_mask, image1, image2):
     # in context image and mask
-    ic_image = cv2.cvtColor(ic_image, cv2.COLOR_BGR2RGB)
-    ic_make = cv2.cvtColor(ic_image,cv2.COLOR_BGR2RGB)
+    ic_image = np.array(ic_image.convert("RGB"))
+    ic_mask = np.array(ic_mask.convert("RGB"))
 
     gt_mask = torch.tensor(ic_mask)[:, :, 0] > 0
     gt_mask = gt_mask.float().unsqueeze(0).flatten(1).cuda()
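The unchanged gt_mask lines now consume the NumPy mask built just above: channel 0 is thresholded at zero and the boolean map is flattened to shape (1, H*W) for the loss computation. A CPU-only sketch with a toy mask (the app appends .cuda()):

```python
import numpy as np
import torch

ic_mask = np.zeros((8, 8, 3), dtype=np.uint8)
ic_mask[2:6, 2:6] = 255                            # toy foreground square
gt_mask = torch.tensor(ic_mask)[:, :, 0] > 0       # bool map, 8 x 8
gt_mask = gt_mask.float().unsqueeze(0).flatten(1)  # shape (1, 64)
print(gt_mask.shape, int(gt_mask.sum()))           # torch.Size([1, 64]) 16
```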
@@ -377,7 +377,7 @@ def inference_finetune(ic_image, ic_mask, image1, image2):
     output_image = []
 
     for test_image in [image1, image2]:
-        test_image = cv2.cvtColor(test_image, cv2.COLOR_BGR2RGB)
+        test_image = np.array(test_image.convert("RGB"))
 
         # Image feature encoding
        predictor.set_image(test_image)
@@ -466,14 +466,14 @@ description = """
 main = gr.Interface(
     fn=inference,
     inputs=[
-        gr.Image(label="in context image",),
-        gr.Image(label="in context mask"),
-        gr.Image(label="test image1"),
-        gr.Image(label="test image2"),
+        gr.Image(label="in context image", type='pil'),
+        gr.Image(label="in context mask", type='pil'),
+        gr.Image(label="test image1", type='pil'),
+        gr.Image(label="test image2", type='pil'),
     ],
     outputs=[
-        gr.Image(label="output image1").style(height=256, width=256),
-        gr.Image(label="output image2").style(height=256, width=256),
+        gr.Image(label="output image1", type='pil').style(height=256, width=256),
+        gr.Image(label="output image2", type='pil').style(height=256, width=256),
     ],
     allow_flagging="never",
     cache_examples=False,
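This interface change is what enables the PIL handling above: type='pil' makes Gradio hand the callback PIL.Image objects rather than NumPy arrays. A minimal sketch, assuming Gradio 3.x, where .style(height=..., width=...) is still the sizing API (later Gradio releases moved sizing into constructor arguments):

```python
import gradio as gr

def echo(img):   # img arrives as a PIL.Image because of type='pil'
    return img

demo = gr.Interface(
    fn=echo,
    inputs=gr.Image(label="input", type='pil'),
    outputs=gr.Image(label="output").style(height=256, width=256),
)
# demo.launch()
```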
@@ -490,13 +490,13 @@ main = gr.Interface(
 main_scribble = gr.Interface(
     fn=inference_scribble,
     inputs=[
-        gr.ImageMask(label="[Stroke] Draw on Image"),
-        gr.Image(label="test image1"),
-        gr.Image(label="test image2"),
+        gr.ImageMask(label="[Stroke] Draw on Image", brush_radius=4, type='pil'),
+        gr.Image(label="test image1", type='pil'),
+        gr.Image(label="test image2", type='pil'),
     ],
     outputs=[
-        gr.Image(label="output image1").style(height=256, width=256),
-        gr.Image(label="output image2").style(height=256, width=256),
+        gr.Image(label="output image1", type='pil').style(height=256, width=256),
+        gr.Image(label="output image2", type='pil').style(height=256, width=256),
     ],
     allow_flagging="never",
     cache_examples=False,
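The added brush_radius=4 pins the sketch brush to a thin stroke, which suits scribble-style mask prompts; in Gradio 3.x it is an ImageMask/Image constructor argument.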
@@ -510,17 +510,18 @@ main_scribble = gr.Interface(
 )
 """
 
+
 main_finetune = gr.Interface(
     fn=inference_finetune,
     inputs=[
-        gr.Image(label="in context image",),
-        gr.Image(label="in context mask"),
-        gr.Image(label="test image1"),
-        gr.Image(label="test image2"),
+        gr.Image(label="in context image", type='pil'),
+        gr.Image(label="in context mask", type='pil'),
+        gr.Image(label="test image1", type='pil'),
+        gr.Image(label="test image2", type='pil'),
     ],
     outputs=[
-        gr.Image(label="output image1").style(height=256, width=256),
-        gr.Image(label="output image2").style(height=256, width=256),
+        gr.Image(label="output image1", type='pil').style(height=256, width=256),
+        gr.Image(label="output image2", type='pil').style(height=256, width=256),
     ],
     allow_flagging="never",
     cache_examples=False,
 