chaoxu commited on
Commit
703e611
·
1 Parent(s): 3344730

Support API calls

Browse files
Files changed (2) hide show
  1. app.py +140 -13
  2. instructions_12345.md +32 -0
app.py CHANGED
@@ -12,17 +12,13 @@ elev_est_dir = os.path.abspath(os.path.join(code_dir, "one2345_elev_est"))
12
  sys.path.append(elev_est_dir)
13
 
14
  if not is_local_run:
15
- # import pip
16
- # pip.main(['install', elev_est_dir])
17
  # export TORCH_CUDA_ARCH_LIST="7.0;7.2;8.0;8.6"
18
  # export IABN_FORCE_CUDA=1
19
  os.environ["TORCH_CUDA_ARCH_LIST"] = "8.0;8.6"
20
  os.environ["IABN_FORCE_CUDA"] = "1"
21
  os.environ["FORCE_CUDA"] = "1"
22
- # pip.main(["install", "inplace_abn"])
23
  subprocess.run(['pip', 'install', 'inplace_abn'])
24
  # FORCE_CUDA=1 pip install --no-cache-dir git+https://github.com/mit-han-lab/[email protected]
25
- # pip.main(["install", "--no-cache-dir", "git+https://github.com/mit-han-lab/[email protected]"])
26
  subprocess.run(['pip', 'install', '--no-cache-dir', 'git+https://github.com/mit-han-lab/[email protected]'])
27
 
28
  import shutil
@@ -154,12 +150,9 @@ class CameraVisualizer:
154
 
155
  self._raw_image = raw_image
156
  self._8bit_image = Image.fromarray(raw_image).convert('P', palette='WEB', dither=None)
157
- # self._8bit_image = Image.fromarray(raw_image.clip(0, 254)).convert(
158
- # 'P', palette='WEB', dither=None)
159
  self._image_colorscale = [
160
  [i / 255.0, 'rgb({}, {}, {})'.format(*rgb)] for i, rgb in enumerate(idx_to_color)]
161
  self._elev = elev
162
- # return self.update_figure()
163
 
164
  def update_figure(self):
165
  fig = go.Figure()
@@ -243,9 +236,6 @@ class CameraVisualizer:
243
 
244
  # look at center of scene
245
  fig.update_layout(
246
- # width=640,
247
- # height=480,
248
- # height=400,
249
  height=450,
250
  autosize=True,
251
  hovermode=False,
@@ -312,7 +302,7 @@ def stage1_run(models, device, cam_vis, tmp_dir,
312
  stage2_steps = 50 # ddim_steps
313
  zero123_infer(model, tmp_dir, indices=[0], device=device, ddim_steps=stage2_steps, scale=scale)
314
  try:
315
- elev_output = estimate_elev(tmp_dir)
316
  except:
317
  print("Failed to estimate polar angle")
318
  elev_output = 90
@@ -459,6 +449,119 @@ def init_bbox(image):
459
  gr.update(value=x_max, maximum=width),
460
  gr.update(value=y_max, maximum=height)]
461
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
462
 
463
  def run_demo(
464
  device_idx=_GPU_INDEX,
@@ -473,7 +576,6 @@ def run_demo(
473
  with open('instructions_12345.md', 'r') as f:
474
  article = f.read()
475
 
476
- # NOTE: Examples must match inputs
477
  example_folder = os.path.join(os.path.dirname(__file__), 'demo_examples')
478
  example_fns = os.listdir(example_folder)
479
  example_fns.sort()
@@ -494,7 +596,7 @@ def run_demo(
494
  image_block = gr.Image(type='pil', image_mode='RGBA', height=290, label='Input image', tool=None)
495
 
496
  gr.Examples(
497
- examples=examples_full, # NOTE: elements must match inputs list!
498
  inputs=[image_block],
499
  outputs=[image_block],
500
  cache_examples=False,
@@ -569,6 +671,31 @@ def run_demo(
569
  </div>
570
  """)
571
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
572
  update_guide = lambda GUIDE_TEXT: gr.update(value=GUIDE_TEXT)
573
 
574
  views = [view_1, view_2, view_3, view_4, view_5, view_6, view_7, view_8]
 
12
  sys.path.append(elev_est_dir)
13
 
14
  if not is_local_run:
 
 
15
  # export TORCH_CUDA_ARCH_LIST="7.0;7.2;8.0;8.6"
16
  # export IABN_FORCE_CUDA=1
17
  os.environ["TORCH_CUDA_ARCH_LIST"] = "8.0;8.6"
18
  os.environ["IABN_FORCE_CUDA"] = "1"
19
  os.environ["FORCE_CUDA"] = "1"
 
20
  subprocess.run(['pip', 'install', 'inplace_abn'])
21
  # FORCE_CUDA=1 pip install --no-cache-dir git+https://github.com/mit-han-lab/[email protected]
 
22
  subprocess.run(['pip', 'install', '--no-cache-dir', 'git+https://github.com/mit-han-lab/[email protected]'])
23
 
24
  import shutil
 
150
 
151
  self._raw_image = raw_image
152
  self._8bit_image = Image.fromarray(raw_image).convert('P', palette='WEB', dither=None)
 
 
153
  self._image_colorscale = [
154
  [i / 255.0, 'rgb({}, {}, {})'.format(*rgb)] for i, rgb in enumerate(idx_to_color)]
155
  self._elev = elev
 
156
 
157
  def update_figure(self):
158
  fig = go.Figure()
 
236
 
237
  # look at center of scene
238
  fig.update_layout(
 
 
 
239
  height=450,
240
  autosize=True,
241
  hovermode=False,
 
302
  stage2_steps = 50 # ddim_steps
303
  zero123_infer(model, tmp_dir, indices=[0], device=device, ddim_steps=stage2_steps, scale=scale)
304
  try:
305
+ elev_output = int(estimate_elev(tmp_dir))
306
  except:
307
  print("Failed to estimate polar angle")
308
  elev_output = 90
 
449
  gr.update(value=x_max, maximum=width),
450
  gr.update(value=y_max, maximum=height)]
451
 
452
+ ### API functions
453
+ def preprocess_api(predictor, raw_im):
454
+ raw_im.thumbnail([512, 512], Image.Resampling.LANCZOS)
455
+ image_rem = raw_im.convert('RGBA')
456
+ image_nobg = remove(image_rem, alpha_matting=True)
457
+ arr = np.asarray(image_nobg)[:,:,-1]
458
+ x_nonzero = np.nonzero(arr.sum(axis=0))
459
+ y_nonzero = np.nonzero(arr.sum(axis=1))
460
+ x_min = int(x_nonzero[0].min())
461
+ y_min = int(y_nonzero[0].min())
462
+ x_max = int(x_nonzero[0].max())
463
+ y_max = int(y_nonzero[0].max())
464
+ image_sam = sam_out_nosave(predictor, raw_im.convert("RGB"), x_min, y_min, x_max, y_max)
465
+ input_256 = image_preprocess_nosave(image_sam, lower_contrast=False, rescale=True)
466
+ torch.cuda.empty_cache()
467
+ return input_256
468
+
469
+ def estimate_elev_api(models, device, predictor,
470
+ input_im, preprocess=True, scale=3, ddim_steps=50):
471
+ model = models['turncam'].half()
472
+ tmp_dir = tempfile.TemporaryDirectory(dir=os.path.join(os.path.dirname(__file__), 'demo_tmp')).name
473
+ stage1_dir = os.path.join(tmp_dir, "stage1_8")
474
+ os.makedirs(stage1_dir, exist_ok=True)
475
+ if preprocess:
476
+ input_im = preprocess_api(predictor, input_im)
477
+ input_image = input_im.convert("RGB")
478
+ output_ims = predict_stage1_gradio(model, input_image, save_path=stage1_dir, adjust_set=[0], device=device, ddim_steps=ddim_steps, scale=scale)
479
+ stage2_steps = 50 # ddim_steps
480
+ zero123_infer(model, tmp_dir, indices=[0], device=device, ddim_steps=stage2_steps, scale=scale)
481
+ try:
482
+ polar_angle = int(estimate_elev(tmp_dir))
483
+ except:
484
+ print("Failed to estimate polar angle")
485
+ polar_angle = 90
486
+ print("Estimated polar angle:", polar_angle)
487
+ return 90-polar_angle
488
+
489
+ def convert_mesh_format(exp_dir, output_format=".obj"):
490
+ ply_path = os.path.join(exp_dir, f"meshes_val_bg/lod0/mesh_00215000_gradio_lod0.ply")
491
+ mesh_path = os.path.join(exp_dir, f"mesh{output_format}")
492
+ mesh = trimesh.load_mesh(ply_path)
493
+ rotation_matrix = trimesh.transformations.rotation_matrix(np.pi/2, [1, 0, 0]) @ trimesh.transformations.rotation_matrix(np.pi, [0, 0, 1])
494
+ mesh.apply_transform(rotation_matrix)
495
+ mesh.vertices[:, 0] = -mesh.vertices[:, 0]
496
+ mesh.faces = np.fliplr(mesh.faces)
497
+ if output_format == ".obj":
498
+ # Export the mesh as .obj file with colors
499
+ mesh.export(mesh_path, file_type='obj', include_color=True)
500
+ else:
501
+ mesh.export(mesh_path, file_type='glb')
502
+ return mesh_path
503
+
504
+ def reconstruct(exp_dir, output_format=".ply", device_idx=0):
505
+
506
+ main_dir_path = os.path.dirname(__file__)
507
+ torch.cuda.empty_cache()
508
+ os.chdir(os.path.join(code_dir, 'SparseNeuS_demo_v1/'))
509
+
510
+ bash_script = f'CUDA_VISIBLE_DEVICES={device_idx} python exp_runner_generic_blender_val.py \
511
+ --specific_dataset_name {exp_dir} \
512
+ --mode export_mesh \
513
+ --conf confs/one2345_lod0_val_demo.conf'
514
+ print(bash_script)
515
+ os.system(bash_script)
516
+ os.chdir(main_dir_path)
517
+
518
+ ply_path = os.path.join(exp_dir, f"meshes_val_bg/lod0/mesh_00215000_gradio_lod0.ply")
519
+ if output_format == ".ply":
520
+ return ply_path
521
+ if output_format not in [".obj", ".glb"]:
522
+ print("Invalid output format, must be one of .ply, .obj, .glb")
523
+ return ply_path
524
+ return convert_mesh_format(exp_dir, output_format=output_format)
525
+
526
+ def gen_mesh_api(models, predictor, device,
527
+ input_im, preprocess=True, scale=3, ddim_steps=75, stage2_steps=50):
528
+ if preprocess:
529
+ input_im = preprocess_api(predictor, input_im)
530
+ model = models['turncam'].half()
531
+ # folder to save the stage 1 images
532
+ exp_dir = tempfile.TemporaryDirectory(dir=os.path.join(os.path.dirname(__file__), 'demo_tmp')).name
533
+ stage1_dir = os.path.join(exp_dir, "stage1_8")
534
+ os.makedirs(stage1_dir, exist_ok=True)
535
+
536
+ # stage 1: generate 4 views at the same elevation as the input
537
+ output_ims = predict_stage1_gradio(model, input_im, save_path=stage1_dir, adjust_set=list(range(4)), device=device, ddim_steps=ddim_steps, scale=scale)
538
+
539
+ # stage 2 for the first image
540
+ # infer 4 nearby views for an image to estimate the polar angle of the input
541
+ stage2_steps = 50 # ddim_steps
542
+ zero123_infer(model, exp_dir, indices=[0], device=device, ddim_steps=stage2_steps, scale=scale)
543
+ # estimate the camera pose (elevation) of the input image.
544
+ try:
545
+ polar_angle = int(estimate_elev(exp_dir))
546
+ except:
547
+ print("Failed to estimate polar angle")
548
+ polar_angle = 90
549
+ print("Estimated polar angle:", polar_angle)
550
+ gen_poses(exp_dir, polar_angle)
551
+
552
+ # stage 1: generate another 4 views at a different elevation
553
+ if polar_angle <= 75:
554
+ output_ims_2 = predict_stage1_gradio(model, input_im, save_path=stage1_dir, adjust_set=list(range(4,8)), device=device, ddim_steps=ddim_steps, scale=scale)
555
+ else:
556
+ output_ims_2 = predict_stage1_gradio(model, input_im, save_path=stage1_dir, adjust_set=list(range(8,12)), device=device, ddim_steps=ddim_steps, scale=scale)
557
+ torch.cuda.empty_cache()
558
+ # stage 2 for the remaining 7 images, generate 7*4=28 views
559
+ if polar_angle <= 75:
560
+ zero123_infer(model, exp_dir, indices=list(range(1,8)), device=device, ddim_steps=stage2_steps, scale=scale)
561
+ else:
562
+ zero123_infer(model, exp_dir, indices=list(range(1,4))+list(range(8,12)), device=device, ddim_steps=stage2_steps, scale=scale)
563
+ return reconstruct(exp_dir)
564
+
565
 
566
  def run_demo(
567
  device_idx=_GPU_INDEX,
 
576
  with open('instructions_12345.md', 'r') as f:
577
  article = f.read()
578
 
 
579
  example_folder = os.path.join(os.path.dirname(__file__), 'demo_examples')
580
  example_fns = os.listdir(example_folder)
581
  example_fns.sort()
 
596
  image_block = gr.Image(type='pil', image_mode='RGBA', height=290, label='Input image', tool=None)
597
 
598
  gr.Examples(
599
+ examples=examples_full,
600
  inputs=[image_block],
601
  outputs=[image_block],
602
  cache_examples=False,
 
671
  </div>
672
  """)
673
 
674
+ # hidden buttons for supporting API calls
675
+ elev_est_btn = gr.Button('Run API', variant='primary', visible=False)
676
+ elev_est_out = gr.Number(value=0, visible=False)
677
+ elev_preprocess_chk = gr.Checkbox(value=True, visible=False)
678
+
679
+ elev_est_btn.click(fn=partial(estimate_elev_api, models, device, predictor),
680
+ inputs=[image_block, elev_preprocess_chk],
681
+ outputs=[elev_est_out],
682
+ api_name='estimate_elevation',
683
+ queue=True)
684
+
685
+ preprocess_btn = gr.Button('Run API', variant='primary', visible=False)
686
+ preprocess_btn.click(fn=partial(preprocess_api, predictor),
687
+ inputs=[image_block],
688
+ outputs=[sam_block],
689
+ api_name='preprocess',
690
+ queue=True)
691
+
692
+ gen_mesh_btn = gr.Button('Run API', variant='primary', visible=False)
693
+ gen_mesh_btn.click(fn=partial(gen_mesh_api, models, predictor, device),
694
+ inputs=[image_block, elev_preprocess_chk],
695
+ outputs=[mesh_output],
696
+ api_name='generate_mesh',
697
+ queue=True)
698
+
699
  update_guide = lambda GUIDE_TEXT: gr.update(value=GUIDE_TEXT)
700
 
701
  views = [view_1, view_2, view_3, view_4, view_5, view_6, view_7, view_8]
instructions_12345.md CHANGED
@@ -1,3 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ## Tuning Tips:
2
 
3
  1. The multi-view prediction module (Zero123) operates probabilistically. If some of the predicted views are not satisfactory, you may select and regenerate them.
 
1
+ ## APIs:
2
+ <details>
3
+ <summary>We offer handy APIs for our pipeline and its components.</summary>
4
+
5
+ ```python
6
+ from gradio_client import Client
7
+ client = Client("https://one-2-3-45-one-2-3-45.hf.space/")
8
+ input_img_path = "https://huggingface.co/spaces/One-2-3-45/One-2-3-45/resolve/main/demo_examples/01_wild_hydrant.png"
9
+
10
+ ### Single image to 3D mesh
11
+ generated_mesh_filepath = client.predict(
12
+ input_img_path,
13
+ True, # image preprocessing
14
+ api_name="/generate_mesh"
15
+ )
16
+
17
+ ### Elevation estimation
18
+ # DON'T TO ASK USERS TO ESTIMATE ELEVATION! This OFF-THE-SHELF algorithm is ALL YOU NEED!
19
+ elevation_angle_deg = client.predict(
20
+ input_img_path,
21
+ True, # image preprocessing
22
+ api_name="/estimate_elevation"
23
+ )
24
+
25
+ ### Image preprocessing: segment, rescale, and recenter
26
+ segmented_img_filepath = client.predict(
27
+ input_img_path,
28
+ api_name="/preprocess"
29
+ )
30
+ ```
31
+ </details>
32
+
33
  ## Tuning Tips:
34
 
35
  1. The multi-view prediction module (Zero123) operates probabilistically. If some of the predicted views are not satisfactory, you may select and regenerate them.