Weiyu Liu committed on
Commit
8c02843
·
1 Parent(s): a77a4ae
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. app.py +131 -0
  2. configs/base.yaml +3 -0
  3. configs/conditional_pose_diffusion.yaml +81 -0
  4. configs/pairwise_collision.yaml +42 -0
  5. data/data00000000.h5 +3 -0
  6. data/data00000002.h5 +3 -0
  7. data/data00000003.h5 +3 -0
  8. data/data00000004.h5 +3 -0
  9. data/data00000006.h5 +3 -0
  10. data/data00000008.h5 +3 -0
  11. data/data00000009.h5 +3 -0
  12. data/data00000012.h5 +3 -0
  13. data/data00000013.h5 +3 -0
  14. data/data00000015.h5 +3 -0
  15. data/type_vocabs_coarse.json +1 -0
  16. packages.txt +1 -0
  17. requirements.txt +13 -0
  18. scripts/infer.py +78 -0
  19. scripts/infer_with_discriminator.py +81 -0
  20. scripts/train_discriminator.py +46 -0
  21. scripts/train_generator.py +49 -0
  22. src/StructDiffusion/__init__.py +0 -0
  23. src/StructDiffusion/__pycache__/__init__.cpython-37.pyc +0 -0
  24. src/StructDiffusion/__pycache__/__init__.cpython-38.pyc +0 -0
  25. src/StructDiffusion/data/__init__.py +0 -0
  26. src/StructDiffusion/data/__pycache__/__init__.cpython-37.pyc +0 -0
  27. src/StructDiffusion/data/__pycache__/__init__.cpython-38.pyc +0 -0
  28. src/StructDiffusion/data/__pycache__/pairwise_collision.cpython-37.pyc +0 -0
  29. src/StructDiffusion/data/__pycache__/semantic_arrangement.cpython-37.pyc +0 -0
  30. src/StructDiffusion/data/__pycache__/semantic_arrangement.cpython-38.pyc +0 -0
  31. src/StructDiffusion/data/__pycache__/semantic_arrangement_demo.cpython-38.pyc +0 -0
  32. src/StructDiffusion/data/pairwise_collision.py +361 -0
  33. src/StructDiffusion/data/semantic_arrangement.py +579 -0
  34. src/StructDiffusion/data/semantic_arrangement_demo.py +563 -0
  35. src/StructDiffusion/diffusion/__init__.py +0 -0
  36. src/StructDiffusion/diffusion/__pycache__/__init__.cpython-37.pyc +0 -0
  37. src/StructDiffusion/diffusion/__pycache__/__init__.cpython-38.pyc +0 -0
  38. src/StructDiffusion/diffusion/__pycache__/noise_schedule.cpython-37.pyc +0 -0
  39. src/StructDiffusion/diffusion/__pycache__/noise_schedule.cpython-38.pyc +0 -0
  40. src/StructDiffusion/diffusion/__pycache__/pose_conversion.cpython-37.pyc +0 -0
  41. src/StructDiffusion/diffusion/__pycache__/pose_conversion.cpython-38.pyc +0 -0
  42. src/StructDiffusion/diffusion/__pycache__/sampler.cpython-37.pyc +0 -0
  43. src/StructDiffusion/diffusion/__pycache__/sampler.cpython-38.pyc +0 -0
  44. src/StructDiffusion/diffusion/noise_schedule.py +81 -0
  45. src/StructDiffusion/diffusion/pose_conversion.py +103 -0
  46. src/StructDiffusion/diffusion/sampler.py +296 -0
  47. src/StructDiffusion/language/__init__.py +0 -0
  48. src/StructDiffusion/language/__pycache__/__init__.cpython-37.pyc +0 -0
  49. src/StructDiffusion/language/__pycache__/__init__.cpython-38.pyc +0 -0
  50. src/StructDiffusion/language/__pycache__/tokenizer.cpython-37.pyc +0 -0
app.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import argparse
3
+ import torch
4
+ import trimesh
5
+ import numpy as np
6
+ import pytorch_lightning as pl
7
+ import gradio as gr
8
+ from omegaconf import OmegaConf
9
+
10
+ import sys
11
+ sys.path.append('./src')
12
+
13
+ from StructDiffusion.data.semantic_arrangement_demo import SemanticArrangementDataset
14
+ from StructDiffusion.language.tokenizer import Tokenizer
15
+ from StructDiffusion.models.pl_models import ConditionalPoseDiffusionModel
16
+ from StructDiffusion.diffusion.sampler import Sampler
17
+ from StructDiffusion.diffusion.pose_conversion import get_struct_objs_poses
18
+ from StructDiffusion.utils.files import get_checkpoint_path_from_dir
19
+ from StructDiffusion.utils.batch_inference import move_pc_and_create_scene_simple, visualize_batch_pcs
20
+ from StructDiffusion.utils.rearrangement import show_pcs_with_trimesh
21
+
22
+
23
class Infer_Wrapper:
    """Load the tokenizer, demo dataset and diffusion sampler once and serve single-scene inference.

    An instance is meant to back an interactive demo: ``run`` takes a dataset
    index, samples object poses for that scene, and writes the rearranged
    scene to a ``.glb`` file whose path it returns.
    """

    def __init__(self, args, cfg):
        """Seed the RNGs and build the tokenizer, dataset and sampler.

        Args:
            args: options object; ``eval_random_seed``, ``checkpoint_id`` and
                ``num_samples`` are read from it.
            cfg: merged configuration; ``cfg.WANDB.save_dir``,
                ``cfg.WANDB.project`` and ``cfg.DATASET`` are read from it.
                ``cfg.DATASET.ignore_rgb`` is overwritten with ``False``.
        """
        # Keep the options on the instance so run() does not depend on a
        # module-level variable that merely happens to be called `args`.
        self.args = args

        pl.seed_everything(args.eval_random_seed)
        self.device = (torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"))

        checkpoint_dir = os.path.join(cfg.WANDB.save_dir, cfg.WANDB.project, args.checkpoint_id, "checkpoints")
        checkpoint_path = get_checkpoint_path_from_dir(checkpoint_dir)

        self.tokenizer = Tokenizer(cfg.DATASET.vocab_dir)
        # override ignore_rgb for visualization
        cfg.DATASET.ignore_rgb = False
        self.dataset = SemanticArrangementDataset(tokenizer=self.tokenizer, **cfg.DATASET)

        self.sampler = Sampler(ConditionalPoseDiffusionModel, checkpoint_path, self.device)

    def run(self, di):
        """Sample a rearrangement for dataset entry ``di`` and export it as a GLB scene.

        Args:
            di: index into ``self.dataset``. Floats (e.g. from a UI slider)
                are truncated to ``int`` before use.

        Returns:
            Path of the exported scene file (``"./tmp_data/scene.glb"``).
            NOTE(review): the path is fixed, so concurrent calls overwrite
            each other's output.
        """
        # UI widgets may hand over a float; dataset indices must be integers.
        di = int(di)

        raw_datum = self.dataset.get_raw_data(di)
        print(self.tokenizer.convert_structure_params_to_natural_language(raw_datum["sentence"]))
        datum = self.dataset.convert_to_tensors(raw_datum, self.tokenizer)
        batch = self.dataset.single_datum_to_batch(datum, self.args.num_samples, self.device, inference_mode=True)

        num_poses = datum["goal_poses"].shape[0]
        xs = self.sampler.sample(batch, num_poses)

        struct_pose, pc_poses_in_struct = get_struct_objs_poses(xs[0])
        new_obj_xyzs = move_pc_and_create_scene_simple(batch["pcs"], struct_pose, pc_poses_in_struct)

        # vis
        vis_obj_xyzs = new_obj_xyzs[:3]
        if torch.is_tensor(vis_obj_xyzs):
            # detach() is a no-op without autograd history and cpu() is a
            # no-op on CPU tensors, so this is safe for every tensor input.
            vis_obj_xyzs = vis_obj_xyzs.detach().cpu().numpy()

        # Only the first sample is exported; each object's array is split
        # into its first three columns and the remaining columns.
        vis_obj_xyz = vis_obj_xyzs[0]
        scene = show_pcs_with_trimesh([xyz[:, :3] for xyz in vis_obj_xyz], [xyz[:, 3:] for xyz in vis_obj_xyz], return_scene=True)

        scene_filename = "./tmp_data/scene.glb"
        # The export directory is not guaranteed to exist on a fresh checkout.
        os.makedirs(os.path.dirname(scene_filename), exist_ok=True)
        scene.export(scene_filename)

        return scene_filename
109
+
110
+
111
# Fixed options for the hosted demo (there is no command line here).
args = OmegaConf.create()
args.base_config_file = "./configs/base.yaml"
args.config_file = "./configs/conditional_pose_diffusion.yaml"
args.checkpoint_id = "ConditionalPoseDiffusion"
args.eval_random_seed = 42
args.num_samples = 1

# The task config is merged on top of the base config.
base_cfg = OmegaConf.load(args.base_config_file)
cfg = OmegaConf.load(args.config_file)
cfg = OmegaConf.merge(base_cfg, cfg)

infer_wrapper = Infer_Wrapper(args, cfg)

demo = gr.Interface(
    fn=infer_wrapper.run,
    # Valid dataset indices are 0 .. len(dataset) - 1; step=1 keeps the
    # slider on whole numbers so the value can be used as an index.
    inputs=gr.Slider(0, max(len(infer_wrapper.dataset) - 1, 0), step=1),
    # clear color range [0-1.0]
    outputs=gr.Model3D(clear_color=[0, 0, 0, 0], label="3D Model")
)

demo.launch()
configs/base.yaml ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ base_dirs:
2
+ data: data
3
+ wandb_dir: wandb_logs
configs/conditional_pose_diffusion.yaml ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ random_seed: 1
2
+
3
+ WANDB:
4
+ project: StructDiffusion
5
+ save_dir: ${base_dirs.wandb_dir}
6
+ name: conditional_pose_diffusion
7
+
8
+ DATASET:
9
+ data_root: ${base_dirs.data}
10
+ vocab_dir: ${base_dirs.data}/type_vocabs_coarse.json
11
+
12
+ # important
13
+ use_virtual_structure_frame: True
14
+ ignore_distractor_objects: True
15
+ ignore_rgb: True
16
+
17
+ # the following are determined by the dataset
18
+ max_num_target_objects: 7
19
+ max_num_distractor_objects: 5
20
+ max_num_shape_parameters: 5
21
+ # set to zeros because they are not used for now
22
+ max_num_rearrange_features: 0
23
+ max_num_anchor_features: 0
24
+
25
+ num_pts: 1024
26
+ filter_num_moved_objects_range:
27
+ data_augmentation: False
28
+
29
+ DATALOADER:
30
+ batch_size: 64
31
+ num_workers: 8
32
+ pin_memory: True
33
+
34
+ MODEL:
35
+ # transformer encoder
36
+ encoder_input_dim: 256
37
+ num_attention_heads: 8
38
+ encoder_hidden_dim: 512
39
+ encoder_dropout: 0.0
40
+ encoder_activation: relu
41
+ encoder_num_layers: 8
42
+ # output head
43
+ structure_dropout: 0
44
+ object_dropout: 0
45
+ # pc encoder
46
+ ignore_rgb: ${DATASET.ignore_rgb}
47
+ pc_emb_dim: 256
48
+ posed_pc_emb_dim: 80
49
+ # pose encoder
50
+ pose_emb_dim: 80
51
+ # language
52
+ word_emb_dim: 160
53
+ # diffusion step
54
+ time_emb_dim: 80
55
+ # sequence embeddings
56
+ # max_num_target_objects (+ max_num_distractor_objects if not ignore_distractor_objects)
57
+ max_seq_size: 7
58
+ max_token_type_size: 4
59
+ seq_pos_emb_dim: 8
60
+ seq_type_emb_dim: 8
61
+ # virtual frame
62
+ use_virtual_structure_frame: ${DATASET.use_virtual_structure_frame}
63
+
64
+ NOISE_SCHEDULE:
65
+ timesteps: 200
66
+
67
+ LOSS:
68
+ type: huber
69
+
70
+ OPTIMIZER:
71
+ lr: 0.0001
72
+ weight_decay: 0 #0.0001
73
+ # lr_restart: 3000
74
+ # warmup: 10
75
+
76
+ TRAINER:
77
+ max_epochs: 200
78
+ gradient_clip_val: 1.0
79
+ gpus: 1
80
+ deterministic: False
81
+ # enable_progress_bar: False
configs/pairwise_collision.yaml ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ random_seed: 1
2
+
3
+ WANDB:
4
+ project: StructDiffusion
5
+ save_dir: ${base_dirs.wandb_dir}
6
+ name: pairwise_collision
7
+
8
+ DATASET:
9
+ urdf_pc_idx_file: ${base_dirs.pairwise_collision_data}/urdf_pc_idx.pkl
10
+ collision_data_dir: ${base_dirs.pairwise_collision_data}
11
+
12
+ # important
13
+ num_pts: 1024
14
+ num_scene_pts: 2048
15
+ normalize_pc: True
16
+ random_rotation: True
17
+ data_augmentation: False
18
+
19
+ DATALOADER:
20
+ batch_size: 32
21
+ num_workers: 8
22
+ pin_memory: True
23
+
24
+ MODEL:
25
+ max_num_objects: 2
26
+ include_env_pc: False
27
+ pct_random_sampling: True
28
+
29
+ LOSS:
30
+ type: Focal
31
+ focal_gamma: 2
32
+
33
+ OPTIMIZER:
34
+ lr: 0.0001
35
+ weight_decay: 0
36
+
37
+ TRAINER:
38
+ max_epochs: 200
39
+ gradient_clip_val: 1.0
40
+ gpus: 1
41
+ deterministic: False
42
+ # enable_progress_bar: False
data/data00000000.h5 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:947574252625d338b9f37217eacf61f520136e27b458b6d3e65330339e8b299c
3
+ size 1271489
data/data00000002.h5 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3302432de555fed767c5b0d99c35ca01d5e4ac38cf4a0760b8ccb456b432e0e0
3
+ size 3235242
data/data00000003.h5 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0b907ba7c3a17f98a438617b462b2a4d3d3f8593c2dc47feb5a6cc3da8c034fc
3
+ size 2059708
data/data00000004.h5 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d8ec0136dd4055d304e9b7f5697b79613099b8f8f1e5eec94281f22d8d47cca1
3
+ size 2591656
data/data00000006.h5 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0e74ebf185b0af58df0fa2483d5fd58a12b3b62ccac27ff665f35c5c7a13b8d8
3
+ size 1572332
data/data00000008.h5 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:db015354a9d53e6fbaf0b040ce226484150b0af226a5c13a0b9f5cb9961db73c
3
+ size 2167265
data/data00000009.h5 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:990ad13f423d9089b30de81d002d23d9d00cf3e007fd7073793cbec03c456ebb
3
+ size 3607752
data/data00000012.h5 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:93161f5666c54dbc259c9efa516b67340613b592a1ed42e6c63d4cc8a495002a
3
+ size 2525622
data/data00000013.h5 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:94c9cfe6d9f0df176eb0a3baccdf53c7e6e5fc807e5e7ea9e138ad7159f500d9
3
+ size 1715352
data/data00000015.h5 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9aab522f9ace1a03b1705fe3bd693b589971d133d45c232ef9ec53842a540bfa
3
+ size 2647026
data/type_vocabs_coarse.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"class": {"Basket": 0, "BeerBottle": 1, "Book": 2, "Bottle": 3, "Bowl": 4, "Calculator": 5, "Candle": 6, "CellPhone": 7, "ComputerMouse": 8, "Controller": 9, "Cup": 10, "Donut": 11, "Fork": 12, "Hammer": 13, "Knife": 14, "Marker": 15, "MilkCarton": 16, "Mug": 17, "Pan": 18, "Pen": 19, "PillBottle": 20, "Plate": 21, "PowerStrip": 22, "Scissors": 23, "SoapBottle": 24, "SodaCan": 25, "Spoon": 26, "Stapler": 27, "Teapot": 28, "VideoGameController": 29, "WineBottle": 30, "CanOpener":31, "Fruit": 32}, "scene": {"dinner": 0}, "size": {"L": 0, "M": 1, "S": 2}, "color": {"blue": 0, "cyan": 1, "green": 2, "magenta": 3, "red": 4, "yellow": 5}, "material": {"glass": 0, "metal": 1, "plastic": 2}, "comparator": {"less": 1, "greater": 2, "equal": 3}, "radius": [0.0, 0.5, 3], "position_x": [-0.1, 1.0, 3], "position_y": [-0.5, 0.5, 3], "rotation": [-3.15, 3.15, 4], "height": [0.0, 0.5, 10], "volumn": [0.0, 0.015, 10], "uniform_angle": {"False": 0, "True": 1}, "face_center": {"False": 0, "True": 1}, "angle_ratio": {"0.5": 0, "1.0": 1}, "shape": {"circle": 0, "line": 1, "tower": 2, "dinner": 3}, "obj_x": [-1.0, 1.0, 200], "obj_y": [-1.0, 1.0, 200], "obj_z": [-1.0, 1.0, 200], "obj_rr": [-3.15, 3.15, 360], "obj_rp": [-3.15, 3.15, 360], "obj_ry": [-3.15, 3.15, 360],"struct_x": [-1.0, 1.0, 200], "struct_y": [-1.0, 1.0, 200], "struct_z": [-1.0, 1.0, 200], "struct_rr": [-3.15, 3.15, 360], "struct_rp": [-3.15, 3.15, 360], "struct_ry": [-3.15, 3.15, 360]}
packages.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ python3-opencv
requirements.txt ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ numpy==1.21
2
+ h5py==2.10.0
3
+ opencv-python
4
+ open3d
5
+ trimesh==3.10.2
6
+ pyglet==1.5.0
7
+ pybullet==3.1.7
8
+ nvisii==1.1.70
9
+ openpyxl
10
+ pytorch_lightning==1.6.1
11
+ wandb===0.13.10
12
+ pytorch3d==0.3.0
13
+ omegaconf==2.2.2
scripts/infer.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import argparse
3
+ import torch
4
+ import numpy as np
5
+ import pytorch_lightning as pl
6
+ from omegaconf import OmegaConf
7
+
8
+ from StructDiffusion.data.semantic_arrangement import SemanticArrangementDataset
9
+ from StructDiffusion.language.tokenizer import Tokenizer
10
+ from StructDiffusion.models.pl_models import ConditionalPoseDiffusionModel
11
+ from StructDiffusion.diffusion.sampler import Sampler
12
+ from StructDiffusion.diffusion.pose_conversion import get_struct_objs_poses
13
+ from StructDiffusion.utils.files import get_checkpoint_path_from_dir
14
+ from StructDiffusion.utils.batch_inference import move_pc_and_create_scene_simple, visualize_batch_pcs
15
+
16
+
17
def main(args, cfg):
    """Sample and visualize rearrangements for every test scene, in random order.

    Args:
        args: parsed command-line options (``eval_random_seed``,
            ``checkpoint_id``, ``eval_mode``, ``num_samples`` are read).
        cfg: merged configuration; ``cfg.DATASET.ignore_rgb`` is overwritten
            with ``False`` when ``eval_mode`` is ``"infer"``.
    """
    pl.seed_everything(args.eval_random_seed)

    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    checkpoint_path = get_checkpoint_path_from_dir(
        os.path.join(cfg.WANDB.save_dir, cfg.WANDB.project, args.checkpoint_id, "checkpoints")
    )

    # "infer" is the only mode handled here; anything else does nothing further.
    if args.eval_mode != "infer":
        return

    tokenizer = Tokenizer(cfg.DATASET.vocab_dir)
    # override ignore_rgb for visualization
    cfg.DATASET.ignore_rgb = False
    dataset = SemanticArrangementDataset(split="test", tokenizer=tokenizer, **cfg.DATASET)

    sampler = Sampler(ConditionalPoseDiffusionModel, checkpoint_path, device)

    for scene_idx in np.random.permutation(len(dataset)):
        raw = dataset.get_raw_data(scene_idx)
        print(tokenizer.convert_structure_params_to_natural_language(raw["sentence"]))
        tensors = dataset.convert_to_tensors(raw, tokenizer)
        batch = dataset.single_datum_to_batch(tensors, args.num_samples, device, inference_mode=True)

        samples = sampler.sample(batch, tensors["goal_poses"].shape[0])

        struct_pose, pc_poses_in_struct = get_struct_objs_poses(samples[0])
        moved_pcs = move_pc_and_create_scene_simple(batch["pcs"], struct_pose, pc_poses_in_struct)
        visualize_batch_pcs(moved_pcs, args.num_samples, limit_B=10, trimesh=True)
48
+
49
+
50
if __name__ == "__main__":
    # Command-line interface. The default config paths are relative, so they
    # resolve against the current working directory.
    parser = argparse.ArgumentParser(description="infer")
    parser.add_argument("--base_config_file", type=str, default='../configs/base.yaml',
                        help='base config yaml file')
    parser.add_argument("--config_file", type=str, default='../configs/conditional_pose_diffusion.yaml',
                        help='config yaml file')
    parser.add_argument("--checkpoint_id", type=str, default="ConditionalPoseDiffusion")
    parser.add_argument("--eval_mode", type=str, default="infer")
    parser.add_argument("--eval_random_seed", type=int, default=42)
    parser.add_argument("--num_samples", type=int, default=10)
    args = parser.parse_args()

    # The task config is merged on top of the base config.
    base_cfg = OmegaConf.load(args.base_config_file)
    cfg = OmegaConf.merge(base_cfg, OmegaConf.load(args.config_file))

    main(args, cfg)
77
+
78
+
scripts/infer_with_discriminator.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import argparse
3
+ import torch
4
+ import numpy as np
5
+ import pytorch_lightning as pl
6
+ from omegaconf import OmegaConf
7
+
8
+ from StructDiffusion.data.semantic_arrangement import SemanticArrangementDataset
9
+ from StructDiffusion.language.tokenizer import Tokenizer
10
+ from StructDiffusion.models.pl_models import ConditionalPoseDiffusionModel, PairwiseCollisionModel
11
+ from StructDiffusion.diffusion.sampler import SamplerV2
12
+ from StructDiffusion.diffusion.pose_conversion import get_struct_objs_poses
13
+ from StructDiffusion.utils.files import get_checkpoint_path_from_dir
14
+ from StructDiffusion.utils.batch_inference import move_pc_and_create_scene_simple, visualize_batch_pcs
15
+
16
+
17
def main(args, cfg):
    """Sample rearrangements with the diffusion model plus collision model and visualize them.

    Args:
        args: parsed command-line options (``eval_random_seed``,
            ``diffusion_checkpoint_id``, ``collision_checkpoint_id``,
            ``eval_mode``, ``num_samples`` are read).
        cfg: merged configuration; ``cfg.DATASET.ignore_rgb`` is overwritten
            with ``False`` when ``eval_mode`` is ``"infer"``.
    """
    pl.seed_everything(args.eval_random_seed)

    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    # Both checkpoints live under <save_dir>/<project>/<run id>/checkpoints.
    run_root = (cfg.WANDB.save_dir, cfg.WANDB.project)
    diffusion_checkpoint_path = get_checkpoint_path_from_dir(
        os.path.join(*run_root, args.diffusion_checkpoint_id, "checkpoints"))
    collision_checkpoint_path = get_checkpoint_path_from_dir(
        os.path.join(*run_root, args.collision_checkpoint_id, "checkpoints"))

    # "infer" is the only mode handled here; anything else does nothing further.
    if args.eval_mode != "infer":
        return

    tokenizer = Tokenizer(cfg.DATASET.vocab_dir)
    # override ignore_rgb for visualization
    cfg.DATASET.ignore_rgb = False
    dataset = SemanticArrangementDataset(split="test", tokenizer=tokenizer, **cfg.DATASET)

    sampler = SamplerV2(ConditionalPoseDiffusionModel, diffusion_checkpoint_path,
                        PairwiseCollisionModel, collision_checkpoint_path, device)

    for scene_idx in np.random.permutation(len(dataset)):
        raw = dataset.get_raw_data(scene_idx)
        print(tokenizer.convert_structure_params_to_natural_language(raw["sentence"]))
        tensors = dataset.convert_to_tensors(raw, tokenizer)
        batch = dataset.single_datum_to_batch(tensors, args.num_samples, device, inference_mode=True)

        # Unlike Sampler, SamplerV2.sample is unpacked directly into the two pose results.
        struct_pose, pc_poses_in_struct = sampler.sample(batch, tensors["goal_poses"].shape[0])

        moved_pcs = move_pc_and_create_scene_simple(batch["pcs"], struct_pose, pc_poses_in_struct)
        visualize_batch_pcs(moved_pcs, args.num_samples, limit_B=10, trimesh=True)
48
+
49
+
50
if __name__ == "__main__":
    # Command-line interface. The default config paths are relative, so they
    # resolve against the current working directory.
    parser = argparse.ArgumentParser(description="infer")
    parser.add_argument("--base_config_file", type=str, default='../configs/base.yaml',
                        help='base config yaml file')
    parser.add_argument("--config_file", type=str, default='../configs/conditional_pose_diffusion.yaml',
                        help='config yaml file')
    parser.add_argument("--diffusion_checkpoint_id", type=str, default="ConditionalPoseDiffusion")
    parser.add_argument("--collision_checkpoint_id", type=str, default="curhl56k")
    parser.add_argument("--eval_mode", type=str, default="infer")
    parser.add_argument("--eval_random_seed", type=int, default=42)
    parser.add_argument("--num_samples", type=int, default=10)
    args = parser.parse_args()

    # The task config is merged on top of the base config.
    base_cfg = OmegaConf.load(args.base_config_file)
    cfg = OmegaConf.merge(base_cfg, OmegaConf.load(args.config_file))

    main(args, cfg)
80
+
81
+
scripts/train_discriminator.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import torch
3
+ from torch.utils.data import DataLoader
4
+ from omegaconf import OmegaConf
5
+ import pytorch_lightning as pl
6
+ from pytorch_lightning.loggers import WandbLogger
7
+ from pytorch_lightning.callbacks import ModelCheckpoint
8
+
9
+ from StructDiffusion.data.pairwise_collision import PairwiseCollisionDataset
10
+ from StructDiffusion.models.pl_models import PairwiseCollisionModel
11
+
12
+
13
def main(cfg):
    """Train the pairwise collision model on a 70/30 random train/validation split.

    Args:
        cfg: merged configuration; the ``random_seed``, ``WANDB``, ``DATASET``,
            ``DATALOADER``, ``MODEL``, ``LOSS``, ``OPTIMIZER`` and ``TRAINER``
            sections are read.
    """
    pl.seed_everything(cfg.random_seed)

    logger = WandbLogger(**cfg.WANDB)
    logger.experiment.config.update(cfg)
    checkpointing = ModelCheckpoint()

    dataset = PairwiseCollisionDataset(**cfg.DATASET)
    # 70% of the samples for training, the remainder for validation.
    num_train = int(len(dataset) * 0.7)
    num_valid = len(dataset) - num_train
    train_split, valid_split = torch.utils.data.random_split(dataset, [num_train, num_valid])

    train_loader = DataLoader(train_split, shuffle=True, **cfg.DATALOADER)
    valid_loader = DataLoader(valid_split, shuffle=False, **cfg.DATALOADER)

    model = PairwiseCollisionModel(cfg.MODEL, cfg.LOSS, cfg.OPTIMIZER, cfg.DATASET)

    trainer = pl.Trainer(logger=logger, callbacks=[checkpointing], **cfg.TRAINER)
    trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=valid_loader)
31
+
32
+
33
if __name__ == "__main__":
    # Command-line interface. The default config paths are relative, so they
    # resolve against the current working directory.
    parser = argparse.ArgumentParser(description="train")
    parser.add_argument("--base_config_file", type=str, default='../configs/base.yaml',
                        help='base config yaml file')
    parser.add_argument("--config_file", type=str, default='../configs/pairwise_collision.yaml',
                        help='config yaml file')
    args = parser.parse_args()

    # The task config is merged on top of the base config.
    base_cfg = OmegaConf.load(args.base_config_file)
    cfg = OmegaConf.merge(base_cfg, OmegaConf.load(args.config_file))

    main(cfg)
scripts/train_generator.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch.utils.data import DataLoader
2
+ import argparse
3
+ from omegaconf import OmegaConf
4
+ import pytorch_lightning as pl
5
+ from pytorch_lightning.loggers import WandbLogger
6
+ from pytorch_lightning.callbacks import ModelCheckpoint
7
+
8
+ from StructDiffusion.data.semantic_arrangement import SemanticArrangementDataset
9
+ from StructDiffusion.language.tokenizer import Tokenizer
10
+ from StructDiffusion.models.pl_models import ConditionalPoseDiffusionModel
11
+
12
+
13
def main(cfg):
    """Train the conditional pose diffusion model on the train/valid splits.

    Args:
        cfg: merged configuration; the ``random_seed``, ``WANDB``, ``DATASET``,
            ``DATALOADER``, ``MODEL``, ``LOSS``, ``NOISE_SCHEDULE``,
            ``OPTIMIZER`` and ``TRAINER`` sections are read.
    """
    pl.seed_everything(cfg.random_seed)

    logger = WandbLogger(**cfg.WANDB)
    logger.experiment.config.update(cfg)
    checkpointing = ModelCheckpoint()

    tokenizer = Tokenizer(cfg.DATASET.vocab_dir)
    vocab_size = tokenizer.get_vocab_size()

    # One dataset and loader per split; only the training loader shuffles.
    train_split = SemanticArrangementDataset(split="train", tokenizer=tokenizer, **cfg.DATASET)
    valid_split = SemanticArrangementDataset(split="valid", tokenizer=tokenizer, **cfg.DATASET)
    train_loader = DataLoader(train_split, shuffle=True, **cfg.DATALOADER)
    valid_loader = DataLoader(valid_split, shuffle=False, **cfg.DATALOADER)

    model = ConditionalPoseDiffusionModel(vocab_size, cfg.MODEL, cfg.LOSS, cfg.NOISE_SCHEDULE, cfg.OPTIMIZER)

    trainer = pl.Trainer(logger=logger, callbacks=[checkpointing], **cfg.TRAINER)
    trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=valid_loader)
34
+
35
+
36
if __name__ == "__main__":
    # Command-line interface. The default config paths are relative, so they
    # resolve against the current working directory.
    parser = argparse.ArgumentParser(description="train")
    parser.add_argument("--base_config_file", type=str, default='../configs/base.yaml',
                        help='base config yaml file')
    parser.add_argument("--config_file", type=str, default='../configs/conditional_pose_diffusion.yaml',
                        help='config yaml file')
    args = parser.parse_args()

    # The task config is merged on top of the base config.
    base_cfg = OmegaConf.load(args.base_config_file)
    cfg = OmegaConf.merge(base_cfg, OmegaConf.load(args.config_file))

    main(cfg)
src/StructDiffusion/__init__.py ADDED
File without changes
src/StructDiffusion/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (171 Bytes). View file
 
src/StructDiffusion/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (175 Bytes). View file
 
src/StructDiffusion/data/__init__.py ADDED
File without changes
src/StructDiffusion/data/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (176 Bytes). View file
 
src/StructDiffusion/data/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (180 Bytes). View file
 
src/StructDiffusion/data/__pycache__/pairwise_collision.cpython-37.pyc ADDED
Binary file (9.72 kB). View file
 
src/StructDiffusion/data/__pycache__/semantic_arrangement.cpython-37.pyc ADDED
Binary file (17.1 kB). View file
 
src/StructDiffusion/data/__pycache__/semantic_arrangement.cpython-38.pyc ADDED
Binary file (17.1 kB). View file
 
src/StructDiffusion/data/__pycache__/semantic_arrangement_demo.cpython-38.pyc ADDED
Binary file (16.4 kB). View file
 
src/StructDiffusion/data/pairwise_collision.py ADDED
@@ -0,0 +1,361 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import h5py
3
+ import numpy as np
4
+ import os
5
+ import trimesh
6
+ import torch
7
+ import json
8
+ from collections import defaultdict
9
+ import tqdm
10
+ import pickle
11
+ from random import shuffle
12
+
13
+ # Local imports
14
+ from StructDiffusion.utils.rearrangement import show_pcs, get_pts, array_to_tensor
15
+ from StructDiffusion.utils.pointnet import pc_normalize
16
+
17
+ import StructDiffusion.utils.brain2.camera as cam
18
+ import StructDiffusion.utils.brain2.image as img
19
+ import StructDiffusion.utils.transformations as tra
20
+
21
+
22
def load_pairwise_collision_data(h5_filename):
    """Read one pairwise-collision HDF5 file into a plain dictionary.

    Args:
        h5_filename: path of the HDF5 file to read.

    Returns:
        dict with keys ``"obj1_info"`` and ``"obj2_info"`` (the evaluated
        contents of the corresponding scalar datasets), ``"obj1_poses"``,
        ``"obj2_poses"`` and ``"intersection_labels"`` (the corresponding
        datasets read fully into memory).
    """
    # The context manager closes the file even if a dataset is missing or
    # malformed; every value is copied into memory before the file closes.
    with h5py.File(h5_filename, 'r') as fh:
        data_dict = {}
        # SECURITY NOTE: eval() executes whatever expression is stored in the
        # file. Only load collision data from trusted sources.
        data_dict["obj1_info"] = eval(fh["obj1_info"][()])
        data_dict["obj2_info"] = eval(fh["obj2_info"][()])
        data_dict["obj1_poses"] = fh["obj1_poses"][:]
        data_dict["obj2_poses"] = fh["obj2_poses"][:]
        data_dict["intersection_labels"] = fh["intersection_labels"][:]

    return data_dict
33
+
34
+
35
+ class PairwiseCollisionDataset(torch.utils.data.Dataset):
36
+
37
+ def __init__(self, urdf_pc_idx_file, collision_data_dir, random_rotation=True,
38
+ num_pts=1024, normalize_pc=True, num_scene_pts=2048, data_augmentation=False,
39
+ debug=False):
40
+
41
+ # load dictionary mapping from urdf to list of pc data, each sample is
42
+ # {"step_t": step_t, "obj": obj, "filename": filename}
43
+ with open(urdf_pc_idx_file, "rb") as fh:
44
+ self.urdf_to_pc_data = pickle.load(fh)
45
+ # filter out broken files
46
+ for urdf in self.urdf_to_pc_data:
47
+ valid_pc_data = []
48
+ for pd in self.urdf_to_pc_data[urdf]:
49
+ filename = pd["filename"]
50
+ if "data00026058" in filename or "data00011415" in filename or "data00026061" in filename or "data00700565" in filename or "data00505290" in filename:
51
+ continue
52
+ valid_pc_data.append(pd)
53
+ if valid_pc_data:
54
+ self.urdf_to_pc_data[urdf] = valid_pc_data
55
+
56
+ # build data index
57
+ # each sample is a tuple of (collision filename, idx for the labels and poses)
58
+ if collision_data_dir is not None:
59
+ self.data_idxs = self.build_data_idxs(collision_data_dir)
60
+ else:
61
+ print("WARNING: collision_data_dir is None")
62
+
63
+ self.num_pts = num_pts
64
+ self.debug = debug
65
+ self.normalize_pc = normalize_pc
66
+ self.num_scene_pts = num_scene_pts
67
+ self.random_rotation = random_rotation
68
+
69
+ # Noise
70
+ self.data_augmentation = data_augmentation
71
+ # additive noise
72
+ self.gp_rescale_factor_range = [12, 20]
73
+ self.gaussian_scale_range = [0., 0.003]
74
+ # multiplicative noise
75
+ self.gamma_shape = 1000.
76
+ self.gamma_scale = 0.001
77
+
78
def build_data_idxs(self, collision_data_dir):
    """
    Index the pairwise collision files of a directory and balance the two classes.

    Every directory entry whose name contains "h5" is loaded with
    load_pairwise_collision_data. A file is skipped when the urdf of either
    object has no entry in self.urdf_to_pc_data. Each label of a kept file
    becomes one (h5_filename, label_index) sample.

    @param collision_data_dir: directory holding the collision files
    @return: list of (h5_filename, label_index) tuples, intersecting samples first;
             when the class sizes differ, both classes are randomly cut to the smaller size
    """
    print("Load collision data...")
    intersecting = []
    separated = []
    for entry in tqdm.tqdm(os.listdir(collision_data_dir)):
        if "h5" in entry:
            h5_filename = os.path.join(collision_data_dir, entry)
            data_dict = load_pairwise_collision_data(h5_filename)
            urdf_pair = (data_dict["obj1_info"]["urdf"], data_dict["obj2_info"]["urdf"])
            for urdf in urdf_pair:
                if urdf not in self.urdf_to_pc_data:
                    # no point cloud recorded for this model, the whole file is unusable
                    print("no pc data for urdf:", urdf)
                    break
            else:
                for label_idx, label in enumerate(data_dict["intersection_labels"]):
                    # truthy label means the two objects intersect
                    bucket = intersecting if label else separated
                    bucket.append((h5_filename, label_idx))
    print("Num pairwise intersections:", len(intersecting))
    print("Num pairwise no intersections:", len(separated))

    num_pos, num_neg = len(intersecting), len(separated)
    if num_neg != num_pos:
        keep = min(num_neg, num_pos)
        # positives are permuted first, then negatives, to keep the RNG draw order
        intersecting = [intersecting[i] for i in np.random.permutation(num_pos)[:keep]]
        separated = [separated[i] for i in np.random.permutation(num_neg)[:keep]]
        print("after balancing")
        print("Num pairwise intersections:", len(intersecting))
        print("Num pairwise no intersections:", len(separated))

    return intersecting + separated
113
+
114
def create_urdf_pc_idxs(self, urdf_pc_idx_file, data_roots, index_roots):
    """
    Build and pickle a mapping from urdf to the recorded steps that contain that object.

    For every (data_root, index_root) pair the "train" arrangement index file is read.
    Each listed (file, step) is opened and every "object_*" id is paired, in sorted
    order, with the object infos of the goal specification (rearrange + anchor + distract).

    @param urdf_pc_idx_file: path the resulting mapping is pickled to
    @param data_roots: list of dataset root directories
    @param index_roots: list of index sub-directories, paired with data_roots
    @return: defaultdict mapping urdf -> list of {"step_t", "obj", "filename"} dicts
    """
    print("Load pc data")
    arrangement_steps = []
    for split in ["train"]:
        for data_root, index_root in zip(data_roots, index_roots):
            arrangement_indices_file = os.path.join(data_root, index_root,"{}_arrangement_indices_file_all.txt".format(split))
            if os.path.exists(arrangement_indices_file):
                with open(arrangement_indices_file, "r") as fh:
                    # SECURITY NOTE(review): eval executes whatever the index file contains.
                    # Only use index files from a trusted source; consider ast.literal_eval.
                    arrangement_steps.extend([(os.path.join(data_root, f[0]), f[1]) for f in eval(fh.readline().strip())])
            else:
                print("{} does not exist".format(arrangement_indices_file))

    urdf_to_pc_data = defaultdict(list)
    for filename, step_t in tqdm.tqdm(arrangement_steps):
        # the context manager closes each file; previously one handle leaked per step
        with h5py.File(filename, 'r') as h5:
            ids = self._get_ids(h5)
            goal_specification = json.loads(str(np.array(h5["goal_specification"])))
        all_objs = sorted([o for o in ids.keys() if "object_" in o])
        obj_infos = goal_specification["rearrange"]["objects"] + goal_specification["anchor"]["objects"] + goal_specification["distract"]["objects"]
        for obj, obj_info in zip(all_objs, obj_infos):
            urdf_to_pc_data[obj_info["urdf"]].append({"step_t": step_t, "obj": obj, "filename": filename})

    with open(urdf_pc_idx_file, "wb") as fh:
        pickle.dump(urdf_to_pc_data, fh)

    return urdf_to_pc_data
141
+
142
def add_noise_to_depth(self, depth_img):
    """
    Scale a depth image by one random multiplicative factor.

    A single value is drawn from a gamma distribution parameterised by
    self.gamma_shape and self.gamma_scale and applied to the whole image.

    @param depth_img: depth image
    @return: the scaled depth image
    """
    scale_factor = np.random.gamma(self.gamma_shape, self.gamma_scale)
    return scale_factor * depth_img
147
+
148
def add_noise_to_xyz(self, xyz_img, depth_img):
    """
    Add smooth additive noise to the points of an xyz image that have positive depth.

    Gaussian noise is sampled on a grid that is smaller than the image by a random
    integer factor, upsampled to the image size with bicubic interpolation, and
    added where depth_img > 0. The input image is copied, not modified.

    @param xyz_img: H x W x C point image
    @param depth_img: depth image used to select the pixels that receive noise
    @return: noisy copy of xyz_img
    """
    noisy_xyz = xyz_img.copy()
    height, width, channels = noisy_xyz.shape

    # draw order (randint, uniform, normal) is kept so seeded runs match
    rescale_lo, rescale_hi = self.gp_rescale_factor_range[0], self.gp_rescale_factor_range[1]
    rescale = np.random.randint(rescale_lo, rescale_hi)
    sigma_lo, sigma_hi = self.gaussian_scale_range[0], self.gaussian_scale_range[1]
    sigma = np.random.uniform(sigma_lo, sigma_hi)

    coarse_h, coarse_w = (np.array([height, width]) / rescale).astype(int)
    coarse_noise = np.random.normal(loc=0.0, scale=sigma, size=(coarse_h, coarse_w, channels))
    # cv2.resize takes the target size as (width, height)
    full_noise = cv2.resize(coarse_noise, (width, height), interpolation=cv2.INTER_CUBIC)

    has_depth = depth_img > 0
    noisy_xyz[has_depth, :] += full_noise[has_depth, :]
    return noisy_xyz
161
+
162
def _get_images(self, h5, idx, ee=True):
    """
    Load one frame and return its images together with a transformed point image.

    Depth is decoded as raw / 20000 * (dmax - dmin) + dmin. When
    self.data_augmentation is set, multiplicative noise is applied to the depth
    and additive noise to the computed xyz image. The xyz image is then
    transformed with the stored camera view matrix.

    @param h5: opened h5 file of the sequence
    @param idx: frame index
    @param ee: read the "ee_*" datasets when True, otherwise the plain ones
    @return: tuple (rgb, depth, seg, valid, xyz); valid marks 0.1 < depth < 2 and is
             computed before any noise is applied
    """
    if ee:
        keys = {"rgb": "ee_rgb", "depth": "ee_depth", "seg": "ee_seg",
                "dmin": "ee_depth_min", "dmax": "ee_depth_max",
                "view": "ee_camera_view"}
    else:
        keys = {"rgb": "rgb", "depth": "depth", "seg": "seg",
                "dmin": "depth_min", "dmax": "depth_max",
                "view": "camera_view"}

    depth_lo = h5[keys["dmin"]][idx]
    depth_hi = h5[keys["dmax"]][idx]
    rgb_img = img.PNGToNumpy(h5[keys["rgb"]][idx])[:, :, :3] / 255.  # drop the alpha channel
    depth_img = h5[keys["depth"]][idx] / 20000. * (depth_hi - depth_lo) + depth_lo
    seg_img = img.PNGToNumpy(h5[keys["seg"]][idx])

    valid_mask = np.logical_and(depth_img > 0.1, depth_img < 2.)

    camera = cam.get_camera_from_h5(h5)
    if self.data_augmentation:
        depth_img = self.add_noise_to_depth(depth_img)

    xyz_img = cam.compute_xyz(depth_img, camera)
    if self.data_augmentation:
        xyz_img = self.add_noise_to_xyz(xyz_img, depth_img)

    # pose used to transform the point image
    view_pose = h5[keys["view"]][idx]
    if ee:
        # ee_camera_view has 0s for x, y, z, so the translation is taken from ee_cam_pose
        view_pose[:3, 3] = h5["ee_cam_pose"][:][:3, 3]

    rows, cols, _ = xyz_img.shape
    flat_points = xyz_img.reshape(rows * cols, -1)
    flat_points = trimesh.transform_points(flat_points, view_pose)
    xyz_img = flat_points.reshape(rows, cols, -1)

    return rgb_img, depth_img, seg_img, valid_mask, xyz_img
205
+
206
def _get_ids(self, h5):
    """
    Collect the object ids stored in an h5 file.

    @param h5: opened h5 file (or any mapping with the same keys()/[] interface)
    @return: dict mapping each key that starts with "id_", with that prefix removed,
             to the value stored under it
    """
    prefix_len = len("id_")
    return {key[prefix_len:]: h5[key][()] for key in h5.keys() if key.startswith("id_")}
218
+
219
def get_obj_pc(self, h5, step_t, obj):
    """
    Extract the point cloud of one object from one frame.

    @param h5: opened h5 file of the sequence
    @param step_t: frame index
    @param obj: object name, used as key into the ids and into h5 for the pose
    @return: tuple (obj_xyz, obj_rgb, obj_pc_pose, obj_pose); obj_pc_pose is an identity
             4x4 matrix whose translation is the mean of obj_xyz, obj_pose is h5[obj][step_t]
    @raise Exception: when the object has no valid pixel or get_pts reports failure
    """
    rgb, _, seg, valid, xyz = self._get_images(h5, step_t, ee=True)

    # getting object point clouds
    ids = self._get_ids(h5)
    obj_mask = np.logical_and(seg == ids[obj], valid)
    if np.sum(obj_mask) <= 0:
        raise Exception("no valid points for {} at step {}".format(obj, step_t))
    ok, obj_xyz, obj_rgb, _ = get_pts(xyz, rgb, obj_mask, num_pts=self.num_pts, to_tensor=False)
    if not ok:
        # same check as SemanticArrangementDataset.get_raw_data; the flag used to be ignored
        raise Exception("failed to sample points for {} at step {}".format(obj, step_t))
    obj_pc_center = np.mean(obj_xyz, axis=0)
    obj_pose = h5[obj][step_t]

    obj_pc_pose = np.eye(4)
    obj_pc_pose[:3, 3] = obj_pc_center[:3]

    return obj_xyz, obj_rgb, obj_pc_pose, obj_pose
236
+
237
def __len__(self):
    """
    Number of samples in the dataset.

    NOTE: self.data_idxs is only assigned in __init__ when collision_data_dir is not None.
    """
    num_samples = len(self.data_idxs)
    return num_samples
239
+
240
def __getitem__(self, idx):
    """
    Build one training sample for the pairwise collision classifier.

    A recorded point cloud is drawn at random for each of the two objects, moved to the
    pose stored in the collision file, tagged with a one-hot object indicator and merged
    into a single scene cloud that is subsampled (with replacement) to self.num_scene_pts.

    @param idx: sample index into self.data_idxs
    @return: dict with "scene_xyz" (xyz plus two indicator columns, converted by
             array_to_tensor) and "label" (FloatTensor holding the 0/1 intersection label)
    """
    collision_filename, collision_idx = self.data_idxs[idx]
    collision_data_dict = load_pairwise_collision_data(collision_filename)

    obj1_urdf = collision_data_dict["obj1_info"]["urdf"]
    obj2_urdf = collision_data_dict["obj2_info"]["urdf"]

    # TODO: find a better way to sample pc data?
    obj1_pc_data = np.random.choice(self.urdf_to_pc_data[obj1_urdf])
    obj2_pc_data = np.random.choice(self.urdf_to_pc_data[obj2_urdf])

    # open the recordings in context managers so the handles are released;
    # previously two file handles leaked for every sample
    with h5py.File(obj1_pc_data["filename"], "r") as obj1_h5:
        obj1_xyz, obj1_rgb, obj1_pc_pose, obj1_pose = self.get_obj_pc(obj1_h5, obj1_pc_data["step_t"], obj1_pc_data["obj"])
    with h5py.File(obj2_pc_data["filename"], "r") as obj2_h5:
        obj2_xyz, obj2_rgb, obj2_pc_pose, obj2_pose = self.get_obj_pc(obj2_h5, obj2_pc_data["step_t"], obj2_pc_data["obj"])

    obj1_c_pose = collision_data_dict["obj1_poses"][collision_idx]
    obj2_c_pose = collision_data_dict["obj2_poses"][collision_idx]
    label = collision_data_dict["intersection_labels"][collision_idx]

    # move each cloud from its recorded object pose to the pose of the collision sample
    obj1_transform = obj1_c_pose @ np.linalg.inv(obj1_pose)
    obj2_transform = obj2_c_pose @ np.linalg.inv(obj2_pose)
    obj1_c_xyz = trimesh.transform_points(obj1_xyz, obj1_transform)
    obj2_c_xyz = trimesh.transform_points(obj2_xyz, obj2_transform)

    # if self.debug:
    #     show_pcs([obj1_c_xyz, obj2_c_xyz], [obj1_rgb, obj2_rgb], add_coordinate_frame=True)

    ###################################
    obj_xyzs = [obj1_c_xyz, obj2_c_xyz]
    # random order, so the indicator columns do not encode which object is obj1
    shuffle(obj_xyzs)

    num_indicator = 2
    new_obj_xyzs = []
    for oi, obj_xyz in enumerate(obj_xyzs):
        obj_xyz = np.concatenate([obj_xyz, np.tile(np.eye(num_indicator)[oi], (obj_xyz.shape[0], 1))], axis=1)
        new_obj_xyzs.append(obj_xyz)
    scene_xyz = np.concatenate(new_obj_xyzs, axis=0)

    # subsampling and normalizing pc
    # (separate name: the original overwrote the idx parameter here)
    sample_idxs = np.random.randint(0, scene_xyz.shape[0], self.num_scene_pts)
    scene_xyz = scene_xyz[sample_idxs]
    if self.normalize_pc:
        scene_xyz[:, 0:3] = pc_normalize(scene_xyz[:, 0:3])

    if self.random_rotation:
        # random rotation around the z axis
        scene_xyz[:, 0:3] = trimesh.transform_points(scene_xyz[:, 0:3], tra.euler_matrix(0, 0, np.random.uniform(low=0, high=2 * np.pi)))

    ###################################
    scene_xyz = array_to_tensor(scene_xyz)
    # convert to torch data
    label = int(label)

    if self.debug:
        print("intersection:", label)
        # dtype=float replaces np.float, an alias of the builtin that NumPy 1.24 removed
        show_pcs([scene_xyz[:, 0:3]], [np.tile(np.array([0, 1, 0], dtype=float), (scene_xyz.shape[0], 1))], add_coordinate_frame=True)

    datum = {
        "scene_xyz": scene_xyz,
        "label": torch.FloatTensor([label]),
    }
    return datum
300
+
301
+ # @staticmethod
302
+ # def collate_fn(data):
303
+ # """
304
+ # :param data:
305
+ # :return:
306
+ # """
307
+ #
308
+ # batched_data_dict = {}
309
+ # for key in ["is_circle"]:
310
+ # batched_data_dict[key] = torch.cat([dict[key] for dict in data], dim=0)
311
+ # for key in ["scene_xyz"]:
312
+ # batched_data_dict[key] = torch.stack([dict[key] for dict in data], dim=0)
313
+ #
314
+ # return batched_data_dict
315
+ #
316
+ # # def create_pair_xyzs_from_obj_xyzs(self, new_obj_xyzs, debug=False):
317
+ # #
318
+ # # new_obj_xyzs = [xyz.cpu().numpy() for xyz in new_obj_xyzs]
319
+ # #
320
+ # # # compute pairwise collision
321
+ # # scene_xyzs = []
322
+ # # obj_xyz_pair_idxs = list(itertools.combinations(range(len(new_obj_xyzs)), 2))
323
+ # #
324
+ # # for obj_xyz_pair_idx in obj_xyz_pair_idxs:
325
+ # # obj_xyz_pair = [new_obj_xyzs[obj_xyz_pair_idx[0]], new_obj_xyzs[obj_xyz_pair_idx[1]]]
326
+ # # num_indicator = 2
327
+ # # obj_xyz_pair_ind = []
328
+ # # for oi, obj_xyz in enumerate(obj_xyz_pair):
329
+ # # obj_xyz = np.concatenate([obj_xyz, np.tile(np.eye(num_indicator)[oi], (obj_xyz.shape[0], 1))], axis=1)
330
+ # # obj_xyz_pair_ind.append(obj_xyz)
331
+ # # pair_scene_xyz = np.concatenate(obj_xyz_pair_ind, axis=0)
332
+ # #
333
+ # # # subsampling and normalizing pc
334
+ # # rand_idx = np.random.randint(0, pair_scene_xyz.shape[0], self.num_scene_pts)
335
+ # # pair_scene_xyz = pair_scene_xyz[rand_idx]
336
+ # # if self.normalize_pc:
337
+ # # pair_scene_xyz[:, 0:3] = pc_normalize(pair_scene_xyz[:, 0:3])
338
+ # #
339
+ # # scene_xyzs.append(array_to_tensor(pair_scene_xyz))
340
+ # #
341
+ # # if debug:
342
+ # # for scene_xyz in scene_xyzs:
343
+ # # show_pcs([scene_xyz[:, 0:3]], [np.tile(np.array([0, 1, 0], dtype=np.float), (scene_xyz.shape[0], 1))],
344
+ # # add_coordinate_frame=True)
345
+ # #
346
+ # # return scene_xyzs
347
+
348
+
349
if __name__ == "__main__":
    # Smoke test: build the dataset from hard-coded local paths and fetch every
    # sample once, in random order.
    dataset_kwargs = dict(
        urdf_pc_idx_file="/home/weiyu/data_drive/StructDiffusion/pairwise_collision_data/urdf_pc_idx.pkl",
        collision_data_dir="/home/weiyu/data_drive/StructDiffusion/pairwise_collision_data",
        debug=False,
    )
    dataset = PairwiseCollisionDataset(**dataset_kwargs)

    visit_order = np.random.permutation(len(dataset))
    for i in tqdm.tqdm(visit_order):
        # print(i)
        d = dataset[i]
        # print(d["label"])

    # Alternative check through a DataLoader:
    # dl = torch.utils.data.DataLoader(dataset, batch_size=32, num_workers=8)
    # for b in tqdm.tqdm(dl):
    #     pass
src/StructDiffusion/data/semantic_arrangement.py ADDED
@@ -0,0 +1,579 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import cv2
3
+ import h5py
4
+ import numpy as np
5
+ import os
6
+ import trimesh
7
+ import torch
8
+ from tqdm import tqdm
9
+ import json
10
+ import random
11
+
12
+ from torch.utils.data import DataLoader
13
+
14
+ # Local imports
15
+ from StructDiffusion.utils.rearrangement import show_pcs, get_pts, combine_and_sample_xyzs
16
+ from StructDiffusion.language.tokenizer import Tokenizer
17
+
18
+ import StructDiffusion.utils.brain2.camera as cam
19
+ import StructDiffusion.utils.brain2.image as img
20
+ import StructDiffusion.utils.transformations as tra
21
+
22
+
23
+ class SemanticArrangementDataset(torch.utils.data.Dataset):
24
+
25
def __init__(self, data_roots, index_roots, split, tokenizer,
             max_num_target_objects=11, max_num_distractor_objects=5,
             max_num_shape_parameters=7, max_num_rearrange_features=1, max_num_anchor_features=3,
             num_pts=1024,
             use_virtual_structure_frame=True, ignore_distractor_objects=True, ignore_rgb=True,
             filter_num_moved_objects_range=None, shuffle_object_index=False,
             data_augmentation=True, debug=False, **kwargs):
    """
    Collect the goal frames listed in the arrangement index files.

    Note: setting filter_num_moved_objects_range=[k, k] and max_num_target_objects=k will
    create no padding for target objs

    :param data_roots: list of dataset root directories
    :param index_roots: list of index sub-directories, paired with data_roots
    :param split: train, valid, or test; selects "<split>_arrangement_indices_file_all.txt"
    :param tokenizer: used to tokenize the language part
    :param max_num_target_objects: number of target object slots after padding
    :param max_num_distractor_objects: number of other object slots after padding
    :param max_num_shape_parameters: number of sentence slots used for the type/position indices
    :param max_num_rearrange_features: stored on the instance; not read in this constructor
    :param max_num_anchor_features: stored on the instance; not read in this constructor
    :param num_pts: number of points sampled per object
    :param use_virtual_structure_frame: express goal poses in the frame of the structure
    :param ignore_distractor_objects: leave other objects out of the returned sequence
    :param ignore_rgb: drop colors from the point clouds; forced to False when debug is set
    :param filter_num_moved_objects_range: optional [min, max] number of moved objects to keep
    :param shuffle_object_index: whether to shuffle the positions of target objects in the sequence
    :param data_augmentation: add noise to depth and point images
    :param debug: print and visualize intermediate results
    :param kwargs: accepted and ignored
    """

    self.use_virtual_structure_frame = use_virtual_structure_frame
    self.ignore_distractor_objects = ignore_distractor_objects
    # colors are needed by the debug visualizations
    self.ignore_rgb = ignore_rgb and not debug

    self.num_pts = num_pts
    self.debug = debug

    self.max_num_objects = max_num_target_objects
    self.max_num_other_objects = max_num_distractor_objects
    self.max_num_shape_parameters = max_num_shape_parameters
    self.max_num_rearrange_features = max_num_rearrange_features
    self.max_num_anchor_features = max_num_anchor_features
    self.shuffle_object_index = shuffle_object_index

    # used to tokenize the language part
    self.tokenizer = tokenizer

    # retrieve data
    self.data_roots = data_roots
    self.arrangement_data = []
    arrangement_steps = []
    for data_root, index_root in zip(data_roots, index_roots):
        arrangement_indices_file = os.path.join(data_root, index_root, "{}_arrangement_indices_file_all.txt".format(split))
        if os.path.exists(arrangement_indices_file):
            with open(arrangement_indices_file, "r") as fh:
                # SECURITY NOTE(review): eval executes whatever the index file contains.
                # Only use index files from a trusted source; consider ast.literal_eval.
                arrangement_steps.extend([(os.path.join(data_root, f[0]), f[1]) for f in eval(fh.readline().strip())])
        else:
            print("{} does not exist".format(arrangement_indices_file))
    # only keep the goal, ignore the intermediate steps
    for filename, step_t in arrangement_steps:
        if step_t == 0:
            # hard-coded skip list of recordings
            if "data00026058" in filename or "data00011415" in filename or "data00026061" in filename or "data00700565" in filename:
                continue
            self.arrangement_data.append((filename, step_t))
    # if specified, filter data
    if filter_num_moved_objects_range is not None:
        self.arrangement_data = self.filter_based_on_number_of_moved_objects(filter_num_moved_objects_range)
    print("{} valid sequences".format(len(self.arrangement_data)))

    # Data Aug
    self.data_augmentation = data_augmentation
    # additive noise
    self.gp_rescale_factor_range = [12, 20]
    self.gaussian_scale_range = [0., 0.003]
    # multiplicative noise
    self.gamma_shape = 1000.
    self.gamma_scale = 0.001
98
+
99
def filter_based_on_number_of_moved_objects(self, filter_num_moved_objects_range):
    """
    Keep the entries of self.arrangement_data whose number of moved objects is in range.

    The number of moved objects is the number of comma-separated names stored in the
    'moved_objs' dataset of each file.

    :param filter_num_moved_objects_range: pair (min_num, max_num), both inclusive
    :return: list of the (filename, step_t) entries that pass the filter
    """
    assert len(list(filter_num_moved_objects_range)) == 2
    min_num, max_num = filter_num_moved_objects_range
    print("Remove scenes that have less than {} or more than {} objects being moved".format(min_num, max_num))
    ok_data = []
    for filename, step_t in self.arrangement_data:
        # the context manager closes each file; previously one handle leaked per sequence
        with h5py.File(filename, 'r') as h5:
            moved_objs = h5['moved_objs'][()]
        if isinstance(moved_objs, bytes):
            # h5py can return stored strings as bytes, which cannot be split on a str separator
            moved_objs = moved_objs.decode("utf-8")
        moved_objs = moved_objs.split(',')
        if min_num <= len(moved_objs) <= max_num:
            ok_data.append((filename, step_t))
    print("{} valid sequences left".format(len(ok_data)))
    return ok_data
111
+
112
def get_data_idx(self, idx):
    """
    Translate a global sample index into an opened file and an index local to that file.

    NOTE(review): reads self.file_to_count and self.data_files, which the __init__ of
    this class does not assign; confirm whether this method is still in use.

    :param idx: global sample index
    :return: tuple (opened h5 file, index within that file, file index)
    """
    # first position where idx is smaller than the stored count
    file_idx = np.argmax(idx < self.file_to_count)
    data = h5py.File(self.data_files[file_idx], 'r')
    if not file_idx > 0:
        return data, idx, file_idx
    # for lang2sym, idx is always 0
    local_idx = idx - self.file_to_count[file_idx - 1]
    return data, local_idx, file_idx
120
+
121
def add_noise_to_depth(self, depth_img):
    """
    Scale a depth image by one random multiplicative factor.

    A single value is drawn from a gamma distribution parameterised by
    self.gamma_shape and self.gamma_scale and applied to the whole image.

    :param depth_img: depth image
    :return: the scaled depth image
    """
    scale_factor = np.random.gamma(self.gamma_shape, self.gamma_scale)
    return scale_factor * depth_img
126
+
127
def add_noise_to_xyz(self, xyz_img, depth_img):
    """
    Add smooth additive noise to the points of an xyz image that have positive depth.

    Gaussian noise is sampled on a grid that is smaller than the image by a random
    integer factor, upsampled to the image size with bicubic interpolation, and
    added where depth_img > 0. The input image is copied, not modified.

    :param xyz_img: H x W x C point image
    :param depth_img: depth image used to select the pixels that receive noise
    :return: noisy copy of xyz_img
    """
    noisy_xyz = xyz_img.copy()
    height, width, channels = noisy_xyz.shape

    # draw order (randint, uniform, normal) is kept so seeded runs match
    rescale_lo, rescale_hi = self.gp_rescale_factor_range[0], self.gp_rescale_factor_range[1]
    rescale = np.random.randint(rescale_lo, rescale_hi)
    sigma_lo, sigma_hi = self.gaussian_scale_range[0], self.gaussian_scale_range[1]
    sigma = np.random.uniform(sigma_lo, sigma_hi)

    coarse_h, coarse_w = (np.array([height, width]) / rescale).astype(int)
    coarse_noise = np.random.normal(loc=0.0, scale=sigma, size=(coarse_h, coarse_w, channels))
    # cv2.resize takes the target size as (width, height)
    full_noise = cv2.resize(coarse_noise, (width, height), interpolation=cv2.INTER_CUBIC)

    has_depth = depth_img > 0
    noisy_xyz[has_depth, :] += full_noise[has_depth, :]
    return noisy_xyz
140
+
141
def random_index(self):
    """
    Return the sample stored at a uniformly drawn index.

    :return: self[i] for a random i in [0, len(self))
    """
    picked = np.random.randint(len(self))
    return self[picked]
143
+
144
def _get_rgb(self, h5, idx, ee=True):
    """
    Decode the color image of one frame.

    :param h5: opened h5 file of the sequence
    :param idx: frame index
    :param ee: read "ee_rgb" when True, otherwise "rgb"
    :return: decoded image limited to its first three channels, divided by 255
    """
    if ee:
        rgb_key = "ee_rgb"
    else:
        rgb_key = "rgb"
    decoded = img.PNGToNumpy(h5[rgb_key][idx])
    # remove alpha
    return decoded[:, :, :3] / 255.
148
+
149
def _get_depth(self, h5, idx, ee=True):
    """
    Decode the depth image of one frame.

    The original body only bound the dataset name and returned None. The decoding
    below is the same formula _get_images applies: raw / 20000 * (dmax - dmin) + dmin.
    No augmentation noise is added here.

    :param h5: opened h5 file of the sequence
    :param idx: frame index
    :param ee: read the "ee_*" datasets when True, otherwise the plain ones
    :return: decoded depth image
    """
    DEPTH = "ee_depth" if ee else "depth"
    DMIN = "ee_depth_min" if ee else "depth_min"
    DMAX = "ee_depth_max" if ee else "depth_max"
    dmin = h5[DMIN][idx]
    dmax = h5[DMAX][idx]
    return h5[DEPTH][idx] / 20000. * (dmax - dmin) + dmin
151
+
152
def _get_images(self, h5, idx, ee=True):
    """
    Load one frame and return its images together with a transformed point image.

    Depth is decoded as raw / 20000 * (dmax - dmin) + dmin. When
    self.data_augmentation is set, multiplicative noise is applied to the depth
    and additive noise to the computed xyz image. The xyz image is then
    transformed with the stored camera view matrix.

    :param h5: opened h5 file of the sequence
    :param idx: frame index
    :param ee: read the "ee_*" datasets when True, otherwise the plain ones
    :return: tuple (rgb, depth, seg, valid, xyz); valid marks 0.1 < depth < 2 and is
             computed before any noise is applied
    """
    if ee:
        keys = {"rgb": "ee_rgb", "depth": "ee_depth", "seg": "ee_seg",
                "dmin": "ee_depth_min", "dmax": "ee_depth_max",
                "view": "ee_camera_view"}
    else:
        keys = {"rgb": "rgb", "depth": "depth", "seg": "seg",
                "dmin": "depth_min", "dmax": "depth_max",
                "view": "camera_view"}

    depth_lo = h5[keys["dmin"]][idx]
    depth_hi = h5[keys["dmax"]][idx]
    rgb_img = img.PNGToNumpy(h5[keys["rgb"]][idx])[:, :, :3] / 255.  # drop the alpha channel
    depth_img = h5[keys["depth"]][idx] / 20000. * (depth_hi - depth_lo) + depth_lo
    seg_img = img.PNGToNumpy(h5[keys["seg"]][idx])

    valid_mask = np.logical_and(depth_img > 0.1, depth_img < 2.)

    camera = cam.get_camera_from_h5(h5)
    if self.data_augmentation:
        depth_img = self.add_noise_to_depth(depth_img)

    xyz_img = cam.compute_xyz(depth_img, camera)
    if self.data_augmentation:
        xyz_img = self.add_noise_to_xyz(xyz_img, depth_img)

    # pose used to transform the point image
    view_pose = h5[keys["view"]][idx]
    if ee:
        # ee_camera_view has 0s for x, y, z, so the translation is taken from ee_cam_pose
        view_pose[:3, 3] = h5["ee_cam_pose"][:][:3, 3]

    rows, cols, _ = xyz_img.shape
    flat_points = xyz_img.reshape(rows * cols, -1)
    flat_points = trimesh.transform_points(flat_points, view_pose)
    xyz_img = flat_points.reshape(rows, cols, -1)

    return rgb_img, depth_img, seg_img, valid_mask, xyz_img
195
+
196
def __len__(self):
    """Number of (filename, step) entries kept in self.arrangement_data."""
    num_sequences = len(self.arrangement_data)
    return num_sequences
198
+
199
def _get_ids(self, h5):
    """
    Collect the object ids stored in an h5 file.

    @param h5: opened h5 file (or any mapping with the same keys()/[] interface)
    @return: dict mapping each key that starts with "id_", with that prefix removed,
             to the value stored under it
    """
    prefix_len = len("id_")
    return {key[prefix_len:]: h5[key][()] for key in h5.keys() if key.startswith("id_")}
211
+
212
def get_positive_ratio(self):
    """
    Ratio between the entries with step_t != 0 and the entries with step_t == 0.

    :return: (number of entries - number of step-0 entries) / number of step-0 entries
    :raise ZeroDivisionError: when no entry has step_t == 0
    """
    num_pos = sum(1 for _, step_t in self.arrangement_data if step_t == 0)
    num_rest = len(self.arrangement_data) - num_pos
    return num_rest * 1.0 / num_pos
219
+
220
def get_object_position_vocab_sizes(self):
    """Forward to the tokenizer and return its object-position vocabulary sizes."""
    vocab_sizes = self.tokenizer.get_object_position_vocab_sizes()
    return vocab_sizes
222
+
223
def get_vocab_size(self):
    """Forward to the tokenizer and return its vocabulary size."""
    vocab_size = self.tokenizer.get_vocab_size()
    return vocab_size
225
+
226
def get_data_index(self, idx):
    """
    Look up the raw index entry of a sample.

    :param idx: sample index
    :return: the self.arrangement_data entry at idx, i.e. the (filename, step_t) pair
    """
    entry = self.arrangement_data[idx]
    return entry
229
+
230
def get_raw_data(self, idx, inference_mode=False, shuffle_object_index=False):
    """
    Build the untokenized sample for one arrangement.

    Object point clouds are taken from the frame at step_t = number of rearranged
    objects; goal poses come from frame 0 of each target object.

    :param idx: index into self.arrangement_data
    :param inference_mode: also return the scene, object poses, ids and goal specification
    :param shuffle_object_index: used to test different orders of objects
    :return: dict with "pcs", "sentence", "goal_poses", "type_index", "position_index",
             "pad_mask", "t" and "filename"; in inference mode additionally "rgb",
             "goal_obj_poses", "current_obj_poses", "target_objs", "initial_scene", "ids",
             "goal_specification" and "current_pc_poses"
    :raise KeyError: when the structure type is not circle, line, tower or dinner
    :raise Exception: when an object has no valid points or get_pts reports failure
    """

    filename, _ = self.arrangement_data[idx]

    # Everything that reads from the file happens inside the context manager so the
    # handle is closed afterwards; previously one handle leaked per sample.
    with h5py.File(filename, 'r') as h5:
        ids = self._get_ids(h5)
        all_objs = sorted([o for o in ids.keys() if "object_" in o])
        goal_specification = json.loads(str(np.array(h5["goal_specification"])))
        num_rearrange_objs = len(goal_specification["rearrange"]["objects"])
        num_other_objs = len(goal_specification["anchor"]["objects"] + goal_specification["distract"]["objects"])
        assert len(all_objs) == num_rearrange_objs + num_other_objs, "{}, {}".format(len(all_objs), num_rearrange_objs + num_other_objs)
        assert num_rearrange_objs <= self.max_num_objects
        assert num_other_objs <= self.max_num_other_objects

        # important: only using the last step
        step_t = num_rearrange_objs

        target_objs = all_objs[:num_rearrange_objs]
        other_objs = all_objs[num_rearrange_objs:]

        structure_parameters = goal_specification["shape"]

        # Important: ensure the order is correct
        if structure_parameters["type"] == "circle" or structure_parameters["type"] == "line":
            target_objs = target_objs[::-1]
        elif structure_parameters["type"] == "tower" or structure_parameters["type"] == "dinner":
            target_objs = target_objs
        else:
            raise KeyError("{} structure is not recognized".format(structure_parameters["type"]))
        all_objs = target_objs + other_objs

        ###################################
        # getting scene images and point clouds
        scene = self._get_images(h5, step_t, ee=True)

        # (goal pose, current pose) per target object, read now so the file can be closed
        obj_pose_pairs = [(h5[obj][0], h5[obj][step_t]) for obj in target_objs]

    rgb, depth, seg, valid, xyz = scene
    if inference_mode:
        initial_scene = scene

    # getting object point clouds
    obj_pcs = []
    obj_pad_mask = []
    current_pc_poses = []
    other_obj_pcs = []
    other_obj_pad_mask = []
    for obj in all_objs:
        obj_mask = np.logical_and(seg == ids[obj], valid)
        if np.sum(obj_mask) <= 0:
            raise Exception
        ok, obj_xyz, obj_rgb, _ = get_pts(xyz, rgb, obj_mask, num_pts=self.num_pts)
        if not ok:
            raise Exception

        if obj in target_objs:
            if self.ignore_rgb:
                obj_pcs.append(obj_xyz)
            else:
                obj_pcs.append(torch.concat([obj_xyz, obj_rgb], dim=-1))
            obj_pad_mask.append(0)
            # pose of the point cloud: identity rotation, translation at the cloud mean
            pc_pose = np.eye(4)
            pc_pose[:3, 3] = torch.mean(obj_xyz, dim=0).numpy()
            current_pc_poses.append(pc_pose)
        elif obj in other_objs:
            if self.ignore_rgb:
                other_obj_pcs.append(obj_xyz)
            else:
                other_obj_pcs.append(torch.concat([obj_xyz, obj_rgb], dim=-1))
            other_obj_pad_mask.append(0)
        else:
            raise Exception

    ###################################
    # computes goal positions for objects
    # Important: because of the noises we added to point clouds, the rearranged point clouds will not be perfect
    if self.use_virtual_structure_frame:
        goal_structure_pose = tra.euler_matrix(structure_parameters["rotation"][0], structure_parameters["rotation"][1],
                                               structure_parameters["rotation"][2])
        goal_structure_pose[:3, 3] = [structure_parameters["position"][0], structure_parameters["position"][1],
                                      structure_parameters["position"][2]]
        goal_structure_pose_inv = np.linalg.inv(goal_structure_pose)

    goal_obj_poses = []
    current_obj_poses = []
    goal_pc_poses = []
    for (goal_pose, current_pose), current_pc_pose in zip(obj_pose_pairs, current_pc_poses):
        if inference_mode:
            goal_obj_poses.append(goal_pose)
            current_obj_poses.append(current_pose)

        # apply the object's current-to-goal motion to the pose of its point cloud
        goal_pc_pose = goal_pose @ np.linalg.inv(current_pose) @ current_pc_pose
        if self.use_virtual_structure_frame:
            goal_pc_pose = goal_structure_pose_inv @ goal_pc_pose
        goal_pc_poses.append(goal_pc_pose)

    # transform current object point cloud to the goal point cloud in the world frame
    if self.debug:
        new_obj_pcs = [copy.deepcopy(pc.numpy()) for pc in obj_pcs]
        for i, obj_pc in enumerate(new_obj_pcs):

            current_pc_pose = current_pc_poses[i]
            goal_pc_pose = goal_pc_poses[i]
            if self.use_virtual_structure_frame:
                goal_pc_pose = goal_structure_pose @ goal_pc_pose
            print("current pc pose", current_pc_pose)
            print("goal pc pose", goal_pc_pose)

            goal_pc_transform = goal_pc_pose @ np.linalg.inv(current_pc_pose)
            print("transform", goal_pc_transform)
            new_obj_pc = copy.deepcopy(obj_pc)
            new_obj_pc[:, :3] = trimesh.transform_points(obj_pc[:, :3], goal_pc_transform)
            print(new_obj_pc.shape)

            # visualize rearrangement sequence (new_obj_xyzs), the current object before moving (obj_xyz), and other objects
            # dtype=float replaces np.float, an alias of the builtin that NumPy 1.24 removed
            new_obj_pcs[i] = new_obj_pc
            new_obj_pcs[i][:, 3:] = np.tile(np.array([1, 0, 0], dtype=float), (new_obj_pc.shape[0], 1))
            new_obj_rgb_current = np.tile(np.array([0, 1, 0], dtype=float), (new_obj_pc.shape[0], 1))
            show_pcs([pc[:, :3] for pc in new_obj_pcs] + [pc[:, :3] for pc in other_obj_pcs] + [obj_pc[:, :3]],
                     [pc[:, 3:] for pc in new_obj_pcs] + [pc[:, 3:] for pc in other_obj_pcs] + [new_obj_rgb_current],
                     add_coordinate_frame=True)
        show_pcs([pc[:, :3] for pc in new_obj_pcs], [pc[:, 3:] for pc in new_obj_pcs], add_coordinate_frame=True)

    # pad data
    for i in range(self.max_num_objects - len(target_objs)):
        obj_pcs.append(torch.zeros_like(obj_pcs[0], dtype=torch.float32))
        obj_pad_mask.append(1)
    for i in range(self.max_num_other_objects - len(other_objs)):
        other_obj_pcs.append(torch.zeros_like(obj_pcs[0], dtype=torch.float32))
        other_obj_pad_mask.append(1)

    ###################################
    # preparing sentence
    sentence = []
    sentence_pad_mask = []

    # structure parameters
    # 5 parameters
    structure_parameters = goal_specification["shape"]
    if structure_parameters["type"] == "circle" or structure_parameters["type"] == "line":
        sentence.append((structure_parameters["type"], "shape"))
        sentence.append((structure_parameters["rotation"][2], "rotation"))
        sentence.append((structure_parameters["position"][0], "position_x"))
        sentence.append((structure_parameters["position"][1], "position_y"))
        if structure_parameters["type"] == "circle":
            sentence.append((structure_parameters["radius"], "radius"))
        elif structure_parameters["type"] == "line":
            # half of the line length is stored in the "radius" slot
            sentence.append((structure_parameters["length"] / 2.0, "radius"))
        for _ in range(5):
            sentence_pad_mask.append(0)
    else:
        sentence.append((structure_parameters["type"], "shape"))
        sentence.append((structure_parameters["rotation"][2], "rotation"))
        sentence.append((structure_parameters["position"][0], "position_x"))
        sentence.append((structure_parameters["position"][1], "position_y"))
        for _ in range(4):
            sentence_pad_mask.append(0)
        sentence.append(("PAD", None))
        sentence_pad_mask.append(1)

    ###################################
    # paddings
    for i in range(self.max_num_objects - len(target_objs)):
        goal_pc_poses.append(np.eye(4))

    ###################################
    if self.debug:
        print("---")
        print("all objects:", all_objs)
        print("target objects:", target_objs)
        print("other objects:", other_objs)
        print("goal specification:", goal_specification)
        print("sentence:", sentence)
        show_pcs([pc[:, :3] for pc in obj_pcs + other_obj_pcs], [pc[:, 3:] for pc in obj_pcs + other_obj_pcs], add_coordinate_frame=True)

    assert len(obj_pcs) == len(goal_pc_poses)
    ###################################

    # shuffle the position of objects
    if shuffle_object_index:
        shuffle_target_object_indices = list(range(len(target_objs)))
        random.shuffle(shuffle_target_object_indices)
        # padded lists: real objects are permuted, padding entries stay at the end
        shuffle_object_indices = shuffle_target_object_indices + list(range(len(target_objs), self.max_num_objects))
        obj_pcs = [obj_pcs[i] for i in shuffle_object_indices]
        goal_pc_poses = [goal_pc_poses[i] for i in shuffle_object_indices]
        if inference_mode:
            # These lists hold one entry per real target object and are never padded.
            # Indexing them with the padded indices raised IndexError whenever a scene
            # had fewer than max_num_objects targets, so the target-only permutation is used.
            goal_obj_poses = [goal_obj_poses[i] for i in shuffle_target_object_indices]
            current_obj_poses = [current_obj_poses[i] for i in shuffle_target_object_indices]
            target_objs = [target_objs[i] for i in shuffle_target_object_indices]
            current_pc_poses = [current_pc_poses[i] for i in shuffle_target_object_indices]

    ###################################
    if self.use_virtual_structure_frame:
        if self.ignore_distractor_objects:
            # language, structure virtual frame, target objects
            pcs = obj_pcs
            type_index = [0] * self.max_num_shape_parameters + [2] + [3] * self.max_num_objects
            position_index = list(range(self.max_num_shape_parameters)) + [0] + list(range(self.max_num_objects))
            pad_mask = sentence_pad_mask + [0] + obj_pad_mask
        else:
            # language, distractor objects, structure virtual frame, target objects
            pcs = other_obj_pcs + obj_pcs
            type_index = [0] * self.max_num_shape_parameters + [1] * self.max_num_other_objects + [2] + [3] * self.max_num_objects
            position_index = list(range(self.max_num_shape_parameters)) + list(range(self.max_num_other_objects)) + [0] + list(range(self.max_num_objects))
            pad_mask = sentence_pad_mask + other_obj_pad_mask + [0] + obj_pad_mask
        goal_poses = [goal_structure_pose] + goal_pc_poses
    else:
        if self.ignore_distractor_objects:
            # language, target objects
            pcs = obj_pcs
            type_index = [0] * self.max_num_shape_parameters + [3] * self.max_num_objects
            position_index = list(range(self.max_num_shape_parameters)) + list(range(self.max_num_objects))
            pad_mask = sentence_pad_mask + obj_pad_mask
        else:
            # language, distractor objects, target objects
            pcs = other_obj_pcs + obj_pcs
            type_index = [0] * self.max_num_shape_parameters + [1] * self.max_num_other_objects + [3] * self.max_num_objects
            position_index = list(range(self.max_num_shape_parameters)) + list(range(self.max_num_other_objects)) + list(range(self.max_num_objects))
            pad_mask = sentence_pad_mask + other_obj_pad_mask + obj_pad_mask
        goal_poses = goal_pc_poses

    datum = {
        "pcs": pcs,
        "sentence": sentence,
        "goal_poses": goal_poses,
        "type_index": type_index,
        "position_index": position_index,
        "pad_mask": pad_mask,
        "t": step_t,
        "filename": filename
    }

    if inference_mode:
        datum["rgb"] = rgb
        datum["goal_obj_poses"] = goal_obj_poses
        datum["current_obj_poses"] = current_obj_poses
        datum["target_objs"] = target_objs
        datum["initial_scene"] = initial_scene
        datum["ids"] = ids
        datum["goal_specification"] = goal_specification
        datum["current_pc_poses"] = current_pc_poses

    return datum
479
+
480
@staticmethod
def convert_to_tensors(datum, tokenizer):
    """
    Convert a raw datum (as returned by get_raw_data) into torch tensors.

    :param datum: dict with keys "pcs", "sentence", "goal_poses", "type_index",
                  "position_index", "pad_mask", "t" and "filename"
    :param tokenizer: object whose tokenize(*entry) maps each sentence entry to an integer id
    :return: dict with the same keys; "t" and "filename" are passed through unchanged
    """
    token_ids = np.array([tokenizer.tokenize(*entry) for entry in datum["sentence"]])

    tensors = {
        "pcs": torch.stack(datum["pcs"], dim=0),
        "sentence": torch.LongTensor(token_ids),
        "goal_poses": torch.FloatTensor(np.array(datum["goal_poses"])),
    }
    # the three index-like sequences are all converted the same way
    for key in ("type_index", "position_index", "pad_mask"):
        tensors[key] = torch.LongTensor(np.array(datum[key]))
    # metadata is not tensorized
    for key in ("t", "filename"):
        tensors[key] = datum[key]
    return tensors
493
+
494
def __getitem__(self, idx):
    """
    Return the tensorized sample at position idx.

    The raw datum is built with the dataset-level shuffle_object_index setting and
    then converted with the dataset's tokenizer.
    """
    raw_datum = self.get_raw_data(idx, shuffle_object_index=self.shuffle_object_index)
    return self.convert_to_tensors(raw_datum, self.tokenizer)
500
+
501
def single_datum_to_batch(self, x, num_samples, device, inference_mode=True):
    """
    Replicate one tensorized datum num_samples times along a new leading batch dimension.

    :param x: dict of tensors as produced by convert_to_tensors
    :param num_samples: number of copies to place in the batch
    :param device: device every tensor is moved to before replication
    :param inference_mode: when True, "goal_poses" is left out of the batch
    :return: dict of batched tensors
    """
    # (key, number of dimensions the un-batched tensor is indexed with)
    fields = [("pcs", 3), ("sentence", 1)]
    if not inference_mode:
        fields.append(("goal_poses", 3))
    fields.extend([("type_index", 1), ("position_index", 1), ("pad_mask", 1)])

    batch = {}
    for key, rank in fields:
        # same as x[key][None, :, ...]: prepend a batch axis of size one
        add_batch_axis = (None,) + (slice(None),) * rank
        batch[key] = x[key].to(device)[add_batch_axis].repeat(num_samples, *([1] * rank))
    return batch
514
+
515
+
516
def compute_min_max(dataloader):
    """
    Compute the element-wise minimum and maximum of all goal poses in a dataloader.

    Every batch's "goal_poses" tensor is flattened to rows of 16 values (one 4x4
    matrix per row) and the running bounds are updated per element. All rows are
    included, whatever they represent.

    :param dataloader: iterable of batches, each a dict with a "goal_poses" tensor
                       whose number of elements is a multiple of 16
    :return: (min_value, max_value), two tensors of 16 values each
    """

    # Previously recorded outputs, kept for reference:
    # tensor([-0.3557, -0.3847, 0.0000, -1.0000, -1.0000, -0.4759, -1.0000, -1.0000,
    # -0.9079, -0.8668, -0.9105, -0.4186])
    # tensor([0.3915, 0.3494, 0.3267, 1.0000, 1.0000, 0.8961, 1.0000, 1.0000, 0.8194,
    # 0.4787, 0.6421, 1.0000])
    # tensor([0.0918, -0.3758, 0.0000, -1.0000, -1.0000, 0.0000, -1.0000, -1.0000,
    # -0.0000, 0.0000, 0.0000, 1.0000])
    # tensor([0.9199, 0.3710, 0.0000, 1.0000, 1.0000, 0.0000, 1.0000, 1.0000, -0.0000,
    # 0.0000, 0.0000, 1.0000])

    min_value = torch.full((16,), float("inf"))
    max_value = torch.full((16,), float("-inf"))
    for d in tqdm(dataloader):
        goal_poses = d["goal_poses"].reshape(-1, 16)
        current_max, _ = torch.max(goal_poses, dim=0)
        current_min, _ = torch.min(goal_poses, dim=0)
        # the minimum must be tracked in min_value; writing it into max_value
        # (as was done before) corrupts both bounds
        max_value = torch.max(max_value, current_max.to(max_value.dtype))
        min_value = torch.min(min_value, current_min.to(min_value.dtype))
    print(f"{min_value} - {max_value}")
    return min_value, max_value
537
+
538
+
539
if __name__ == "__main__":
    # Smoke test: build the validation dataset and iterate it once through a DataLoader.

    tokenizer = Tokenizer("/home/weiyu/data_drive/data_new_objects/type_vocabs_coarse.json")

    shape_and_index = [("circle", "index_10k"), ("line", "index_10k"), ("stacking", "index_10k"), ("dinner", "index_10k")]
    data_roots = ["/home/weiyu/data_drive/data_new_objects/examples_{}_new_objects/result".format(shape)
                  for shape, _ in shape_and_index]
    index_roots = [index for _, index in shape_and_index]

    dataset_options = dict(
        max_num_target_objects=7,
        max_num_distractor_objects=5,
        max_num_shape_parameters=5,
        max_num_rearrange_features=0,
        max_num_anchor_features=0,
        num_pts=1024,
        use_virtual_structure_frame=True,
        ignore_distractor_objects=True,
        ignore_rgb=True,
        filter_num_moved_objects_range=None,  # [5, 5]
        data_augmentation=False,
        shuffle_object_index=False,
        debug=False,
    )
    dataset = SemanticArrangementDataset(data_roots=data_roots,
                                         index_roots=index_roots,
                                         split="valid", tokenizer=tokenizer,
                                         **dataset_options)

    # To inspect samples by hand, iterate `dataset` directly and print the tensor
    # shapes / values of each datum instead of only running the loader.
    dataloader = DataLoader(dataset, batch_size=64, shuffle=False, num_workers=8)
    for _ in tqdm(dataloader):
        pass
src/StructDiffusion/data/semantic_arrangement_demo.py ADDED
@@ -0,0 +1,563 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import cv2
3
+ import h5py
4
+ import numpy as np
5
+ import os
6
+ import trimesh
7
+ import torch
8
+ from tqdm import tqdm
9
+ import json
10
+ import random
11
+
12
+ from torch.utils.data import DataLoader
13
+
14
+ # Local imports
15
+ from StructDiffusion.utils.rearrangement import show_pcs, get_pts, combine_and_sample_xyzs
16
+ from StructDiffusion.language.tokenizer import Tokenizer
17
+
18
+ import StructDiffusion.utils.brain2.camera as cam
19
+ import StructDiffusion.utils.brain2.image as img
20
+ import StructDiffusion.utils.transformations as tra
21
+
22
+
23
class SemanticArrangementDataset(torch.utils.data.Dataset):
    """
    Demo dataset: every ``.h5`` file found directly inside ``data_root`` is one sample.

    Each sample is a token sequence made of a language part (structure parameters),
    optional distractor objects, an optional virtual structure frame and the target
    objects, together with the goal pose of every target object.
    """

    def __init__(self, data_root, tokenizer,
                 max_num_target_objects=11, max_num_distractor_objects=5,
                 max_num_shape_parameters=7, max_num_rearrange_features=1, max_num_anchor_features=3,
                 num_pts=1024,
                 use_virtual_structure_frame=True, ignore_distractor_objects=True, ignore_rgb=True,
                 filter_num_moved_objects_range=None, shuffle_object_index=False,
                 data_augmentation=True, debug=False, **kwargs):
        """
        :param data_root: directory that directly contains the ``.h5`` scene files
        :param tokenizer: used to tokenize the language part
        :param max_num_target_objects: target objects are padded up to this number
        :param max_num_distractor_objects: other (anchor + distractor) objects are padded up to this number
        :param max_num_shape_parameters: number of language slots used for type_index / position_index.
               NOTE(review): get_raw_data always emits 5 sentence entries, so values other
               than 5 look inconsistent with pad_mask - confirm before using another value
        :param max_num_rearrange_features: stored, not used by this class
        :param max_num_anchor_features: stored, not used by this class
        :param num_pts: number of points requested from get_pts for every object
        :param use_virtual_structure_frame: whether to add the structure frame token and express goal
               poses in the structure frame
        :param ignore_distractor_objects: whether to leave distractor objects out of the sequence
        :param ignore_rgb: drop colors from the point clouds (forced off when debug is True)
        :param filter_num_moved_objects_range: accepted for interface compatibility; this demo
               class does not apply the filter in its constructor
        :param shuffle_object_index: whether __getitem__ shuffles the order of target objects
        :param data_augmentation: whether to add noise to depth and xyz images
        :param debug: print and visualize intermediate results
        :param kwargs: ignored
        """

        self.use_virtual_structure_frame = use_virtual_structure_frame
        self.ignore_distractor_objects = ignore_distractor_objects
        # the debug visualizations read colors from the point clouds
        self.ignore_rgb = ignore_rgb and not debug

        self.num_pts = num_pts
        self.debug = debug

        self.max_num_objects = max_num_target_objects
        self.max_num_other_objects = max_num_distractor_objects
        self.max_num_shape_parameters = max_num_shape_parameters
        self.max_num_rearrange_features = max_num_rearrange_features
        self.max_num_anchor_features = max_num_anchor_features
        self.shuffle_object_index = shuffle_object_index

        # used to tokenize the language part
        self.tokenizer = tokenizer

        # retrieve data: (path, step) pairs; sorted so that an index always maps to the
        # same file regardless of the order in which the OS lists the directory
        self.data_root = data_root
        self.arrangement_data = [
            (os.path.join(data_root, filename), 0)
            for filename in sorted(os.listdir(data_root))
            if filename.endswith(".h5")
        ]
        print("{} valid sequences".format(len(self.arrangement_data)))

        # Data Aug
        self.data_augmentation = data_augmentation
        # additive noise
        self.gp_rescale_factor_range = [12, 20]
        self.gaussian_scale_range = [0., 0.003]
        # multiplicative noise
        self.gamma_shape = 1000.
        self.gamma_scale = 0.001

    def filter_based_on_number_of_moved_objects(self, filter_num_moved_objects_range):
        """
        Return the entries of arrangement_data whose number of moved objects is in range.

        :param filter_num_moved_objects_range: (min_num, max_num), both inclusive
        :return: list of (filename, step_t) pairs; self.arrangement_data is not modified
        """
        assert len(list(filter_num_moved_objects_range)) == 2
        min_num, max_num = filter_num_moved_objects_range
        print("Remove scenes that have less than {} or more than {} objects being moved".format(min_num, max_num))
        ok_data = []
        for filename, step_t in self.arrangement_data:
            # close every file right after reading instead of leaking one handle per scene
            with h5py.File(filename, 'r') as h5:
                moved_objs = h5['moved_objs'][()]
            if isinstance(moved_objs, bytes):
                # some h5py versions return stored strings as bytes
                moved_objs = moved_objs.decode("utf-8")
            if min_num <= len(moved_objs.split(',')) <= max_num:
                ok_data.append((filename, step_t))
        print("{} valid sequences left".format(len(ok_data)))
        return ok_data

    def get_data_idx(self, idx):
        """
        Map a global index to (open h5 file, index inside the file, file index).

        NOTE(review): reads self.file_to_count and self.data_files, which this class's
        __init__ does not set - looks like a leftover from another dataset; confirm
        before calling. The returned file is left open for the caller.
        """
        # Create the datum to return
        file_idx = np.argmax(idx < self.file_to_count)
        data = h5py.File(self.data_files[file_idx], 'r')
        if file_idx > 0:
            # for lang2sym, idx is always 0
            idx = idx - self.file_to_count[file_idx - 1]
        return data, idx, file_idx

    def add_noise_to_depth(self, depth_img):
        """Scale the whole depth image by one gamma-distributed factor (multiplicative noise)."""
        multiplicative_noise = np.random.gamma(self.gamma_shape, self.gamma_scale)
        depth_img = multiplicative_noise * depth_img
        return depth_img

    def add_noise_to_xyz(self, xyz_img, depth_img):
        """
        Add smooth additive noise to the xyz image wherever depth is positive.

        Gaussian noise is drawn on a downscaled grid and resized back to the image size.
        The input xyz image is copied, not modified. TODO: remove this code or at least
        clean it up.
        """
        xyz_img = xyz_img.copy()
        H, W, C = xyz_img.shape
        gp_rescale_factor = np.random.randint(self.gp_rescale_factor_range[0],
                                              self.gp_rescale_factor_range[1])
        gp_scale = np.random.uniform(self.gaussian_scale_range[0],
                                     self.gaussian_scale_range[1])
        small_H, small_W = (np.array([H, W]) / gp_rescale_factor).astype(int)
        additive_noise = np.random.normal(loc=0.0, scale=gp_scale, size=(small_H, small_W, C))
        additive_noise = cv2.resize(additive_noise, (W, H), interpolation=cv2.INTER_CUBIC)
        xyz_img[depth_img > 0, :] += additive_noise[depth_img > 0, :]
        return xyz_img

    def random_index(self):
        """Return the tensorized sample at a uniformly random index."""
        return self[np.random.randint(len(self))]

    def _get_rgb(self, h5, idx, ee=True):
        """Decode the color image at idx, drop the alpha channel and scale to [0, 1]."""
        RGB = "ee_rgb" if ee else "rgb"
        rgb1 = img.PNGToNumpy(h5[RGB][idx])[:, :, :3] / 255.  # remove alpha
        return rgb1

    def _get_depth(self, h5, idx, ee=True):
        """Unfinished stub: only selects the dataset key and returns None."""
        DEPTH = "ee_depth" if ee else "depth"

    def _get_images(self, h5, idx, ee=True):
        """
        Load one scene observation from an open h5 file.

        :param h5: open h5 file
        :param idx: index of the observation inside the per-step datasets
        :param ee: read the "ee_*" datasets instead of the un-prefixed ones
        :return: (rgb, depth, seg, valid, xyz); xyz has been transformed with the camera
                 pose, valid marks pixels with depth in (0.1, 2.0) before noise is added
        """
        if ee:
            RGB, DEPTH, SEG = "ee_rgb", "ee_depth", "ee_seg"
            DMIN, DMAX = "ee_depth_min", "ee_depth_max"
        else:
            RGB, DEPTH, SEG = "rgb", "depth", "seg"
            DMIN, DMAX = "depth_min", "depth_max"
        dmin = h5[DMIN][idx]
        dmax = h5[DMAX][idx]
        rgb1 = img.PNGToNumpy(h5[RGB][idx])[:, :, :3] / 255.  # remove alpha
        # stored depth is rescaled into the [dmin, dmax] range
        depth1 = h5[DEPTH][idx] / 20000. * (dmax - dmin) + dmin
        seg1 = img.PNGToNumpy(h5[SEG][idx])

        valid1 = np.logical_and(depth1 > 0.1, depth1 < 2.)

        camera = cam.get_camera_from_h5(h5)
        if self.data_augmentation:
            depth1 = self.add_noise_to_depth(depth1)

        xyz1 = cam.compute_xyz(depth1, camera)
        if self.data_augmentation:
            xyz1 = self.add_noise_to_xyz(xyz1, depth1)

        # Transform the point cloud with the camera pose
        CAM_POSE = "ee_camera_view" if ee else "camera_view"
        cam_pose = h5[CAM_POSE][idx]
        if ee:
            # ee_camera_view has 0s for x, y, z
            cam_pos = h5["ee_cam_pose"][:][:3, 3]
            cam_pose[:3, 3] = cam_pos

        # Get transformed point cloud
        h, w, _ = xyz1.shape
        xyz1 = xyz1.reshape(h * w, -1)
        xyz1 = trimesh.transform_points(xyz1, cam_pose)
        xyz1 = xyz1.reshape(h, w, -1)

        scene1 = rgb1, depth1, seg1, valid1, xyz1

        return scene1

    def __len__(self):
        """Number of scene files found in data_root."""
        return len(self.arrangement_data)

    def _get_ids(self, h5):
        """
        get object ids

        @param h5: open h5 file (or any mapping whose values support ``[()]``)
        @return: dict mapping every key that starts with "id_" (prefix removed) to its value
        """
        ids = {}
        for k in h5.keys():
            if k.startswith("id_"):
                ids[k[3:]] = h5[k][()]
        return ids

    def get_positive_ratio(self):
        """
        Return (number of entries with step != 0) / (number of entries with step == 0).

        Raises ZeroDivisionError when no entry has step 0 (e.g. an empty dataset).
        """
        num_pos = 0
        for d in self.arrangement_data:
            filename, step_t = d
            if step_t == 0:
                num_pos += 1
        return (len(self.arrangement_data) - num_pos) * 1.0 / num_pos

    def get_object_position_vocab_sizes(self):
        """Forward to the tokenizer."""
        return self.tokenizer.get_object_position_vocab_sizes()

    def get_vocab_size(self):
        """Forward to the tokenizer."""
        return self.tokenizer.get_vocab_size()

    def get_data_index(self, idx):
        """Return the (filename, step) entry stored at idx."""
        filename = self.arrangement_data[idx]
        return filename

    @staticmethod
    def _permute_prefix(items, order):
        """
        Reorder the first len(order) items by order and keep the remaining items in place.

        Works for lists that end with padding as well as for lists that hold exactly one
        entry per target object.
        """
        return [items[i] for i in order] + list(items[len(order):])

    def get_raw_data(self, idx, inference_mode=False, shuffle_object_index=False):
        """
        Build the raw (not yet tensorized) datum for one scene.

        :param idx: index into arrangement_data
        :param inference_mode: additionally return images, object poses and metadata
        :param shuffle_object_index: used to test different orders of objects
        :return: dict with "pcs", "sentence", "goal_poses", "type_index", "position_index",
                 "pad_mask", "t" and "filename" (plus the inference-only entries)
        """

        filename, _ = self.arrangement_data[idx]

        # everything read from the file is copied out, so it can be closed afterwards
        with h5py.File(filename, 'r') as h5:
            return self._build_raw_datum(h5, filename, inference_mode, shuffle_object_index)

    def _build_raw_datum(self, h5, filename, inference_mode, shuffle_object_index):
        """Assemble the raw datum of get_raw_data from an already opened h5 file."""
        ids = self._get_ids(h5)
        all_objs = sorted([o for o in ids.keys() if "object_" in o])
        goal_specification = json.loads(str(np.array(h5["goal_specification"])))
        num_rearrange_objs = len(goal_specification["rearrange"]["objects"])
        num_other_objs = len(goal_specification["anchor"]["objects"] + goal_specification["distract"]["objects"])
        assert len(all_objs) == num_rearrange_objs + num_other_objs, "{}, {}".format(len(all_objs), num_rearrange_objs + num_other_objs)
        assert num_rearrange_objs <= self.max_num_objects
        assert num_other_objs <= self.max_num_other_objects

        # important: only using the last step
        step_t = num_rearrange_objs

        target_objs = all_objs[:num_rearrange_objs]
        other_objs = all_objs[num_rearrange_objs:]

        structure_parameters = goal_specification["shape"]

        # Important: ensure the order is correct
        if structure_parameters["type"] == "circle" or structure_parameters["type"] == "line":
            target_objs = target_objs[::-1]
        elif structure_parameters["type"] == "tower" or structure_parameters["type"] == "dinner":
            target_objs = target_objs
        else:
            raise KeyError("{} structure is not recognized".format(structure_parameters["type"]))
        all_objs = target_objs + other_objs

        ###################################
        # getting scene images and point clouds
        scene = self._get_images(h5, step_t, ee=True)
        rgb, _, seg, valid, xyz = scene
        if inference_mode:
            initial_scene = scene

        # getting object point clouds
        obj_pcs = []
        obj_pad_mask = []
        current_pc_poses = []
        other_obj_pcs = []
        other_obj_pad_mask = []
        for obj in all_objs:
            obj_mask = np.logical_and(seg == ids[obj], valid)
            if np.sum(obj_mask) <= 0:
                raise Exception("{} has no valid pixel in {}".format(obj, filename))
            ok, obj_xyz, obj_rgb, _ = get_pts(xyz, rgb, obj_mask, num_pts=self.num_pts)
            if not ok:
                raise Exception("could not sample points for {} in {}".format(obj, filename))

            obj_pc = obj_xyz if self.ignore_rgb else torch.cat([obj_xyz, obj_rgb], dim=-1)
            if obj in target_objs:
                obj_pcs.append(obj_pc)
                obj_pad_mask.append(0)
                # pose of the point cloud: identity rotation, translation at the centroid
                pc_pose = np.eye(4)
                pc_pose[:3, 3] = torch.mean(obj_xyz, dim=0).numpy()
                current_pc_poses.append(pc_pose)
            elif obj in other_objs:
                other_obj_pcs.append(obj_pc)
                other_obj_pad_mask.append(0)
            else:
                raise Exception("{} is neither a target nor another object".format(obj))

        ###################################
        # computes goal positions for objects
        # Important: because of the noises we added to point clouds, the rearranged point clouds will not be perfect
        if self.use_virtual_structure_frame:
            goal_structure_pose = tra.euler_matrix(structure_parameters["rotation"][0], structure_parameters["rotation"][1],
                                                   structure_parameters["rotation"][2])
            goal_structure_pose[:3, 3] = [structure_parameters["position"][0], structure_parameters["position"][1],
                                          structure_parameters["position"][2]]
            goal_structure_pose_inv = np.linalg.inv(goal_structure_pose)

        goal_obj_poses = []
        current_obj_poses = []
        goal_pc_poses = []
        for obj, current_pc_pose in zip(target_objs, current_pc_poses):
            goal_pose = h5[obj][0]
            current_pose = h5[obj][step_t]
            if inference_mode:
                goal_obj_poses.append(goal_pose)
                current_obj_poses.append(current_pose)

            # apply the object's current-to-goal motion to the point cloud pose
            goal_pc_pose = goal_pose @ np.linalg.inv(current_pose) @ current_pc_pose
            if self.use_virtual_structure_frame:
                goal_pc_pose = goal_structure_pose_inv @ goal_pc_pose
            goal_pc_poses.append(goal_pc_pose)

        # transform current object point cloud to the goal point cloud in the world frame
        if self.debug:
            new_obj_pcs = [copy.deepcopy(pc.numpy()) for pc in obj_pcs]
            for i, obj_pc in enumerate(new_obj_pcs):

                current_pc_pose = current_pc_poses[i]
                goal_pc_pose = goal_pc_poses[i]
                if self.use_virtual_structure_frame:
                    goal_pc_pose = goal_structure_pose @ goal_pc_pose
                print("current pc pose", current_pc_pose)
                print("goal pc pose", goal_pc_pose)

                goal_pc_transform = goal_pc_pose @ np.linalg.inv(current_pc_pose)
                print("transform", goal_pc_transform)
                new_obj_pc = copy.deepcopy(obj_pc)
                new_obj_pc[:, :3] = trimesh.transform_points(obj_pc[:, :3], goal_pc_transform)
                print(new_obj_pc.shape)

                # visualize rearrangement sequence (new_obj_xyzs), the current object before moving (obj_xyz), and other objects
                # (np.float was removed from NumPy; the builtin float is the same dtype)
                new_obj_pcs[i] = new_obj_pc
                new_obj_pcs[i][:, 3:] = np.tile(np.array([1, 0, 0], dtype=float), (new_obj_pc.shape[0], 1))
                new_obj_rgb_current = np.tile(np.array([0, 1, 0], dtype=float), (new_obj_pc.shape[0], 1))
                show_pcs([pc[:, :3] for pc in new_obj_pcs] + [pc[:, :3] for pc in other_obj_pcs] + [obj_pc[:, :3]],
                         [pc[:, 3:] for pc in new_obj_pcs] + [pc[:, 3:] for pc in other_obj_pcs] + [new_obj_rgb_current],
                         add_coordinate_frame=True)
            show_pcs([pc[:, :3] for pc in new_obj_pcs], [pc[:, 3:] for pc in new_obj_pcs], add_coordinate_frame=True)

        # pad data
        for i in range(self.max_num_objects - len(target_objs)):
            obj_pcs.append(torch.zeros_like(obj_pcs[0], dtype=torch.float32))
            obj_pad_mask.append(1)
        for i in range(self.max_num_other_objects - len(other_objs)):
            other_obj_pcs.append(torch.zeros_like(obj_pcs[0], dtype=torch.float32))
            other_obj_pad_mask.append(1)

        ###################################
        # preparing sentence
        sentence = []
        sentence_pad_mask = []

        # structure parameters
        # 5 parameters
        structure_parameters = goal_specification["shape"]
        if structure_parameters["type"] == "circle" or structure_parameters["type"] == "line":
            sentence.append((structure_parameters["type"], "shape"))
            sentence.append((structure_parameters["rotation"][2], "rotation"))
            sentence.append((structure_parameters["position"][0], "position_x"))
            sentence.append((structure_parameters["position"][1], "position_y"))
            if structure_parameters["type"] == "circle":
                sentence.append((structure_parameters["radius"], "radius"))
            elif structure_parameters["type"] == "line":
                sentence.append((structure_parameters["length"] / 2.0, "radius"))
            for _ in range(5):
                sentence_pad_mask.append(0)
        else:
            sentence.append((structure_parameters["type"], "shape"))
            sentence.append((structure_parameters["rotation"][2], "rotation"))
            sentence.append((structure_parameters["position"][0], "position_x"))
            sentence.append((structure_parameters["position"][1], "position_y"))
            for _ in range(4):
                sentence_pad_mask.append(0)
            # these structures have no size parameter: fill the fifth slot with padding
            sentence.append(("PAD", None))
            sentence_pad_mask.append(1)

        ###################################
        # paddings
        for i in range(self.max_num_objects - len(target_objs)):
            goal_pc_poses.append(np.eye(4))

        ###################################
        if self.debug:
            print("---")
            print("all objects:", all_objs)
            print("target objects:", target_objs)
            print("other objects:", other_objs)
            print("goal specification:", goal_specification)
            print("sentence:", sentence)
            show_pcs([pc[:, :3] for pc in obj_pcs + other_obj_pcs], [pc[:, 3:] for pc in obj_pcs + other_obj_pcs], add_coordinate_frame=True)

        assert len(obj_pcs) == len(goal_pc_poses)
        ###################################

        # shuffle the position of objects
        if shuffle_object_index:
            shuffle_target_object_indices = list(range(len(target_objs)))
            random.shuffle(shuffle_target_object_indices)
            # padded lists keep their padding at the end
            obj_pcs = self._permute_prefix(obj_pcs, shuffle_target_object_indices)
            goal_pc_poses = self._permute_prefix(goal_pc_poses, shuffle_target_object_indices)
            if inference_mode:
                # these lists hold one entry per real target object (no padding), so they
                # must only be indexed with the target permutation
                goal_obj_poses = self._permute_prefix(goal_obj_poses, shuffle_target_object_indices)
                current_obj_poses = self._permute_prefix(current_obj_poses, shuffle_target_object_indices)
                target_objs = self._permute_prefix(target_objs, shuffle_target_object_indices)
                current_pc_poses = self._permute_prefix(current_pc_poses, shuffle_target_object_indices)

        ###################################
        # type ids: 0 language, 1 distractor objects, 2 structure virtual frame, 3 target objects
        if self.use_virtual_structure_frame:
            if self.ignore_distractor_objects:
                # language, structure virtual frame, target objects
                pcs = obj_pcs
                type_index = [0] * self.max_num_shape_parameters + [2] + [3] * self.max_num_objects
                position_index = list(range(self.max_num_shape_parameters)) + [0] + list(range(self.max_num_objects))
                pad_mask = sentence_pad_mask + [0] + obj_pad_mask
            else:
                # language, distractor objects, structure virtual frame, target objects
                pcs = other_obj_pcs + obj_pcs
                type_index = [0] * self.max_num_shape_parameters + [1] * self.max_num_other_objects + [2] + [3] * self.max_num_objects
                position_index = list(range(self.max_num_shape_parameters)) + list(range(self.max_num_other_objects)) + [0] + list(range(self.max_num_objects))
                pad_mask = sentence_pad_mask + other_obj_pad_mask + [0] + obj_pad_mask
            goal_poses = [goal_structure_pose] + goal_pc_poses
        else:
            if self.ignore_distractor_objects:
                # language, target objects
                pcs = obj_pcs
                type_index = [0] * self.max_num_shape_parameters + [3] * self.max_num_objects
                position_index = list(range(self.max_num_shape_parameters)) + list(range(self.max_num_objects))
                pad_mask = sentence_pad_mask + obj_pad_mask
            else:
                # language, distractor objects, target objects
                pcs = other_obj_pcs + obj_pcs
                type_index = [0] * self.max_num_shape_parameters + [1] * self.max_num_other_objects + [3] * self.max_num_objects
                position_index = list(range(self.max_num_shape_parameters)) + list(range(self.max_num_other_objects)) + list(range(self.max_num_objects))
                pad_mask = sentence_pad_mask + other_obj_pad_mask + obj_pad_mask
            goal_poses = goal_pc_poses

        datum = {
            "pcs": pcs,
            "sentence": sentence,
            "goal_poses": goal_poses,
            "type_index": type_index,
            "position_index": position_index,
            "pad_mask": pad_mask,
            "t": step_t,
            "filename": filename
        }

        if inference_mode:
            datum["rgb"] = rgb
            datum["goal_obj_poses"] = goal_obj_poses
            datum["current_obj_poses"] = current_obj_poses
            datum["target_objs"] = target_objs
            datum["initial_scene"] = initial_scene
            datum["ids"] = ids
            datum["goal_specification"] = goal_specification
            datum["current_pc_poses"] = current_pc_poses

        return datum

    @staticmethod
    def convert_to_tensors(datum, tokenizer):
        """
        Convert a raw datum into torch tensors.

        :param datum: dict as returned by get_raw_data
        :param tokenizer: object whose tokenize(*entry) maps a sentence entry to an integer id
        :return: dict with the same training keys; "t" and "filename" are passed through
        """
        tensors = {
            "pcs": torch.stack(datum["pcs"], dim=0),
            "sentence": torch.LongTensor(np.array([tokenizer.tokenize(*i) for i in datum["sentence"]])),
            "goal_poses": torch.FloatTensor(np.array(datum["goal_poses"])),
            "type_index": torch.LongTensor(np.array(datum["type_index"])),
            "position_index": torch.LongTensor(np.array(datum["position_index"])),
            "pad_mask": torch.LongTensor(np.array(datum["pad_mask"])),
            "t": datum["t"],
            "filename": datum["filename"]
        }
        return tensors

    def __getitem__(self, idx):
        """Return the tensorized sample at idx, shuffled according to the dataset setting."""

        datum = self.convert_to_tensors(self.get_raw_data(idx, shuffle_object_index=self.shuffle_object_index),
                                        self.tokenizer)

        return datum

    def single_datum_to_batch(self, x, num_samples, device, inference_mode=True):
        """
        Replicate one tensorized datum num_samples times along a new leading batch dimension.

        :param x: dict of tensors as produced by convert_to_tensors
        :param num_samples: number of copies
        :param device: device the tensors are moved to
        :param inference_mode: when True, "goal_poses" is left out of the batch
        :return: dict of batched tensors
        """
        tensor_x = {}

        tensor_x["pcs"] = x["pcs"].to(device)[None, :, :, :].repeat(num_samples, 1, 1, 1)
        tensor_x["sentence"] = x["sentence"].to(device)[None, :].repeat(num_samples, 1)
        if not inference_mode:
            tensor_x["goal_poses"] = x["goal_poses"].to(device)[None, :, :, :].repeat(num_samples, 1, 1, 1)

        tensor_x["type_index"] = x["type_index"].to(device)[None, :].repeat(num_samples, 1)
        tensor_x["position_index"] = x["position_index"].to(device)[None, :].repeat(num_samples, 1)
        tensor_x["pad_mask"] = x["pad_mask"].to(device)[None, :].repeat(num_samples, 1)

        return tensor_x
498
+
499
+
500
def compute_min_max(dataloader):
    """
    Compute the element-wise minimum and maximum of all goal poses in a dataloader.

    Every batch's "goal_poses" tensor is flattened to rows of 16 values (one 4x4
    matrix per row) and the running bounds are updated per element. All rows are
    included, whatever they represent.

    :param dataloader: iterable of batches, each a dict with a "goal_poses" tensor
                       whose number of elements is a multiple of 16
    :return: (min_value, max_value), two tensors of 16 values each
    """

    # Previously recorded outputs, kept for reference:
    # tensor([-0.3557, -0.3847, 0.0000, -1.0000, -1.0000, -0.4759, -1.0000, -1.0000,
    # -0.9079, -0.8668, -0.9105, -0.4186])
    # tensor([0.3915, 0.3494, 0.3267, 1.0000, 1.0000, 0.8961, 1.0000, 1.0000, 0.8194,
    # 0.4787, 0.6421, 1.0000])
    # tensor([0.0918, -0.3758, 0.0000, -1.0000, -1.0000, 0.0000, -1.0000, -1.0000,
    # -0.0000, 0.0000, 0.0000, 1.0000])
    # tensor([0.9199, 0.3710, 0.0000, 1.0000, 1.0000, 0.0000, 1.0000, 1.0000, -0.0000,
    # 0.0000, 0.0000, 1.0000])

    min_value = torch.full((16,), float("inf"))
    max_value = torch.full((16,), float("-inf"))
    for d in tqdm(dataloader):
        goal_poses = d["goal_poses"].reshape(-1, 16)
        current_max, _ = torch.max(goal_poses, dim=0)
        current_min, _ = torch.min(goal_poses, dim=0)
        # the minimum must be tracked in min_value; writing it into max_value
        # (as was done before) corrupts both bounds
        max_value = torch.max(max_value, current_max.to(max_value.dtype))
        min_value = torch.min(min_value, current_min.to(min_value.dtype))
    print(f"{min_value} - {max_value}")
    return min_value, max_value
521
+
522
+
523
if __name__ == "__main__":
    # Smoke test: build the demo dataset for every structure type and iterate it once.

    tokenizer = Tokenizer("/home/weiyu/data_drive/data_new_objects/type_vocabs_coarse.json")

    data_roots = ["/home/weiyu/data_drive/data_new_objects/examples_{}_new_objects/result".format(shape)
                  for shape in ("circle", "line", "stacking", "dinner")]

    # The demo dataset reads the .h5 files of ONE directory (`data_root`); passing
    # `data_roots=` / `index_roots=` / `split=` as before raised a TypeError because the
    # required `data_root` argument was missing. Build one dataset per directory and
    # concatenate them instead.
    dataset = torch.utils.data.ConcatDataset([
        SemanticArrangementDataset(data_root=data_root,
                                   tokenizer=tokenizer,
                                   max_num_target_objects=7,
                                   max_num_distractor_objects=5,
                                   max_num_shape_parameters=5,
                                   max_num_rearrange_features=0,
                                   max_num_anchor_features=0,
                                   num_pts=1024,
                                   use_virtual_structure_frame=True,
                                   ignore_distractor_objects=True,
                                   ignore_rgb=True,
                                   filter_num_moved_objects_range=None,  # [5, 5]
                                   data_augmentation=False,
                                   shuffle_object_index=False,
                                   debug=False)
        for data_root in data_roots
    ])

    # To inspect samples by hand, iterate `dataset` directly and print the tensor
    # shapes / values of each datum instead of only running the loader.
    dataloader = DataLoader(dataset, batch_size=64, shuffle=False, num_workers=8)
    for i, d in enumerate(tqdm(dataloader)):
        pass
src/StructDiffusion/diffusion/__init__.py ADDED
File without changes
src/StructDiffusion/diffusion/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (181 Bytes). View file
 
src/StructDiffusion/diffusion/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (185 Bytes). View file
 
src/StructDiffusion/diffusion/__pycache__/noise_schedule.cpython-37.pyc ADDED
Binary file (2.57 kB). View file
 
src/StructDiffusion/diffusion/__pycache__/noise_schedule.cpython-38.pyc ADDED
Binary file (2.57 kB). View file
 
src/StructDiffusion/diffusion/__pycache__/pose_conversion.cpython-37.pyc ADDED
Binary file (2.25 kB). View file
 
src/StructDiffusion/diffusion/__pycache__/pose_conversion.cpython-38.pyc ADDED
Binary file (2.27 kB). View file
 
src/StructDiffusion/diffusion/__pycache__/sampler.cpython-37.pyc ADDED
Binary file (5.74 kB). View file
 
src/StructDiffusion/diffusion/__pycache__/sampler.cpython-38.pyc ADDED
Binary file (5.71 kB). View file
 
src/StructDiffusion/diffusion/noise_schedule.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ import torch.nn.functional as F
4
+
5
+
6
def cosine_beta_schedule(timesteps, s=0.008):
    """
    Cosine beta schedule, see https://arxiv.org/abs/2102.09672.

    :param timesteps: number of diffusion steps
    :param s: offset added inside the cosine
    :return: tensor of `timesteps` betas, clipped to [0.0001, 0.9999]
    """
    num_points = timesteps + 1
    grid = torch.linspace(0, timesteps, num_points)
    # cumulative alpha follows a squared cosine, normalized so that it starts at 1
    alpha_bar = torch.cos(((grid / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2
    alpha_bar = alpha_bar / alpha_bar[0]
    # beta_t = 1 - alpha_bar_t / alpha_bar_{t-1}
    step_betas = 1 - (alpha_bar[1:] / alpha_bar[:-1])
    return torch.clip(step_betas, 0.0001, 0.9999)
16
+
17
+
18
def linear_beta_schedule(timesteps):
    """Return `timesteps` betas evenly spaced from 0.0001 to 0.02 (both included)."""
    first_beta, last_beta = 0.0001, 0.02
    return torch.linspace(first_beta, last_beta, timesteps)
22
+
23
+
24
def quadratic_beta_schedule(timesteps):
    """Return `timesteps` betas from 0.0001 to 0.02 whose square roots are evenly spaced."""
    first_beta, last_beta = 0.0001, 0.02
    sqrt_betas = torch.linspace(first_beta**0.5, last_beta**0.5, timesteps)
    return sqrt_betas ** 2
28
+
29
+
30
def sigmoid_beta_schedule(timesteps):
    """Return `timesteps` betas that follow a sigmoid over [-6, 6], rescaled to (0.0001, 0.02)."""
    first_beta, last_beta = 0.0001, 0.02
    logits = torch.linspace(-6, 6, timesteps)
    return torch.sigmoid(logits) * (last_beta - first_beta) + first_beta
35
+
36
+
37
class NoiseSchedule:
    """Precomputed per-timestep constants of a diffusion process with a linear beta schedule."""

    def __init__(self, timesteps=200):
        """
        :param timesteps: number of diffusion steps; every attribute below is a tensor
                          with one value per step
        """
        self.timesteps = timesteps

        # beta schedule (linear; the cosine schedule is available as an alternative)
        betas = linear_beta_schedule(timesteps=timesteps)
        alphas = 1. - betas
        # alpha bar: cumulative product of the alphas, and the same shifted by one step
        # (with 1.0 in front)
        alpha_bar = torch.cumprod(alphas, axis=0)
        alpha_bar_prev = F.pad(alpha_bar[:-1], (1, 0), value=1.0)

        self.betas = betas
        self.alphas = alphas
        self.alphas_cumprod = alpha_bar
        self.alphas_cumprod_prev = alpha_bar_prev
        self.sqrt_recip_alphas = torch.sqrt(1.0 / alphas)

        # used by the forward diffusion q(x_t | x_0)
        self.sqrt_alphas_cumprod = torch.sqrt(alpha_bar)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alpha_bar)

        # variance of the posterior q(x_{t-1} | x_t, x_0)
        self.posterior_variance = betas * (1. - alpha_bar_prev) / (1. - alpha_bar)
60
+
61
+
62
def extract(a, t, x_shape):
    """
    Pick a[t] for every entry of t and shape the result for broadcasting against x_shape.

    :param a: 1-D tensor of per-timestep values (indexed on the CPU)
    :param t: 1-D tensor of timestep indices, one per batch element
    :param x_shape: shape of the tensor the result will be multiplied with
    :return: tensor of shape (len(t), 1, ..., 1) with len(x_shape) dimensions, on t's device
    """
    picked = a.gather(-1, t.cpu())
    broadcast_shape = (t.shape[0],) + (1,) * (len(x_shape) - 1)
    return picked.reshape(*broadcast_shape).to(t.device)
66
+
67
+
68
+ # forward diffusion (using the nice property)
69
def q_sample(x_start, t, noise_schedule, noise=None):
    """
    Sample x_t from x_0 in one step of the forward diffusion.

    :param x_start: clean sample x_0; the first dimension is the batch
    :param t: tensor of timestep indices, one per batch element
    :param noise_schedule: object providing sqrt_alphas_cumprod and sqrt_one_minus_alphas_cumprod
    :param noise: noise to mix in; drawn from a standard normal when None
    :return: sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise
    """
    if noise is None:
        noise = torch.randn_like(x_start)

    signal_scale = extract(noise_schedule.sqrt_alphas_cumprod, t, x_start.shape)
    noise_scale = extract(noise_schedule.sqrt_one_minus_alphas_cumprod, t, x_start.shape)

    return signal_scale * x_start + noise_scale * noise
src/StructDiffusion/diffusion/pose_conversion.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import pytorch3d.transforms as tra3d
4
+
5
+ from StructDiffusion.utils.rotation_continuity import compute_rotation_matrix_from_ortho6d
6
+
7
+
8
def get_diffusion_variables_from_9D_actions(struct_xyztheta_inputs, obj_xyztheta_inputs):
    """Build the 9-D diffusion variables from 12-D xyz + rotation inputs.

    The last dimension of each input holds xyz at indices 0..2 followed by a
    row-major 3x3 rotation matrix at indices 3..11:

        [[ 3,  4,  5],
         [ 6,  7,  8],
         [ 9, 10, 11]]

    Only the first two *columns* (not rows) of the rotation are kept.

    :param struct_xyztheta_inputs: B, 1, 12
    :param obj_xyztheta_inputs: B, N, 12
    :return: B, 1 + N, 9 -- the structure entry first, then the objects
    """
    translation = [0, 1, 2]
    rot_col_0 = [3, 6, 9]
    rot_col_1 = [4, 7, 10]
    keep = translation + rot_col_0 + rot_col_1

    struct_9d = struct_xyztheta_inputs[:, :, keep]  # B, 1, 9
    obj_9d = obj_xyztheta_inputs[:, :, keep]  # B, N, 9

    return torch.cat((struct_9d, obj_9d), dim=1)  # B, 1 + N, 9
28
+
29
+
30
def get_diffusion_variables_from_H(poses):
    """Convert homogeneous transforms into the 9-D diffusion representation.

    Flat (row-major) indices of a 4x4 matrix:

        [[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11],
         [12, 13, 14, 15]]

    :param poses: B, N, 4, 4
    :return: B, N, 9 -- the translation column (3, 7, 11) followed by the
        first two rotation columns (0, 4, 8) and (1, 5, 9)
    """
    B, N, _, _ = poses.shape
    flat = poses.reshape(B, N, 16)

    translation = [3, 7, 11]
    rot_col_0 = [0, 4, 8]
    rot_col_1 = [1, 5, 9]

    return flat[:, :, translation + rot_col_0 + rot_col_1]  # B, N, 9
45
+
46
+
47
def get_struct_objs_poses(x):
    """Turn 9-D diffusion variables into structure and object transforms.

    :param x: B, 1 + N, 9 -- xyz followed by a 6-D rotation; the first entry
        along dim 1 is the structure frame, the remaining N are objects
    :return: tuple ``(struct_pose, pc_poses_in_struct)`` with shapes
        B, 1, 4, 4 and B, N, 4, 4. Results are returned on the CPU when the
        input was not a CUDA tensor.
    """
    started_on_gpu = x.is_cuda
    if not started_on_gpu:
        # the computation is always carried out on the GPU (the original note
        # says compute_rotation_matrix_from_ortho6d requires a GPU input)
        x = x.cuda()
    device = x.device

    # important: the noisy x can go out of bounds
    x = torch.clamp(x, min=-1, max=1)

    B, num_frames = x.shape[0], x.shape[1]  # num_frames = 1 + N

    # compute_rotation_matrix_from_ortho6d takes in [B, 6], outputs [B, 3, 3]
    rot_6d = x[:, :, 3:].reshape(-1, 6)
    rot_mats = compute_rotation_matrix_from_ortho6d(rot_6d).reshape(B, num_frames, 3, 3)

    # assemble homogeneous transforms: rotation block plus translation column
    transforms = torch.eye(4).repeat(B, num_frames, 1, 1).to(device)
    transforms[:, :, :3, :3] = rot_mats
    transforms[:, :, :3, 3] = x[:, :, :3]

    struct_pose = transforms[:, 0].unsqueeze(1)  # B, 1, 4, 4
    pc_poses_in_struct = transforms[:, 1:]  # B, N, 4, 4

    if not started_on_gpu:
        struct_pose = struct_pose.cpu()
        pc_poses_in_struct = pc_poses_in_struct.cpu()

    return struct_pose, pc_poses_in_struct
81
+
82
+
83
def compute_current_and_goal_pc_poses(obj_xyzs, struct_pose, pc_poses_in_struct):
    """Compute current and goal poses of each object point cloud.

    The current pose of an object is an identity rotation translated to the
    mean of its points. The goal pose is the structure pose composed with the
    object's pose expressed in the structure frame.

    :param obj_xyzs: B, N, P, 3
    :param struct_pose: B, 1, 4, 4
    :param pc_poses_in_struct: B, N, 4, 4
    :return: tuple ``(current_pc_poses, goal_pc_poses)``, each B, N, 4, 4
    """
    device = obj_xyzs.device

    B, N, _, _ = pc_poses_in_struct.shape
    _, _, _num_points, _ = obj_xyzs.shape  # B, N, P, 3

    # identity rotation, translation = centroid of each object's points
    centroids = torch.mean(obj_xyzs, dim=2)  # B, N, 3
    current_pc_poses = torch.eye(4).repeat(B, N, 1, 1).to(device)  # B, N, 4, 4
    current_pc_poses[:, :, :3, 3] = centroids

    # replicate the structure pose once per object, then flatten both sets of
    # transforms so they can be multiplied as one batch
    struct_flat = struct_pose.repeat(1, N, 1, 1).reshape(B * N, 4, 4)  # B x N, 4, 4
    obj_flat = pc_poses_in_struct.reshape(B * N, 4, 4)  # B x N, 4, 4

    goal_pc_poses = (struct_flat @ obj_flat).reshape(B, N, 4, 4)  # B, N, 4, 4
    return current_pc_poses, goal_pc_poses
src/StructDiffusion/diffusion/sampler.py ADDED
@@ -0,0 +1,296 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from tqdm import tqdm
3
+ import pytorch3d.transforms as tra3d
4
+
5
+ from StructDiffusion.diffusion.noise_schedule import extract
6
+ from StructDiffusion.diffusion.pose_conversion import get_struct_objs_poses
7
+ from StructDiffusion.utils.batch_inference import move_pc_and_create_scene_simple, visualize_batch_pcs, move_pc_and_create_scene_new
8
+
9
class Sampler:
    """Draws pose samples by running the reverse diffusion chain of a trained model."""

    def __init__(self, model_class, checkpoint_path, device, debug=False):
        """
        :param model_class: class providing ``load_from_checkpoint``; the
            loaded object must expose ``model`` (the backbone) and
            ``noise_schedule``
        :param checkpoint_path: checkpoint passed to ``load_from_checkpoint``
        :param device: device the backbone is moved to and samples are created on
        :param debug: stored on the instance; not read by this class
        """
        self.debug = debug
        self.device = device

        self.model = model_class.load_from_checkpoint(checkpoint_path)
        self.backbone = self.model.model
        self.backbone.to(device)
        self.backbone.eval()

    def sample(self, batch, num_poses):
        """Run the full reverse diffusion process starting from Gaussian noise.

        :param batch: dict with keys "pcs", "sentence", "type_index",
            "position_index" and "pad_mask"; the batch size is taken from
            ``batch["pcs"].shape[0]``
        :param num_poses: number of 9-D poses generated per batch element
        :return: list of B, num_poses, 9 tensors, one per timestep, ordered so
            that index 0 is the final (t = 0) sample
        """
        schedule = self.model.noise_schedule
        num_steps = schedule.timesteps

        B = batch["pcs"].shape[0]
        x_noisy = torch.randn((B, num_poses, 9), device=self.device)

        trajectory = []
        steps = tqdm(reversed(range(0, num_steps)),
                     desc='sampling loop time step', total=num_steps)
        for t_index in steps:
            t = torch.full((B,), t_index, device=self.device, dtype=torch.long)

            # schedule coefficients for this timestep
            x_shape = x_noisy.shape
            betas_t = extract(schedule.betas, t, x_shape)
            sqrt_one_minus_alphas_cumprod_t = extract(schedule.sqrt_one_minus_alphas_cumprod, t, x_shape)
            sqrt_recip_alphas_t = extract(schedule.sqrt_recip_alphas, t, x_shape)

            # predict the noise by calling the backbone directly instead of
            # the pytorch-lightning model
            with torch.no_grad():
                predicted_noise = self.backbone.forward(
                    t, batch["pcs"], batch["sentence"], x_noisy,
                    batch["type_index"], batch["position_index"], batch["pad_mask"])

            # mean of p(x_{t-1} | x_t)
            model_mean = sqrt_recip_alphas_t * (
                x_noisy - betas_t * predicted_noise / sqrt_one_minus_alphas_cumprod_t)

            if t_index != 0:
                # add posterior noise on every step except the last one
                posterior_variance_t = extract(schedule.posterior_variance, t, x_noisy.shape)
                noise = torch.randn_like(x_noisy)
                x_noisy = model_mean + torch.sqrt(posterior_variance_t) * noise
            else:
                x_noisy = model_mean

            trajectory.append(x_noisy)

        # put the fully denoised sample first
        return trajectory[::-1]
63
+
64
class SamplerV2:
    """Diffusion sampler that ranks its samples with a pairwise collision model.

    Samples are generated with the reverse diffusion chain, scored by the
    collision model (higher score = fewer predicted collisions), and the
    best-scoring ("elite") samples are returned.
    """

    def __init__(self, diffusion_model_class, diffusion_checkpoint_path,
                 collision_model_class, collision_checkpoint_path,
                 device, debug=False):
        """
        :param diffusion_model_class: class providing ``load_from_checkpoint``;
            the loaded object must expose ``model`` and ``noise_schedule``
        :param diffusion_checkpoint_path: checkpoint of the diffusion model
        :param collision_model_class: class providing ``load_from_checkpoint``;
            the loaded object must expose ``model`` and ``data_cfg``
        :param collision_checkpoint_path: checkpoint of the collision model
        :param device: device both backbones are moved to
        :param debug: stored on the instance; not read by this class
        """
        self.debug = debug
        self.device = device

        self.diffusion_model = diffusion_model_class.load_from_checkpoint(diffusion_checkpoint_path)
        self.diffusion_backbone = self.diffusion_model.model
        self.diffusion_backbone.to(device)
        self.diffusion_backbone.eval()

        self.collision_model = collision_model_class.load_from_checkpoint(collision_checkpoint_path)
        self.collision_backbone = self.collision_model.model
        self.collision_backbone.to(device)
        self.collision_backbone.eval()

    def _denoise(self, batch, num_poses):
        """Run the reverse diffusion chain; index 0 of the returned list is the final sample."""
        noise_schedule = self.diffusion_model.noise_schedule

        B = batch["pcs"].shape[0]

        x_noisy = torch.randn((B, num_poses, 9), device=self.device)

        xs = []
        for t_index in tqdm(reversed(range(0, noise_schedule.timesteps)),
                            desc='sampling loop time step', total=noise_schedule.timesteps):

            t = torch.full((B,), t_index, device=self.device, dtype=torch.long)

            # noise schedule
            betas_t = extract(noise_schedule.betas, t, x_noisy.shape)
            sqrt_one_minus_alphas_cumprod_t = extract(noise_schedule.sqrt_one_minus_alphas_cumprod, t, x_noisy.shape)
            sqrt_recip_alphas_t = extract(noise_schedule.sqrt_recip_alphas, t, x_noisy.shape)

            # predict noise
            pcs = batch["pcs"]
            sentence = batch["sentence"]
            type_index = batch["type_index"]
            position_index = batch["position_index"]
            pad_mask = batch["pad_mask"]
            # calling the backbone instead of the pytorch-lightning model
            with torch.no_grad():
                predicted_noise = self.diffusion_backbone.forward(t, pcs, sentence, x_noisy, type_index, position_index, pad_mask)

            # compute noisy x at t
            model_mean = sqrt_recip_alphas_t * (x_noisy - betas_t * predicted_noise / sqrt_one_minus_alphas_cumprod_t)
            if t_index == 0:
                x_noisy = model_mean
            else:
                posterior_variance_t = extract(noise_schedule.posterior_variance, t, x_noisy.shape)
                noise = torch.randn_like(x_noisy)
                x_noisy = model_mean + torch.sqrt(posterior_variance_t) * noise

            xs.append(x_noisy)

        return list(reversed(xs))

    def sample(self, batch, num_poses):
        """Generate samples and return the ones with the fewest predicted collisions.

        :param batch: dict with keys "pcs", "sentence", "type_index",
            "position_index" and "pad_mask". Only the first batch element's
            point clouds and pad mask are used for scoring, which presumably
            assumes every batch element describes the same scene -- confirm
            against the caller.
        :param num_poses: number of 9-D poses generated per batch element
        :return: tuple ``(struct_pose, pc_poses_in_struct)`` with shapes
            num_elite, 1, 4, 4 and num_elite, N, 4, 4, sorted from best to
            worst score, where ``num_elite = min(10, batch size)``
        """
        B = batch["pcs"].shape[0]

        xs = self._denoise(batch, num_poses)

        visualize = True

        struct_pose, pc_poses_in_struct = get_struct_objs_poses(xs[0])
        # struct_pose: B, 1, 4, 4
        # pc_poses_in_struct: B, N, 4, 4

        S = B
        # Never ask for more elites than there are samples: with a fixed 10,
        # any batch smaller than 10 made the reshapes to num_elite * N fail.
        num_elite = min(10, S)

        ####################################################
        # only keep one copy

        # N, P, 3
        obj_xyzs = batch["pcs"][0][:, :, :3]
        print("obj_xyzs shape", obj_xyzs.shape)

        # 1, N
        # object_pad_mask: padding location has 1
        num_target_objs = num_poses
        if self.diffusion_backbone.use_virtual_structure_frame:
            # one of the poses belongs to the structure frame, not an object
            num_target_objs -= 1
        object_pad_mask = batch["pad_mask"][0][-num_target_objs:].unsqueeze(0)
        target_object_inds = 1 - object_pad_mask
        print("target_object_inds shape", target_object_inds.shape)
        print("target_object_inds", target_object_inds)

        N, P, _ = obj_xyzs.shape
        print("S, N, P: {}, {}, {}".format(S, N, P))

        ####################################################
        # S, N, ...

        struct_pose = struct_pose.repeat(1, N, 1, 1)  # S, N, 4, 4
        struct_pose = struct_pose.reshape(S * N, 4, 4)  # S x N, 4, 4

        new_obj_xyzs = obj_xyzs.repeat(S, 1, 1, 1)  # S, N, P, 3
        current_pc_pose = torch.eye(4).repeat(S, N, 1, 1).to(self.device)  # S, N, 4, 4
        current_pc_pose[:, :, :3, 3] = torch.mean(new_obj_xyzs, dim=2)  # S, N, 4, 4
        current_pc_pose = current_pc_pose.reshape(S * N, 4, 4)  # S x N, 4, 4

        # object poses as xyz + "XYZ" euler angles
        obj_params = torch.zeros((S, N, 6)).to(self.device)
        obj_params[:, :, :3] = pc_poses_in_struct[:, :, :3, 3]
        obj_params[:, :, 3:] = tra3d.matrix_to_euler_angles(pc_poses_in_struct[:, :, :3, :3], "XYZ")  # S, N, 6

        ####################################################
        # rank

        # evaluate in batches
        no_intersection_scores = torch.zeros(S).to(self.device)  # the higher the better
        num_batches = int(S / B)
        if S % B != 0:
            num_batches += 1
        for b in range(num_batches):
            cur_batch_idxs_start = b * B
            # the last batch takes whatever remains
            cur_batch_idxs_end = S if b + 1 == num_batches else (b + 1) * B
            cur_batch_size = cur_batch_idxs_end - cur_batch_idxs_start

            batch_obj_params = obj_params[cur_batch_idxs_start: cur_batch_idxs_end]
            batch_struct_pose = struct_pose[cur_batch_idxs_start * N: cur_batch_idxs_end * N]
            batch_current_pc_pose = current_pc_pose[cur_batch_idxs_start * N:cur_batch_idxs_end * N]

            # only the pairwise point clouds (last return value) are needed
            _, _, _, _, obj_pair_xyzs = \
                move_pc_and_create_scene_new(obj_xyzs, batch_obj_params, batch_struct_pose, batch_current_pc_pose,
                                             target_object_inds, self.device,
                                             return_scene_pts=False,
                                             return_scene_pts_and_pc_idxs=False,
                                             num_scene_pts=False,
                                             normalize_pc=False,
                                             return_pair_pc=True,
                                             num_pair_pc_pts=self.collision_model.data_cfg.num_scene_pts,
                                             normalize_pair_pc=self.collision_model.data_cfg.normalize_pc)

            #######################################
            # predict whether there are pairwise collisions
            with torch.no_grad():
                _, num_comb, num_pair_pc_pts, _ = obj_pair_xyzs.shape
                collision_logits = self.collision_backbone.forward(obj_pair_xyzs.reshape(cur_batch_size * num_comb, num_pair_pc_pts, -1))
                collision_scores = self.collision_backbone.convert_logits(collision_logits).reshape(cur_batch_size, num_comb)  # cur_batch_size, num_comb

            # 1 - mean() since the collision model predicts 1 if there is a collision
            no_intersection_scores[cur_batch_idxs_start:cur_batch_idxs_end] = 1 - torch.mean(collision_scores, dim=1)
            if visualize:
                print("no intersection scores", no_intersection_scores)

        # NOTE: discriminator-based scoring is currently disabled; samples are
        # ranked by the collision score alone.
        scores = no_intersection_scores
        sort_idx = torch.argsort(scores).flip(dims=[0])[:num_elite]
        elite_obj_params = obj_params[sort_idx]  # num_elite, N, 6
        elite_struct_poses = struct_pose.reshape(S, N, 4, 4)[sort_idx]  # num_elite, N, 4, 4
        elite_struct_poses = elite_struct_poses.reshape(num_elite * N, 4, 4)  # num_elite x N, 4, 4
        elite_scores = scores[sort_idx]
        print("elite scores:", elite_scores)

        ####################################################
        # convert the elite xyz + euler parameters back to homogeneous transforms

        # num_elite, N, 6
        elite_obj_params = elite_obj_params.reshape(num_elite * N, -1)
        pc_poses_in_struct = torch.eye(4).repeat(num_elite * N, 1, 1).to(self.device)
        pc_poses_in_struct[:, :3, :3] = tra3d.euler_angles_to_matrix(elite_obj_params[:, 3:], "XYZ")
        pc_poses_in_struct[:, :3, 3] = elite_obj_params[:, :3]
        pc_poses_in_struct = pc_poses_in_struct.reshape(num_elite, N, 4, 4)  # num_elite, N, 4, 4

        # the structure pose was replicated per object above; keep one copy
        struct_pose = elite_struct_poses.reshape(num_elite, N, 4, 4)[:, 0,].unsqueeze(1)  # num_elite, 1, 4, 4

        return struct_pose, pc_poses_in_struct
src/StructDiffusion/language/__init__.py ADDED
File without changes
src/StructDiffusion/language/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (180 Bytes). View file
 
src/StructDiffusion/language/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (184 Bytes). View file
 
src/StructDiffusion/language/__pycache__/tokenizer.cpython-37.pyc ADDED
Binary file (11.4 kB). View file