Spaces:
Paused
Paused
Weiyu Liu
commited on
Commit
·
f392320
1
Parent(s):
38a6100
add natural language model and app
Browse files- __pycache__/app.cpython-38.pyc +0 -0
- app.py +197 -45
- app_v0.py +282 -0
- app_v1.py +217 -0
- configs/conditional_pose_diffusion_language.yaml +92 -0
- data/template_sentence_data.pkl +3 -0
- requirements.txt +2 -1
- src/StructDiffusion/data/__pycache__/semantic_arrangement_demo.cpython-38.pyc +0 -0
- src/StructDiffusion/data/__pycache__/semantic_arrangement_language.cpython-38.pyc +0 -0
- src/StructDiffusion/data/__pycache__/semantic_arrangement_language_demo.cpython-38.pyc +0 -0
- src/StructDiffusion/data/pairwise_collision.py +19 -63
- src/StructDiffusion/data/semantic_arrangement.py +1 -44
- src/StructDiffusion/data/semantic_arrangement_language.py +633 -0
- src/StructDiffusion/data/semantic_arrangement_language_demo.py +693 -0
- src/StructDiffusion/diffusion/__pycache__/pose_conversion.cpython-38.pyc +0 -0
- src/StructDiffusion/diffusion/__pycache__/sampler.cpython-38.pyc +0 -0
- src/StructDiffusion/diffusion/sampler.py +243 -235
- src/StructDiffusion/language/__pycache__/sentence_encoder.cpython-38.pyc +0 -0
- src/StructDiffusion/language/__pycache__/tokenizer.cpython-38.pyc +0 -0
- src/StructDiffusion/language/convert_to_natural_language.ipynb +773 -0
- src/StructDiffusion/language/sentence_encoder.py +23 -0
- src/StructDiffusion/language/test_parrot_paraphrase.py +38 -0
- src/StructDiffusion/language/tokenizer.py +1 -22
- src/StructDiffusion/models/__pycache__/models.cpython-38.pyc +0 -0
- src/StructDiffusion/models/models.py +17 -3
- src/StructDiffusion/utils/__pycache__/batch_inference.cpython-38.pyc +0 -0
- src/StructDiffusion/utils/__pycache__/files.cpython-38.pyc +0 -0
- src/StructDiffusion/utils/__pycache__/rearrangement.cpython-38.pyc +0 -0
- src/StructDiffusion/utils/__pycache__/rotation_continuity.cpython-38.pyc +0 -0
- src/StructDiffusion/utils/__pycache__/tra3d.cpython-38.pyc +0 -0
- src/StructDiffusion/utils/batch_inference.py +11 -247
- src/StructDiffusion/utils/files.py +9 -1
- src/StructDiffusion/utils/np_speed_test.py +41 -0
- src/StructDiffusion/utils/rearrangement.py +29 -4
- src/StructDiffusion/utils/tra3d.py +148 -0
- tmp_data/input_scene.glb +0 -0
- tmp_data/input_scene_102.glb +0 -0
- tmp_data/input_scene_None.glb +0 -0
- tmp_data/output_scene.glb +0 -0
- tmp_data/output_scene_102.glb +0 -0
- wandb_logs/StructDiffusion/CollisionDiscriminator/checkpoints/epoch=199-step=653400.ckpt +3 -0
- wandb_logs/StructDiffusion/ConditionalPoseDiffusionLanguage/checkpoints/epoch=199-step=100000.ckpt +3 -0
__pycache__/app.cpython-38.pyc
CHANGED
Binary files a/__pycache__/app.cpython-38.pyc and b/__pycache__/app.cpython-38.pyc differ
|
|
app.py
CHANGED
@@ -10,13 +10,15 @@ from omegaconf import OmegaConf
|
|
10 |
import sys
|
11 |
sys.path.append('./src')
|
12 |
|
13 |
-
from StructDiffusion.data.
|
14 |
from StructDiffusion.language.tokenizer import Tokenizer
|
15 |
-
from StructDiffusion.models.pl_models import ConditionalPoseDiffusionModel
|
16 |
-
from StructDiffusion.diffusion.sampler import Sampler
|
17 |
from StructDiffusion.diffusion.pose_conversion import get_struct_objs_poses
|
18 |
from StructDiffusion.utils.files import get_checkpoint_path_from_dir
|
19 |
-
from StructDiffusion.utils.rearrangement import show_pcs_with_trimesh
|
|
|
|
|
20 |
import StructDiffusion.utils.transformations as tra
|
21 |
|
22 |
|
@@ -65,23 +67,31 @@ class Infer_Wrapper:
|
|
65 |
|
66 |
def __init__(self, args, cfg):
|
67 |
|
|
|
|
|
68 |
# load
|
69 |
pl.seed_everything(args.eval_random_seed)
|
70 |
self.device = (torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"))
|
71 |
|
72 |
-
|
73 |
-
|
74 |
|
75 |
self.tokenizer = Tokenizer(cfg.DATASET.vocab_dir)
|
76 |
# override ignore_rgb for visualization
|
77 |
cfg.DATASET.ignore_rgb = False
|
78 |
self.dataset = SemanticArrangementDataset(tokenizer=self.tokenizer, **cfg.DATASET)
|
79 |
|
80 |
-
self.sampler =
|
|
|
|
|
|
|
|
|
|
|
81 |
|
82 |
def visualize_scene(self, di, session_id):
|
83 |
-
|
84 |
-
|
|
|
85 |
|
86 |
obj_xyz = raw_datum["pcs"]
|
87 |
scene = show_pcs_with_trimesh([xyz[:, :3] for xyz in obj_xyz], [xyz[:, 3:] for xyz in obj_xyz], return_scene=True)
|
@@ -93,20 +103,77 @@ class Infer_Wrapper:
|
|
93 |
|
94 |
return language_command, scene_filename
|
95 |
|
96 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
97 |
|
98 |
-
|
99 |
|
100 |
-
raw_datum = self.dataset.
|
101 |
-
|
102 |
-
datum = self.dataset.convert_to_tensors(raw_datum, self.tokenizer)
|
103 |
batch = self.dataset.single_datum_to_batch(datum, args.num_samples, self.device, inference_mode=True)
|
104 |
|
105 |
-
num_poses =
|
106 |
-
|
107 |
|
108 |
-
|
109 |
-
new_obj_xyzs = move_pc_and_create_scene_simple(batch["pcs"], struct_pose, pc_poses_in_struct)
|
110 |
|
111 |
# vis
|
112 |
vis_obj_xyzs = new_obj_xyzs[:3]
|
@@ -115,18 +182,11 @@ class Infer_Wrapper:
|
|
115 |
vis_obj_xyzs = vis_obj_xyzs.detach().cpu()
|
116 |
vis_obj_xyzs = vis_obj_xyzs.numpy()
|
117 |
|
118 |
-
# for bi, vis_obj_xyz in enumerate(vis_obj_xyzs):
|
119 |
-
# if verbose:
|
120 |
-
# print("example {}".format(bi))
|
121 |
-
# print(vis_obj_xyz.shape)
|
122 |
-
#
|
123 |
-
# if trimesh:
|
124 |
-
# show_pcs_with_trimesh([xyz[:, :3] for xyz in vis_obj_xyz], [xyz[:, 3:] for xyz in vis_obj_xyz])
|
125 |
vis_obj_xyz = vis_obj_xyzs[0]
|
126 |
-
scene = show_pcs_with_trimesh([xyz[:, :3] for xyz in vis_obj_xyz], [xyz[:, 3:] for xyz in vis_obj_xyz], return_scene=True)
|
127 |
-
|
|
|
128 |
scene.apply_transform(tra.euler_matrix(np.pi, 0, np.pi/2))
|
129 |
-
|
130 |
scene_filename = "./tmp_data/output_scene_{}.glb".format(session_id)
|
131 |
scene.export(scene_filename)
|
132 |
|
@@ -167,10 +227,13 @@ class Infer_Wrapper:
|
|
167 |
|
168 |
args = OmegaConf.create()
|
169 |
args.base_config_file = "./configs/base.yaml"
|
170 |
-
args.config_file = "./configs/
|
171 |
-
args.
|
|
|
172 |
args.eval_random_seed = 42
|
173 |
-
args.num_samples =
|
|
|
|
|
174 |
|
175 |
base_cfg = OmegaConf.load(args.base_config_file)
|
176 |
cfg = OmegaConf.load(args.config_file)
|
@@ -178,34 +241,123 @@ cfg = OmegaConf.merge(base_cfg, cfg)
|
|
178 |
|
179 |
infer_wrapper = Infer_Wrapper(args, cfg)
|
180 |
|
181 |
-
# version
|
182 |
-
# demo = gr.
|
183 |
-
#
|
184 |
-
#
|
185 |
-
# #
|
186 |
-
#
|
187 |
-
# )
|
188 |
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
189 |
# demo.launch()
|
190 |
|
191 |
# version 1
|
192 |
-
demo = gr.Blocks(theme=gr.themes.Soft())
|
|
|
193 |
with demo:
|
194 |
gr.Markdown("<p style='text-align:center;font-size:18px'><b>StructDiffusion Demo</b></p>")
|
195 |
# font-size:18px
|
196 |
gr.Markdown("<p style='text-align:center'>StructDiffusion combines a diffusion model and an object-centric transformer to construct structures given partial-view point clouds and high-level language goals.<br><a href='https://structdiffusion.github.io/'>Website</a> | <a href='https://github.com/StructDiffusion/StructDiffusion'>Code</a></p>")
|
197 |
|
198 |
session_id = gr.State(value=np.random.randint(0, 1000))
|
199 |
-
|
200 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
201 |
language_command = gr.Textbox(label="Input Language Command")
|
202 |
-
output_scene = gr.Model3D(clear_color=[0, 0, 0, 0], label="Generated 3D Structure")
|
203 |
|
204 |
-
b1 = gr.Button("Show Input Language and Scene")
|
205 |
b2 = gr.Button("Generate 3D Structure")
|
206 |
|
207 |
-
|
208 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
209 |
|
210 |
demo.queue(concurrency_count=10)
|
211 |
demo.launch()
|
|
|
10 |
import sys
|
11 |
sys.path.append('./src')
|
12 |
|
13 |
+
from StructDiffusion.data.semantic_arrangement_language_demo import SemanticArrangementDataset
|
14 |
from StructDiffusion.language.tokenizer import Tokenizer
|
15 |
+
from StructDiffusion.models.pl_models import ConditionalPoseDiffusionModel, PairwiseCollisionModel
|
16 |
+
from StructDiffusion.diffusion.sampler import Sampler, SamplerV2
|
17 |
from StructDiffusion.diffusion.pose_conversion import get_struct_objs_poses
|
18 |
from StructDiffusion.utils.files import get_checkpoint_path_from_dir
|
19 |
+
from StructDiffusion.utils.rearrangement import show_pcs_with_trimesh, get_trimesh_scene_with_table
|
20 |
+
import StructDiffusion.utils.transformations as tra
|
21 |
+
from StructDiffusion.language.sentence_encoder import SentenceBertEncoder
|
22 |
import StructDiffusion.utils.transformations as tra
|
23 |
|
24 |
|
|
|
67 |
|
68 |
def __init__(self, args, cfg):
|
69 |
|
70 |
+
self.num_pts = cfg.DATASET.num_pts
|
71 |
+
|
72 |
# load
|
73 |
pl.seed_everything(args.eval_random_seed)
|
74 |
self.device = (torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"))
|
75 |
|
76 |
+
diffusion_checkpoint_path = get_checkpoint_path_from_dir(os.path.join(cfg.WANDB.save_dir, cfg.WANDB.project, args.diffusion_checkpoint_id, "checkpoints"))
|
77 |
+
collision_checkpoint_path = get_checkpoint_path_from_dir(os.path.join(cfg.WANDB.save_dir, cfg.WANDB.project, args.collision_checkpoint_id, "checkpoints"))
|
78 |
|
79 |
self.tokenizer = Tokenizer(cfg.DATASET.vocab_dir)
|
80 |
# override ignore_rgb for visualization
|
81 |
cfg.DATASET.ignore_rgb = False
|
82 |
self.dataset = SemanticArrangementDataset(tokenizer=self.tokenizer, **cfg.DATASET)
|
83 |
|
84 |
+
self.sampler = SamplerV2(ConditionalPoseDiffusionModel, diffusion_checkpoint_path,
|
85 |
+
PairwiseCollisionModel, collision_checkpoint_path, self.device)
|
86 |
+
|
87 |
+
self.sentence_encoder = SentenceBertEncoder()
|
88 |
+
|
89 |
+
self.session_id_to_obj_xyzs = {}
|
90 |
|
91 |
def visualize_scene(self, di, session_id):
|
92 |
+
|
93 |
+
raw_datum = self.dataset.get_raw_data(di, inference_mode=True, shuffle_object_index=True)
|
94 |
+
language_command = raw_datum["template_sentence"]
|
95 |
|
96 |
obj_xyz = raw_datum["pcs"]
|
97 |
scene = show_pcs_with_trimesh([xyz[:, :3] for xyz in obj_xyz], [xyz[:, 3:] for xyz in obj_xyz], return_scene=True)
|
|
|
103 |
|
104 |
return language_command, scene_filename
|
105 |
|
106 |
+
def build_scene(self, mesh_filename_1, x_1, y_1, z_1, ai_1, aj_1, ak_1, scale_1,
|
107 |
+
mesh_filename_2, x_2, y_2, z_2, ai_2, aj_2, ak_2, scale_2,
|
108 |
+
mesh_filename_3, x_3, y_3, z_3, ai_3, aj_3, ak_3, scale_3,
|
109 |
+
mesh_filename_4, x_4, y_4, z_4, ai_4, aj_4, ak_4, scale_4,
|
110 |
+
mesh_filename_5, x_5, y_5, z_5, ai_5, aj_5, ak_5, scale_5, session_id):
|
111 |
+
|
112 |
+
object_list = [(mesh_filename_1, x_1, y_1, z_1, ai_1, aj_1, ak_1, scale_1),
|
113 |
+
(mesh_filename_2, x_2, y_2, z_2, ai_2, aj_2, ak_2, scale_2),
|
114 |
+
(mesh_filename_3, x_3, y_3, z_3, ai_3, aj_3, ak_3, scale_3),
|
115 |
+
(mesh_filename_4, x_4, y_4, z_4, ai_4, aj_4, ak_4, scale_4),
|
116 |
+
(mesh_filename_5, x_5, y_5, z_5, ai_5, aj_5, ak_5, scale_5)]
|
117 |
+
|
118 |
+
scene = get_trimesh_scene_with_table()
|
119 |
+
|
120 |
+
obj_xyzs = []
|
121 |
+
for mesh_filename, x, y, z, ai, aj, ak, scale in object_list:
|
122 |
+
if mesh_filename is None:
|
123 |
+
continue
|
124 |
+
obj_mesh = trimesh.load(mesh_filename)
|
125 |
+
obj_mesh.apply_scale(scale)
|
126 |
+
z_min = obj_mesh.bounds[0, 2]
|
127 |
+
tform = tra.euler_matrix(ai, aj, ak)
|
128 |
+
tform[:3, 3] = [x, y, z - z_min]
|
129 |
+
obj_mesh.apply_transform(tform)
|
130 |
+
obj_xyz = obj_mesh.sample(self.num_pts)
|
131 |
+
obj = trimesh.PointCloud(obj_xyz)
|
132 |
+
scene.add_geometry(obj)
|
133 |
+
|
134 |
+
obj_xyzs.append(obj_xyz)
|
135 |
+
|
136 |
+
self.session_id_to_obj_xyzs[session_id] = obj_xyzs
|
137 |
+
|
138 |
+
# scene.show()
|
139 |
+
|
140 |
+
# obj_file = "/home/weiyu/data_drive/StructDiffusion/housekeep_custom_handpicked_small/visual/book_Eat_to_Live_The_Amazing_NutrientRich_Program_for_Fast_and_Sustained_Weight_Loss_Revised_Edition_Book_L/model.obj"
|
141 |
+
# obj = trimesh.load(obj_file)
|
142 |
+
#
|
143 |
+
# scene = get_trimesh_scene_with_table()
|
144 |
+
# scene.add_geometry(obj)
|
145 |
+
#
|
146 |
+
# scene.show()
|
147 |
+
|
148 |
+
# raw_datum = self.dataset.get_raw_data(di, inference_mode=True, shuffle_object_index=True)
|
149 |
+
# language_command = raw_datum["template_sentence"]
|
150 |
+
#
|
151 |
+
# obj_xyz = raw_datum["pcs"]
|
152 |
+
# scene = show_pcs_with_trimesh([xyz[:, :3] for xyz in obj_xyz], [xyz[:, 3:] for xyz in obj_xyz],
|
153 |
+
# return_scene=True)
|
154 |
+
|
155 |
+
scene.apply_transform(tra.euler_matrix(np.pi, 0, np.pi / 2))
|
156 |
+
scene_filename = "./tmp_data/input_scene_{}.glb".format(session_id)
|
157 |
+
scene.export(scene_filename)
|
158 |
+
|
159 |
+
return scene_filename
|
160 |
+
|
161 |
+
# return language_command, scene_filename
|
162 |
+
|
163 |
+
def infer(self, language_command, session_id, progress=gr.Progress()):
|
164 |
+
|
165 |
+
obj_xyzs = self.session_id_to_obj_xyzs[session_id]
|
166 |
|
167 |
+
sentence_embedding = self.sentence_encoder.encode([language_command]).flatten()
|
168 |
|
169 |
+
raw_datum = self.dataset.build_data_from_xyzs(obj_xyzs, sentence_embedding)
|
170 |
+
datum = self.dataset.convert_to_tensors(raw_datum, self.tokenizer, use_sentence_embedding=True)
|
|
|
171 |
batch = self.dataset.single_datum_to_batch(datum, args.num_samples, self.device, inference_mode=True)
|
172 |
|
173 |
+
num_poses = raw_datum["num_goal_poses"]
|
174 |
+
struct_pose, pc_poses_in_struct = self.sampler.sample(batch, num_poses, args.num_elites, args.discriminator_batch_size)
|
175 |
|
176 |
+
new_obj_xyzs = move_pc_and_create_scene_simple(batch["pcs"][:args.num_elites], struct_pose, pc_poses_in_struct)
|
|
|
177 |
|
178 |
# vis
|
179 |
vis_obj_xyzs = new_obj_xyzs[:3]
|
|
|
182 |
vis_obj_xyzs = vis_obj_xyzs.detach().cpu()
|
183 |
vis_obj_xyzs = vis_obj_xyzs.numpy()
|
184 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
185 |
vis_obj_xyz = vis_obj_xyzs[0]
|
186 |
+
# scene = show_pcs_with_trimesh([xyz[:, :3] for xyz in vis_obj_xyz], [xyz[:, 3:] for xyz in vis_obj_xyz], return_scene=True)
|
187 |
+
scene = show_pcs_with_trimesh([xyz[:, :3] for xyz in vis_obj_xyz], obj_rgbs=None, return_scene=True)
|
188 |
+
scene.show()
|
189 |
scene.apply_transform(tra.euler_matrix(np.pi, 0, np.pi/2))
|
|
|
190 |
scene_filename = "./tmp_data/output_scene_{}.glb".format(session_id)
|
191 |
scene.export(scene_filename)
|
192 |
|
|
|
227 |
|
228 |
args = OmegaConf.create()
|
229 |
args.base_config_file = "./configs/base.yaml"
|
230 |
+
args.config_file = "./configs/conditional_pose_diffusion_language.yaml"
|
231 |
+
args.diffusion_checkpoint_id = "ConditionalPoseDiffusionLanguage"
|
232 |
+
args.collision_checkpoint_id = "CollisionDiscriminator"
|
233 |
args.eval_random_seed = 42
|
234 |
+
args.num_samples = 50
|
235 |
+
args.num_elites = 3
|
236 |
+
args.discriminator_batch_size = 10
|
237 |
|
238 |
base_cfg = OmegaConf.load(args.base_config_file)
|
239 |
cfg = OmegaConf.load(args.config_file)
|
|
|
241 |
|
242 |
infer_wrapper = Infer_Wrapper(args, cfg)
|
243 |
|
244 |
+
# # version 1
|
245 |
+
# demo = gr.Blocks(theme=gr.themes.Soft())
|
246 |
+
# with demo:
|
247 |
+
# gr.Markdown("<p style='text-align:center;font-size:18px'><b>StructDiffusion Demo</b></p>")
|
248 |
+
# # font-size:18px
|
249 |
+
# gr.Markdown("<p style='text-align:center'>StructDiffusion combines a diffusion model and an object-centric transformer to construct structures given partial-view point clouds and high-level language goals.<br><a href='https://structdiffusion.github.io/'>Website</a> | <a href='https://github.com/StructDiffusion/StructDiffusion'>Code</a></p>")
|
|
|
250 |
#
|
251 |
+
# session_id = gr.State(value=np.random.randint(0, 1000))
|
252 |
+
# data_selection = gr.Number(label="Example No.", minimum=0, maximum=len(infer_wrapper.dataset) - 1, precision=0)
|
253 |
+
# input_scene = gr.Model3D(clear_color=[0, 0, 0, 0], label="Input 3D Scene")
|
254 |
+
# language_command = gr.Textbox(label="Input Language Command")
|
255 |
+
# output_scene = gr.Model3D(clear_color=[0, 0, 0, 0], label="Generated 3D Structure")
|
256 |
+
#
|
257 |
+
# b1 = gr.Button("Show Input Language and Scene")
|
258 |
+
# b2 = gr.Button("Generate 3D Structure")
|
259 |
+
#
|
260 |
+
# b1.click(infer_wrapper.visualize_scene, inputs=[data_selection, session_id], outputs=[language_command, input_scene])
|
261 |
+
# b2.click(infer_wrapper.infer, inputs=[data_selection, session_id], outputs=output_scene)
|
262 |
+
#
|
263 |
+
# demo.queue(concurrency_count=10)
|
264 |
# demo.launch()
|
265 |
|
266 |
# version 1
|
267 |
+
# demo = gr.Blocks(theme=gr.themes.Soft())
|
268 |
+
demo = gr.Blocks()
|
269 |
with demo:
|
270 |
gr.Markdown("<p style='text-align:center;font-size:18px'><b>StructDiffusion Demo</b></p>")
|
271 |
# font-size:18px
|
272 |
gr.Markdown("<p style='text-align:center'>StructDiffusion combines a diffusion model and an object-centric transformer to construct structures given partial-view point clouds and high-level language goals.<br><a href='https://structdiffusion.github.io/'>Website</a> | <a href='https://github.com/StructDiffusion/StructDiffusion'>Code</a></p>")
|
273 |
|
274 |
session_id = gr.State(value=np.random.randint(0, 1000))
|
275 |
+
with gr.Tab("Object 1"):
|
276 |
+
with gr.Column(scale=1, min_width=600):
|
277 |
+
mesh_filename_1 = gr.Model3D(clear_color=[0, 0, 0, 0], label="Load 3D Object")
|
278 |
+
with gr.Row():
|
279 |
+
x_1 = gr.Slider(0, 1, label="x")
|
280 |
+
y_1 = gr.Slider(-0.5, 0.5, label="y")
|
281 |
+
z_1 = gr.Slider(0, 0.5, label="z")
|
282 |
+
with gr.Row():
|
283 |
+
ai_1 = gr.Slider(0, np.pi * 2, label="roll")
|
284 |
+
aj_1 = gr.Slider(0, np.pi * 2, label="pitch")
|
285 |
+
ak_1 = gr.Slider(0, np.pi * 2, label="yaw")
|
286 |
+
scale_1 = gr.Slider(0, 1)
|
287 |
+
with gr.Tab("Object 2"):
|
288 |
+
with gr.Column(scale=1, min_width=600):
|
289 |
+
mesh_filename_2 = gr.Model3D(clear_color=[0, 0, 0, 0], label="Load 3D Object")
|
290 |
+
with gr.Row():
|
291 |
+
x_2 = gr.Slider(0, 1, label="x")
|
292 |
+
y_2 = gr.Slider(-0.5, 0.5, label="y")
|
293 |
+
z_2 = gr.Slider(0, 0.5, label="z")
|
294 |
+
with gr.Row():
|
295 |
+
ai_2 = gr.Slider(0, np.pi * 2, label="roll")
|
296 |
+
aj_2 = gr.Slider(0, np.pi * 2, label="pitch")
|
297 |
+
ak_2 = gr.Slider(0, np.pi * 2, label="yaw")
|
298 |
+
scale_2 = gr.Slider(0, 1)
|
299 |
+
with gr.Tab("Object 3"):
|
300 |
+
with gr.Column(scale=1, min_width=600):
|
301 |
+
mesh_filename_3 = gr.Model3D(clear_color=[0, 0, 0, 0], label="Load 3D Object")
|
302 |
+
with gr.Row():
|
303 |
+
x_3 = gr.Slider(0, 1, label="x")
|
304 |
+
y_3 = gr.Slider(-0.5, 0.5, label="y")
|
305 |
+
z_3 = gr.Slider(0, 0.5, label="z")
|
306 |
+
with gr.Row():
|
307 |
+
ai_3 = gr.Slider(0, np.pi * 2, label="roll")
|
308 |
+
aj_3 = gr.Slider(0, np.pi * 2, label="pitch")
|
309 |
+
ak_3 = gr.Slider(0, np.pi * 2, label="yaw")
|
310 |
+
scale_3 = gr.Slider(0, 1)
|
311 |
+
with gr.Tab("Object 4"):
|
312 |
+
with gr.Column(scale=1, min_width=600):
|
313 |
+
mesh_filename_4 = gr.Model3D(clear_color=[0, 0, 0, 0], label="Load 3D Object")
|
314 |
+
with gr.Row():
|
315 |
+
x_4 = gr.Slider(0, 1, label="x")
|
316 |
+
y_4 = gr.Slider(-0.5, 0.5, label="y")
|
317 |
+
z_4 = gr.Slider(0, 0.5, label="z")
|
318 |
+
with gr.Row():
|
319 |
+
ai_4 = gr.Slider(0, np.pi * 2, label="roll")
|
320 |
+
aj_4 = gr.Slider(0, np.pi * 2, label="pitch")
|
321 |
+
ak_4 = gr.Slider(0, np.pi * 2, label="yaw")
|
322 |
+
scale_4 = gr.Slider(0, 1)
|
323 |
+
with gr.Tab("Object 5"):
|
324 |
+
with gr.Column(scale=1, min_width=600):
|
325 |
+
mesh_filename_5 = gr.Model3D(clear_color=[0, 0, 0, 0], label="Load 3D Object")
|
326 |
+
with gr.Row():
|
327 |
+
x_5 = gr.Slider(0, 1, label="x")
|
328 |
+
y_5 = gr.Slider(-0.5, 0.5, label="y")
|
329 |
+
z_5 = gr.Slider(0, 0.5, label="z")
|
330 |
+
with gr.Row():
|
331 |
+
ai_5 = gr.Slider(0, np.pi * 2, label="roll")
|
332 |
+
aj_5 = gr.Slider(0, np.pi * 2, label="pitch")
|
333 |
+
ak_5 = gr.Slider(0, np.pi * 2, label="yaw")
|
334 |
+
scale_5 = gr.Slider(0, 1)
|
335 |
+
|
336 |
+
b1 = gr.Button("Build Initial Scene")
|
337 |
+
|
338 |
+
initial_scene = gr.Model3D(clear_color=[0, 0, 0, 0], label="Initial 3D Scene")
|
339 |
language_command = gr.Textbox(label="Input Language Command")
|
|
|
340 |
|
|
|
341 |
b2 = gr.Button("Generate 3D Structure")
|
342 |
|
343 |
+
output_scene = gr.Model3D(clear_color=[0, 0, 0, 0], label="Generated 3D Structure")
|
344 |
+
|
345 |
+
# data_selection = gr.Number(label="Example No.", minimum=0, maximum=len(infer_wrapper.dataset) - 1, precision=0)
|
346 |
+
# input_scene = gr.Model3D(clear_color=[0, 0, 0, 0], label="Input 3D Scene")
|
347 |
+
# language_command = gr.Textbox(label="Input Language Command")
|
348 |
+
# output_scene = gr.Model3D(clear_color=[0, 0, 0, 0], label="Generated 3D Structure")
|
349 |
+
#
|
350 |
+
# b1 = gr.Button("Show Input Language and Scene")
|
351 |
+
# b2 = gr.Button("Generate 3D Structure")
|
352 |
+
|
353 |
+
b1.click(infer_wrapper.build_scene, inputs=[mesh_filename_1, x_1, y_1, z_1, ai_1, aj_1, ak_1, scale_1,
|
354 |
+
mesh_filename_2, x_2, y_2, z_2, ai_2, aj_2, ak_2, scale_2,
|
355 |
+
mesh_filename_3, x_3, y_3, z_3, ai_3, aj_3, ak_3, scale_3,
|
356 |
+
mesh_filename_4, x_4, y_4, z_4, ai_4, aj_4, ak_4, scale_4,
|
357 |
+
mesh_filename_5, x_5, y_5, z_5, ai_5, aj_5, ak_5, scale_5,
|
358 |
+
session_id], outputs=[initial_scene])
|
359 |
+
|
360 |
+
b2.click(infer_wrapper.infer, inputs=[language_command, session_id], outputs=output_scene)
|
361 |
|
362 |
demo.queue(concurrency_count=10)
|
363 |
demo.launch()
|
app_v0.py
ADDED
@@ -0,0 +1,282 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import argparse
|
3 |
+
import torch
|
4 |
+
import trimesh
|
5 |
+
import numpy as np
|
6 |
+
import pytorch_lightning as pl
|
7 |
+
import gradio as gr
|
8 |
+
from omegaconf import OmegaConf
|
9 |
+
|
10 |
+
import sys
|
11 |
+
sys.path.append('./src')
|
12 |
+
|
13 |
+
from StructDiffusion.data.semantic_arrangement_demo import SemanticArrangementDataset
|
14 |
+
from StructDiffusion.language.tokenizer import Tokenizer
|
15 |
+
from StructDiffusion.models.pl_models import ConditionalPoseDiffusionModel
|
16 |
+
from StructDiffusion.diffusion.sampler import Sampler
|
17 |
+
from StructDiffusion.diffusion.pose_conversion import get_struct_objs_poses
|
18 |
+
from StructDiffusion.utils.files import get_checkpoint_path_from_dir
|
19 |
+
from StructDiffusion.utils.rearrangement import show_pcs_with_trimesh
|
20 |
+
import StructDiffusion.utils.transformations as tra
|
21 |
+
|
22 |
+
|
23 |
+
def move_pc_and_create_scene_simple(obj_xyzs, struct_pose, pc_poses_in_struct):
|
24 |
+
|
25 |
+
device = obj_xyzs.device
|
26 |
+
|
27 |
+
# obj_xyzs: B, N, P, 3 or 6
|
28 |
+
# struct_pose: B, 1, 4, 4
|
29 |
+
# pc_poses_in_struct: B, N, 4, 4
|
30 |
+
|
31 |
+
B, N, _, _ = pc_poses_in_struct.shape
|
32 |
+
_, _, P, _ = obj_xyzs.shape
|
33 |
+
|
34 |
+
current_pc_poses = torch.eye(4).repeat(B, N, 1, 1).to(device) # B, N, 4, 4
|
35 |
+
# print(torch.mean(obj_xyzs, dim=2).shape)
|
36 |
+
current_pc_poses[:, :, :3, 3] = torch.mean(obj_xyzs[:, :, :, :3], dim=2) # B, N, 4, 4
|
37 |
+
current_pc_poses = current_pc_poses.reshape(B * N, 4, 4) # B x N, 4, 4
|
38 |
+
|
39 |
+
struct_pose = struct_pose.repeat(1, N, 1, 1) # B, N, 4, 4
|
40 |
+
struct_pose = struct_pose.reshape(B * N, 4, 4) # B x 1, 4, 4
|
41 |
+
pc_poses_in_struct = pc_poses_in_struct.reshape(B * N, 4, 4) # B x N, 4, 4
|
42 |
+
|
43 |
+
goal_pc_pose = struct_pose @ pc_poses_in_struct # B x N, 4, 4
|
44 |
+
# print("goal pc poses")
|
45 |
+
# print(goal_pc_pose)
|
46 |
+
goal_pc_transform = goal_pc_pose @ torch.inverse(current_pc_poses) # B x N, 4, 4
|
47 |
+
|
48 |
+
# # important: pytorch3d uses row-major ordering, need to transpose each transformation matrix
|
49 |
+
# transpose = tra3d.Transform3d(matrix=goal_pc_transform.transpose(1, 2))
|
50 |
+
# new_obj_xyzs = obj_xyzs.reshape(B * N, P, -1) # B x N, P, 3
|
51 |
+
# new_obj_xyzs[:, :, :3] = transpose.transform_points(new_obj_xyzs[:, :, :3])
|
52 |
+
|
53 |
+
# a verision that does not rely on pytorch3d
|
54 |
+
new_obj_xyzs = obj_xyzs.reshape(B * N, P, -1)[:, :, :3] # B x N, P, 3
|
55 |
+
new_obj_xyzs = torch.concat([new_obj_xyzs, torch.ones(B * N, P, 1).to(device)], dim=-1) # B x N, P, 4
|
56 |
+
new_obj_xyzs = torch.einsum('bij,bkj->bki', goal_pc_transform, new_obj_xyzs)[:, :, :3] # # B x N, P, 3
|
57 |
+
|
58 |
+
# put it back to B, N, P, 3
|
59 |
+
obj_xyzs[:, :, :, :3] = new_obj_xyzs.reshape(B, N, P, -1)
|
60 |
+
|
61 |
+
return obj_xyzs
|
62 |
+
|
63 |
+
|
64 |
+
class Infer_Wrapper:
|
65 |
+
|
66 |
+
def __init__(self, args, cfg):
|
67 |
+
|
68 |
+
# load
|
69 |
+
pl.seed_everything(args.eval_random_seed)
|
70 |
+
self.device = (torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"))
|
71 |
+
|
72 |
+
checkpoint_dir = os.path.join(cfg.WANDB.save_dir, cfg.WANDB.project, args.checkpoint_id, "checkpoints")
|
73 |
+
checkpoint_path = get_checkpoint_path_from_dir(checkpoint_dir)
|
74 |
+
|
75 |
+
self.tokenizer = Tokenizer(cfg.DATASET.vocab_dir)
|
76 |
+
# override ignore_rgb for visualization
|
77 |
+
cfg.DATASET.ignore_rgb = False
|
78 |
+
self.dataset = SemanticArrangementDataset(tokenizer=self.tokenizer, **cfg.DATASET)
|
79 |
+
|
80 |
+
self.sampler = Sampler(ConditionalPoseDiffusionModel, checkpoint_path, self.device)
|
81 |
+
|
82 |
+
def visualize_scene(self, di, session_id):
|
83 |
+
raw_datum = self.dataset.get_raw_data(di)
|
84 |
+
language_command = self.tokenizer.convert_structure_params_to_natural_language(raw_datum["sentence"])
|
85 |
+
|
86 |
+
obj_xyz = raw_datum["pcs"]
|
87 |
+
scene = show_pcs_with_trimesh([xyz[:, :3] for xyz in obj_xyz], [xyz[:, 3:] for xyz in obj_xyz], return_scene=True)
|
88 |
+
|
89 |
+
scene.apply_transform(tra.euler_matrix(np.pi, 0, np.pi/2))
|
90 |
+
|
91 |
+
scene_filename = "./tmp_data/input_scene_{}.glb".format(session_id)
|
92 |
+
scene.export(scene_filename)
|
93 |
+
|
94 |
+
return language_command, scene_filename
|
95 |
+
|
96 |
+
def infer(self, di, session_id, progress=gr.Progress()):
|
97 |
+
|
98 |
+
# di = np.random.choice(len(self.dataset))
|
99 |
+
|
100 |
+
raw_datum = self.dataset.get_raw_data(di)
|
101 |
+
print(self.tokenizer.convert_structure_params_to_natural_language(raw_datum["sentence"]))
|
102 |
+
datum = self.dataset.convert_to_tensors(raw_datum, self.tokenizer)
|
103 |
+
batch = self.dataset.single_datum_to_batch(datum, args.num_samples, self.device, inference_mode=True)
|
104 |
+
|
105 |
+
num_poses = datum["goal_poses"].shape[0]
|
106 |
+
xs = self.sampler.sample(batch, num_poses, progress)
|
107 |
+
|
108 |
+
struct_pose, pc_poses_in_struct = get_struct_objs_poses(xs[0])
|
109 |
+
new_obj_xyzs = move_pc_and_create_scene_simple(batch["pcs"], struct_pose, pc_poses_in_struct)
|
110 |
+
|
111 |
+
# vis
|
112 |
+
vis_obj_xyzs = new_obj_xyzs[:3]
|
113 |
+
if torch.is_tensor(vis_obj_xyzs):
|
114 |
+
if vis_obj_xyzs.is_cuda:
|
115 |
+
vis_obj_xyzs = vis_obj_xyzs.detach().cpu()
|
116 |
+
vis_obj_xyzs = vis_obj_xyzs.numpy()
|
117 |
+
|
118 |
+
# for bi, vis_obj_xyz in enumerate(vis_obj_xyzs):
|
119 |
+
# if verbose:
|
120 |
+
# print("example {}".format(bi))
|
121 |
+
# print(vis_obj_xyz.shape)
|
122 |
+
#
|
123 |
+
# if trimesh:
|
124 |
+
# show_pcs_with_trimesh([xyz[:, :3] for xyz in vis_obj_xyz], [xyz[:, 3:] for xyz in vis_obj_xyz])
|
125 |
+
vis_obj_xyz = vis_obj_xyzs[0]
|
126 |
+
scene = show_pcs_with_trimesh([xyz[:, :3] for xyz in vis_obj_xyz], [xyz[:, 3:] for xyz in vis_obj_xyz], return_scene=True)
|
127 |
+
|
128 |
+
scene.apply_transform(tra.euler_matrix(np.pi, 0, np.pi/2))
|
129 |
+
|
130 |
+
scene_filename = "./tmp_data/output_scene_{}.glb".format(session_id)
|
131 |
+
scene.export(scene_filename)
|
132 |
+
|
133 |
+
# pc_filename = "/home/weiyu/Research/StructDiffusion/StructDiffusion/interactive_demo/tmp_data/pc.glb"
|
134 |
+
# scene_filename = "/home/weiyu/Research/StructDiffusion/StructDiffusion/interactive_demo/tmp_data/scene.glb"
|
135 |
+
#
|
136 |
+
# vis_obj_xyz = vis_obj_xyz.reshape(-1, 6)
|
137 |
+
# vis_pc = trimesh.PointCloud(vis_obj_xyz[:, :3], colors=np.concatenate([vis_obj_xyz[:, 3:] * 255, np.ones([vis_obj_xyz.shape[0], 1]) * 255], axis=-1))
|
138 |
+
# vis_pc.export(pc_filename)
|
139 |
+
#
|
140 |
+
# scene = trimesh.Scene()
|
141 |
+
# # add the coordinate frame first
|
142 |
+
# # geom = trimesh.creation.axis(0.01)
|
143 |
+
# # scene.add_geometry(geom)
|
144 |
+
# table = trimesh.creation.box(extents=[1.0, 1.0, 0.02])
|
145 |
+
# table.apply_translation([0.5, 0, -0.01])
|
146 |
+
# table.visual.vertex_colors = [150, 111, 87, 125]
|
147 |
+
# scene.add_geometry(table)
|
148 |
+
# # bounds = trimesh.creation.box(extents=[4.0, 4.0, 4.0])
|
149 |
+
# # bounds = trimesh.creation.icosphere(subdivisions=3, radius=3.1)
|
150 |
+
# # bounds.apply_translation([0, 0, 0])
|
151 |
+
# # bounds.visual.vertex_colors = [30, 30, 30, 30]
|
152 |
+
# # scene.add_geometry(bounds)
|
153 |
+
# # RT_4x4 = np.array([[-0.39560353822208355, -0.9183993826406329, 0.006357240869497738, 0.2651463080169481],
|
154 |
+
# # [-0.797630370081598, 0.3401340617616391, -0.4980909683511864, 0.2225696480721997],
|
155 |
+
# # [0.45528412367406523, -0.2021172778236285, -0.8671014777611122, 0.9449050652025951],
|
156 |
+
# # [0.0, 0.0, 0.0, 1.0]])
|
157 |
+
# # RT_4x4 = np.linalg.inv(RT_4x4)
|
158 |
+
# # RT_4x4 = RT_4x4 @ np.diag([1, -1, -1, 1])
|
159 |
+
# # scene.camera_transform = RT_4x4
|
160 |
+
#
|
161 |
+
# mesh_list = trimesh.util.concatenate(scene.dump())
|
162 |
+
# print(mesh_list)
|
163 |
+
# trimesh.io.export.export_mesh(mesh_list, scene_filename, file_type='obj')
|
164 |
+
|
165 |
+
return scene_filename
|
166 |
+
|
167 |
+
def infer_new(self, di, session_id, progress=gr.Progress()):
    """Sample a structure for one dataset example and export it as a GLB scene.

    Args:
        di: Index of the example, passed to ``self.dataset.get_raw_data``.
        session_id: Value embedded in the output filename so that different
            sessions write to different files.
        progress: Gradio progress tracker, forwarded to ``self.sampler.sample``.

    Returns:
        Path of the exported ``.glb`` file.
    """
    # Local import so this method does not depend on the module's import block.
    import os

    raw_datum = self.dataset.get_raw_data(di)
    print(self.tokenizer.convert_structure_params_to_natural_language(raw_datum["sentence"]))
    datum = self.dataset.convert_to_tensors(raw_datum, self.tokenizer)
    # NOTE: ``args`` is the module-level configuration object, not an attribute of self.
    batch = self.dataset.single_datum_to_batch(datum, args.num_samples, self.device, inference_mode=True)

    num_poses = datum["goal_poses"].shape[0]
    xs = self.sampler.sample(batch, num_poses, progress)

    # Only the first element returned by the sampler is converted to poses.
    struct_pose, pc_poses_in_struct = get_struct_objs_poses(xs[0])
    new_obj_xyzs = move_pc_and_create_scene_simple(batch["pcs"], struct_pose, pc_poses_in_struct)

    # vis
    vis_obj_xyzs = new_obj_xyzs[:3]
    if torch.is_tensor(vis_obj_xyzs):
        # Always detach before converting: ``Tensor.numpy()`` raises on a tensor
        # that requires grad, regardless of the device it lives on.
        vis_obj_xyzs = vis_obj_xyzs.detach().cpu().numpy()

    # Only the first sample in the batch is visualized.
    vis_obj_xyz = vis_obj_xyzs[0]
    scene = show_pcs_with_trimesh([xyz[:, :3] for xyz in vis_obj_xyz], [xyz[:, 3:] for xyz in vis_obj_xyz], return_scene=True)

    scene.apply_transform(tra.euler_matrix(np.pi, 0, np.pi/2))

    # Make sure the output directory exists before exporting.
    os.makedirs("./tmp_data", exist_ok=True)
    scene_filename = "./tmp_data/output_scene_{}.glb".format(session_id)
    scene.export(scene_filename)

    return scene_filename
|
237 |
+
|
238 |
+
|
239 |
+
# Standard-library import kept next to its only use (per-session ids below).
import uuid

# Demo configuration.
args = OmegaConf.create()
args.base_config_file = "./configs/base.yaml"
args.config_file = "./configs/conditional_pose_diffusion.yaml"
args.checkpoint_id = "ConditionalPoseDiffusion"
args.eval_random_seed = 42
args.num_samples = 1

# The experiment config is merged on top of the base config.
base_cfg = OmegaConf.load(args.base_config_file)
cfg = OmegaConf.load(args.config_file)
cfg = OmegaConf.merge(base_cfg, cfg)

infer_wrapper = Infer_Wrapper(args, cfg)

demo = gr.Blocks(theme=gr.themes.Soft())
with demo:
    gr.Markdown("<p style='text-align:center;font-size:18px'><b>StructDiffusion Demo</b></p>")
    gr.Markdown("<p style='text-align:center'>StructDiffusion combines a diffusion model and an object-centric transformer to construct structures given partial-view point clouds and high-level language goals.<br><a href='https://structdiffusion.github.io/'>Website</a> | <a href='https://github.com/StructDiffusion/StructDiffusion'>Code</a></p>")

    # The initial value of gr.State is computed once when this script runs and
    # copied into every session, so a random number here would be shared by all
    # users and their exported scene files would overwrite each other. The id
    # is therefore generated on every page load instead.
    session_id = gr.State()
    data_selection = gr.Number(label="Example No.", minimum=0, maximum=len(infer_wrapper.dataset) - 1, precision=0)
    input_scene = gr.Model3D(clear_color=[0, 0, 0, 0], label="Input 3D Scene")
    language_command = gr.Textbox(label="Input Language Command")
    output_scene = gr.Model3D(clear_color=[0, 0, 0, 0], label="Generated 3D Structure")

    b1 = gr.Button("Show Input Language and Scene")
    b2 = gr.Button("Generate 3D Structure")

    b1.click(infer_wrapper.visualize_scene, inputs=[data_selection, session_id], outputs=[language_command, input_scene])
    b2.click(infer_wrapper.infer, inputs=[data_selection, session_id], outputs=output_scene)

    # Assign a fresh id each time the page is loaded in a browser.
    demo.load(lambda: uuid.uuid4().hex, inputs=None, outputs=session_id)

demo.queue(concurrency_count=10)
demo.launch()
|
app_v1.py
ADDED
@@ -0,0 +1,217 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import argparse
|
3 |
+
import torch
|
4 |
+
import trimesh
|
5 |
+
import numpy as np
|
6 |
+
import pytorch_lightning as pl
|
7 |
+
import gradio as gr
|
8 |
+
from omegaconf import OmegaConf
|
9 |
+
|
10 |
+
import sys
|
11 |
+
sys.path.append('./src')
|
12 |
+
|
13 |
+
from StructDiffusion.data.semantic_arrangement_language_demo import SemanticArrangementDataset
|
14 |
+
from StructDiffusion.language.tokenizer import Tokenizer
|
15 |
+
from StructDiffusion.models.pl_models import ConditionalPoseDiffusionModel, PairwiseCollisionModel
|
16 |
+
from StructDiffusion.diffusion.sampler import Sampler, SamplerV2
|
17 |
+
from StructDiffusion.diffusion.pose_conversion import get_struct_objs_poses
|
18 |
+
from StructDiffusion.utils.files import get_checkpoint_path_from_dir
|
19 |
+
from StructDiffusion.utils.rearrangement import show_pcs_with_trimesh, get_trimesh_scene_with_table
|
20 |
+
import StructDiffusion.utils.transformations as tra
|
21 |
+
from StructDiffusion.language.sentence_encoder import SentenceBertEncoder
|
22 |
+
|
23 |
+
|
24 |
+
def move_pc_and_create_scene_simple(obj_xyzs, struct_pose, pc_poses_in_struct):
    """Rigidly move each object point cloud to its goal pose in the structure.

    The current pose of every object is taken to be an identity rotation placed
    at the centroid of its points. The goal pose is
    ``struct_pose @ pc_poses_in_struct``, and the transform applied to the
    points is ``goal_pose @ inverse(current_pose)``.

    Args:
        obj_xyzs: Tensor of shape (B, N, P, C) with C >= 3; the first three
            channels are xyz, any remaining channels are left untouched.
        struct_pose: Tensor of shape (B, 1, 4, 4), pose of the structure frame.
        pc_poses_in_struct: Tensor of shape (B, N, 4, 4), pose of each object
            expressed in the structure frame.

    Returns:
        ``obj_xyzs`` itself. Its xyz channels are overwritten IN PLACE with the
        transformed coordinates, so the caller's tensor is modified.
    """
    device = obj_xyzs.device
    # Build every helper tensor in the dtype of the points so that inputs other
    # than float32 (e.g. float64) do not hit dtype mismatches in matmul/einsum.
    dtype = obj_xyzs.dtype

    B, N, _, _ = pc_poses_in_struct.shape
    _, _, P, _ = obj_xyzs.shape

    # Current pose of each object: identity rotation at the point centroid.
    current_pc_poses = torch.eye(4, dtype=dtype, device=device).repeat(B, N, 1, 1)  # B, N, 4, 4
    current_pc_poses[:, :, :3, 3] = torch.mean(obj_xyzs[:, :, :, :3], dim=2)
    current_pc_poses = current_pc_poses.reshape(B * N, 4, 4)  # B x N, 4, 4

    # Broadcast the single structure pose to every object.
    struct_pose = struct_pose.to(dtype).repeat(1, N, 1, 1)  # B, N, 4, 4
    struct_pose = struct_pose.reshape(B * N, 4, 4)  # B x N, 4, 4
    pc_poses_in_struct = pc_poses_in_struct.to(dtype).reshape(B * N, 4, 4)  # B x N, 4, 4

    goal_pc_pose = struct_pose @ pc_poses_in_struct  # B x N, 4, 4
    goal_pc_transform = goal_pc_pose @ torch.inverse(current_pc_poses)  # B x N, 4, 4

    # A version that does not rely on pytorch3d: lift the points to homogeneous
    # coordinates and apply the 4x4 transform with a batched einsum.
    new_obj_xyzs = obj_xyzs.reshape(B * N, P, -1)[:, :, :3]  # B x N, P, 3
    ones = torch.ones(B * N, P, 1, dtype=dtype, device=device)
    new_obj_xyzs = torch.cat([new_obj_xyzs, ones], dim=-1)  # B x N, P, 4
    new_obj_xyzs = torch.einsum('bij,bkj->bki', goal_pc_transform, new_obj_xyzs)[:, :, :3]  # B x N, P, 3

    # Put it back to B, N, P, 3 (in place; extra channels such as rgb are kept).
    obj_xyzs[:, :, :, :3] = new_obj_xyzs.reshape(B, N, P, -1)

    return obj_xyzs
|
63 |
+
|
64 |
+
|
65 |
+
class Infer_Wrapper:
    """Bundles the dataset, tokenizer and sampler used by the Gradio demo.

    The two public methods are the Gradio callbacks: ``visualize_scene`` shows
    the input scene and language command of an example, and ``infer`` samples a
    rearranged structure for it. Both export a ``.glb`` file and return its path.
    """

    def __init__(self, args, cfg):
        """Load the dataset and build the sampler from the configured checkpoints.

        Args:
            args: Configuration providing ``eval_random_seed``,
                ``diffusion_checkpoint_id`` and ``collision_checkpoint_id``
                here, and ``num_samples``, ``num_elites`` and
                ``discriminator_batch_size`` in ``infer``.
            cfg: Merged experiment config; ``cfg.WANDB`` locates the
                checkpoints and ``cfg.DATASET`` configures the dataset.
                NOTE: ``cfg.DATASET.ignore_rgb`` is overwritten with ``False``.
        """
        # Keep the arguments on the instance so that ``infer`` does not depend
        # on a module-level ``args`` happening to exist.
        self.args = args

        # load
        pl.seed_everything(args.eval_random_seed)
        self.device = (torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"))

        checkpoint_root = os.path.join(cfg.WANDB.save_dir, cfg.WANDB.project)
        diffusion_checkpoint_path = get_checkpoint_path_from_dir(os.path.join(checkpoint_root, args.diffusion_checkpoint_id, "checkpoints"))
        collision_checkpoint_path = get_checkpoint_path_from_dir(os.path.join(checkpoint_root, args.collision_checkpoint_id, "checkpoints"))

        self.tokenizer = Tokenizer(cfg.DATASET.vocab_dir)
        # override ignore_rgb for visualization
        cfg.DATASET.ignore_rgb = False
        self.dataset = SemanticArrangementDataset(tokenizer=self.tokenizer, **cfg.DATASET)

        self.sampler = SamplerV2(ConditionalPoseDiffusionModel, diffusion_checkpoint_path,
                                 PairwiseCollisionModel, collision_checkpoint_path, self.device)

    @staticmethod
    def _export_scene(scene, scene_filename):
        """Export ``scene`` to ``scene_filename``, creating its directory if needed."""
        os.makedirs(os.path.dirname(scene_filename), exist_ok=True)
        scene.export(scene_filename)
        return scene_filename

    def visualize_scene(self, di, session_id):
        """Export the input scene of example ``di`` and return its language command.

        Args:
            di: Index of the example, passed to ``self.dataset.get_raw_data``.
            session_id: Value embedded in the output filename.

        Returns:
            Tuple ``(language_command, scene_filename)``.
        """
        raw_datum = self.dataset.get_raw_data(di, inference_mode=True, shuffle_object_index=True)
        language_command = raw_datum["template_sentence"]

        obj_xyz = raw_datum["pcs"]
        scene = show_pcs_with_trimesh([xyz[:, :3] for xyz in obj_xyz], [xyz[:, 3:] for xyz in obj_xyz], return_scene=True)

        scene.apply_transform(tra.euler_matrix(np.pi, 0, np.pi/2))

        scene_filename = self._export_scene(scene, "./tmp_data/input_scene_{}.glb".format(session_id))

        return language_command, scene_filename

    def infer(self, di, session_id, progress=gr.Progress()):
        """Sample a structure for example ``di`` and export it as a GLB scene.

        Args:
            di: Index of the example, passed to ``self.dataset.get_raw_data``.
            session_id: Value embedded in the output filename.
            progress: Gradio progress tracker (not forwarded to the sampler).

        Returns:
            Path of the exported ``.glb`` file.
        """
        args = self.args

        raw_datum = self.dataset.get_raw_data(di, inference_mode=True, shuffle_object_index=True)
        print(raw_datum["template_sentence"])
        datum = self.dataset.convert_to_tensors(raw_datum, self.tokenizer, use_sentence_embedding=self.dataset.use_sentence_embedding)
        batch = self.dataset.single_datum_to_batch(datum, args.num_samples, self.device, inference_mode=True)

        num_poses = datum["goal_poses"].shape[0]
        struct_pose, pc_poses_in_struct = self.sampler.sample(batch, num_poses, args.num_elites, args.discriminator_batch_size)

        # Only the first ``num_elites`` point-cloud entries are moved, matching
        # the number of elites requested from the sampler.
        new_obj_xyzs = move_pc_and_create_scene_simple(batch["pcs"][:args.num_elites], struct_pose, pc_poses_in_struct)

        # vis
        vis_obj_xyzs = new_obj_xyzs[:3]
        if torch.is_tensor(vis_obj_xyzs):
            # Always detach before converting: ``Tensor.numpy()`` raises on a
            # tensor that requires grad, regardless of its device.
            vis_obj_xyzs = vis_obj_xyzs.detach().cpu().numpy()

        # Only the first returned sample is visualized.
        vis_obj_xyz = vis_obj_xyzs[0]
        scene = show_pcs_with_trimesh([xyz[:, :3] for xyz in vis_obj_xyz], [xyz[:, 3:] for xyz in vis_obj_xyz], return_scene=True)

        scene.apply_transform(tra.euler_matrix(np.pi, 0, np.pi/2))

        return self._export_scene(scene, "./tmp_data/output_scene_{}.glb".format(session_id))
|
169 |
+
|
170 |
+
|
171 |
+
# Standard-library import kept next to its only use (per-session ids below).
import uuid

# Demo configuration.
args = OmegaConf.create()
args.base_config_file = "./configs/base.yaml"
args.config_file = "./configs/conditional_pose_diffusion_language.yaml"
args.diffusion_checkpoint_id = "ConditionalPoseDiffusionLanguage"
args.collision_checkpoint_id = "CollisionDiscriminator"
args.eval_random_seed = 42
args.num_samples = 50
args.num_elites = 3
args.discriminator_batch_size = 10

# The experiment config is merged on top of the base config.
base_cfg = OmegaConf.load(args.base_config_file)
cfg = OmegaConf.load(args.config_file)
cfg = OmegaConf.merge(base_cfg, cfg)

infer_wrapper = Infer_Wrapper(args, cfg)

demo = gr.Blocks(theme=gr.themes.Soft())
with demo:
    gr.Markdown("<p style='text-align:center;font-size:18px'><b>StructDiffusion Demo</b></p>")
    gr.Markdown("<p style='text-align:center'>StructDiffusion combines a diffusion model and an object-centric transformer to construct structures given partial-view point clouds and high-level language goals.<br><a href='https://structdiffusion.github.io/'>Website</a> | <a href='https://github.com/StructDiffusion/StructDiffusion'>Code</a></p>")

    # The initial value of gr.State is computed once when this script runs and
    # copied into every session, so a random number here would be shared by all
    # users and their exported scene files would overwrite each other. The id
    # is therefore generated on every page load instead.
    session_id = gr.State()
    data_selection = gr.Number(label="Example No.", minimum=0, maximum=len(infer_wrapper.dataset) - 1, precision=0)
    input_scene = gr.Model3D(clear_color=[0, 0, 0, 0], label="Input 3D Scene")
    language_command = gr.Textbox(label="Input Language Command")
    output_scene = gr.Model3D(clear_color=[0, 0, 0, 0], label="Generated 3D Structure")

    b1 = gr.Button("Show Input Language and Scene")
    b2 = gr.Button("Generate 3D Structure")

    b1.click(infer_wrapper.visualize_scene, inputs=[data_selection, session_id], outputs=[language_command, input_scene])
    b2.click(infer_wrapper.infer, inputs=[data_selection, session_id], outputs=output_scene)

    # Assign a fresh id each time the page is loaded in a browser.
    demo.load(lambda: uuid.uuid4().hex, inputs=None, outputs=session_id)

demo.queue(concurrency_count=10)
demo.launch()
|
configs/conditional_pose_diffusion_language.yaml
ADDED
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
random_seed: 1
|
2 |
+
|
3 |
+
WANDB:
|
4 |
+
project: StructDiffusion
|
5 |
+
save_dir: ${base_dirs.wandb_dir}
|
6 |
+
name: conditional_pose_diffusion_language_shuffle
|
7 |
+
|
8 |
+
DATASET:
|
9 |
+
data_root: ${base_dirs.data}
|
10 |
+
vocab_dir: ${base_dirs.data}/type_vocabs_coarse.json
|
11 |
+
|
12 |
+
# important
|
13 |
+
use_virtual_structure_frame: True
|
14 |
+
ignore_distractor_objects: True
|
15 |
+
ignore_rgb: True
|
16 |
+
|
17 |
+
# the following are determined by the dataset
|
18 |
+
max_num_target_objects: 7
|
19 |
+
max_num_distractor_objects: 5
|
20 |
+
# set to 1 because we use sentence embedding, which only takes one spot in the input seq to transformer diffusion
|
21 |
+
max_num_shape_parameters: 1
|
22 |
+
# set to zeros because they are not used for now
|
23 |
+
max_num_rearrange_features: 0
|
24 |
+
max_num_anchor_features: 0
|
25 |
+
|
26 |
+
# language
|
27 |
+
sentence_embedding_file: ${base_dirs.data}/template_sentence_data.pkl
|
28 |
+
use_incomplete_sentence: True
|
29 |
+
|
30 |
+
# shuffle
|
31 |
+
shuffle_object_index: True
|
32 |
+
|
33 |
+
num_pts: 1024
|
34 |
+
filter_num_moved_objects_range:
|
35 |
+
data_augmentation: False
|
36 |
+
|
37 |
+
DATALOADER:
|
38 |
+
batch_size: 64
|
39 |
+
num_workers: 8
|
40 |
+
pin_memory: True
|
41 |
+
|
42 |
+
MODEL:
|
43 |
+
# transformer encoder
|
44 |
+
encoder_input_dim: 256
|
45 |
+
num_attention_heads: 8
|
46 |
+
encoder_hidden_dim: 512
|
47 |
+
encoder_dropout: 0.0
|
48 |
+
encoder_activation: relu
|
49 |
+
encoder_num_layers: 8
|
50 |
+
# output head
|
51 |
+
structure_dropout: 0
|
52 |
+
object_dropout: 0
|
53 |
+
# pc encoder
|
54 |
+
ignore_rgb: ${DATASET.ignore_rgb}
|
55 |
+
pc_emb_dim: 256
|
56 |
+
posed_pc_emb_dim: 80
|
57 |
+
# pose encoder
|
58 |
+
pose_emb_dim: 80
|
59 |
+
# language
|
60 |
+
word_emb_dim: 160
|
61 |
+
# diffusion step
|
62 |
+
time_emb_dim: 80
|
63 |
+
# sequence embeddings
|
64 |
+
# max_num_target_objects (+ max_num_distractor_objects if not ignore_distractor_objects)
|
65 |
+
max_seq_size: 7
|
66 |
+
max_token_type_size: 4
|
67 |
+
seq_pos_emb_dim: 8
|
68 |
+
seq_type_emb_dim: 8
|
69 |
+
# virtual frame
|
70 |
+
use_virtual_structure_frame: ${DATASET.use_virtual_structure_frame}
|
71 |
+
# language
|
72 |
+
use_sentence_embedding: True
|
73 |
+
sentence_embedding_dim: 384
|
74 |
+
|
75 |
+
NOISE_SCHEDULE:
|
76 |
+
timesteps: 200
|
77 |
+
|
78 |
+
LOSS:
|
79 |
+
type: huber
|
80 |
+
|
81 |
+
OPTIMIZER:
|
82 |
+
lr: 0.0001
|
83 |
+
weight_decay: 0 #0.0001
|
84 |
+
# lr_restart: 3000
|
85 |
+
# warmup: 10
|
86 |
+
|
87 |
+
TRAINER:
|
88 |
+
max_epochs: 200
|
89 |
+
gradient_clip_val: 1.0
|
90 |
+
gpus: 1
|
91 |
+
deterministic: False
|
92 |
+
# enable_progress_bar: False
|
data/template_sentence_data.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:bbb41fd847ca6ee48d484c346fd8b0bbf478dfcc4559bf9646e3b7e7bf9fe83b
|
3 |
+
size 5680159
|
requirements.txt
CHANGED
@@ -7,4 +7,5 @@ pyglet==1.5.0
|
|
7 |
openpyxl
|
8 |
pytorch_lightning==1.6.1
|
9 |
wandb===0.13.10
|
10 |
-
omegaconf==2.2.2
|
|
|
|
7 |
openpyxl
|
8 |
pytorch_lightning==1.6.1
|
9 |
wandb===0.13.10
|
10 |
+
omegaconf==2.2.2
|
11 |
+
sentence-transformers
|
src/StructDiffusion/data/__pycache__/semantic_arrangement_demo.cpython-38.pyc
CHANGED
Binary files a/src/StructDiffusion/data/__pycache__/semantic_arrangement_demo.cpython-38.pyc and b/src/StructDiffusion/data/__pycache__/semantic_arrangement_demo.cpython-38.pyc differ
|
|
src/StructDiffusion/data/__pycache__/semantic_arrangement_language.cpython-38.pyc
ADDED
Binary file (18.5 kB). View file
|
|
src/StructDiffusion/data/__pycache__/semantic_arrangement_language_demo.cpython-38.pyc
ADDED
Binary file (19.3 kB). View file
|
|
src/StructDiffusion/data/pairwise_collision.py
CHANGED
@@ -32,11 +32,27 @@ def load_pairwise_collision_data(h5_filename):
|
|
32 |
return data_dict
|
33 |
|
34 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
35 |
class PairwiseCollisionDataset(torch.utils.data.Dataset):
|
36 |
|
37 |
def __init__(self, urdf_pc_idx_file, collision_data_dir, random_rotation=True,
|
38 |
num_pts=1024, normalize_pc=True, num_scene_pts=2048, data_augmentation=False,
|
39 |
-
debug=False):
|
40 |
|
41 |
# load dictionary mapping from urdf to list of pc data, each sample is
|
42 |
# {"step_t": step_t, "obj": obj, "filename": filename}
|
@@ -49,6 +65,8 @@ class PairwiseCollisionDataset(torch.utils.data.Dataset):
|
|
49 |
filename = pd["filename"]
|
50 |
if "data00026058" in filename or "data00011415" in filename or "data00026061" in filename or "data00700565" in filename or "data00505290" in filename:
|
51 |
continue
|
|
|
|
|
52 |
valid_pc_data.append(pd)
|
53 |
if valid_pc_data:
|
54 |
self.urdf_to_pc_data[urdf] = valid_pc_data
|
@@ -297,65 +315,3 @@ class PairwiseCollisionDataset(torch.utils.data.Dataset):
|
|
297 |
"label": torch.FloatTensor([label]),
|
298 |
}
|
299 |
return datum
|
300 |
-
|
301 |
-
# @staticmethod
|
302 |
-
# def collate_fn(data):
|
303 |
-
# """
|
304 |
-
# :param data:
|
305 |
-
# :return:
|
306 |
-
# """
|
307 |
-
#
|
308 |
-
# batched_data_dict = {}
|
309 |
-
# for key in ["is_circle"]:
|
310 |
-
# batched_data_dict[key] = torch.cat([dict[key] for dict in data], dim=0)
|
311 |
-
# for key in ["scene_xyz"]:
|
312 |
-
# batched_data_dict[key] = torch.stack([dict[key] for dict in data], dim=0)
|
313 |
-
#
|
314 |
-
# return batched_data_dict
|
315 |
-
#
|
316 |
-
# # def create_pair_xyzs_from_obj_xyzs(self, new_obj_xyzs, debug=False):
|
317 |
-
# #
|
318 |
-
# # new_obj_xyzs = [xyz.cpu().numpy() for xyz in new_obj_xyzs]
|
319 |
-
# #
|
320 |
-
# # # compute pairwise collision
|
321 |
-
# # scene_xyzs = []
|
322 |
-
# # obj_xyz_pair_idxs = list(itertools.combinations(range(len(new_obj_xyzs)), 2))
|
323 |
-
# #
|
324 |
-
# # for obj_xyz_pair_idx in obj_xyz_pair_idxs:
|
325 |
-
# # obj_xyz_pair = [new_obj_xyzs[obj_xyz_pair_idx[0]], new_obj_xyzs[obj_xyz_pair_idx[1]]]
|
326 |
-
# # num_indicator = 2
|
327 |
-
# # obj_xyz_pair_ind = []
|
328 |
-
# # for oi, obj_xyz in enumerate(obj_xyz_pair):
|
329 |
-
# # obj_xyz = np.concatenate([obj_xyz, np.tile(np.eye(num_indicator)[oi], (obj_xyz.shape[0], 1))], axis=1)
|
330 |
-
# # obj_xyz_pair_ind.append(obj_xyz)
|
331 |
-
# # pair_scene_xyz = np.concatenate(obj_xyz_pair_ind, axis=0)
|
332 |
-
# #
|
333 |
-
# # # subsampling and normalizing pc
|
334 |
-
# # rand_idx = np.random.randint(0, pair_scene_xyz.shape[0], self.num_scene_pts)
|
335 |
-
# # pair_scene_xyz = pair_scene_xyz[rand_idx]
|
336 |
-
# # if self.normalize_pc:
|
337 |
-
# # pair_scene_xyz[:, 0:3] = pc_normalize(pair_scene_xyz[:, 0:3])
|
338 |
-
# #
|
339 |
-
# # scene_xyzs.append(array_to_tensor(pair_scene_xyz))
|
340 |
-
# #
|
341 |
-
# # if debug:
|
342 |
-
# # for scene_xyz in scene_xyzs:
|
343 |
-
# # show_pcs([scene_xyz[:, 0:3]], [np.tile(np.array([0, 1, 0], dtype=np.float), (scene_xyz.shape[0], 1))],
|
344 |
-
# # add_coordinate_frame=True)
|
345 |
-
# #
|
346 |
-
# # return scene_xyzs
|
347 |
-
|
348 |
-
|
349 |
-
if __name__ == "__main__":
|
350 |
-
dataset = PairwiseCollisionDataset(urdf_pc_idx_file="/home/weiyu/data_drive/StructDiffusion/pairwise_collision_data/urdf_pc_idx.pkl",
|
351 |
-
collision_data_dir="/home/weiyu/data_drive/StructDiffusion/pairwise_collision_data",
|
352 |
-
debug=False)
|
353 |
-
|
354 |
-
for i in tqdm.tqdm(np.random.permutation(len(dataset))):
|
355 |
-
# print(i)
|
356 |
-
d = dataset[i]
|
357 |
-
# print(d["label"])
|
358 |
-
|
359 |
-
# dl = torch.utils.data.DataLoader(dataset, batch_size=32, num_workers=8)
|
360 |
-
# for b in tqdm.tqdm(dl):
|
361 |
-
# pass
|
|
|
32 |
return data_dict
|
33 |
|
34 |
|
35 |
+
def replace_root_directory(original_filename: str, new_root: str) -> str:
    """Re-root a path so that everything up to ``data_new_objects`` becomes ``new_root``.

    The path is split on ``/``; every component up to and including the first
    ``data_new_objects`` component is dropped and the remaining components are
    appended to ``new_root``.

    Args:
        original_filename: ``/``-separated path containing a
            ``data_new_objects`` component.
        new_root: Directory that replaces the ``.../data_new_objects`` prefix.
            A trailing ``/`` is tolerated.

    Returns:
        The re-rooted path, or ``new_root`` unchanged when nothing follows
        ``data_new_objects`` in ``original_filename``.

    Raises:
        ValueError: If ``original_filename`` has no ``data_new_objects`` component.
    """
    marker = 'data_new_objects'

    # Split the original filename into a list by directory
    original_parts = original_filename.split('/')

    # Find the index of the "data_new_objects" part
    try:
        data_index = original_parts.index(marker)
    except ValueError:
        raise ValueError(
            "cannot replace root directory: {!r} has no {!r} component".format(original_filename, marker)
        ) from None

    remaining_parts = original_parts[data_index + 1:]
    if not remaining_parts:
        return new_root

    # Strip trailing slashes from the new root so the join cannot produce "//".
    return '/'.join([new_root.rstrip('/'), *remaining_parts])
|
49 |
+
|
50 |
+
|
51 |
class PairwiseCollisionDataset(torch.utils.data.Dataset):
|
52 |
|
53 |
def __init__(self, urdf_pc_idx_file, collision_data_dir, random_rotation=True,
|
54 |
num_pts=1024, normalize_pc=True, num_scene_pts=2048, data_augmentation=False,
|
55 |
+
debug=False, new_data_root=None):
|
56 |
|
57 |
# load dictionary mapping from urdf to list of pc data, each sample is
|
58 |
# {"step_t": step_t, "obj": obj, "filename": filename}
|
|
|
65 |
filename = pd["filename"]
|
66 |
if "data00026058" in filename or "data00011415" in filename or "data00026061" in filename or "data00700565" in filename or "data00505290" in filename:
|
67 |
continue
|
68 |
+
if new_data_root:
|
69 |
+
pd["filename"] = replace_root_directory(pd["filename"], new_data_root)
|
70 |
valid_pc_data.append(pd)
|
71 |
if valid_pc_data:
|
72 |
self.urdf_to_pc_data[urdf] = valid_pc_data
|
|
|
315 |
"label": torch.FloatTensor([label]),
|
316 |
}
|
317 |
return datum
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/StructDiffusion/data/semantic_arrangement.py
CHANGED
@@ -533,47 +533,4 @@ def compute_min_max(dataloader):
|
|
533 |
current_min, _ = torch.min(goal_poses, dim=0)
|
534 |
max_value[max_value < current_max] = current_max[max_value < current_max]
|
535 |
max_value[max_value > current_min] = current_min[max_value > current_min]
|
536 |
-
print(f"{min_value} - {max_value}")
|
537 |
-
|
538 |
-
|
539 |
-
if __name__ == "__main__":
|
540 |
-
|
541 |
-
tokenizer = Tokenizer("/home/weiyu/data_drive/data_new_objects/type_vocabs_coarse.json")
|
542 |
-
|
543 |
-
data_roots = []
|
544 |
-
index_roots = []
|
545 |
-
for shape, index in [("circle", "index_10k"), ("line", "index_10k"), ("stacking", "index_10k"), ("dinner", "index_10k")]:
|
546 |
-
data_roots.append("/home/weiyu/data_drive/data_new_objects/examples_{}_new_objects/result".format(shape))
|
547 |
-
index_roots.append(index)
|
548 |
-
|
549 |
-
dataset = SemanticArrangementDataset(data_roots=data_roots,
|
550 |
-
index_roots=index_roots,
|
551 |
-
split="valid", tokenizer=tokenizer,
|
552 |
-
max_num_target_objects=7,
|
553 |
-
max_num_distractor_objects=5,
|
554 |
-
max_num_shape_parameters=5,
|
555 |
-
max_num_rearrange_features=0,
|
556 |
-
max_num_anchor_features=0,
|
557 |
-
num_pts=1024,
|
558 |
-
use_virtual_structure_frame=True,
|
559 |
-
ignore_distractor_objects=True,
|
560 |
-
ignore_rgb=True,
|
561 |
-
filter_num_moved_objects_range=None, # [5, 5]
|
562 |
-
data_augmentation=False,
|
563 |
-
shuffle_object_index=False,
|
564 |
-
debug=False)
|
565 |
-
|
566 |
-
# print(len(dataset))
|
567 |
-
# for d in dataset:
|
568 |
-
# print("\n\n" + "="*100)
|
569 |
-
|
570 |
-
dataloader = DataLoader(dataset, batch_size=64, shuffle=False, num_workers=8)
|
571 |
-
for i, d in enumerate(tqdm(dataloader)):
|
572 |
-
pass
|
573 |
-
# for k in d:
|
574 |
-
# if isinstance(d[k], torch.Tensor):
|
575 |
-
# print("--size", k, d[k].shape)
|
576 |
-
# for k in d:
|
577 |
-
# print(k, d[k])
|
578 |
-
#
|
579 |
-
# input("next?")
|
|
|
533 |
current_min, _ = torch.min(goal_poses, dim=0)
|
534 |
max_value[max_value < current_max] = current_max[max_value < current_max]
|
535 |
max_value[max_value > current_min] = current_min[max_value > current_min]
|
536 |
+
print(f"{min_value} - {max_value}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/StructDiffusion/data/semantic_arrangement_language.py
ADDED
@@ -0,0 +1,633 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
import cv2
|
3 |
+
import h5py
|
4 |
+
import numpy as np
|
5 |
+
import os
|
6 |
+
import trimesh
|
7 |
+
import torch
|
8 |
+
from tqdm import tqdm
|
9 |
+
import json
|
10 |
+
import random
|
11 |
+
import pickle
|
12 |
+
|
13 |
+
from torch.utils.data import DataLoader
|
14 |
+
|
15 |
+
# Local imports
|
16 |
+
from StructDiffusion.utils.rearrangement import show_pcs, get_pts, combine_and_sample_xyzs
|
17 |
+
from StructDiffusion.language.tokenizer import Tokenizer
|
18 |
+
|
19 |
+
import StructDiffusion.utils.brain2.camera as cam
|
20 |
+
import StructDiffusion.utils.brain2.image as img
|
21 |
+
import StructDiffusion.utils.transformations as tra
|
22 |
+
|
23 |
+
|
24 |
+
class SemanticArrangementDataset(torch.utils.data.Dataset):
|
25 |
+
|
26 |
+
def __init__(self, data_roots, index_roots, split, tokenizer,
             max_num_target_objects=11, max_num_distractor_objects=5,
             max_num_shape_parameters=7, max_num_rearrange_features=1, max_num_anchor_features=3,
             num_pts=1024,
             use_virtual_structure_frame=True, ignore_distractor_objects=True, ignore_rgb=True,
             filter_num_moved_objects_range=None, shuffle_object_index=False,
             sentence_embedding_file=None, use_incomplete_sentence=False,
             data_augmentation=True, debug=False, **kwargs):
    """
    Index the goal scenes listed in the arrangement index files and store the dataset options.

    Note: setting filter_num_moved_objects_range=[k, k] and max_num_target_objects=k will create no padding
    for target objs

    :param data_roots: list of dataset directories
    :param index_roots: list of index sub-directories, paired with data_roots by position
    :param split: train, valid, or test (used to pick "<split>_arrangement_indices_file_all.txt")
    :param tokenizer: used to tokenize the language part
    :param max_num_target_objects: stored as self.max_num_objects
    :param max_num_distractor_objects: stored as self.max_num_other_objects
    :param max_num_shape_parameters: must be 1 when sentence_embedding_file is given
    :param max_num_rearrange_features:
    :param max_num_anchor_features:
    :param num_pts: number of points requested per object point cloud
    :param use_virtual_structure_frame:
    :param ignore_distractor_objects:
    :param ignore_rgb: forced to False when debug is True
    :param filter_num_moved_objects_range: optional (min, max) on the number of moved objects
    :param shuffle_object_index: whether to shuffle the positions of target objects in the sequence
    :param sentence_embedding_file: optional pickle with template sentences and their embeddings
    :param use_incomplete_sentence: only stored when sentence_embedding_file is given
    :param data_augmentation: whether depth / xyz noise is added when images are loaded
    :param debug:
    :param kwargs: accepted and ignored
    """

    # sequence layout
    self.max_num_objects = max_num_target_objects
    self.max_num_other_objects = max_num_distractor_objects
    self.max_num_shape_parameters = max_num_shape_parameters
    self.max_num_rearrange_features = max_num_rearrange_features
    self.max_num_anchor_features = max_num_anchor_features
    self.shuffle_object_index = shuffle_object_index

    # behaviour switches
    self.use_virtual_structure_frame = use_virtual_structure_frame
    self.ignore_distractor_objects = ignore_distractor_objects
    self.ignore_rgb = ignore_rgb and not debug
    self.num_pts = num_pts
    self.debug = debug

    # used to tokenize the language part
    self.tokenizer = tokenizer

    # retrieve data
    self.data_roots = data_roots
    self.arrangement_data = []
    arrangement_steps = []
    for ddx, data_root in enumerate(data_roots):
        index_root = index_roots[ddx]
        arrangement_indices_file = os.path.join(data_root, index_root, "{}_arrangement_indices_file_all.txt".format(split))
        if not os.path.exists(arrangement_indices_file):
            print("{} does not exist".format(arrangement_indices_file))
            continue
        with open(arrangement_indices_file, "r") as fh:
            # NOTE(review): eval() executes whatever the first line of the index file contains;
            # only point this at index files you trust.
            index_entries = eval(fh.readline().strip())
        arrangement_steps.extend([(os.path.join(data_root, f[0]), f[1]) for f in index_entries])

    # only keep the goal, ignore the intermediate steps
    excluded_scenes = ("data00026058", "data00011415", "data00026061", "data00700565")
    for filename, step_t in arrangement_steps:
        if step_t != 0:
            continue
        if any(scene_id in filename for scene_id in excluded_scenes):
            continue
        self.arrangement_data.append((filename, step_t))

    # if specified, filter data
    if filter_num_moved_objects_range is not None:
        self.arrangement_data = self.filter_based_on_number_of_moved_objects(filter_num_moved_objects_range)
    print("{} valid sequences".format(len(self.arrangement_data)))

    # language
    self.use_sentence_embedding = bool(sentence_embedding_file)
    if self.use_sentence_embedding:
        assert max_num_shape_parameters == 1
        # since we do not use them right now, ignore them
        # assert max_num_rearrange_features == 0
        # assert max_num_anchor_features == 0
        with open(sentence_embedding_file, "rb") as fh:
            # NOTE(review): pickle.load can run arbitrary code; only load trusted files.
            template_sentence_data = pickle.load(fh)
        self.type_value_tuple_to_template_sentences = template_sentence_data["type_value_tuple_to_template_sentences"]
        self.template_sentence_to_embedding = template_sentence_data["template_sentence_to_embedding"]
        self.use_incomplete_sentence = use_incomplete_sentence
        print("use sentence embedding")
        print(len(self.type_value_tuple_to_template_sentences))
        print(len(self.template_sentence_to_embedding))

    # Data Aug
    self.data_augmentation = data_augmentation
    # additive noise
    self.gp_rescale_factor_range = [12, 20]
    self.gaussian_scale_range = [0., 0.003]
    # multiplicative noise
    self.gamma_shape = 1000.
    self.gamma_scale = 0.001
|
118 |
+
|
119 |
+
def filter_based_on_number_of_moved_objects(self, filter_num_moved_objects_range):
    """
    Keep only the scenes whose number of moved objects lies in an inclusive range.

    :param filter_num_moved_objects_range: pair (min_num, max_num), both inclusive
    :return: list of (filename, step_t) entries from self.arrangement_data that pass the filter
    """
    assert len(list(filter_num_moved_objects_range)) == 2
    min_num, max_num = filter_num_moved_objects_range
    print("Remove scenes that have less than {} or more than {} objects being moved".format(min_num, max_num))
    ok_data = []
    for filename, step_t in self.arrangement_data:
        # Close each file as soon as the comma-separated object list has been read, so that
        # scanning a large index does not accumulate open file handles.
        with h5py.File(filename, 'r') as h5:
            moved_objs = h5['moved_objs'][()].split(',')
        if min_num <= len(moved_objs) <= max_num:
            ok_data.append((filename, step_t))
    print("{} valid sequences left".format(len(ok_data)))
    return ok_data
|
131 |
+
|
132 |
+
def get_data_idx(self, idx):
    """
    Open the data file that contains a global index and translate the index to a per-file one.

    NOTE(review): reads self.file_to_count and self.data_files, which are not assigned by the
    constructor visible in this file - confirm they are set elsewhere before calling.

    :param idx: global index
    :return: (open h5py file, index inside that file, file position)
    """
    # np.argmax on a boolean array returns the position of the first True entry
    file_idx = np.argmax(idx < self.file_to_count)
    data = h5py.File(self.data_files[file_idx], 'r')
    # for lang2sym, idx is always 0
    local_idx = idx - self.file_to_count[file_idx - 1] if file_idx > 0 else idx
    return data, local_idx, file_idx
|
140 |
+
|
141 |
+
def add_noise_to_depth(self, depth_img):
    """
    Apply multiplicative noise: scale the whole depth image by one random factor drawn from
    a gamma distribution with parameters (self.gamma_shape, self.gamma_scale).
    """
    scale_factor = np.random.gamma(self.gamma_shape, self.gamma_scale)
    return scale_factor * depth_img
|
146 |
+
|
147 |
+
def add_noise_to_xyz(self, xyz_img, depth_img):
    """
    Add smooth additive noise to a copy of an H x W x C xyz image.

    Gaussian noise is drawn on a grid that is smaller than the image by a random integer factor,
    upsampled to the image size with cubic interpolation, and added only where depth_img > 0.

    TODO: remove this code or at least clean it up
    """
    noisy_xyz = xyz_img.copy()
    height, width, channels = noisy_xyz.shape

    rescale_factor = np.random.randint(self.gp_rescale_factor_range[0],
                                       self.gp_rescale_factor_range[1])
    noise_std = np.random.uniform(self.gaussian_scale_range[0],
                                  self.gaussian_scale_range[1])

    coarse_h, coarse_w = (np.array([height, width]) / rescale_factor).astype(int)
    coarse_noise = np.random.normal(loc=0.0, scale=noise_std, size=(coarse_h, coarse_w, channels))
    # cv2.resize takes the target size as (width, height)
    full_noise = cv2.resize(coarse_noise, (width, height), interpolation=cv2.INTER_CUBIC)

    has_depth = depth_img > 0
    noisy_xyz[has_depth, :] += full_noise[has_depth, :]
    return noisy_xyz
|
160 |
+
|
161 |
+
def random_index(self):
    """Return the sample stored at a position drawn uniformly from range(len(self))."""
    position = np.random.randint(len(self))
    return self[position]
|
163 |
+
|
164 |
+
def _get_rgb(self, h5, idx, ee=True):
    """
    Decode the PNG stored at position idx and return its first three channels divided by 255.

    :param h5: opened data file
    :param idx: frame position
    :param ee: read from "ee_rgb" if True, otherwise from "rgb"
    """
    rgb_key = "ee_rgb" if ee else "rgb"
    decoded = img.PNGToNumpy(h5[rgb_key][idx])
    return decoded[:, :, :3] / 255.  # remove alpha
|
168 |
+
|
169 |
+
def _get_depth(self, h5, idx, ee=True):
    """
    Return the depth image stored at position idx.

    The stored values are rescaled with the same formula _get_images uses:
    stored / 20000 * (depth_max - depth_min) + depth_min.

    :param h5: opened data file
    :param idx: frame position
    :param ee: read the "ee_"-prefixed datasets if True, otherwise the unprefixed ones
    :return: rescaled depth image
    """
    if ee:
        DEPTH, DMIN, DMAX = "ee_depth", "ee_depth_min", "ee_depth_max"
    else:
        DEPTH, DMIN, DMAX = "depth", "depth_min", "depth_max"
    dmin = h5[DMIN][idx]
    dmax = h5[DMAX][idx]
    return h5[DEPTH][idx] / 20000. * (dmax - dmin) + dmin
|
171 |
+
|
172 |
+
def _get_images(self, h5, idx, ee=True):
    """
    Load one frame and return (rgb, depth, seg, valid, xyz).

    :param h5: opened data file
    :param idx: frame position
    :param ee: read the "ee_"-prefixed datasets if True, otherwise the unprefixed ones
    :return: tuple of
        rgb   - first three PNG channels divided by 255
        depth - rescaled depth (with multiplicative noise if data augmentation is on)
        seg   - decoded segmentation PNG
        valid - mask of pixels with 0.1 < depth < 2, computed before any noise is added
        xyz   - H x W x 3 points after applying the camera pose (with additive noise if
                data augmentation is on)
    """
    if ee:
        rgb_key, depth_key, seg_key = "ee_rgb", "ee_depth", "ee_seg"
        dmin_key, dmax_key = "ee_depth_min", "ee_depth_max"
        view_key = "ee_camera_view"
    else:
        rgb_key, depth_key, seg_key = "rgb", "depth", "seg"
        dmin_key, dmax_key = "depth_min", "depth_max"
        view_key = "camera_view"

    depth_lo = h5[dmin_key][idx]
    depth_hi = h5[dmax_key][idx]
    rgb = img.PNGToNumpy(h5[rgb_key][idx])[:, :, :3] / 255.  # remove alpha
    depth = h5[depth_key][idx] / 20000. * (depth_hi - depth_lo) + depth_lo
    seg = img.PNGToNumpy(h5[seg_key][idx])

    valid = np.logical_and(depth > 0.1, depth < 2.)

    camera = cam.get_camera_from_h5(h5)
    if self.data_augmentation:
        depth = self.add_noise_to_depth(depth)

    xyz = cam.compute_xyz(depth, camera)
    if self.data_augmentation:
        xyz = self.add_noise_to_xyz(xyz, depth)

    # Transform the point cloud with the stored camera view
    cam_pose = h5[view_key][idx]
    if ee:
        # ee_camera_view has 0s for x, y, z, so take the translation from ee_cam_pose
        cam_pose[:3, 3] = h5["ee_cam_pose"][:][:3, 3]

    # flatten to a list of points, transform, then restore the image layout
    rows, cols, _ = xyz.shape
    flat_points = trimesh.transform_points(xyz.reshape(rows * cols, -1), cam_pose)
    xyz = flat_points.reshape(rows, cols, -1)

    return rgb, depth, seg, valid, xyz
|
215 |
+
|
216 |
+
def __len__(self):
    """Number of indexed (filename, step) entries."""
    num_entries = len(self.arrangement_data)
    return num_entries
|
218 |
+
|
219 |
+
def _get_ids(self, h5):
    """
    get object ids

    Every key of the file that starts with "id_" contributes one entry: the key without that
    prefix, mapped to the value stored under it.

    @param h5: opened data file
    @return: dict from name to stored id
    """
    prefix = "id_"
    return {key[len(prefix):]: h5[key][()] for key in h5.keys() if key.startswith(prefix)}
|
231 |
+
|
232 |
+
def get_positive_ratio(self):
    """
    Return (number of entries with step_t != 0) / (number of entries with step_t == 0).

    Raises ZeroDivisionError when no entry has step_t == 0.
    """
    num_pos = sum(1 for _, step_t in self.arrangement_data if step_t == 0)
    num_neg = len(self.arrangement_data) - num_pos
    return num_neg * 1.0 / num_pos
|
239 |
+
|
240 |
+
def get_object_position_vocab_sizes(self):
    """Forward to the tokenizer's get_object_position_vocab_sizes() and return its result."""
    vocab_sizes = self.tokenizer.get_object_position_vocab_sizes()
    return vocab_sizes
|
242 |
+
|
243 |
+
def get_vocab_size(self):
    """Forward to the tokenizer's get_vocab_size() and return its result."""
    vocab_size = self.tokenizer.get_vocab_size()
    return vocab_size
|
245 |
+
|
246 |
+
def get_data_index(self, idx):
    """Return the entry of self.arrangement_data stored at position idx."""
    return self.arrangement_data[idx]
|
249 |
+
|
250 |
+
def get_raw_data(self, idx, inference_mode=False, shuffle_object_index=False):
    """
    Build the un-tensorized sample for one goal scene.

    :param idx: position in self.arrangement_data
    :param inference_mode: if True, also return rgb, object poses, ids, the goal specification, the
                           initial scene and (when sentence embeddings are used) the template sentence
    :param shuffle_object_index: used to test different orders of objects; only applied to "dinner"
    :return: dict with "pcs", "goal_poses", "type_index", "position_index", "pad_mask", "t",
             "filename" and "sentence", plus the inference-only entries
    :raises KeyError: if the structure type is not circle, line, tower or dinner
    :raises Exception: if an object has no valid pixel or its points cannot be sampled
    """

    filename, _ = self.arrangement_data[idx]

    # The file handle is not used after this block; the context manager closes it even when one
    # of the checks below raises.
    with h5py.File(filename, 'r') as h5:
        ids = self._get_ids(h5)
        all_objs = sorted([o for o in ids.keys() if "object_" in o])
        goal_specification = json.loads(str(np.array(h5["goal_specification"])))
        num_rearrange_objs = len(goal_specification["rearrange"]["objects"])
        num_other_objs = len(goal_specification["anchor"]["objects"] + goal_specification["distract"]["objects"])
        assert len(all_objs) == num_rearrange_objs + num_other_objs, "{}, {}".format(len(all_objs), num_rearrange_objs + num_other_objs)
        assert num_rearrange_objs <= self.max_num_objects
        assert num_other_objs <= self.max_num_other_objects

        # important: only using the last step
        step_t = num_rearrange_objs

        target_objs = all_objs[:num_rearrange_objs]
        other_objs = all_objs[num_rearrange_objs:]

        structure_parameters = goal_specification["shape"]
        structure_type = structure_parameters["type"]

        # Important: ensure the order is correct
        if structure_type == "circle" or structure_type == "line":
            target_objs = target_objs[::-1]
        elif structure_type == "tower" or structure_type == "dinner":
            target_objs = target_objs
        else:
            raise KeyError("{} structure is not recognized".format(structure_type))
        all_objs = target_objs + other_objs

        ###################################
        # getting scene images and point clouds
        scene = self._get_images(h5, step_t, ee=True)
        rgb, depth, seg, valid, xyz = scene
        if inference_mode:
            initial_scene = scene

        # getting object point clouds
        obj_pcs = []
        obj_pad_mask = []
        current_pc_poses = []
        other_obj_pcs = []
        other_obj_pad_mask = []
        for obj in all_objs:
            obj_mask = np.logical_and(seg == ids[obj], valid)
            if np.sum(obj_mask) <= 0:
                raise Exception("{} has no valid pixel in {}".format(obj, filename))
            ok, obj_xyz, obj_rgb, _ = get_pts(xyz, rgb, obj_mask, num_pts=self.num_pts)
            if not ok:
                raise Exception("could not sample points for {} in {}".format(obj, filename))

            if obj in target_objs:
                if self.ignore_rgb:
                    obj_pcs.append(obj_xyz)
                else:
                    obj_pcs.append(torch.concat([obj_xyz, obj_rgb], dim=-1))
                obj_pad_mask.append(0)
                # the point-cloud pose is an identity rotation placed at the mean of the points
                pc_pose = np.eye(4)
                pc_pose[:3, 3] = torch.mean(obj_xyz, dim=0).numpy()
                current_pc_poses.append(pc_pose)
            elif obj in other_objs:
                if self.ignore_rgb:
                    other_obj_pcs.append(obj_xyz)
                else:
                    other_obj_pcs.append(torch.concat([obj_xyz, obj_rgb], dim=-1))
                other_obj_pad_mask.append(0)
            else:
                raise Exception("{} is neither a target nor an other object".format(obj))

        ###################################
        # computes goal positions for objects
        # Important: because of the noises we added to point clouds, the rearranged point clouds will not be perfect
        if self.use_virtual_structure_frame:
            goal_structure_pose = tra.euler_matrix(structure_parameters["rotation"][0], structure_parameters["rotation"][1],
                                                   structure_parameters["rotation"][2])
            goal_structure_pose[:3, 3] = [structure_parameters["position"][0], structure_parameters["position"][1],
                                          structure_parameters["position"][2]]
            goal_structure_pose_inv = np.linalg.inv(goal_structure_pose)

        goal_obj_poses = []
        current_obj_poses = []
        goal_pc_poses = []
        for obj, current_pc_pose in zip(target_objs, current_pc_poses):
            goal_pose = h5[obj][0]
            current_pose = h5[obj][step_t]
            if inference_mode:
                goal_obj_poses.append(goal_pose)
                current_obj_poses.append(current_pose)

            # apply the object's current-to-goal motion to its point-cloud pose
            goal_pc_pose = goal_pose @ np.linalg.inv(current_pose) @ current_pc_pose
            if self.use_virtual_structure_frame:
                goal_pc_pose = goal_structure_pose_inv @ goal_pc_pose
            goal_pc_poses.append(goal_pc_pose)

    # transform current object point cloud to the goal point cloud in the world frame
    if self.debug:
        new_obj_pcs = [copy.deepcopy(pc.numpy()) for pc in obj_pcs]
        for i, obj_pc in enumerate(new_obj_pcs):

            current_pc_pose = current_pc_poses[i]
            goal_pc_pose = goal_pc_poses[i]
            if self.use_virtual_structure_frame:
                goal_pc_pose = goal_structure_pose @ goal_pc_pose
            print("current pc pose", current_pc_pose)
            print("goal pc pose", goal_pc_pose)

            goal_pc_transform = goal_pc_pose @ np.linalg.inv(current_pc_pose)
            print("transform", goal_pc_transform)
            new_obj_pc = copy.deepcopy(obj_pc)
            new_obj_pc[:, :3] = trimesh.transform_points(obj_pc[:, :3], goal_pc_transform)
            print(new_obj_pc.shape)

            # visualize rearrangement sequence (new_obj_xyzs), the current object before moving (obj_xyz), and other objects
            # dtype=float replaces the np.float alias, which no longer exists in NumPy >= 1.24
            new_obj_pcs[i] = new_obj_pc
            new_obj_pcs[i][:, 3:] = np.tile(np.array([1, 0, 0], dtype=float), (new_obj_pc.shape[0], 1))
            new_obj_rgb_current = np.tile(np.array([0, 1, 0], dtype=float), (new_obj_pc.shape[0], 1))
            show_pcs([pc[:, :3] for pc in new_obj_pcs] + [pc[:, :3] for pc in other_obj_pcs] + [obj_pc[:, :3]],
                     [pc[:, 3:] for pc in new_obj_pcs] + [pc[:, 3:] for pc in other_obj_pcs] + [new_obj_rgb_current],
                     add_coordinate_frame=True)
        show_pcs([pc[:, :3] for pc in new_obj_pcs], [pc[:, 3:] for pc in new_obj_pcs], add_coordinate_frame=True)

    # pad data
    for i in range(self.max_num_objects - len(target_objs)):
        obj_pcs.append(torch.zeros_like(obj_pcs[0], dtype=torch.float32))
        obj_pad_mask.append(1)
    for i in range(self.max_num_other_objects - len(other_objs)):
        other_obj_pcs.append(torch.zeros_like(obj_pcs[0], dtype=torch.float32))
        other_obj_pad_mask.append(1)

    ###################################
    # preparing sentence
    sentence = []
    sentence_pad_mask = []

    # structure parameters
    # 5 parameters
    structure_parameters = goal_specification["shape"]
    if structure_parameters["type"] == "circle" or structure_parameters["type"] == "line":
        sentence.append((structure_parameters["type"], "shape"))
        sentence.append((structure_parameters["rotation"][2], "rotation"))
        sentence.append((structure_parameters["position"][0], "position_x"))
        sentence.append((structure_parameters["position"][1], "position_y"))
        if structure_parameters["type"] == "circle":
            sentence.append((structure_parameters["radius"], "radius"))
        elif structure_parameters["type"] == "line":
            # a line is described by half its length under the "radius" type
            sentence.append((structure_parameters["length"] / 2.0, "radius"))
        if not self.use_sentence_embedding:
            for _ in range(5):
                sentence_pad_mask.append(0)
    else:
        sentence.append((structure_parameters["type"], "shape"))
        sentence.append((structure_parameters["rotation"][2], "rotation"))
        sentence.append((structure_parameters["position"][0], "position_x"))
        sentence.append((structure_parameters["position"][1], "position_y"))
        if not self.use_sentence_embedding:
            for _ in range(4):
                sentence_pad_mask.append(0)
            # fifth slot is padding for structures without a radius
            sentence.append(("PAD", None))
            sentence_pad_mask.append(1)

    if self.use_sentence_embedding:

        if self.use_incomplete_sentence:
            # keep a random, non-empty subset of the tokens, in their original order
            token_idxs = np.random.permutation(len(sentence))
            token_idxs = token_idxs[:np.random.randint(1, len(sentence) + 1)]
            token_idxs = sorted(token_idxs)
            incomplete_sentence = [sentence[ti] for ti in token_idxs]
        else:
            incomplete_sentence = sentence

        type_value_tuple = self.tokenizer.convert_structure_params_to_type_value_tuple(incomplete_sentence)
        template_sentence = np.random.choice(self.type_value_tuple_to_template_sentences[type_value_tuple])
        sentence_embedding = self.template_sentence_to_embedding[template_sentence]
        # the whole sentence is a single, never-padded element of the sequence
        sentence_pad_mask = [0]

    ###################################
    # paddings
    for i in range(self.max_num_objects - len(target_objs)):
        goal_pc_poses.append(np.eye(4))

    ###################################
    if self.debug:
        print("---")
        print("all objects:", all_objs)
        print("target objects:", target_objs)
        print("other objects:", other_objs)
        print("goal specification:", goal_specification)
        print("sentence:", sentence)
        if self.use_sentence_embedding:
            print("use sentence embedding")
            if self.use_incomplete_sentence:
                print("incomplete_sentence:", incomplete_sentence)
            print("template sentence:", template_sentence)
        show_pcs([pc[:, :3] for pc in obj_pcs + other_obj_pcs], [pc[:, 3:] for pc in obj_pcs + other_obj_pcs], add_coordinate_frame=True)

    assert len(obj_pcs) == len(goal_pc_poses)
    ###################################

    # shuffle the position of objects
    # important: only shuffle for dinner
    if shuffle_object_index and structure_parameters["type"] == "dinner":
        num_target_objs = len(target_objs)
        shuffle_target_object_indices = list(range(num_target_objs))
        random.shuffle(shuffle_target_object_indices)
        # padded slots keep their positions after the shuffled real objects
        shuffle_object_indices = shuffle_target_object_indices + list(range(num_target_objs, self.max_num_objects))
        obj_pcs = [obj_pcs[i] for i in shuffle_object_indices]
        goal_pc_poses = [goal_pc_poses[i] for i in shuffle_object_indices]
        if inference_mode:
            goal_obj_poses = [goal_obj_poses[i] for i in shuffle_object_indices[:num_target_objs]]
            current_obj_poses = [current_obj_poses[i] for i in shuffle_object_indices[:num_target_objs]]
        target_objs = [target_objs[i] for i in shuffle_target_object_indices[:num_target_objs]]
        current_pc_poses = [current_pc_poses[i] for i in shuffle_object_indices[:num_target_objs]]

    ###################################
    # type_index values: 0 = language, 1 = distractor objects, 2 = structure virtual frame, 3 = target objects
    if self.use_virtual_structure_frame:
        if self.ignore_distractor_objects:
            # language, structure virtual frame, target objects
            pcs = obj_pcs
            type_index = [0] * self.max_num_shape_parameters + [2] + [3] * self.max_num_objects
            position_index = list(range(self.max_num_shape_parameters)) + [0] + list(range(self.max_num_objects))
            pad_mask = sentence_pad_mask + [0] + obj_pad_mask
        else:
            # language, distractor objects, structure virtual frame, target objects
            pcs = other_obj_pcs + obj_pcs
            type_index = [0] * self.max_num_shape_parameters + [1] * self.max_num_other_objects + [2] + [3] * self.max_num_objects
            position_index = list(range(self.max_num_shape_parameters)) + list(range(self.max_num_other_objects)) + [0] + list(range(self.max_num_objects))
            pad_mask = sentence_pad_mask + other_obj_pad_mask + [0] + obj_pad_mask
        goal_poses = [goal_structure_pose] + goal_pc_poses
    else:
        if self.ignore_distractor_objects:
            # language, target objects
            pcs = obj_pcs
            type_index = [0] * self.max_num_shape_parameters + [3] * self.max_num_objects
            position_index = list(range(self.max_num_shape_parameters)) + list(range(self.max_num_objects))
            pad_mask = sentence_pad_mask + obj_pad_mask
        else:
            # language, distractor objects, target objects
            pcs = other_obj_pcs + obj_pcs
            type_index = [0] * self.max_num_shape_parameters + [1] * self.max_num_other_objects + [3] * self.max_num_objects
            position_index = list(range(self.max_num_shape_parameters)) + list(range(self.max_num_other_objects)) + list(range(self.max_num_objects))
            pad_mask = sentence_pad_mask + other_obj_pad_mask + obj_pad_mask
        goal_poses = goal_pc_poses

    datum = {
        "pcs": pcs,
        "goal_poses": goal_poses,
        "type_index": type_index,
        "position_index": position_index,
        "pad_mask": pad_mask,
        "t": step_t,
        "filename": filename
    }
    if self.use_sentence_embedding:
        datum["sentence"] = sentence_embedding
    else:
        datum["sentence"] = sentence

    if inference_mode:
        datum["rgb"] = rgb
        datum["goal_obj_poses"] = goal_obj_poses
        datum["current_obj_poses"] = current_obj_poses
        datum["target_objs"] = target_objs
        datum["initial_scene"] = initial_scene
        datum["ids"] = ids
        datum["goal_specification"] = goal_specification
        datum["current_pc_poses"] = current_pc_poses
        if self.use_sentence_embedding:
            datum["template_sentence"] = template_sentence

    return datum
|
528 |
+
|
529 |
+
@staticmethod
def convert_to_tensors(datum, tokenizer, use_sentence_embedding=False):
    """
    Turn the lists of a raw datum into torch tensors.

    :param datum: dict as returned by get_raw_data
    :param tokenizer: only used when use_sentence_embedding is False, to tokenize each
                      (value, type) pair of the sentence
    :param use_sentence_embedding: if True, datum["sentence"] is converted to a FloatTensor as is
    :return: dict with the tensorized entries; "t" and "filename" are passed through unchanged
    """
    tensors = {
        "pcs": torch.stack(datum["pcs"], dim=0),
        "goal_poses": torch.FloatTensor(np.array(datum["goal_poses"])),
    }
    for key in ("type_index", "position_index", "pad_mask"):
        tensors[key] = torch.LongTensor(np.array(datum[key]))
    for key in ("t", "filename"):
        tensors[key] = datum[key]

    if use_sentence_embedding:
        # after batching, B x sentence embed dim
        sentence = torch.FloatTensor(datum["sentence"])
    else:
        token_ids = [tokenizer.tokenize(*pair) for pair in datum["sentence"]]
        sentence = torch.LongTensor(np.array(token_ids))
    tensors["sentence"] = sentence
    return tensors
|
545 |
+
|
546 |
+
def __getitem__(self, idx):
    """Load the raw datum at idx (shuffled per self.shuffle_object_index) and tensorize it."""
    raw_datum = self.get_raw_data(idx, shuffle_object_index=self.shuffle_object_index)
    return self.convert_to_tensors(raw_datum, self.tokenizer, self.use_sentence_embedding)
|
553 |
+
|
554 |
+
def single_datum_to_batch(self, x, num_samples, device, inference_mode=True):
    """
    Replicate one tensorized datum into a batch of num_samples identical copies on device.

    Each tensor gets a new leading axis and is repeated num_samples times along it.

    :param x: dict with "pcs", "sentence", "type_index", "position_index", "pad_mask"
              (and "goal_poses" when inference_mode is False)
    :param num_samples: batch size of the result
    :param device: device the tensors are moved to
    :param inference_mode: if False, "goal_poses" is batched as well
    :return: dict of batched tensors
    """

    def batch_1d(t):
        return t.to(device)[None, :].repeat(num_samples, 1)

    def batch_3d(t):
        return t.to(device)[None, :, :, :].repeat(num_samples, 1, 1, 1)

    tensor_x = {
        "pcs": batch_3d(x["pcs"]),
        "sentence": batch_1d(x["sentence"]),
    }
    if not inference_mode:
        tensor_x["goal_poses"] = batch_3d(x["goal_poses"])
    for key in ("type_index", "position_index", "pad_mask"):
        tensor_x[key] = batch_1d(x[key])

    return tensor_x
|
567 |
+
|
568 |
+
|
569 |
+
def compute_min_max(dataloader):
    """
    Compute the element-wise minimum and maximum of all goal poses served by a dataloader.

    Every batch's "goal_poses" tensor is flattened to rows of 16 values; the running
    minimum and maximum of each of the 16 entries are tracked over all rows.

    :param dataloader: iterable of batches, each a mapping with a "goal_poses" tensor whose
                       number of elements is a multiple of 16
    :return: (min_value, max_value), two tensors of 16 values each; the pair is also printed
    """

    # Values left in the original source as reference:
    # tensor([-0.3557, -0.3847, 0.0000, -1.0000, -1.0000, -0.4759, -1.0000, -1.0000,
    #         -0.9079, -0.8668, -0.9105, -0.4186])
    # tensor([0.3915, 0.3494, 0.3267, 1.0000, 1.0000, 0.8961, 1.0000, 1.0000, 0.8194,
    #         0.4787, 0.6421, 1.0000])
    # tensor([0.0918, -0.3758, 0.0000, -1.0000, -1.0000, 0.0000, -1.0000, -1.0000,
    #         -0.0000, 0.0000, 0.0000, 1.0000])
    # tensor([0.9199, 0.3710, 0.0000, 1.0000, 1.0000, 0.0000, 1.0000, 1.0000, -0.0000,
    #         0.0000, 0.0000, 1.0000])

    min_value = torch.ones(16) * 10000
    max_value = torch.ones(16) * -10000
    for d in tqdm(dataloader):
        goal_poses = d["goal_poses"].reshape(-1, 16)
        current_max, _ = torch.max(goal_poses, dim=0)
        current_min, _ = torch.min(goal_poses, dim=0)
        max_value = torch.maximum(max_value, current_max)
        # the running minimum must be written to min_value, not to max_value
        min_value = torch.minimum(min_value, current_min)
    print(f"{min_value} - {max_value}")
    return min_value, max_value
|
590 |
+
|
591 |
+
|
592 |
+
if __name__ == "__main__":
    # Manual smoke test: build the validation dataset with sentence embeddings and step through
    # its batches, printing tensor sizes and contents and waiting for Enter after each batch.

    tokenizer = Tokenizer("/home/weiyu/data_drive/data_new_objects/type_vocabs_coarse.json")

    # (structure name, index directory) for every dataset to load
    shape_index_pairs = [("circle", "index_10k"), ("line", "index_10k"), ("stacking", "index_10k"), ("dinner", "index_10k")]
    data_roots = ["/home/weiyu/data_drive/data_new_objects/examples_{}_new_objects/result".format(shape)
                  for shape, _ in shape_index_pairs]
    index_roots = [index for _, index in shape_index_pairs]

    dataset = SemanticArrangementDataset(
        data_roots=data_roots,
        index_roots=index_roots,
        split="valid", tokenizer=tokenizer,
        max_num_target_objects=7,
        max_num_distractor_objects=5,
        max_num_shape_parameters=1,
        max_num_rearrange_features=0,
        max_num_anchor_features=0,
        num_pts=1024,
        use_virtual_structure_frame=True,
        ignore_distractor_objects=True,
        ignore_rgb=True,
        filter_num_moved_objects_range=None,  # [5, 5]
        data_augmentation=False,
        shuffle_object_index=True,
        sentence_embedding_file="/home/weiyu/Research/StructDiffusion/old/StructDiffusion/src/StructDiffusion/language/template_sentence_data.pkl",
        use_incomplete_sentence=True,
        debug=False,
    )

    dataloader = DataLoader(dataset, batch_size=64, shuffle=False, num_workers=8)
    for i, d in enumerate(tqdm(dataloader)):
        # first the shapes of all tensor entries, then every entry in full
        for k, v in d.items():
            if isinstance(v, torch.Tensor):
                print("--size", k, v.shape)
        for k, v in d.items():
            print(k, v)

        input("next?")
|
src/StructDiffusion/data/semantic_arrangement_language_demo.py
ADDED
@@ -0,0 +1,693 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
import cv2
|
3 |
+
import h5py
|
4 |
+
import numpy as np
|
5 |
+
import os
|
6 |
+
import trimesh
|
7 |
+
import torch
|
8 |
+
from tqdm import tqdm
|
9 |
+
import json
|
10 |
+
import random
|
11 |
+
import pickle
|
12 |
+
|
13 |
+
from torch.utils.data import DataLoader
|
14 |
+
|
15 |
+
# Local imports
|
16 |
+
from StructDiffusion.utils.rearrangement import show_pcs, get_pts, combine_and_sample_xyzs
|
17 |
+
from StructDiffusion.language.tokenizer import Tokenizer
|
18 |
+
|
19 |
+
import StructDiffusion.utils.brain2.camera as cam
|
20 |
+
import StructDiffusion.utils.brain2.image as img
|
21 |
+
import StructDiffusion.utils.transformations as tra
|
22 |
+
|
23 |
+
|
24 |
+
class SemanticArrangementDataset(torch.utils.data.Dataset):
|
25 |
+
|
26 |
+
def __init__(self, data_root, tokenizer,
|
27 |
+
max_num_target_objects=11, max_num_distractor_objects=5,
|
28 |
+
max_num_shape_parameters=7, max_num_rearrange_features=1, max_num_anchor_features=3,
|
29 |
+
num_pts=1024,
|
30 |
+
use_virtual_structure_frame=True, ignore_distractor_objects=True, ignore_rgb=True,
|
31 |
+
filter_num_moved_objects_range=None, shuffle_object_index=False,
|
32 |
+
sentence_embedding_file=None, use_incomplete_sentence=False,
|
33 |
+
data_augmentation=True, debug=False, **kwargs):
|
34 |
+
"""
|
35 |
+
|
36 |
+
Note: setting filter_num_moved_objects_range=[k, k] and max_num_objects=k will create no padding for target objs
|
37 |
+
|
38 |
+
:param data_root:
|
39 |
+
:param split: train, valid, or test
|
40 |
+
:param shuffle_object_index: whether to shuffle the positions of target objects and other objects in the sequence
|
41 |
+
:param debug:
|
42 |
+
:param max_num_shape_parameters:
|
43 |
+
:param max_num_objects:
|
44 |
+
:param max_num_rearrange_features:
|
45 |
+
:param max_num_anchor_features:
|
46 |
+
:param num_pts:
|
47 |
+
:param use_stored_arrangement_indices:
|
48 |
+
:param kwargs:
|
49 |
+
"""
|
50 |
+
|
51 |
+
self.use_virtual_structure_frame = use_virtual_structure_frame
|
52 |
+
self.ignore_distractor_objects = ignore_distractor_objects
|
53 |
+
self.ignore_rgb = ignore_rgb and not debug
|
54 |
+
|
55 |
+
self.num_pts = num_pts
|
56 |
+
self.debug = debug
|
57 |
+
|
58 |
+
self.max_num_objects = max_num_target_objects
|
59 |
+
self.max_num_other_objects = max_num_distractor_objects
|
60 |
+
self.max_num_shape_parameters = max_num_shape_parameters
|
61 |
+
self.max_num_rearrange_features = max_num_rearrange_features
|
62 |
+
self.max_num_anchor_features = max_num_anchor_features
|
63 |
+
self.shuffle_object_index = shuffle_object_index
|
64 |
+
|
65 |
+
# used to tokenize the language part
|
66 |
+
self.tokenizer = tokenizer
|
67 |
+
|
68 |
+
# retrieve data
|
69 |
+
self.data_root = data_root
|
70 |
+
self.arrangement_data = []
|
71 |
+
for filename in os.listdir(data_root):
|
72 |
+
if ".h5" in filename:
|
73 |
+
self.arrangement_data.append((os.path.join(data_root, filename), 0))
|
74 |
+
print("{} valid sequences".format(len(self.arrangement_data)))
|
75 |
+
|
76 |
+
# language
|
77 |
+
if sentence_embedding_file:
|
78 |
+
assert max_num_shape_parameters == 1
|
79 |
+
# since we do not use them right now, ignore them
|
80 |
+
# assert max_num_rearrange_features == 0
|
81 |
+
# assert max_num_anchor_features == 0
|
82 |
+
with open(sentence_embedding_file, "rb") as fh:
|
83 |
+
template_sentence_data = pickle.load(fh)
|
84 |
+
self.use_sentence_embedding = True
|
85 |
+
self.type_value_tuple_to_template_sentences = template_sentence_data["type_value_tuple_to_template_sentences"]
|
86 |
+
self.template_sentence_to_embedding = template_sentence_data["template_sentence_to_embedding"]
|
87 |
+
self.use_incomplete_sentence = use_incomplete_sentence
|
88 |
+
print("use sentence embedding")
|
89 |
+
print(len(self.type_value_tuple_to_template_sentences))
|
90 |
+
print(len(self.template_sentence_to_embedding))
|
91 |
+
else:
|
92 |
+
self.use_sentence_embedding = False
|
93 |
+
|
94 |
+
# Data Aug
|
95 |
+
self.data_augmentation = data_augmentation
|
96 |
+
# additive noise
|
97 |
+
self.gp_rescale_factor_range = [12, 20]
|
98 |
+
self.gaussian_scale_range = [0., 0.003]
|
99 |
+
# multiplicative noise
|
100 |
+
self.gamma_shape = 1000.
|
101 |
+
self.gamma_scale = 0.001
|
102 |
+
|
103 |
+
def filter_based_on_number_of_moved_objects(self, filter_num_moved_objects_range):
    """
    Return the entries of self.arrangement_data whose number of moved objects is within a range.

    The count is the number of comma-separated items in the "moved_objs" dataset of each file.
    self.arrangement_data itself is not modified.

    :param filter_num_moved_objects_range: two values (min_num, max_num), both inclusive
    :return: list of (filename, step_t) tuples that pass the filter
    """
    # Materialize once so that one-shot iterables are not consumed by the length check.
    bounds = list(filter_num_moved_objects_range)
    assert len(bounds) == 2
    min_num, max_num = bounds
    print("Remove scenes that have less than {} or more than {} objects being moved".format(min_num, max_num))
    ok_data = []
    for filename, step_t in self.arrangement_data:
        # The context manager closes the file even if the dataset is missing.
        with h5py.File(filename, 'r') as h5:
            moved_objs = h5['moved_objs'][()]
        # h5py >= 3 returns stored strings as bytes; decode before splitting on a str separator.
        if isinstance(moved_objs, bytes):
            moved_objs = moved_objs.decode("utf-8")
        if min_num <= len(moved_objs.split(',')) <= max_num:
            ok_data.append((filename, step_t))
    print("{} valid sequences left".format(len(ok_data)))
    return ok_data
|
115 |
+
|
116 |
+
def get_data_idx(self, idx):
    """
    Map a global index to (open h5 file, index inside that file, file index).

    NOTE(review): reads self.file_to_count and self.data_files, which the __init__ visible in
    this file does not set -- presumably carried over from another dataset class; confirm
    before relying on this method.

    :param idx: global index
    :return: (h5py.File opened read-only, local index, file index); the caller owns the file
    """
    # argmax of a boolean array gives the first position where idx < file_to_count
    file_idx = np.argmax(idx < self.file_to_count)
    if file_idx > 0:
        # for lang2sym, idx is always 0
        local_idx = idx - self.file_to_count[file_idx - 1]
    else:
        local_idx = idx
    handle = h5py.File(self.data_files[file_idx], 'r')
    return handle, local_idx, file_idx
|
124 |
+
|
125 |
+
def add_noise_to_depth(self, depth_img):
    """
    Multiply the depth image by a single gamma-distributed factor.

    One scalar is drawn from np.random.gamma(self.gamma_shape, self.gamma_scale) and applied
    to every pixel.

    :param depth_img: depth image
    :return: the scaled depth image
    """
    factor = np.random.gamma(self.gamma_shape, self.gamma_scale)
    return factor * depth_img
|
130 |
+
|
131 |
+
def add_noise_to_xyz(self, xyz_img, depth_img):
    """
    Add smooth Gaussian noise to the xyz image wherever depth is positive.

    Noise is drawn on a coarse grid (image size divided by a random rescale factor), upsampled
    to the image size with bicubic interpolation, and added only at pixels with depth_img > 0.
    The input array is not modified.

    TODO: remove this code or at least clean it up

    :param xyz_img: H x W x C array
    :param depth_img: H x W array used as the validity mask
    :return: noisy copy of xyz_img
    """
    noisy_xyz = xyz_img.copy()
    H, W, C = noisy_xyz.shape
    gp_rescale_factor = np.random.randint(self.gp_rescale_factor_range[0],
                                          self.gp_rescale_factor_range[1])
    gp_scale = np.random.uniform(self.gaussian_scale_range[0],
                                 self.gaussian_scale_range[1])
    # Keep the coarse grid at least 1 x 1: an image smaller than the rescale factor would
    # otherwise produce an empty array that cv2.resize cannot upsample.
    small_H = max(1, int(H / gp_rescale_factor))
    small_W = max(1, int(W / gp_rescale_factor))
    additive_noise = np.random.normal(loc=0.0, scale=gp_scale, size=(small_H, small_W, C))
    additive_noise = cv2.resize(additive_noise, (W, H), interpolation=cv2.INTER_CUBIC)
    valid = depth_img > 0
    noisy_xyz[valid, :] += additive_noise[valid, :]
    return noisy_xyz
|
144 |
+
|
145 |
+
def random_index(self):
    """Return the item stored at an index drawn uniformly from [0, len(self))."""
    chosen = np.random.randint(len(self))
    return self[chosen]
|
147 |
+
|
148 |
+
def _get_rgb(self, h5, idx, ee=True):
    """
    Decode one stored PNG frame into an RGB array scaled by 1/255.

    :param h5: open h5 file
    :param idx: frame index into the "ee_rgb" (ee=True) or "rgb" (ee=False) dataset
    :param ee: selects which of the two datasets is read
    :return: array holding the first three channels divided by 255.
    """
    if ee:
        dataset_name = "ee_rgb"
    else:
        dataset_name = "rgb"
    decoded = img.PNGToNumpy(h5[dataset_name][idx])
    # remove alpha
    return decoded[:, :, :3] / 255.
|
152 |
+
|
153 |
+
def _get_depth(self, h5, idx, ee=True):
|
154 |
+
DEPTH = "ee_depth" if ee else "depth"
|
155 |
+
|
156 |
+
def _get_images(self, h5, idx, ee=True):
    """
    Load one frame and return (rgb, depth, seg, valid, xyz).

    Depth is decoded as stored / 20000 * (dmax - dmin) + dmin, a validity mask keeps
    0.1 < depth < 2, and xyz is computed from depth with the camera read from the file and
    then transformed by the stored camera view matrix. When self.data_augmentation is set,
    noise is added to depth before and to xyz after the projection; the validity mask is
    computed from the noise-free depth.

    :param h5: open h5 file
    :param idx: frame index
    :param ee: read the "ee_*" datasets when true, the unprefixed ones otherwise
    :return: tuple (rgb, depth, seg, valid, xyz)
    """
    if ee:
        rgb_key, depth_key, seg_key = "ee_rgb", "ee_depth", "ee_seg"
        dmin_key, dmax_key = "ee_depth_min", "ee_depth_max"
        view_key = "ee_camera_view"
    else:
        rgb_key, depth_key, seg_key = "rgb", "depth", "seg"
        dmin_key, dmax_key = "depth_min", "depth_max"
        view_key = "camera_view"

    depth_lo = h5[dmin_key][idx]
    depth_hi = h5[dmax_key][idx]
    rgb = img.PNGToNumpy(h5[rgb_key][idx])[:, :, :3] / 255.  # remove alpha
    depth = h5[depth_key][idx] / 20000. * (depth_hi - depth_lo) + depth_lo
    seg = img.PNGToNumpy(h5[seg_key][idx])

    valid = np.logical_and(depth > 0.1, depth < 2.)

    camera = cam.get_camera_from_h5(h5)
    augment = self.data_augmentation
    if augment:
        depth = self.add_noise_to_depth(depth)

    xyz = cam.compute_xyz(depth, camera)
    if augment:
        xyz = self.add_noise_to_xyz(xyz, depth)

    # Transform the point cloud with the camera view matrix.
    view = h5[view_key][idx]
    if ee:
        # ee_camera_view has 0s for x, y, z; take the translation from ee_cam_pose
        view[:3, 3] = h5["ee_cam_pose"][:][:3, 3]

    # Flatten to (h * w, d) for the transform, then restore the image layout.
    h, w, d = xyz.shape
    flat = trimesh.transform_points(xyz.reshape(h * w, -1), view)
    xyz = flat.reshape(h, w, -1)

    return rgb, depth, seg, valid, xyz
|
199 |
+
|
200 |
+
def __len__(self):
|
201 |
+
return len(self.arrangement_data)
|
202 |
+
|
203 |
+
def _get_ids(self, h5):
|
204 |
+
"""
|
205 |
+
get object ids
|
206 |
+
|
207 |
+
@param h5:
|
208 |
+
@return:
|
209 |
+
"""
|
210 |
+
ids = {}
|
211 |
+
for k in h5.keys():
|
212 |
+
if k.startswith("id_"):
|
213 |
+
ids[k[3:]] = h5[k][()]
|
214 |
+
return ids
|
215 |
+
|
216 |
+
def get_positive_ratio(self):
    """
    Return (#entries with step_t != 0) / (#entries with step_t == 0).

    Raises ZeroDivisionError when no entry has step_t == 0.
    """
    total = len(self.arrangement_data)
    num_pos = sum(1 for _, step_t in self.arrangement_data if step_t == 0)
    return (total - num_pos) * 1.0 / num_pos
|
223 |
+
|
224 |
+
def get_object_position_vocab_sizes(self):
    """Return whatever self.tokenizer.get_object_position_vocab_sizes() reports."""
    tokenizer = self.tokenizer
    return tokenizer.get_object_position_vocab_sizes()
|
226 |
+
|
227 |
+
def get_vocab_size(self):
    """Return whatever self.tokenizer.get_vocab_size() reports."""
    tokenizer = self.tokenizer
    return tokenizer.get_vocab_size()
|
229 |
+
|
230 |
+
def get_data_index(self, idx):
    """
    Return the self.arrangement_data entry at idx.

    Entries are appended by __init__ as (path, step) tuples, so the result is that tuple,
    not a bare filename.
    """
    return self.arrangement_data[idx]
|
233 |
+
|
234 |
+
def get_raw_data(self, idx, inference_mode=False, shuffle_object_index=False):
|
235 |
+
"""
|
236 |
+
|
237 |
+
:param idx:
|
238 |
+
:param inference_mode:
|
239 |
+
:param shuffle_object_index: used to test different orders of objects
|
240 |
+
:return:
|
241 |
+
"""
|
242 |
+
|
243 |
+
filename, _ = self.arrangement_data[idx]
|
244 |
+
|
245 |
+
h5 = h5py.File(filename, 'r')
|
246 |
+
ids = self._get_ids(h5)
|
247 |
+
all_objs = sorted([o for o in ids.keys() if "object_" in o])
|
248 |
+
goal_specification = json.loads(str(np.array(h5["goal_specification"])))
|
249 |
+
num_rearrange_objs = len(goal_specification["rearrange"]["objects"])
|
250 |
+
num_other_objs = len(goal_specification["anchor"]["objects"] + goal_specification["distract"]["objects"])
|
251 |
+
assert len(all_objs) == num_rearrange_objs + num_other_objs, "{}, {}".format(len(all_objs), num_rearrange_objs + num_other_objs)
|
252 |
+
assert num_rearrange_objs <= self.max_num_objects
|
253 |
+
assert num_other_objs <= self.max_num_other_objects
|
254 |
+
|
255 |
+
# important: only using the last step
|
256 |
+
step_t = num_rearrange_objs
|
257 |
+
|
258 |
+
target_objs = all_objs[:num_rearrange_objs]
|
259 |
+
other_objs = all_objs[num_rearrange_objs:]
|
260 |
+
|
261 |
+
structure_parameters = goal_specification["shape"]
|
262 |
+
|
263 |
+
# Important: ensure the order is correct
|
264 |
+
if structure_parameters["type"] == "circle" or structure_parameters["type"] == "line":
|
265 |
+
target_objs = target_objs[::-1]
|
266 |
+
elif structure_parameters["type"] == "tower" or structure_parameters["type"] == "dinner":
|
267 |
+
target_objs = target_objs
|
268 |
+
else:
|
269 |
+
raise KeyError("{} structure is not recognized".format(structure_parameters["type"]))
|
270 |
+
all_objs = target_objs + other_objs
|
271 |
+
|
272 |
+
###################################
|
273 |
+
# getting scene images and point clouds
|
274 |
+
scene = self._get_images(h5, step_t, ee=True)
|
275 |
+
rgb, depth, seg, valid, xyz = scene
|
276 |
+
if inference_mode:
|
277 |
+
initial_scene = scene
|
278 |
+
|
279 |
+
# getting object point clouds
|
280 |
+
obj_pcs = []
|
281 |
+
obj_pad_mask = []
|
282 |
+
current_pc_poses = []
|
283 |
+
other_obj_pcs = []
|
284 |
+
other_obj_pad_mask = []
|
285 |
+
for obj in all_objs:
|
286 |
+
obj_mask = np.logical_and(seg == ids[obj], valid)
|
287 |
+
if np.sum(obj_mask) <= 0:
|
288 |
+
raise Exception
|
289 |
+
ok, obj_xyz, obj_rgb, _ = get_pts(xyz, rgb, obj_mask, num_pts=self.num_pts)
|
290 |
+
if not ok:
|
291 |
+
raise Exception
|
292 |
+
|
293 |
+
if obj in target_objs:
|
294 |
+
if self.ignore_rgb:
|
295 |
+
obj_pcs.append(obj_xyz)
|
296 |
+
else:
|
297 |
+
obj_pcs.append(torch.concat([obj_xyz, obj_rgb], dim=-1))
|
298 |
+
obj_pad_mask.append(0)
|
299 |
+
pc_pose = np.eye(4)
|
300 |
+
pc_pose[:3, 3] = torch.mean(obj_xyz, dim=0).numpy()
|
301 |
+
current_pc_poses.append(pc_pose)
|
302 |
+
elif obj in other_objs:
|
303 |
+
if self.ignore_rgb:
|
304 |
+
other_obj_pcs.append(obj_xyz)
|
305 |
+
else:
|
306 |
+
other_obj_pcs.append(torch.concat([obj_xyz, obj_rgb], dim=-1))
|
307 |
+
other_obj_pad_mask.append(0)
|
308 |
+
else:
|
309 |
+
raise Exception
|
310 |
+
|
311 |
+
###################################
|
312 |
+
# computes goal positions for objects
|
313 |
+
# Important: because of the noises we added to point clouds, the rearranged point clouds will not be perfect
|
314 |
+
if self.use_virtual_structure_frame:
|
315 |
+
goal_structure_pose = tra.euler_matrix(structure_parameters["rotation"][0], structure_parameters["rotation"][1],
|
316 |
+
structure_parameters["rotation"][2])
|
317 |
+
goal_structure_pose[:3, 3] = [structure_parameters["position"][0], structure_parameters["position"][1],
|
318 |
+
structure_parameters["position"][2]]
|
319 |
+
goal_structure_pose_inv = np.linalg.inv(goal_structure_pose)
|
320 |
+
|
321 |
+
goal_obj_poses = []
|
322 |
+
current_obj_poses = []
|
323 |
+
goal_pc_poses = []
|
324 |
+
for obj, current_pc_pose in zip(target_objs, current_pc_poses):
|
325 |
+
goal_pose = h5[obj][0]
|
326 |
+
current_pose = h5[obj][step_t]
|
327 |
+
if inference_mode:
|
328 |
+
goal_obj_poses.append(goal_pose)
|
329 |
+
current_obj_poses.append(current_pose)
|
330 |
+
|
331 |
+
goal_pc_pose = goal_pose @ np.linalg.inv(current_pose) @ current_pc_pose
|
332 |
+
if self.use_virtual_structure_frame:
|
333 |
+
goal_pc_pose = goal_structure_pose_inv @ goal_pc_pose
|
334 |
+
goal_pc_poses.append(goal_pc_pose)
|
335 |
+
|
336 |
+
# transform current object point cloud to the goal point cloud in the world frame
|
337 |
+
if self.debug:
|
338 |
+
new_obj_pcs = [copy.deepcopy(pc.numpy()) for pc in obj_pcs]
|
339 |
+
for i, obj_pc in enumerate(new_obj_pcs):
|
340 |
+
|
341 |
+
current_pc_pose = current_pc_poses[i]
|
342 |
+
goal_pc_pose = goal_pc_poses[i]
|
343 |
+
if self.use_virtual_structure_frame:
|
344 |
+
goal_pc_pose = goal_structure_pose @ goal_pc_pose
|
345 |
+
print("current pc pose", current_pc_pose)
|
346 |
+
print("goal pc pose", goal_pc_pose)
|
347 |
+
|
348 |
+
goal_pc_transform = goal_pc_pose @ np.linalg.inv(current_pc_pose)
|
349 |
+
print("transform", goal_pc_transform)
|
350 |
+
new_obj_pc = copy.deepcopy(obj_pc)
|
351 |
+
new_obj_pc[:, :3] = trimesh.transform_points(obj_pc[:, :3], goal_pc_transform)
|
352 |
+
print(new_obj_pc.shape)
|
353 |
+
|
354 |
+
# visualize rearrangement sequence (new_obj_xyzs), the current object before moving (obj_xyz), and other objects
|
355 |
+
new_obj_pcs[i] = new_obj_pc
|
356 |
+
new_obj_pcs[i][:, 3:] = np.tile(np.array([1, 0, 0], dtype=np.float), (new_obj_pc.shape[0], 1))
|
357 |
+
new_obj_rgb_current = np.tile(np.array([0, 1, 0], dtype=np.float), (new_obj_pc.shape[0], 1))
|
358 |
+
show_pcs([pc[:, :3] for pc in new_obj_pcs] + [pc[:, :3] for pc in other_obj_pcs] + [obj_pc[:, :3]],
|
359 |
+
[pc[:, 3:] for pc in new_obj_pcs] + [pc[:, 3:] for pc in other_obj_pcs] + [new_obj_rgb_current],
|
360 |
+
add_coordinate_frame=True)
|
361 |
+
show_pcs([pc[:, :3] for pc in new_obj_pcs], [pc[:, 3:] for pc in new_obj_pcs], add_coordinate_frame=True)
|
362 |
+
|
363 |
+
# pad data
|
364 |
+
for i in range(self.max_num_objects - len(target_objs)):
|
365 |
+
obj_pcs.append(torch.zeros_like(obj_pcs[0], dtype=torch.float32))
|
366 |
+
obj_pad_mask.append(1)
|
367 |
+
for i in range(self.max_num_other_objects - len(other_objs)):
|
368 |
+
other_obj_pcs.append(torch.zeros_like(obj_pcs[0], dtype=torch.float32))
|
369 |
+
other_obj_pad_mask.append(1)
|
370 |
+
|
371 |
+
###################################
|
372 |
+
# preparing sentence
|
373 |
+
sentence = []
|
374 |
+
sentence_pad_mask = []
|
375 |
+
|
376 |
+
# structure parameters
|
377 |
+
# 5 parameters
|
378 |
+
structure_parameters = goal_specification["shape"]
|
379 |
+
if structure_parameters["type"] == "circle" or structure_parameters["type"] == "line":
|
380 |
+
sentence.append((structure_parameters["type"], "shape"))
|
381 |
+
sentence.append((structure_parameters["rotation"][2], "rotation"))
|
382 |
+
sentence.append((structure_parameters["position"][0], "position_x"))
|
383 |
+
sentence.append((structure_parameters["position"][1], "position_y"))
|
384 |
+
if structure_parameters["type"] == "circle":
|
385 |
+
sentence.append((structure_parameters["radius"], "radius"))
|
386 |
+
elif structure_parameters["type"] == "line":
|
387 |
+
sentence.append((structure_parameters["length"] / 2.0, "radius"))
|
388 |
+
if not self.use_sentence_embedding:
|
389 |
+
for _ in range(5):
|
390 |
+
sentence_pad_mask.append(0)
|
391 |
+
else:
|
392 |
+
sentence.append((structure_parameters["type"], "shape"))
|
393 |
+
sentence.append((structure_parameters["rotation"][2], "rotation"))
|
394 |
+
sentence.append((structure_parameters["position"][0], "position_x"))
|
395 |
+
sentence.append((structure_parameters["position"][1], "position_y"))
|
396 |
+
if not self.use_sentence_embedding:
|
397 |
+
for _ in range(4):
|
398 |
+
sentence_pad_mask.append(0)
|
399 |
+
sentence.append(("PAD", None))
|
400 |
+
sentence_pad_mask.append(1)
|
401 |
+
|
402 |
+
if self.use_sentence_embedding:
|
403 |
+
|
404 |
+
if self.use_incomplete_sentence:
|
405 |
+
token_idxs = np.random.permutation(len(sentence))
|
406 |
+
token_idxs = token_idxs[:np.random.randint(1, len(sentence) + 1)]
|
407 |
+
token_idxs = sorted(token_idxs)
|
408 |
+
incomplete_sentence = [sentence[ti] for ti in token_idxs]
|
409 |
+
else:
|
410 |
+
incomplete_sentence = sentence
|
411 |
+
|
412 |
+
type_value_tuple = self.tokenizer.convert_structure_params_to_type_value_tuple(incomplete_sentence)
|
413 |
+
template_sentence = np.random.choice(self.type_value_tuple_to_template_sentences[type_value_tuple])
|
414 |
+
sentence_embedding = self.template_sentence_to_embedding[template_sentence]
|
415 |
+
sentence_pad_mask = [0]
|
416 |
+
|
417 |
+
###################################
|
418 |
+
# paddings
|
419 |
+
for i in range(self.max_num_objects - len(target_objs)):
|
420 |
+
goal_pc_poses.append(np.eye(4))
|
421 |
+
|
422 |
+
###################################
|
423 |
+
if self.debug:
|
424 |
+
print("---")
|
425 |
+
print("all objects:", all_objs)
|
426 |
+
print("target objects:", target_objs)
|
427 |
+
print("other objects:", other_objs)
|
428 |
+
print("goal specification:", goal_specification)
|
429 |
+
print("sentence:", sentence)
|
430 |
+
if self.use_sentence_embedding:
|
431 |
+
print("use sentence embedding")
|
432 |
+
if self.use_incomplete_sentence:
|
433 |
+
print("incomplete_sentence:", incomplete_sentence)
|
434 |
+
print("template sentence:", template_sentence)
|
435 |
+
show_pcs([pc[:, :3] for pc in obj_pcs + other_obj_pcs], [pc[:, 3:] for pc in obj_pcs + other_obj_pcs], add_coordinate_frame=True)
|
436 |
+
|
437 |
+
assert len(obj_pcs) == len(goal_pc_poses)
|
438 |
+
###################################
|
439 |
+
|
440 |
+
# shuffle the position of objects
|
441 |
+
# important: only shuffle for dinner
|
442 |
+
if shuffle_object_index and structure_parameters["type"] == "dinner":
|
443 |
+
num_target_objs = len(target_objs)
|
444 |
+
shuffle_target_object_indices = list(range(num_target_objs))
|
445 |
+
random.shuffle(shuffle_target_object_indices)
|
446 |
+
shuffle_object_indices = shuffle_target_object_indices + list(range(num_target_objs, self.max_num_objects))
|
447 |
+
obj_pcs = [obj_pcs[i] for i in shuffle_object_indices]
|
448 |
+
goal_pc_poses = [goal_pc_poses[i] for i in shuffle_object_indices]
|
449 |
+
if inference_mode:
|
450 |
+
goal_obj_poses = [goal_obj_poses[i] for i in shuffle_object_indices[:num_target_objs]]
|
451 |
+
current_obj_poses = [current_obj_poses[i] for i in shuffle_object_indices[:num_target_objs]]
|
452 |
+
target_objs = [target_objs[i] for i in shuffle_target_object_indices[:num_target_objs]]
|
453 |
+
current_pc_poses = [current_pc_poses[i] for i in shuffle_object_indices[:num_target_objs]]
|
454 |
+
|
455 |
+
###################################
|
456 |
+
if self.use_virtual_structure_frame:
|
457 |
+
if self.ignore_distractor_objects:
|
458 |
+
# language, structure virtual frame, target objects
|
459 |
+
pcs = obj_pcs
|
460 |
+
type_index = [0] * self.max_num_shape_parameters + [2] + [3] * self.max_num_objects
|
461 |
+
position_index = list(range(self.max_num_shape_parameters)) + [0] + list(range(self.max_num_objects))
|
462 |
+
pad_mask = sentence_pad_mask + [0] + obj_pad_mask
|
463 |
+
else:
|
464 |
+
# language, distractor objects, structure virtual frame, target objects
|
465 |
+
pcs = other_obj_pcs + obj_pcs
|
466 |
+
type_index = [0] * self.max_num_shape_parameters + [1] * self.max_num_other_objects + [2] + [3] * self.max_num_objects
|
467 |
+
position_index = list(range(self.max_num_shape_parameters)) + list(range(self.max_num_other_objects)) + [0] + list(range(self.max_num_objects))
|
468 |
+
pad_mask = sentence_pad_mask + other_obj_pad_mask + [0] + obj_pad_mask
|
469 |
+
goal_poses = [goal_structure_pose] + goal_pc_poses
|
470 |
+
else:
|
471 |
+
if self.ignore_distractor_objects:
|
472 |
+
# language, target objects
|
473 |
+
pcs = obj_pcs
|
474 |
+
type_index = [0] * self.max_num_shape_parameters + [3] * self.max_num_objects
|
475 |
+
position_index = list(range(self.max_num_shape_parameters)) + list(range(self.max_num_objects))
|
476 |
+
pad_mask = sentence_pad_mask + obj_pad_mask
|
477 |
+
else:
|
478 |
+
# language, distractor objects, target objects
|
479 |
+
pcs = other_obj_pcs + obj_pcs
|
480 |
+
type_index = [0] * self.max_num_shape_parameters + [1] * self.max_num_other_objects + [3] * self.max_num_objects
|
481 |
+
position_index = list(range(self.max_num_shape_parameters)) + list(range(self.max_num_other_objects)) + list(range(self.max_num_objects))
|
482 |
+
pad_mask = sentence_pad_mask + other_obj_pad_mask + obj_pad_mask
|
483 |
+
goal_poses = goal_pc_poses
|
484 |
+
|
485 |
+
datum = {
|
486 |
+
"pcs": pcs,
|
487 |
+
"goal_poses": goal_poses,
|
488 |
+
"type_index": type_index,
|
489 |
+
"position_index": position_index,
|
490 |
+
"pad_mask": pad_mask,
|
491 |
+
"t": step_t,
|
492 |
+
"filename": filename
|
493 |
+
}
|
494 |
+
if self.use_sentence_embedding:
|
495 |
+
datum["sentence"] = sentence_embedding
|
496 |
+
else:
|
497 |
+
datum["sentence"] = sentence
|
498 |
+
|
499 |
+
if inference_mode:
|
500 |
+
datum["rgb"] = rgb
|
501 |
+
datum["goal_obj_poses"] = goal_obj_poses
|
502 |
+
datum["current_obj_poses"] = current_obj_poses
|
503 |
+
datum["target_objs"] = target_objs
|
504 |
+
datum["initial_scene"] = initial_scene
|
505 |
+
datum["ids"] = ids
|
506 |
+
datum["goal_specification"] = goal_specification
|
507 |
+
datum["current_pc_poses"] = current_pc_poses
|
508 |
+
if self.use_sentence_embedding:
|
509 |
+
datum["template_sentence"] = template_sentence
|
510 |
+
|
511 |
+
return datum
|
512 |
+
|
513 |
+
def build_data_from_xyzs(self, obj_xyzs, sentence_embedding, shuffle_object_index=True):
    """
    Build an inference datum from raw object point clouds and one sentence embedding.

    Target objects are converted to float32 tensors, optionally shuffled, and zero-padded to
    self.max_num_objects. The type / position / pad sequences follow the order
    language, [distractor objects], [structure virtual frame], target objects, where the
    bracketed parts depend on self.ignore_distractor_objects and
    self.use_virtual_structure_frame. Distractor slots are always all padding here.

    The language pad mask is the single entry [0], so the sequences only line up when
    self.max_num_shape_parameters == 1.

    :param obj_xyzs: non-empty sequence of numpy arrays, one per target object, at most
        self.max_num_objects of them
    :param sentence_embedding: stored unchanged under "sentence"
    :param shuffle_object_index: shuffle the order of the target objects (padding stays last)
    :return: dict with keys pcs, type_index, position_index, pad_mask, sentence,
        num_goal_poses, t, filename
    :raises ValueError: if obj_xyzs is empty or holds more than self.max_num_objects clouds
    """
    max_targets = self.max_num_objects
    max_others = self.max_num_other_objects

    ## objects
    obj_pcs = [torch.from_numpy(obj_xyz.astype(np.float32)) for obj_xyz in obj_xyzs]
    num_target_objs = len(obj_pcs)
    if num_target_objs == 0:
        # the padding below copies the shape of the first cloud, so one is required
        raise ValueError("obj_xyzs must contain at least one object point cloud")
    if num_target_objs > max_targets:
        # otherwise pcs would be longer than type_index / position_index / pad_mask
        raise ValueError("got {} object point clouds but max_num_objects is {}".format(
            num_target_objs, max_targets))
    obj_pad_mask = [0] * num_target_objs

    # pad data
    num_pad = max_targets - num_target_objs
    template = obj_pcs[0]
    obj_pcs += [torch.zeros_like(template, dtype=torch.float32) for _ in range(num_pad)]
    obj_pad_mask += [1] * num_pad
    other_obj_pcs = [torch.zeros_like(template, dtype=torch.float32) for _ in range(max_others)]
    other_obj_pad_mask = [1] * max_others

    ## sentence
    sentence_pad_mask = [0]

    if shuffle_object_index:
        # only the real objects move; padded slots keep their trailing positions so that
        # obj_pad_mask stays valid
        shuffle_target_object_indices = list(range(num_target_objs))
        random.shuffle(shuffle_target_object_indices)
        shuffle_object_indices = shuffle_target_object_indices + list(range(num_target_objs, max_targets))
        obj_pcs = [obj_pcs[i] for i in shuffle_object_indices]

    ###################################
    # language
    pcs = []
    type_index = [0] * self.max_num_shape_parameters
    position_index = list(range(self.max_num_shape_parameters))
    pad_mask = list(sentence_pad_mask)

    # distractor objects
    if not self.ignore_distractor_objects:
        pcs += other_obj_pcs
        type_index += [1] * max_others
        position_index += list(range(max_others))
        pad_mask += other_obj_pad_mask

    # structure virtual frame: one extra slot and one extra goal pose
    num_goal_poses = max_targets
    if self.use_virtual_structure_frame:
        type_index += [2]
        position_index += [0]
        pad_mask += [0]
        num_goal_poses += 1

    # target objects
    pcs += obj_pcs
    type_index += [3] * max_targets
    position_index += list(range(max_targets))
    pad_mask += obj_pad_mask

    datum = {
        "pcs": pcs,
        "type_index": type_index,
        "position_index": position_index,
        "pad_mask": pad_mask,
        "sentence": sentence_embedding,
        "num_goal_poses": num_goal_poses,
        "t": 0,
        "filename": "inference"
    }

    return datum
|
586 |
+
|
587 |
+
@staticmethod
|
588 |
+
def convert_to_tensors(datum, tokenizer, use_sentence_embedding=False):
    """
    Convert the list-based datum into tensors.

    :param datum: dict with keys pcs (list of tensors), type_index, position_index, pad_mask,
        t, filename, sentence, and optionally goal_poses
    :param tokenizer: object with tokenize(value, type); only used when
        use_sentence_embedding is False
    :param use_sentence_embedding: if True, datum["sentence"] is converted to a float tensor;
        otherwise each entry of datum["sentence"] is passed to tokenizer.tokenize(*entry)
    :return: dict of tensors; "t" and "filename" are passed through unchanged
    """
    tensors = {
        "pcs": torch.stack(datum["pcs"], dim=0),
        "type_index": torch.LongTensor(np.array(datum["type_index"])),
        "position_index": torch.LongTensor(np.array(datum["position_index"])),
        "pad_mask": torch.LongTensor(np.array(datum["pad_mask"])),
        "t": datum["t"],
        "filename": datum["filename"]
    }
    if "goal_poses" in datum:
        # No trailing comma here: one would wrap the tensor in a 1-tuple, which breaks
        # .to(device) / .reshape on the value and changes how a DataLoader collates it.
        tensors["goal_poses"] = torch.FloatTensor(np.array(datum["goal_poses"]))

    if use_sentence_embedding:
        tensors["sentence"] = torch.FloatTensor(datum["sentence"])  # after batching, B x sentence embed dim
    else:
        tensors["sentence"] = torch.LongTensor(np.array([tokenizer.tokenize(*i) for i in datum["sentence"]]))
    return tensors
|
605 |
+
|
606 |
+
def __getitem__(self, idx):
|
607 |
+
|
608 |
+
datum = self.convert_to_tensors(self.get_raw_data(idx, shuffle_object_index=self.shuffle_object_index),
|
609 |
+
self.tokenizer,
|
610 |
+
self.use_sentence_embedding)
|
611 |
+
|
612 |
+
return datum
|
613 |
+
|
614 |
+
def single_datum_to_batch(self, x, num_samples, device, inference_mode=True):
    """
    Repeat one tensor datum num_samples times along a new leading batch dimension.

    :param x: dict holding 3-D "pcs" (and "goal_poses" when inference_mode is False) plus 1-D
        "sentence", "type_index", "position_index" and "pad_mask" tensors
    :param num_samples: size of the new leading dimension
    :param device: device every tensor is moved to before it is repeated
    :param inference_mode: when False, "goal_poses" is batched as well
    :return: dict of batched tensors
    """

    def repeat_1d(tensor):
        return tensor.to(device)[None, :].repeat(num_samples, 1)

    def repeat_3d(tensor):
        return tensor.to(device)[None, :, :, :].repeat(num_samples, 1, 1, 1)

    batch = {
        "pcs": repeat_3d(x["pcs"]),
        "sentence": repeat_1d(x["sentence"]),
    }
    if not inference_mode:
        batch["goal_poses"] = repeat_3d(x["goal_poses"])

    for key in ("type_index", "position_index", "pad_mask"):
        batch[key] = repeat_1d(x[key])

    return batch
|
627 |
+
|
628 |
+
|
629 |
+
def compute_min_max(dataloader):
    """
    Compute the element-wise min and max of all goal poses produced by a dataloader.

    Each batch's "goal_poses" tensor is flattened to rows of 16 values and reduced over
    rows; the running extremes are printed once at the end and returned.

    :param dataloader: iterable of dicts holding a "goal_poses" tensor whose number of
        elements is a multiple of 16
    :return: (min_value, max_value), two tensors of 16 values each
    """

    # Values recorded in the original source (NOTE(review): they have 12 entries, not 16;
    # presumably from an earlier pose representation -- confirm before reusing):
    # tensor([-0.3557, -0.3847, 0.0000, -1.0000, -1.0000, -0.4759, -1.0000, -1.0000,
    # -0.9079, -0.8668, -0.9105, -0.4186])
    # tensor([0.3915, 0.3494, 0.3267, 1.0000, 1.0000, 0.8961, 1.0000, 1.0000, 0.8194,
    # 0.4787, 0.6421, 1.0000])
    # tensor([0.0918, -0.3758, 0.0000, -1.0000, -1.0000, 0.0000, -1.0000, -1.0000,
    # -0.0000, 0.0000, 0.0000, 1.0000])
    # tensor([0.9199, 0.3710, 0.0000, 1.0000, 1.0000, 0.0000, 1.0000, 1.0000, -0.0000,
    # 0.0000, 0.0000, 1.0000])

    min_value = torch.ones(16) * 10000
    max_value = torch.ones(16) * -10000
    for d in tqdm(dataloader):
        goal_poses = d["goal_poses"].reshape(-1, 16)
        current_max, _ = torch.max(goal_poses, dim=0)
        current_min, _ = torch.min(goal_poses, dim=0)
        max_value = torch.maximum(max_value, current_max)
        # The previous code wrote the batch minimum into max_value here, so min_value never
        # moved from its initial 10000 and max_value was corrupted.
        min_value = torch.minimum(min_value, current_min)
    print(f"{min_value} - {max_value}")
    return min_value, max_value
|
650 |
+
|
651 |
+
|
652 |
+
if __name__ == "__main__":
    # Manual smoke test: build the dataset over the four example folders and print batches.
    # The paths below are machine-specific and must exist for this to run.

    tokenizer = Tokenizer("/home/weiyu/data_drive/data_new_objects/type_vocabs_coarse.json")

    data_roots = [
        "/home/weiyu/data_drive/data_new_objects/examples_{}_new_objects/result".format(shape)
        for shape in ("circle", "line", "stacking", "dinner")
    ]

    # SemanticArrangementDataset takes one directory through its required data_root argument
    # (passing data_roots= / index_roots= / split= raised TypeError), so build one dataset
    # per directory and concatenate them.
    dataset = torch.utils.data.ConcatDataset([
        SemanticArrangementDataset(data_root=data_root,
                                   tokenizer=tokenizer,
                                   max_num_target_objects=7,
                                   max_num_distractor_objects=5,
                                   max_num_shape_parameters=1,
                                   max_num_rearrange_features=0,
                                   max_num_anchor_features=0,
                                   num_pts=1024,
                                   use_virtual_structure_frame=True,
                                   ignore_distractor_objects=True,
                                   ignore_rgb=True,
                                   filter_num_moved_objects_range=None,  # [5, 5]
                                   data_augmentation=False,
                                   shuffle_object_index=True,
                                   sentence_embedding_file="/home/weiyu/Research/StructDiffusion/old/StructDiffusion/src/StructDiffusion/language/template_sentence_data.pkl",
                                   use_incomplete_sentence=True,
                                   debug=False)
        for data_root in data_roots
    ])

    dataloader = DataLoader(dataset, batch_size=64, shuffle=False, num_workers=8)
    for i, d in enumerate(tqdm(dataloader)):
        for k in d:
            if isinstance(d[k], torch.Tensor):
                print("--size", k, d[k].shape)
        for k in d:
            print(k, d[k])

        input("next?")
|
src/StructDiffusion/diffusion/__pycache__/pose_conversion.cpython-38.pyc
CHANGED
Binary files a/src/StructDiffusion/diffusion/__pycache__/pose_conversion.cpython-38.pyc and b/src/StructDiffusion/diffusion/__pycache__/pose_conversion.cpython-38.pyc differ
|
|
src/StructDiffusion/diffusion/__pycache__/sampler.cpython-38.pyc
CHANGED
Binary files a/src/StructDiffusion/diffusion/__pycache__/sampler.cpython-38.pyc and b/src/StructDiffusion/diffusion/__pycache__/sampler.cpython-38.pyc differ
|
|
src/StructDiffusion/diffusion/sampler.py
CHANGED
@@ -1,6 +1,11 @@
|
|
1 |
import torch
|
2 |
from tqdm import tqdm
|
|
|
3 |
from StructDiffusion.diffusion.noise_schedule import extract
|
|
|
|
|
|
|
|
|
4 |
|
5 |
class Sampler:
|
6 |
|
@@ -14,7 +19,7 @@ class Sampler:
|
|
14 |
self.backbone.to(device)
|
15 |
self.backbone.eval()
|
16 |
|
17 |
-
def sample(self, batch, num_poses
|
18 |
|
19 |
noise_schedule = self.model.noise_schedule
|
20 |
|
@@ -23,7 +28,7 @@ class Sampler:
|
|
23 |
x_noisy = torch.randn((B, num_poses, 9), device=self.device)
|
24 |
|
25 |
xs = []
|
26 |
-
for t_index in
|
27 |
desc='sampling loop time step', total=noise_schedule.timesteps):
|
28 |
|
29 |
t = torch.full((B,), t_index, device=self.device, dtype=torch.long)
|
@@ -57,236 +62,239 @@ class Sampler:
|
|
57 |
xs = list(reversed(xs))
|
58 |
return xs
|
59 |
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
#
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
#
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
#
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
#
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
#
|
125 |
-
#
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
#
|
131 |
-
|
132 |
-
#
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
#
|
137 |
-
#
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
#
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
#
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
#
|
165 |
-
#
|
166 |
-
#
|
167 |
-
#
|
168 |
-
#
|
169 |
-
#
|
170 |
-
|
171 |
-
|
172 |
-
#
|
173 |
-
|
174 |
-
#
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
-
#
|
190 |
-
#
|
191 |
-
#
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
|
209 |
-
#
|
210 |
-
#
|
211 |
-
|
212 |
-
|
213 |
-
#
|
214 |
-
|
215 |
-
|
216 |
-
|
217 |
-
#
|
218 |
-
#
|
219 |
-
#
|
220 |
-
#
|
221 |
-
#
|
222 |
-
#
|
223 |
-
#
|
224 |
-
#
|
225 |
-
|
226 |
-
#
|
227 |
-
|
228 |
-
|
229 |
-
|
230 |
-
#
|
231 |
-
#
|
232 |
-
#
|
233 |
-
#
|
234 |
-
#
|
235 |
-
#
|
236 |
-
#
|
237 |
-
#
|
238 |
-
#
|
239 |
-
#
|
240 |
-
#
|
241 |
-
#
|
242 |
-
#
|
243 |
-
#
|
244 |
-
#
|
245 |
-
#
|
246 |
-
#
|
247 |
-
#
|
248 |
-
#
|
249 |
-
#
|
250 |
-
#
|
251 |
-
#
|
252 |
-
#
|
253 |
-
#
|
254 |
-
#
|
255 |
-
#
|
256 |
-
#
|
257 |
-
#
|
258 |
-
#
|
259 |
-
#
|
260 |
-
#
|
261 |
-
|
262 |
-
#
|
263 |
-
|
264 |
-
|
265 |
-
|
266 |
-
|
267 |
-
|
268 |
-
|
269 |
-
|
270 |
-
|
271 |
-
|
272 |
-
#
|
273 |
-
#
|
274 |
-
#
|
275 |
-
#
|
276 |
-
#
|
277 |
-
#
|
278 |
-
#
|
279 |
-
#
|
280 |
-
#
|
281 |
-
#
|
282 |
-
|
283 |
-
#
|
284 |
-
|
285 |
-
|
286 |
-
|
287 |
-
|
288 |
-
|
289 |
-
|
290 |
-
|
291 |
-
|
292 |
-
|
|
|
|
|
|
|
|
1 |
import torch
|
2 |
from tqdm import tqdm
|
3 |
+
|
4 |
from StructDiffusion.diffusion.noise_schedule import extract
|
5 |
+
from StructDiffusion.diffusion.pose_conversion import get_struct_objs_poses
|
6 |
+
from StructDiffusion.utils.batch_inference import move_pc_and_create_scene_new
|
7 |
+
import StructDiffusion.utils.tra3d as tra3d
|
8 |
+
|
9 |
|
10 |
class Sampler:
|
11 |
|
|
|
19 |
self.backbone.to(device)
|
20 |
self.backbone.eval()
|
21 |
|
22 |
+
def sample(self, batch, num_poses):
|
23 |
|
24 |
noise_schedule = self.model.noise_schedule
|
25 |
|
|
|
28 |
x_noisy = torch.randn((B, num_poses, 9), device=self.device)
|
29 |
|
30 |
xs = []
|
31 |
+
for t_index in tqdm(reversed(range(0, noise_schedule.timesteps)),
|
32 |
desc='sampling loop time step', total=noise_schedule.timesteps):
|
33 |
|
34 |
t = torch.full((B,), t_index, device=self.device, dtype=torch.long)
|
|
|
62 |
xs = list(reversed(xs))
|
63 |
return xs
|
64 |
|
65 |
+
class SamplerV2:
|
66 |
+
|
67 |
+
def __init__(self, diffusion_model_class, diffusion_checkpoint_path,
|
68 |
+
collision_model_class, collision_checkpoint_path,
|
69 |
+
device, debug=False):
|
70 |
+
|
71 |
+
self.debug = debug
|
72 |
+
self.device = device
|
73 |
+
|
74 |
+
self.diffusion_model = diffusion_model_class.load_from_checkpoint(diffusion_checkpoint_path)
|
75 |
+
self.diffusion_backbone = self.diffusion_model.model
|
76 |
+
self.diffusion_backbone.to(device)
|
77 |
+
self.diffusion_backbone.eval()
|
78 |
+
|
79 |
+
self.collision_model = collision_model_class.load_from_checkpoint(collision_checkpoint_path)
|
80 |
+
self.collision_backbone = self.collision_model.model
|
81 |
+
self.collision_backbone.to(device)
|
82 |
+
self.collision_backbone.eval()
|
83 |
+
|
84 |
+
def sample(self, batch, num_poses, num_elite, discriminator_batch_size):
|
85 |
+
|
86 |
+
noise_schedule = self.diffusion_model.noise_schedule
|
87 |
+
|
88 |
+
B = batch["pcs"].shape[0]
|
89 |
+
|
90 |
+
x_noisy = torch.randn((B, num_poses, 9), device=self.device)
|
91 |
+
|
92 |
+
xs = []
|
93 |
+
for t_index in tqdm(reversed(range(0, noise_schedule.timesteps)),
|
94 |
+
desc='sampling loop time step', total=noise_schedule.timesteps):
|
95 |
+
|
96 |
+
t = torch.full((B,), t_index, device=self.device, dtype=torch.long)
|
97 |
+
|
98 |
+
# noise schedule
|
99 |
+
betas_t = extract(noise_schedule.betas, t, x_noisy.shape)
|
100 |
+
sqrt_one_minus_alphas_cumprod_t = extract(noise_schedule.sqrt_one_minus_alphas_cumprod, t, x_noisy.shape)
|
101 |
+
sqrt_recip_alphas_t = extract(noise_schedule.sqrt_recip_alphas, t, x_noisy.shape)
|
102 |
+
|
103 |
+
# predict noise
|
104 |
+
pcs = batch["pcs"]
|
105 |
+
sentence = batch["sentence"]
|
106 |
+
type_index = batch["type_index"]
|
107 |
+
position_index = batch["position_index"]
|
108 |
+
pad_mask = batch["pad_mask"]
|
109 |
+
# calling the backbone instead of the pytorch-lightning model
|
110 |
+
with torch.no_grad():
|
111 |
+
predicted_noise = self.diffusion_backbone.forward(t, pcs, sentence, x_noisy, type_index, position_index, pad_mask)
|
112 |
+
|
113 |
+
# compute noisy x at t
|
114 |
+
model_mean = sqrt_recip_alphas_t * (x_noisy - betas_t * predicted_noise / sqrt_one_minus_alphas_cumprod_t)
|
115 |
+
if t_index == 0:
|
116 |
+
x_noisy = model_mean
|
117 |
+
else:
|
118 |
+
posterior_variance_t = extract(noise_schedule.posterior_variance, t, x_noisy.shape)
|
119 |
+
noise = torch.randn_like(x_noisy)
|
120 |
+
x_noisy = model_mean + torch.sqrt(posterior_variance_t) * noise
|
121 |
+
|
122 |
+
xs.append(x_noisy)
|
123 |
+
|
124 |
+
xs = list(reversed(xs))
|
125 |
+
|
126 |
+
visualize = True
|
127 |
+
|
128 |
+
struct_pose, pc_poses_in_struct = get_struct_objs_poses(xs[0])
|
129 |
+
# struct_pose: B, 1, 4, 4
|
130 |
+
# pc_poses_in_struct: B, N, 4, 4
|
131 |
+
|
132 |
+
S = B
|
133 |
+
B_discriminator = discriminator_batch_size
|
134 |
+
####################################################
|
135 |
+
# only keep one copy
|
136 |
+
|
137 |
+
# N, P, 3
|
138 |
+
obj_xyzs = batch["pcs"][0][:, :, :3]
|
139 |
+
print("obj_xyzs shape", obj_xyzs.shape)
|
140 |
+
|
141 |
+
# 1, N
|
142 |
+
# object_pad_mask: padding location has 1
|
143 |
+
num_target_objs = num_poses
|
144 |
+
if self.diffusion_backbone.use_virtual_structure_frame:
|
145 |
+
num_target_objs -= 1
|
146 |
+
object_pad_mask = batch["pad_mask"][0][-num_target_objs:].unsqueeze(0)
|
147 |
+
target_object_inds = 1 - object_pad_mask
|
148 |
+
print("target_object_inds shape", target_object_inds.shape)
|
149 |
+
print("target_object_inds", target_object_inds)
|
150 |
+
|
151 |
+
N, P, _ = obj_xyzs.shape
|
152 |
+
print("S, N, P: {}, {}, {}".format(S, N, P))
|
153 |
+
|
154 |
+
####################################################
|
155 |
+
# S, N, ...
|
156 |
+
|
157 |
+
struct_pose = struct_pose.repeat(1, N, 1, 1) # S, N, 4, 4
|
158 |
+
struct_pose = struct_pose.reshape(S * N, 4, 4) # S x N, 4, 4
|
159 |
+
|
160 |
+
new_obj_xyzs = obj_xyzs.repeat(S, 1, 1, 1) # S, N, P, 3
|
161 |
+
current_pc_pose = torch.eye(4).repeat(S, N, 1, 1).to(self.device) # S, N, 4, 4
|
162 |
+
current_pc_pose[:, :, :3, 3] = torch.mean(new_obj_xyzs, dim=2) # S, N, 4, 4
|
163 |
+
current_pc_pose = current_pc_pose.reshape(S * N, 4, 4) # S x N, 4, 4
|
164 |
+
|
165 |
+
# optimize xyzrpy
|
166 |
+
obj_params = torch.zeros((S, N, 6)).to(self.device)
|
167 |
+
obj_params[:, :, :3] = pc_poses_in_struct[:, :, :3, 3]
|
168 |
+
obj_params[:, :, 3:] = tra3d.matrix_to_euler_angles(pc_poses_in_struct[:, :, :3, :3], "XYZ") # S, N, 6
|
169 |
+
#
|
170 |
+
# new_obj_xyzs_before_cem, goal_pc_pose_before_cem = move_pc(obj_xyzs, obj_params, struct_pose, current_pc_pose, device)
|
171 |
+
#
|
172 |
+
# if visualize:
|
173 |
+
# print("visualizing rearrangements predicted by the generator")
|
174 |
+
# visualize_batch_pcs(new_obj_xyzs_before_cem, S, N, P, limit_B=5)
|
175 |
+
|
176 |
+
####################################################
|
177 |
+
# rank
|
178 |
+
|
179 |
+
# evaluate in batches
|
180 |
+
scores = torch.zeros(S).to(self.device)
|
181 |
+
no_intersection_scores = torch.zeros(S).to(self.device) # the higher the better
|
182 |
+
num_batches = int(S / B_discriminator)
|
183 |
+
if S % B_discriminator != 0:
|
184 |
+
num_batches += 1
|
185 |
+
for b in range(num_batches):
|
186 |
+
if b + 1 == num_batches:
|
187 |
+
cur_batch_idxs_start = b * B_discriminator
|
188 |
+
cur_batch_idxs_end = S
|
189 |
+
else:
|
190 |
+
cur_batch_idxs_start = b * B_discriminator
|
191 |
+
cur_batch_idxs_end = (b + 1) * B_discriminator
|
192 |
+
cur_batch_size = cur_batch_idxs_end - cur_batch_idxs_start
|
193 |
+
|
194 |
+
# print("current batch idxs start", cur_batch_idxs_start)
|
195 |
+
# print("current batch idxs end", cur_batch_idxs_end)
|
196 |
+
# print("size of the current batch", cur_batch_size)
|
197 |
+
|
198 |
+
batch_obj_params = obj_params[cur_batch_idxs_start: cur_batch_idxs_end]
|
199 |
+
batch_struct_pose = struct_pose[cur_batch_idxs_start * N: cur_batch_idxs_end * N]
|
200 |
+
batch_current_pc_pose = current_pc_pose[cur_batch_idxs_start * N:cur_batch_idxs_end * N]
|
201 |
+
|
202 |
+
new_obj_xyzs, _, subsampled_scene_xyz, _, obj_pair_xyzs = \
|
203 |
+
move_pc_and_create_scene_new(obj_xyzs, batch_obj_params, batch_struct_pose, batch_current_pc_pose,
|
204 |
+
target_object_inds, self.device,
|
205 |
+
return_scene_pts=False,
|
206 |
+
return_scene_pts_and_pc_idxs=False,
|
207 |
+
num_scene_pts=False,
|
208 |
+
normalize_pc=False,
|
209 |
+
return_pair_pc=True,
|
210 |
+
num_pair_pc_pts=self.collision_model.data_cfg.num_scene_pts,
|
211 |
+
normalize_pair_pc=self.collision_model.data_cfg.normalize_pc)
|
212 |
+
|
213 |
+
#######################################
|
214 |
+
# predict whether there are pairwise collisions
|
215 |
+
# if collision_score_weight > 0:
|
216 |
+
with torch.no_grad():
|
217 |
+
_, num_comb, num_pair_pc_pts, _ = obj_pair_xyzs.shape
|
218 |
+
# obj_pair_xyzs = obj_pair_xyzs.reshape(cur_batch_size * num_comb, num_pair_pc_pts, -1)
|
219 |
+
collision_logits = self.collision_backbone.forward(obj_pair_xyzs.reshape(cur_batch_size * num_comb, num_pair_pc_pts, -1))
|
220 |
+
collision_scores = self.collision_backbone.convert_logits(collision_logits).reshape(cur_batch_size, num_comb) # cur_batch_size, num_comb
|
221 |
+
|
222 |
+
# debug
|
223 |
+
# for bi, this_obj_pair_xyzs in enumerate(obj_pair_xyzs):
|
224 |
+
# print("batch id", bi)
|
225 |
+
# for pi, obj_pair_xyz in enumerate(this_obj_pair_xyzs):
|
226 |
+
# print("pair", pi)
|
227 |
+
# # obj_pair_xyzs: 2 * P, 5
|
228 |
+
# print("collision score", collision_scores[bi, pi])
|
229 |
+
# trimesh.PointCloud(obj_pair_xyz[:, :3].cpu()).show()
|
230 |
+
|
231 |
+
# 1 - mean() since the collision model predicts 1 if there is a collision
|
232 |
+
no_intersection_scores[cur_batch_idxs_start:cur_batch_idxs_end] = 1 - torch.mean(collision_scores, dim=1)
|
233 |
+
if visualize:
|
234 |
+
print("no intersection scores", no_intersection_scores)
|
235 |
+
# #######################################
|
236 |
+
# if discriminator_score_weight > 0:
|
237 |
+
# # # debug:
|
238 |
+
# # print(subsampled_scene_xyz.shape)
|
239 |
+
# # print(subsampled_scene_xyz[0])
|
240 |
+
# # trimesh.PointCloud(subsampled_scene_xyz[0, :, :3].cpu().numpy()).show()
|
241 |
+
# #
|
242 |
+
# with torch.no_grad():
|
243 |
+
#
|
244 |
+
# # Important: since this discriminator only uses local structure param, takes sentence from the first and last position
|
245 |
+
# # local_sentence = sentence[:, [0, 4]]
|
246 |
+
# # local_sentence_pad_mask = sentence_pad_mask[:, [0, 4]]
|
247 |
+
# # sentence_disc, sentence_pad_mask_disc, position_index_dic = discriminator_inference.dataset.tensorfy_sentence(raw_sentence_discriminator, raw_sentence_pad_mask_discriminator, raw_position_index_discriminator)
|
248 |
+
#
|
249 |
+
# sentence_disc = torch.LongTensor(
|
250 |
+
# [discriminator_tokenizer.tokenize(*i) for i in raw_sentence_discriminator])
|
251 |
+
# sentence_pad_mask_disc = torch.LongTensor(raw_sentence_pad_mask_discriminator)
|
252 |
+
# position_index_dic = torch.LongTensor(raw_position_index_discriminator)
|
253 |
+
#
|
254 |
+
# preds = discriminator_model.forward(subsampled_scene_xyz,
|
255 |
+
# sentence_disc.unsqueeze(0).repeat(cur_batch_size, 1).to(device),
|
256 |
+
# sentence_pad_mask_disc.unsqueeze(0).repeat(cur_batch_size,
|
257 |
+
# 1).to(device),
|
258 |
+
# position_index_dic.unsqueeze(0).repeat(cur_batch_size, 1).to(
|
259 |
+
# device))
|
260 |
+
# # preds = discriminator_model.forward(subsampled_scene_xyz)
|
261 |
+
# preds = discriminator_model.convert_logits(preds)
|
262 |
+
# preds = preds["is_circle"] # cur_batch_size,
|
263 |
+
# scores[cur_batch_idxs_start:cur_batch_idxs_end] = preds
|
264 |
+
# if visualize:
|
265 |
+
# print("discriminator scores", scores)
|
266 |
+
|
267 |
+
# scores = scores * discriminator_score_weight + no_intersection_scores * collision_score_weight
|
268 |
+
scores = no_intersection_scores
|
269 |
+
sort_idx = torch.argsort(scores).flip(dims=[0])[:num_elite]
|
270 |
+
elite_obj_params = obj_params[sort_idx] # num_elite, N, 6
|
271 |
+
elite_struct_poses = struct_pose.reshape(S, N, 4, 4)[sort_idx] # num_elite, N, 4, 4
|
272 |
+
elite_struct_poses = elite_struct_poses.reshape(num_elite * N, 4, 4) # num_elite x N, 4, 4
|
273 |
+
elite_scores = scores[sort_idx]
|
274 |
+
print("elite scores:", elite_scores)
|
275 |
+
|
276 |
+
####################################################
|
277 |
+
# # visualize best samples
|
278 |
+
# num_scene_pts = 4096 # if discriminator_num_scene_pts is None else discriminator_num_scene_pts
|
279 |
+
# batch_current_pc_pose = current_pc_pose[0: num_elite * N]
|
280 |
+
# best_new_obj_xyzs, best_goal_pc_pose, best_subsampled_scene_xyz, _, _ = \
|
281 |
+
# move_pc_and_create_scene_new(obj_xyzs, elite_obj_params, elite_struct_poses, batch_current_pc_pose,
|
282 |
+
# target_object_inds, self.device,
|
283 |
+
# return_scene_pts=True, num_scene_pts=num_scene_pts, normalize_pc=True)
|
284 |
+
# if visualize:
|
285 |
+
# print("visualizing elite rearrangements ranked by collision model/discriminator")
|
286 |
+
# visualize_batch_pcs(best_new_obj_xyzs, num_elite, limit_B=num_elite)
|
287 |
+
|
288 |
+
# num_elite, N, 6
|
289 |
+
elite_obj_params = elite_obj_params.reshape(num_elite * N, -1)
|
290 |
+
pc_poses_in_struct = torch.eye(4).repeat(num_elite * N, 1, 1).to(self.device)
|
291 |
+
pc_poses_in_struct[:, :3, :3] = tra3d.euler_angles_to_matrix(elite_obj_params[:, 3:], "XYZ")
|
292 |
+
pc_poses_in_struct[:, :3, 3] = elite_obj_params[:, :3]
|
293 |
+
pc_poses_in_struct = pc_poses_in_struct.reshape(num_elite, N, 4, 4) # num_elite, N, 4, 4
|
294 |
+
|
295 |
+
struct_pose = elite_struct_poses.reshape(num_elite, N, 4, 4)[:, 0,].unsqueeze(1) # num_elite, 1, 4, 4
|
296 |
+
|
297 |
+
print(struct_pose.shape)
|
298 |
+
print(pc_poses_in_struct.shape)
|
299 |
+
|
300 |
+
return struct_pose, pc_poses_in_struct
|
src/StructDiffusion/language/__pycache__/sentence_encoder.cpython-38.pyc
ADDED
Binary file (881 Bytes). View file
|
|
src/StructDiffusion/language/__pycache__/tokenizer.cpython-38.pyc
CHANGED
Binary files a/src/StructDiffusion/language/__pycache__/tokenizer.cpython-38.pyc and b/src/StructDiffusion/language/__pycache__/tokenizer.cpython-38.pyc differ
|
|
src/StructDiffusion/language/convert_to_natural_language.ipynb
ADDED
@@ -0,0 +1,773 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 51,
|
6 |
+
"outputs": [],
|
7 |
+
"source": [
|
8 |
+
"import os\n",
|
9 |
+
"import h5py\n",
|
10 |
+
"import json\n",
|
11 |
+
"import numpy as np\n",
|
12 |
+
"import tqdm\n",
|
13 |
+
"import itertools\n",
|
14 |
+
"import copy\n",
|
15 |
+
"from collections import defaultdict\n",
|
16 |
+
"\n",
|
17 |
+
"from StructDiffuser.tokenizer import Tokenizer"
|
18 |
+
],
|
19 |
+
"metadata": {
|
20 |
+
"collapsed": false,
|
21 |
+
"pycharm": {
|
22 |
+
"name": "#%%\n"
|
23 |
+
}
|
24 |
+
}
|
25 |
+
},
|
26 |
+
{
|
27 |
+
"cell_type": "code",
|
28 |
+
"execution_count": 13,
|
29 |
+
"metadata": {
|
30 |
+
"collapsed": true,
|
31 |
+
"pycharm": {
|
32 |
+
"name": "#%%\n"
|
33 |
+
}
|
34 |
+
},
|
35 |
+
"outputs": [],
|
36 |
+
"source": [
|
37 |
+
"class SemanticArrangementDataset:\n",
|
38 |
+
"\n",
|
39 |
+
" def __init__(self, data_roots, index_roots, splits, tokenizer):\n",
|
40 |
+
"\n",
|
41 |
+
" self.data_roots = data_roots\n",
|
42 |
+
" print(\"data dirs:\", self.data_roots)\n",
|
43 |
+
"\n",
|
44 |
+
" self.tokenizer = tokenizer\n",
|
45 |
+
"\n",
|
46 |
+
" self.arrangement_data = []\n",
|
47 |
+
" arrangement_steps = []\n",
|
48 |
+
" for split in splits:\n",
|
49 |
+
" for data_root, index_root in zip(data_roots, index_roots):\n",
|
50 |
+
" arrangement_indices_file = os.path.join(data_root, index_root, \"{}_arrangement_indices_file_all.txt\".format(split))\n",
|
51 |
+
" if os.path.exists(arrangement_indices_file):\n",
|
52 |
+
" with open(arrangement_indices_file, \"r\") as fh:\n",
|
53 |
+
" arrangement_steps.extend([(os.path.join(data_root, f[0]), f[1]) for f in eval(fh.readline().strip())])\n",
|
54 |
+
" else:\n",
|
55 |
+
" print(\"{} does not exist\".format(arrangement_indices_file))\n",
|
56 |
+
"\n",
|
57 |
+
" # only keep one dummy step for each rearrangement\n",
|
58 |
+
" for filename, step_t in arrangement_steps:\n",
|
59 |
+
" if step_t == 0:\n",
|
60 |
+
" self.arrangement_data.append(filename)\n",
|
61 |
+
" print(\"{} valid sequences\".format(len(self.arrangement_data)))\n",
|
62 |
+
"\n",
|
63 |
+
" def __len__(self):\n",
|
64 |
+
" return len(self.arrangement_data)\n",
|
65 |
+
"\n",
|
66 |
+
" def get_raw_data(self, idx):\n",
|
67 |
+
"\n",
|
68 |
+
" filename = self.arrangement_data[idx]\n",
|
69 |
+
" h5 = h5py.File(filename, 'r')\n",
|
70 |
+
" goal_specification = json.loads(str(np.array(h5[\"goal_specification\"])))\n",
|
71 |
+
"\n",
|
72 |
+
" ###################################\n",
|
73 |
+
" # preparing sentence\n",
|
74 |
+
" struct_spec = []\n",
|
75 |
+
"\n",
|
76 |
+
" # structure parameters\n",
|
77 |
+
" # 5 parameters\n",
|
78 |
+
" structure_parameters = goal_specification[\"shape\"]\n",
|
79 |
+
" if structure_parameters[\"type\"] == \"circle\" or structure_parameters[\"type\"] == \"line\":\n",
|
80 |
+
" struct_spec.append((structure_parameters[\"type\"], \"shape\"))\n",
|
81 |
+
" struct_spec.append((structure_parameters[\"rotation\"][2], \"rotation\"))\n",
|
82 |
+
" struct_spec.append((structure_parameters[\"position\"][0], \"position_x\"))\n",
|
83 |
+
" struct_spec.append((structure_parameters[\"position\"][1], \"position_y\"))\n",
|
84 |
+
" if structure_parameters[\"type\"] == \"circle\":\n",
|
85 |
+
" struct_spec.append((structure_parameters[\"radius\"], \"radius\"))\n",
|
86 |
+
" elif structure_parameters[\"type\"] == \"line\":\n",
|
87 |
+
" struct_spec.append((structure_parameters[\"length\"] / 2.0, \"radius\"))\n",
|
88 |
+
" else:\n",
|
89 |
+
" struct_spec.append((structure_parameters[\"type\"], \"shape\"))\n",
|
90 |
+
" struct_spec.append((structure_parameters[\"rotation\"][2], \"rotation\"))\n",
|
91 |
+
" struct_spec.append((structure_parameters[\"position\"][0], \"position_x\"))\n",
|
92 |
+
" struct_spec.append((structure_parameters[\"position\"][1], \"position_y\"))\n",
|
93 |
+
"\n",
|
94 |
+
" return struct_spec"
|
95 |
+
]
|
96 |
+
},
|
97 |
+
{
|
98 |
+
"cell_type": "markdown",
|
99 |
+
"source": [],
|
100 |
+
"metadata": {
|
101 |
+
"collapsed": false,
|
102 |
+
"pycharm": {
|
103 |
+
"name": "#%% md\n"
|
104 |
+
}
|
105 |
+
}
|
106 |
+
},
|
107 |
+
{
|
108 |
+
"cell_type": "code",
|
109 |
+
"execution_count": 14,
|
110 |
+
"outputs": [
|
111 |
+
{
|
112 |
+
"name": "stdout",
|
113 |
+
"output_type": "stream",
|
114 |
+
"text": [
|
115 |
+
"\n",
|
116 |
+
"Build one vacab for everything...\n",
|
117 |
+
"The vocab has 124 tokens: {'PAD': 0, 'CLS': 1, 'class:MASK': 2, 'class:Basket': 3, 'class:BeerBottle': 4, 'class:Book': 5, 'class:Bottle': 6, 'class:Bowl': 7, 'class:Calculator': 8, 'class:Candle': 9, 'class:CellPhone': 10, 'class:ComputerMouse': 11, 'class:Controller': 12, 'class:Cup': 13, 'class:Donut': 14, 'class:Fork': 15, 'class:Hammer': 16, 'class:Knife': 17, 'class:Marker': 18, 'class:MilkCarton': 19, 'class:Mug': 20, 'class:Pan': 21, 'class:Pen': 22, 'class:PillBottle': 23, 'class:Plate': 24, 'class:PowerStrip': 25, 'class:Scissors': 26, 'class:SoapBottle': 27, 'class:SodaCan': 28, 'class:Spoon': 29, 'class:Stapler': 30, 'class:Teapot': 31, 'class:VideoGameController': 32, 'class:WineBottle': 33, 'class:CanOpener': 34, 'class:Fruit': 35, 'scene:MASK': 36, 'scene:dinner': 37, 'size:MASK': 38, 'size:L': 39, 'size:M': 40, 'size:S': 41, 'color:MASK': 42, 'color:blue': 43, 'color:cyan': 44, 'color:green': 45, 'color:magenta': 46, 'color:red': 47, 'color:yellow': 48, 'material:MASK': 49, 'material:glass': 50, 'material:metal': 51, 'material:plastic': 52, 'radius:MASK': 53, 'radius:less': 54, 'radius:greater': 55, 'radius:equal': 56, 'radius:0': 57, 'radius:1': 58, 'radius:2': 59, 'position_x:MASK': 60, 'position_x:less': 61, 'position_x:greater': 62, 'position_x:equal': 63, 'position_x:0': 64, 'position_x:1': 65, 'position_x:2': 66, 'position_y:MASK': 67, 'position_y:less': 68, 'position_y:greater': 69, 'position_y:equal': 70, 'position_y:0': 71, 'position_y:1': 72, 'position_y:2': 73, 'rotation:MASK': 74, 'rotation:less': 75, 'rotation:greater': 76, 'rotation:equal': 77, 'rotation:0': 78, 'rotation:1': 79, 'rotation:2': 80, 'rotation:3': 81, 'height:MASK': 82, 'height:less': 83, 'height:greater': 84, 'height:equal': 85, 'height:0': 86, 'height:1': 87, 'height:2': 88, 'height:3': 89, 'height:4': 90, 'height:5': 91, 'height:6': 92, 'height:7': 93, 'height:8': 94, 'height:9': 95, 'volumn:MASK': 96, 'volumn:less': 97, 'volumn:greater': 98, 'volumn:equal': 99, 
'volumn:0': 100, 'volumn:1': 101, 'volumn:2': 102, 'volumn:3': 103, 'volumn:4': 104, 'volumn:5': 105, 'volumn:6': 106, 'volumn:7': 107, 'volumn:8': 108, 'volumn:9': 109, 'uniform_angle:MASK': 110, 'uniform_angle:False': 111, 'uniform_angle:True': 112, 'face_center:MASK': 113, 'face_center:False': 114, 'face_center:True': 115, 'angle_ratio:MASK': 116, 'angle_ratio:0.5': 117, 'angle_ratio:1.0': 118, 'shape:MASK': 119, 'shape:circle': 120, 'shape:line': 121, 'shape:tower': 122, 'shape:dinner': 123}\n",
|
118 |
+
"\n",
|
119 |
+
"Build vocabs for object position\n",
|
120 |
+
"The obj_x vocab has 202 tokens: {'PAD': 0, 'MASK': 1, '0': 2, '1': 3, '2': 4, '3': 5, '4': 6, '5': 7, '6': 8, '7': 9, '8': 10, '9': 11, '10': 12, '11': 13, '12': 14, '13': 15, '14': 16, '15': 17, '16': 18, '17': 19, '18': 20, '19': 21, '20': 22, '21': 23, '22': 24, '23': 25, '24': 26, '25': 27, '26': 28, '27': 29, '28': 30, '29': 31, '30': 32, '31': 33, '32': 34, '33': 35, '34': 36, '35': 37, '36': 38, '37': 39, '38': 40, '39': 41, '40': 42, '41': 43, '42': 44, '43': 45, '44': 46, '45': 47, '46': 48, '47': 49, '48': 50, '49': 51, '50': 52, '51': 53, '52': 54, '53': 55, '54': 56, '55': 57, '56': 58, '57': 59, '58': 60, '59': 61, '60': 62, '61': 63, '62': 64, '63': 65, '64': 66, '65': 67, '66': 68, '67': 69, '68': 70, '69': 71, '70': 72, '71': 73, '72': 74, '73': 75, '74': 76, '75': 77, '76': 78, '77': 79, '78': 80, '79': 81, '80': 82, '81': 83, '82': 84, '83': 85, '84': 86, '85': 87, '86': 88, '87': 89, '88': 90, '89': 91, '90': 92, '91': 93, '92': 94, '93': 95, '94': 96, '95': 97, '96': 98, '97': 99, '98': 100, '99': 101, '100': 102, '101': 103, '102': 104, '103': 105, '104': 106, '105': 107, '106': 108, '107': 109, '108': 110, '109': 111, '110': 112, '111': 113, '112': 114, '113': 115, '114': 116, '115': 117, '116': 118, '117': 119, '118': 120, '119': 121, '120': 122, '121': 123, '122': 124, '123': 125, '124': 126, '125': 127, '126': 128, '127': 129, '128': 130, '129': 131, '130': 132, '131': 133, '132': 134, '133': 135, '134': 136, '135': 137, '136': 138, '137': 139, '138': 140, '139': 141, '140': 142, '141': 143, '142': 144, '143': 145, '144': 146, '145': 147, '146': 148, '147': 149, '148': 150, '149': 151, '150': 152, '151': 153, '152': 154, '153': 155, '154': 156, '155': 157, '156': 158, '157': 159, '158': 160, '159': 161, '160': 162, '161': 163, '162': 164, '163': 165, '164': 166, '165': 167, '166': 168, '167': 169, '168': 170, '169': 171, '170': 172, '171': 173, '172': 174, '173': 175, '174': 176, '175': 177, '176': 178, '177': 179, '178': 180, '179': 181, 
'180': 182, '181': 183, '182': 184, '183': 185, '184': 186, '185': 187, '186': 188, '187': 189, '188': 190, '189': 191, '190': 192, '191': 193, '192': 194, '193': 195, '194': 196, '195': 197, '196': 198, '197': 199, '198': 200, '199': 201}\n",
|
121 |
+
"The obj_y vocab has 202 tokens: {'PAD': 0, 'MASK': 1, '0': 2, '1': 3, '2': 4, '3': 5, '4': 6, '5': 7, '6': 8, '7': 9, '8': 10, '9': 11, '10': 12, '11': 13, '12': 14, '13': 15, '14': 16, '15': 17, '16': 18, '17': 19, '18': 20, '19': 21, '20': 22, '21': 23, '22': 24, '23': 25, '24': 26, '25': 27, '26': 28, '27': 29, '28': 30, '29': 31, '30': 32, '31': 33, '32': 34, '33': 35, '34': 36, '35': 37, '36': 38, '37': 39, '38': 40, '39': 41, '40': 42, '41': 43, '42': 44, '43': 45, '44': 46, '45': 47, '46': 48, '47': 49, '48': 50, '49': 51, '50': 52, '51': 53, '52': 54, '53': 55, '54': 56, '55': 57, '56': 58, '57': 59, '58': 60, '59': 61, '60': 62, '61': 63, '62': 64, '63': 65, '64': 66, '65': 67, '66': 68, '67': 69, '68': 70, '69': 71, '70': 72, '71': 73, '72': 74, '73': 75, '74': 76, '75': 77, '76': 78, '77': 79, '78': 80, '79': 81, '80': 82, '81': 83, '82': 84, '83': 85, '84': 86, '85': 87, '86': 88, '87': 89, '88': 90, '89': 91, '90': 92, '91': 93, '92': 94, '93': 95, '94': 96, '95': 97, '96': 98, '97': 99, '98': 100, '99': 101, '100': 102, '101': 103, '102': 104, '103': 105, '104': 106, '105': 107, '106': 108, '107': 109, '108': 110, '109': 111, '110': 112, '111': 113, '112': 114, '113': 115, '114': 116, '115': 117, '116': 118, '117': 119, '118': 120, '119': 121, '120': 122, '121': 123, '122': 124, '123': 125, '124': 126, '125': 127, '126': 128, '127': 129, '128': 130, '129': 131, '130': 132, '131': 133, '132': 134, '133': 135, '134': 136, '135': 137, '136': 138, '137': 139, '138': 140, '139': 141, '140': 142, '141': 143, '142': 144, '143': 145, '144': 146, '145': 147, '146': 148, '147': 149, '148': 150, '149': 151, '150': 152, '151': 153, '152': 154, '153': 155, '154': 156, '155': 157, '156': 158, '157': 159, '158': 160, '159': 161, '160': 162, '161': 163, '162': 164, '163': 165, '164': 166, '165': 167, '166': 168, '167': 169, '168': 170, '169': 171, '170': 172, '171': 173, '172': 174, '173': 175, '174': 176, '175': 177, '176': 178, '177': 179, '178': 180, '179': 181, 
'180': 182, '181': 183, '182': 184, '183': 185, '184': 186, '185': 187, '186': 188, '187': 189, '188': 190, '189': 191, '190': 192, '191': 193, '192': 194, '193': 195, '194': 196, '195': 197, '196': 198, '197': 199, '198': 200, '199': 201}\n",
|
122 |
+
"The obj_z vocab has 202 tokens: {'PAD': 0, 'MASK': 1, '0': 2, '1': 3, '2': 4, '3': 5, '4': 6, '5': 7, '6': 8, '7': 9, '8': 10, '9': 11, '10': 12, '11': 13, '12': 14, '13': 15, '14': 16, '15': 17, '16': 18, '17': 19, '18': 20, '19': 21, '20': 22, '21': 23, '22': 24, '23': 25, '24': 26, '25': 27, '26': 28, '27': 29, '28': 30, '29': 31, '30': 32, '31': 33, '32': 34, '33': 35, '34': 36, '35': 37, '36': 38, '37': 39, '38': 40, '39': 41, '40': 42, '41': 43, '42': 44, '43': 45, '44': 46, '45': 47, '46': 48, '47': 49, '48': 50, '49': 51, '50': 52, '51': 53, '52': 54, '53': 55, '54': 56, '55': 57, '56': 58, '57': 59, '58': 60, '59': 61, '60': 62, '61': 63, '62': 64, '63': 65, '64': 66, '65': 67, '66': 68, '67': 69, '68': 70, '69': 71, '70': 72, '71': 73, '72': 74, '73': 75, '74': 76, '75': 77, '76': 78, '77': 79, '78': 80, '79': 81, '80': 82, '81': 83, '82': 84, '83': 85, '84': 86, '85': 87, '86': 88, '87': 89, '88': 90, '89': 91, '90': 92, '91': 93, '92': 94, '93': 95, '94': 96, '95': 97, '96': 98, '97': 99, '98': 100, '99': 101, '100': 102, '101': 103, '102': 104, '103': 105, '104': 106, '105': 107, '106': 108, '107': 109, '108': 110, '109': 111, '110': 112, '111': 113, '112': 114, '113': 115, '114': 116, '115': 117, '116': 118, '117': 119, '118': 120, '119': 121, '120': 122, '121': 123, '122': 124, '123': 125, '124': 126, '125': 127, '126': 128, '127': 129, '128': 130, '129': 131, '130': 132, '131': 133, '132': 134, '133': 135, '134': 136, '135': 137, '136': 138, '137': 139, '138': 140, '139': 141, '140': 142, '141': 143, '142': 144, '143': 145, '144': 146, '145': 147, '146': 148, '147': 149, '148': 150, '149': 151, '150': 152, '151': 153, '152': 154, '153': 155, '154': 156, '155': 157, '156': 158, '157': 159, '158': 160, '159': 161, '160': 162, '161': 163, '162': 164, '163': 165, '164': 166, '165': 167, '166': 168, '167': 169, '168': 170, '169': 171, '170': 172, '171': 173, '172': 174, '173': 175, '174': 176, '175': 177, '176': 178, '177': 179, '178': 180, '179': 181, 
'180': 182, '181': 183, '182': 184, '183': 185, '184': 186, '185': 187, '186': 188, '187': 189, '188': 190, '189': 191, '190': 192, '191': 193, '192': 194, '193': 195, '194': 196, '195': 197, '196': 198, '197': 199, '198': 200, '199': 201}\n",
|
123 |
+
"The obj_rr vocab has 362 tokens: {'PAD': 0, 'MASK': 1, '0': 2, '1': 3, '2': 4, '3': 5, '4': 6, '5': 7, '6': 8, '7': 9, '8': 10, '9': 11, '10': 12, '11': 13, '12': 14, '13': 15, '14': 16, '15': 17, '16': 18, '17': 19, '18': 20, '19': 21, '20': 22, '21': 23, '22': 24, '23': 25, '24': 26, '25': 27, '26': 28, '27': 29, '28': 30, '29': 31, '30': 32, '31': 33, '32': 34, '33': 35, '34': 36, '35': 37, '36': 38, '37': 39, '38': 40, '39': 41, '40': 42, '41': 43, '42': 44, '43': 45, '44': 46, '45': 47, '46': 48, '47': 49, '48': 50, '49': 51, '50': 52, '51': 53, '52': 54, '53': 55, '54': 56, '55': 57, '56': 58, '57': 59, '58': 60, '59': 61, '60': 62, '61': 63, '62': 64, '63': 65, '64': 66, '65': 67, '66': 68, '67': 69, '68': 70, '69': 71, '70': 72, '71': 73, '72': 74, '73': 75, '74': 76, '75': 77, '76': 78, '77': 79, '78': 80, '79': 81, '80': 82, '81': 83, '82': 84, '83': 85, '84': 86, '85': 87, '86': 88, '87': 89, '88': 90, '89': 91, '90': 92, '91': 93, '92': 94, '93': 95, '94': 96, '95': 97, '96': 98, '97': 99, '98': 100, '99': 101, '100': 102, '101': 103, '102': 104, '103': 105, '104': 106, '105': 107, '106': 108, '107': 109, '108': 110, '109': 111, '110': 112, '111': 113, '112': 114, '113': 115, '114': 116, '115': 117, '116': 118, '117': 119, '118': 120, '119': 121, '120': 122, '121': 123, '122': 124, '123': 125, '124': 126, '125': 127, '126': 128, '127': 129, '128': 130, '129': 131, '130': 132, '131': 133, '132': 134, '133': 135, '134': 136, '135': 137, '136': 138, '137': 139, '138': 140, '139': 141, '140': 142, '141': 143, '142': 144, '143': 145, '144': 146, '145': 147, '146': 148, '147': 149, '148': 150, '149': 151, '150': 152, '151': 153, '152': 154, '153': 155, '154': 156, '155': 157, '156': 158, '157': 159, '158': 160, '159': 161, '160': 162, '161': 163, '162': 164, '163': 165, '164': 166, '165': 167, '166': 168, '167': 169, '168': 170, '169': 171, '170': 172, '171': 173, '172': 174, '173': 175, '174': 176, '175': 177, '176': 178, '177': 179, '178': 180, '179': 181, 
'180': 182, '181': 183, '182': 184, '183': 185, '184': 186, '185': 187, '186': 188, '187': 189, '188': 190, '189': 191, '190': 192, '191': 193, '192': 194, '193': 195, '194': 196, '195': 197, '196': 198, '197': 199, '198': 200, '199': 201, '200': 202, '201': 203, '202': 204, '203': 205, '204': 206, '205': 207, '206': 208, '207': 209, '208': 210, '209': 211, '210': 212, '211': 213, '212': 214, '213': 215, '214': 216, '215': 217, '216': 218, '217': 219, '218': 220, '219': 221, '220': 222, '221': 223, '222': 224, '223': 225, '224': 226, '225': 227, '226': 228, '227': 229, '228': 230, '229': 231, '230': 232, '231': 233, '232': 234, '233': 235, '234': 236, '235': 237, '236': 238, '237': 239, '238': 240, '239': 241, '240': 242, '241': 243, '242': 244, '243': 245, '244': 246, '245': 247, '246': 248, '247': 249, '248': 250, '249': 251, '250': 252, '251': 253, '252': 254, '253': 255, '254': 256, '255': 257, '256': 258, '257': 259, '258': 260, '259': 261, '260': 262, '261': 263, '262': 264, '263': 265, '264': 266, '265': 267, '266': 268, '267': 269, '268': 270, '269': 271, '270': 272, '271': 273, '272': 274, '273': 275, '274': 276, '275': 277, '276': 278, '277': 279, '278': 280, '279': 281, '280': 282, '281': 283, '282': 284, '283': 285, '284': 286, '285': 287, '286': 288, '287': 289, '288': 290, '289': 291, '290': 292, '291': 293, '292': 294, '293': 295, '294': 296, '295': 297, '296': 298, '297': 299, '298': 300, '299': 301, '300': 302, '301': 303, '302': 304, '303': 305, '304': 306, '305': 307, '306': 308, '307': 309, '308': 310, '309': 311, '310': 312, '311': 313, '312': 314, '313': 315, '314': 316, '315': 317, '316': 318, '317': 319, '318': 320, '319': 321, '320': 322, '321': 323, '322': 324, '323': 325, '324': 326, '325': 327, '326': 328, '327': 329, '328': 330, '329': 331, '330': 332, '331': 333, '332': 334, '333': 335, '334': 336, '335': 337, '336': 338, '337': 339, '338': 340, '339': 341, '340': 342, '341': 343, '342': 344, '343': 345, '344': 346, '345': 347, '346': 
348, '347': 349, '348': 350, '349': 351, '350': 352, '351': 353, '352': 354, '353': 355, '354': 356, '355': 357, '356': 358, '357': 359, '358': 360, '359': 361}\n",
|
124 |
+
"The obj_rp vocab has 362 tokens: {'PAD': 0, 'MASK': 1, '0': 2, '1': 3, '2': 4, '3': 5, '4': 6, '5': 7, '6': 8, '7': 9, '8': 10, '9': 11, '10': 12, '11': 13, '12': 14, '13': 15, '14': 16, '15': 17, '16': 18, '17': 19, '18': 20, '19': 21, '20': 22, '21': 23, '22': 24, '23': 25, '24': 26, '25': 27, '26': 28, '27': 29, '28': 30, '29': 31, '30': 32, '31': 33, '32': 34, '33': 35, '34': 36, '35': 37, '36': 38, '37': 39, '38': 40, '39': 41, '40': 42, '41': 43, '42': 44, '43': 45, '44': 46, '45': 47, '46': 48, '47': 49, '48': 50, '49': 51, '50': 52, '51': 53, '52': 54, '53': 55, '54': 56, '55': 57, '56': 58, '57': 59, '58': 60, '59': 61, '60': 62, '61': 63, '62': 64, '63': 65, '64': 66, '65': 67, '66': 68, '67': 69, '68': 70, '69': 71, '70': 72, '71': 73, '72': 74, '73': 75, '74': 76, '75': 77, '76': 78, '77': 79, '78': 80, '79': 81, '80': 82, '81': 83, '82': 84, '83': 85, '84': 86, '85': 87, '86': 88, '87': 89, '88': 90, '89': 91, '90': 92, '91': 93, '92': 94, '93': 95, '94': 96, '95': 97, '96': 98, '97': 99, '98': 100, '99': 101, '100': 102, '101': 103, '102': 104, '103': 105, '104': 106, '105': 107, '106': 108, '107': 109, '108': 110, '109': 111, '110': 112, '111': 113, '112': 114, '113': 115, '114': 116, '115': 117, '116': 118, '117': 119, '118': 120, '119': 121, '120': 122, '121': 123, '122': 124, '123': 125, '124': 126, '125': 127, '126': 128, '127': 129, '128': 130, '129': 131, '130': 132, '131': 133, '132': 134, '133': 135, '134': 136, '135': 137, '136': 138, '137': 139, '138': 140, '139': 141, '140': 142, '141': 143, '142': 144, '143': 145, '144': 146, '145': 147, '146': 148, '147': 149, '148': 150, '149': 151, '150': 152, '151': 153, '152': 154, '153': 155, '154': 156, '155': 157, '156': 158, '157': 159, '158': 160, '159': 161, '160': 162, '161': 163, '162': 164, '163': 165, '164': 166, '165': 167, '166': 168, '167': 169, '168': 170, '169': 171, '170': 172, '171': 173, '172': 174, '173': 175, '174': 176, '175': 177, '176': 178, '177': 179, '178': 180, '179': 181, 
'180': 182, '181': 183, '182': 184, '183': 185, '184': 186, '185': 187, '186': 188, '187': 189, '188': 190, '189': 191, '190': 192, '191': 193, '192': 194, '193': 195, '194': 196, '195': 197, '196': 198, '197': 199, '198': 200, '199': 201, '200': 202, '201': 203, '202': 204, '203': 205, '204': 206, '205': 207, '206': 208, '207': 209, '208': 210, '209': 211, '210': 212, '211': 213, '212': 214, '213': 215, '214': 216, '215': 217, '216': 218, '217': 219, '218': 220, '219': 221, '220': 222, '221': 223, '222': 224, '223': 225, '224': 226, '225': 227, '226': 228, '227': 229, '228': 230, '229': 231, '230': 232, '231': 233, '232': 234, '233': 235, '234': 236, '235': 237, '236': 238, '237': 239, '238': 240, '239': 241, '240': 242, '241': 243, '242': 244, '243': 245, '244': 246, '245': 247, '246': 248, '247': 249, '248': 250, '249': 251, '250': 252, '251': 253, '252': 254, '253': 255, '254': 256, '255': 257, '256': 258, '257': 259, '258': 260, '259': 261, '260': 262, '261': 263, '262': 264, '263': 265, '264': 266, '265': 267, '266': 268, '267': 269, '268': 270, '269': 271, '270': 272, '271': 273, '272': 274, '273': 275, '274': 276, '275': 277, '276': 278, '277': 279, '278': 280, '279': 281, '280': 282, '281': 283, '282': 284, '283': 285, '284': 286, '285': 287, '286': 288, '287': 289, '288': 290, '289': 291, '290': 292, '291': 293, '292': 294, '293': 295, '294': 296, '295': 297, '296': 298, '297': 299, '298': 300, '299': 301, '300': 302, '301': 303, '302': 304, '303': 305, '304': 306, '305': 307, '306': 308, '307': 309, '308': 310, '309': 311, '310': 312, '311': 313, '312': 314, '313': 315, '314': 316, '315': 317, '316': 318, '317': 319, '318': 320, '319': 321, '320': 322, '321': 323, '322': 324, '323': 325, '324': 326, '325': 327, '326': 328, '327': 329, '328': 330, '329': 331, '330': 332, '331': 333, '332': 334, '333': 335, '334': 336, '335': 337, '336': 338, '337': 339, '338': 340, '339': 341, '340': 342, '341': 343, '342': 344, '343': 345, '344': 346, '345': 347, '346': 
348, '347': 349, '348': 350, '349': 351, '350': 352, '351': 353, '352': 354, '353': 355, '354': 356, '355': 357, '356': 358, '357': 359, '358': 360, '359': 361}\n",
|
125 |
+
"The obj_ry vocab has 362 tokens: {'PAD': 0, 'MASK': 1, '0': 2, '1': 3, '2': 4, '3': 5, '4': 6, '5': 7, '6': 8, '7': 9, '8': 10, '9': 11, '10': 12, '11': 13, '12': 14, '13': 15, '14': 16, '15': 17, '16': 18, '17': 19, '18': 20, '19': 21, '20': 22, '21': 23, '22': 24, '23': 25, '24': 26, '25': 27, '26': 28, '27': 29, '28': 30, '29': 31, '30': 32, '31': 33, '32': 34, '33': 35, '34': 36, '35': 37, '36': 38, '37': 39, '38': 40, '39': 41, '40': 42, '41': 43, '42': 44, '43': 45, '44': 46, '45': 47, '46': 48, '47': 49, '48': 50, '49': 51, '50': 52, '51': 53, '52': 54, '53': 55, '54': 56, '55': 57, '56': 58, '57': 59, '58': 60, '59': 61, '60': 62, '61': 63, '62': 64, '63': 65, '64': 66, '65': 67, '66': 68, '67': 69, '68': 70, '69': 71, '70': 72, '71': 73, '72': 74, '73': 75, '74': 76, '75': 77, '76': 78, '77': 79, '78': 80, '79': 81, '80': 82, '81': 83, '82': 84, '83': 85, '84': 86, '85': 87, '86': 88, '87': 89, '88': 90, '89': 91, '90': 92, '91': 93, '92': 94, '93': 95, '94': 96, '95': 97, '96': 98, '97': 99, '98': 100, '99': 101, '100': 102, '101': 103, '102': 104, '103': 105, '104': 106, '105': 107, '106': 108, '107': 109, '108': 110, '109': 111, '110': 112, '111': 113, '112': 114, '113': 115, '114': 116, '115': 117, '116': 118, '117': 119, '118': 120, '119': 121, '120': 122, '121': 123, '122': 124, '123': 125, '124': 126, '125': 127, '126': 128, '127': 129, '128': 130, '129': 131, '130': 132, '131': 133, '132': 134, '133': 135, '134': 136, '135': 137, '136': 138, '137': 139, '138': 140, '139': 141, '140': 142, '141': 143, '142': 144, '143': 145, '144': 146, '145': 147, '146': 148, '147': 149, '148': 150, '149': 151, '150': 152, '151': 153, '152': 154, '153': 155, '154': 156, '155': 157, '156': 158, '157': 159, '158': 160, '159': 161, '160': 162, '161': 163, '162': 164, '163': 165, '164': 166, '165': 167, '166': 168, '167': 169, '168': 170, '169': 171, '170': 172, '171': 173, '172': 174, '173': 175, '174': 176, '175': 177, '176': 178, '177': 179, '178': 180, '179': 181, 
'180': 182, '181': 183, '182': 184, '183': 185, '184': 186, '185': 187, '186': 188, '187': 189, '188': 190, '189': 191, '190': 192, '191': 193, '192': 194, '193': 195, '194': 196, '195': 197, '196': 198, '197': 199, '198': 200, '199': 201, '200': 202, '201': 203, '202': 204, '203': 205, '204': 206, '205': 207, '206': 208, '207': 209, '208': 210, '209': 211, '210': 212, '211': 213, '212': 214, '213': 215, '214': 216, '215': 217, '216': 218, '217': 219, '218': 220, '219': 221, '220': 222, '221': 223, '222': 224, '223': 225, '224': 226, '225': 227, '226': 228, '227': 229, '228': 230, '229': 231, '230': 232, '231': 233, '232': 234, '233': 235, '234': 236, '235': 237, '236': 238, '237': 239, '238': 240, '239': 241, '240': 242, '241': 243, '242': 244, '243': 245, '244': 246, '245': 247, '246': 248, '247': 249, '248': 250, '249': 251, '250': 252, '251': 253, '252': 254, '253': 255, '254': 256, '255': 257, '256': 258, '257': 259, '258': 260, '259': 261, '260': 262, '261': 263, '262': 264, '263': 265, '264': 266, '265': 267, '266': 268, '267': 269, '268': 270, '269': 271, '270': 272, '271': 273, '272': 274, '273': 275, '274': 276, '275': 277, '276': 278, '277': 279, '278': 280, '279': 281, '280': 282, '281': 283, '282': 284, '283': 285, '284': 286, '285': 287, '286': 288, '287': 289, '288': 290, '289': 291, '290': 292, '291': 293, '292': 294, '293': 295, '294': 296, '295': 297, '296': 298, '297': 299, '298': 300, '299': 301, '300': 302, '301': 303, '302': 304, '303': 305, '304': 306, '305': 307, '306': 308, '307': 309, '308': 310, '309': 311, '310': 312, '311': 313, '312': 314, '313': 315, '314': 316, '315': 317, '316': 318, '317': 319, '318': 320, '319': 321, '320': 322, '321': 323, '322': 324, '323': 325, '324': 326, '325': 327, '326': 328, '327': 329, '328': 330, '329': 331, '330': 332, '331': 333, '332': 334, '333': 335, '334': 336, '335': 337, '336': 338, '337': 339, '338': 340, '339': 341, '340': 342, '341': 343, '342': 344, '343': 345, '344': 346, '345': 347, '346': 
348, '347': 349, '348': 350, '349': 351, '350': 352, '351': 353, '352': 354, '353': 355, '354': 356, '355': 357, '356': 358, '357': 359, '358': 360, '359': 361}\n",
|
126 |
+
"The struct_x vocab has 202 tokens: {'PAD': 0, 'MASK': 1, '0': 2, '1': 3, '2': 4, '3': 5, '4': 6, '5': 7, '6': 8, '7': 9, '8': 10, '9': 11, '10': 12, '11': 13, '12': 14, '13': 15, '14': 16, '15': 17, '16': 18, '17': 19, '18': 20, '19': 21, '20': 22, '21': 23, '22': 24, '23': 25, '24': 26, '25': 27, '26': 28, '27': 29, '28': 30, '29': 31, '30': 32, '31': 33, '32': 34, '33': 35, '34': 36, '35': 37, '36': 38, '37': 39, '38': 40, '39': 41, '40': 42, '41': 43, '42': 44, '43': 45, '44': 46, '45': 47, '46': 48, '47': 49, '48': 50, '49': 51, '50': 52, '51': 53, '52': 54, '53': 55, '54': 56, '55': 57, '56': 58, '57': 59, '58': 60, '59': 61, '60': 62, '61': 63, '62': 64, '63': 65, '64': 66, '65': 67, '66': 68, '67': 69, '68': 70, '69': 71, '70': 72, '71': 73, '72': 74, '73': 75, '74': 76, '75': 77, '76': 78, '77': 79, '78': 80, '79': 81, '80': 82, '81': 83, '82': 84, '83': 85, '84': 86, '85': 87, '86': 88, '87': 89, '88': 90, '89': 91, '90': 92, '91': 93, '92': 94, '93': 95, '94': 96, '95': 97, '96': 98, '97': 99, '98': 100, '99': 101, '100': 102, '101': 103, '102': 104, '103': 105, '104': 106, '105': 107, '106': 108, '107': 109, '108': 110, '109': 111, '110': 112, '111': 113, '112': 114, '113': 115, '114': 116, '115': 117, '116': 118, '117': 119, '118': 120, '119': 121, '120': 122, '121': 123, '122': 124, '123': 125, '124': 126, '125': 127, '126': 128, '127': 129, '128': 130, '129': 131, '130': 132, '131': 133, '132': 134, '133': 135, '134': 136, '135': 137, '136': 138, '137': 139, '138': 140, '139': 141, '140': 142, '141': 143, '142': 144, '143': 145, '144': 146, '145': 147, '146': 148, '147': 149, '148': 150, '149': 151, '150': 152, '151': 153, '152': 154, '153': 155, '154': 156, '155': 157, '156': 158, '157': 159, '158': 160, '159': 161, '160': 162, '161': 163, '162': 164, '163': 165, '164': 166, '165': 167, '166': 168, '167': 169, '168': 170, '169': 171, '170': 172, '171': 173, '172': 174, '173': 175, '174': 176, '175': 177, '176': 178, '177': 179, '178': 180, '179': 
181, '180': 182, '181': 183, '182': 184, '183': 185, '184': 186, '185': 187, '186': 188, '187': 189, '188': 190, '189': 191, '190': 192, '191': 193, '192': 194, '193': 195, '194': 196, '195': 197, '196': 198, '197': 199, '198': 200, '199': 201}\n",
|
127 |
+
"The struct_y vocab has 202 tokens: {'PAD': 0, 'MASK': 1, '0': 2, '1': 3, '2': 4, '3': 5, '4': 6, '5': 7, '6': 8, '7': 9, '8': 10, '9': 11, '10': 12, '11': 13, '12': 14, '13': 15, '14': 16, '15': 17, '16': 18, '17': 19, '18': 20, '19': 21, '20': 22, '21': 23, '22': 24, '23': 25, '24': 26, '25': 27, '26': 28, '27': 29, '28': 30, '29': 31, '30': 32, '31': 33, '32': 34, '33': 35, '34': 36, '35': 37, '36': 38, '37': 39, '38': 40, '39': 41, '40': 42, '41': 43, '42': 44, '43': 45, '44': 46, '45': 47, '46': 48, '47': 49, '48': 50, '49': 51, '50': 52, '51': 53, '52': 54, '53': 55, '54': 56, '55': 57, '56': 58, '57': 59, '58': 60, '59': 61, '60': 62, '61': 63, '62': 64, '63': 65, '64': 66, '65': 67, '66': 68, '67': 69, '68': 70, '69': 71, '70': 72, '71': 73, '72': 74, '73': 75, '74': 76, '75': 77, '76': 78, '77': 79, '78': 80, '79': 81, '80': 82, '81': 83, '82': 84, '83': 85, '84': 86, '85': 87, '86': 88, '87': 89, '88': 90, '89': 91, '90': 92, '91': 93, '92': 94, '93': 95, '94': 96, '95': 97, '96': 98, '97': 99, '98': 100, '99': 101, '100': 102, '101': 103, '102': 104, '103': 105, '104': 106, '105': 107, '106': 108, '107': 109, '108': 110, '109': 111, '110': 112, '111': 113, '112': 114, '113': 115, '114': 116, '115': 117, '116': 118, '117': 119, '118': 120, '119': 121, '120': 122, '121': 123, '122': 124, '123': 125, '124': 126, '125': 127, '126': 128, '127': 129, '128': 130, '129': 131, '130': 132, '131': 133, '132': 134, '133': 135, '134': 136, '135': 137, '136': 138, '137': 139, '138': 140, '139': 141, '140': 142, '141': 143, '142': 144, '143': 145, '144': 146, '145': 147, '146': 148, '147': 149, '148': 150, '149': 151, '150': 152, '151': 153, '152': 154, '153': 155, '154': 156, '155': 157, '156': 158, '157': 159, '158': 160, '159': 161, '160': 162, '161': 163, '162': 164, '163': 165, '164': 166, '165': 167, '166': 168, '167': 169, '168': 170, '169': 171, '170': 172, '171': 173, '172': 174, '173': 175, '174': 176, '175': 177, '176': 178, '177': 179, '178': 180, '179': 
181, '180': 182, '181': 183, '182': 184, '183': 185, '184': 186, '185': 187, '186': 188, '187': 189, '188': 190, '189': 191, '190': 192, '191': 193, '192': 194, '193': 195, '194': 196, '195': 197, '196': 198, '197': 199, '198': 200, '199': 201}\n",
|
128 |
+
"The struct_z vocab has 202 tokens: {'PAD': 0, 'MASK': 1, '0': 2, '1': 3, '2': 4, '3': 5, '4': 6, '5': 7, '6': 8, '7': 9, '8': 10, '9': 11, '10': 12, '11': 13, '12': 14, '13': 15, '14': 16, '15': 17, '16': 18, '17': 19, '18': 20, '19': 21, '20': 22, '21': 23, '22': 24, '23': 25, '24': 26, '25': 27, '26': 28, '27': 29, '28': 30, '29': 31, '30': 32, '31': 33, '32': 34, '33': 35, '34': 36, '35': 37, '36': 38, '37': 39, '38': 40, '39': 41, '40': 42, '41': 43, '42': 44, '43': 45, '44': 46, '45': 47, '46': 48, '47': 49, '48': 50, '49': 51, '50': 52, '51': 53, '52': 54, '53': 55, '54': 56, '55': 57, '56': 58, '57': 59, '58': 60, '59': 61, '60': 62, '61': 63, '62': 64, '63': 65, '64': 66, '65': 67, '66': 68, '67': 69, '68': 70, '69': 71, '70': 72, '71': 73, '72': 74, '73': 75, '74': 76, '75': 77, '76': 78, '77': 79, '78': 80, '79': 81, '80': 82, '81': 83, '82': 84, '83': 85, '84': 86, '85': 87, '86': 88, '87': 89, '88': 90, '89': 91, '90': 92, '91': 93, '92': 94, '93': 95, '94': 96, '95': 97, '96': 98, '97': 99, '98': 100, '99': 101, '100': 102, '101': 103, '102': 104, '103': 105, '104': 106, '105': 107, '106': 108, '107': 109, '108': 110, '109': 111, '110': 112, '111': 113, '112': 114, '113': 115, '114': 116, '115': 117, '116': 118, '117': 119, '118': 120, '119': 121, '120': 122, '121': 123, '122': 124, '123': 125, '124': 126, '125': 127, '126': 128, '127': 129, '128': 130, '129': 131, '130': 132, '131': 133, '132': 134, '133': 135, '134': 136, '135': 137, '136': 138, '137': 139, '138': 140, '139': 141, '140': 142, '141': 143, '142': 144, '143': 145, '144': 146, '145': 147, '146': 148, '147': 149, '148': 150, '149': 151, '150': 152, '151': 153, '152': 154, '153': 155, '154': 156, '155': 157, '156': 158, '157': 159, '158': 160, '159': 161, '160': 162, '161': 163, '162': 164, '163': 165, '164': 166, '165': 167, '166': 168, '167': 169, '168': 170, '169': 171, '170': 172, '171': 173, '172': 174, '173': 175, '174': 176, '175': 177, '176': 178, '177': 179, '178': 180, '179': 
181, '180': 182, '181': 183, '182': 184, '183': 185, '184': 186, '185': 187, '186': 188, '187': 189, '188': 190, '189': 191, '190': 192, '191': 193, '192': 194, '193': 195, '194': 196, '195': 197, '196': 198, '197': 199, '198': 200, '199': 201}\n",
|
129 |
+
"The struct_rr vocab has 362 tokens: {'PAD': 0, 'MASK': 1, '0': 2, '1': 3, '2': 4, '3': 5, '4': 6, '5': 7, '6': 8, '7': 9, '8': 10, '9': 11, '10': 12, '11': 13, '12': 14, '13': 15, '14': 16, '15': 17, '16': 18, '17': 19, '18': 20, '19': 21, '20': 22, '21': 23, '22': 24, '23': 25, '24': 26, '25': 27, '26': 28, '27': 29, '28': 30, '29': 31, '30': 32, '31': 33, '32': 34, '33': 35, '34': 36, '35': 37, '36': 38, '37': 39, '38': 40, '39': 41, '40': 42, '41': 43, '42': 44, '43': 45, '44': 46, '45': 47, '46': 48, '47': 49, '48': 50, '49': 51, '50': 52, '51': 53, '52': 54, '53': 55, '54': 56, '55': 57, '56': 58, '57': 59, '58': 60, '59': 61, '60': 62, '61': 63, '62': 64, '63': 65, '64': 66, '65': 67, '66': 68, '67': 69, '68': 70, '69': 71, '70': 72, '71': 73, '72': 74, '73': 75, '74': 76, '75': 77, '76': 78, '77': 79, '78': 80, '79': 81, '80': 82, '81': 83, '82': 84, '83': 85, '84': 86, '85': 87, '86': 88, '87': 89, '88': 90, '89': 91, '90': 92, '91': 93, '92': 94, '93': 95, '94': 96, '95': 97, '96': 98, '97': 99, '98': 100, '99': 101, '100': 102, '101': 103, '102': 104, '103': 105, '104': 106, '105': 107, '106': 108, '107': 109, '108': 110, '109': 111, '110': 112, '111': 113, '112': 114, '113': 115, '114': 116, '115': 117, '116': 118, '117': 119, '118': 120, '119': 121, '120': 122, '121': 123, '122': 124, '123': 125, '124': 126, '125': 127, '126': 128, '127': 129, '128': 130, '129': 131, '130': 132, '131': 133, '132': 134, '133': 135, '134': 136, '135': 137, '136': 138, '137': 139, '138': 140, '139': 141, '140': 142, '141': 143, '142': 144, '143': 145, '144': 146, '145': 147, '146': 148, '147': 149, '148': 150, '149': 151, '150': 152, '151': 153, '152': 154, '153': 155, '154': 156, '155': 157, '156': 158, '157': 159, '158': 160, '159': 161, '160': 162, '161': 163, '162': 164, '163': 165, '164': 166, '165': 167, '166': 168, '167': 169, '168': 170, '169': 171, '170': 172, '171': 173, '172': 174, '173': 175, '174': 176, '175': 177, '176': 178, '177': 179, '178': 180, '179': 
181, '180': 182, '181': 183, '182': 184, '183': 185, '184': 186, '185': 187, '186': 188, '187': 189, '188': 190, '189': 191, '190': 192, '191': 193, '192': 194, '193': 195, '194': 196, '195': 197, '196': 198, '197': 199, '198': 200, '199': 201, '200': 202, '201': 203, '202': 204, '203': 205, '204': 206, '205': 207, '206': 208, '207': 209, '208': 210, '209': 211, '210': 212, '211': 213, '212': 214, '213': 215, '214': 216, '215': 217, '216': 218, '217': 219, '218': 220, '219': 221, '220': 222, '221': 223, '222': 224, '223': 225, '224': 226, '225': 227, '226': 228, '227': 229, '228': 230, '229': 231, '230': 232, '231': 233, '232': 234, '233': 235, '234': 236, '235': 237, '236': 238, '237': 239, '238': 240, '239': 241, '240': 242, '241': 243, '242': 244, '243': 245, '244': 246, '245': 247, '246': 248, '247': 249, '248': 250, '249': 251, '250': 252, '251': 253, '252': 254, '253': 255, '254': 256, '255': 257, '256': 258, '257': 259, '258': 260, '259': 261, '260': 262, '261': 263, '262': 264, '263': 265, '264': 266, '265': 267, '266': 268, '267': 269, '268': 270, '269': 271, '270': 272, '271': 273, '272': 274, '273': 275, '274': 276, '275': 277, '276': 278, '277': 279, '278': 280, '279': 281, '280': 282, '281': 283, '282': 284, '283': 285, '284': 286, '285': 287, '286': 288, '287': 289, '288': 290, '289': 291, '290': 292, '291': 293, '292': 294, '293': 295, '294': 296, '295': 297, '296': 298, '297': 299, '298': 300, '299': 301, '300': 302, '301': 303, '302': 304, '303': 305, '304': 306, '305': 307, '306': 308, '307': 309, '308': 310, '309': 311, '310': 312, '311': 313, '312': 314, '313': 315, '314': 316, '315': 317, '316': 318, '317': 319, '318': 320, '319': 321, '320': 322, '321': 323, '322': 324, '323': 325, '324': 326, '325': 327, '326': 328, '327': 329, '328': 330, '329': 331, '330': 332, '331': 333, '332': 334, '333': 335, '334': 336, '335': 337, '336': 338, '337': 339, '338': 340, '339': 341, '340': 342, '341': 343, '342': 344, '343': 345, '344': 346, '345': 347, 
'346': 348, '347': 349, '348': 350, '349': 351, '350': 352, '351': 353, '352': 354, '353': 355, '354': 356, '355': 357, '356': 358, '357': 359, '358': 360, '359': 361}\n",
|
130 |
+
"The struct_rp vocab has 362 tokens: {'PAD': 0, 'MASK': 1, '0': 2, '1': 3, '2': 4, '3': 5, '4': 6, '5': 7, '6': 8, '7': 9, '8': 10, '9': 11, '10': 12, '11': 13, '12': 14, '13': 15, '14': 16, '15': 17, '16': 18, '17': 19, '18': 20, '19': 21, '20': 22, '21': 23, '22': 24, '23': 25, '24': 26, '25': 27, '26': 28, '27': 29, '28': 30, '29': 31, '30': 32, '31': 33, '32': 34, '33': 35, '34': 36, '35': 37, '36': 38, '37': 39, '38': 40, '39': 41, '40': 42, '41': 43, '42': 44, '43': 45, '44': 46, '45': 47, '46': 48, '47': 49, '48': 50, '49': 51, '50': 52, '51': 53, '52': 54, '53': 55, '54': 56, '55': 57, '56': 58, '57': 59, '58': 60, '59': 61, '60': 62, '61': 63, '62': 64, '63': 65, '64': 66, '65': 67, '66': 68, '67': 69, '68': 70, '69': 71, '70': 72, '71': 73, '72': 74, '73': 75, '74': 76, '75': 77, '76': 78, '77': 79, '78': 80, '79': 81, '80': 82, '81': 83, '82': 84, '83': 85, '84': 86, '85': 87, '86': 88, '87': 89, '88': 90, '89': 91, '90': 92, '91': 93, '92': 94, '93': 95, '94': 96, '95': 97, '96': 98, '97': 99, '98': 100, '99': 101, '100': 102, '101': 103, '102': 104, '103': 105, '104': 106, '105': 107, '106': 108, '107': 109, '108': 110, '109': 111, '110': 112, '111': 113, '112': 114, '113': 115, '114': 116, '115': 117, '116': 118, '117': 119, '118': 120, '119': 121, '120': 122, '121': 123, '122': 124, '123': 125, '124': 126, '125': 127, '126': 128, '127': 129, '128': 130, '129': 131, '130': 132, '131': 133, '132': 134, '133': 135, '134': 136, '135': 137, '136': 138, '137': 139, '138': 140, '139': 141, '140': 142, '141': 143, '142': 144, '143': 145, '144': 146, '145': 147, '146': 148, '147': 149, '148': 150, '149': 151, '150': 152, '151': 153, '152': 154, '153': 155, '154': 156, '155': 157, '156': 158, '157': 159, '158': 160, '159': 161, '160': 162, '161': 163, '162': 164, '163': 165, '164': 166, '165': 167, '166': 168, '167': 169, '168': 170, '169': 171, '170': 172, '171': 173, '172': 174, '173': 175, '174': 176, '175': 177, '176': 178, '177': 179, '178': 180, '179': 
181, '180': 182, '181': 183, '182': 184, '183': 185, '184': 186, '185': 187, '186': 188, '187': 189, '188': 190, '189': 191, '190': 192, '191': 193, '192': 194, '193': 195, '194': 196, '195': 197, '196': 198, '197': 199, '198': 200, '199': 201, '200': 202, '201': 203, '202': 204, '203': 205, '204': 206, '205': 207, '206': 208, '207': 209, '208': 210, '209': 211, '210': 212, '211': 213, '212': 214, '213': 215, '214': 216, '215': 217, '216': 218, '217': 219, '218': 220, '219': 221, '220': 222, '221': 223, '222': 224, '223': 225, '224': 226, '225': 227, '226': 228, '227': 229, '228': 230, '229': 231, '230': 232, '231': 233, '232': 234, '233': 235, '234': 236, '235': 237, '236': 238, '237': 239, '238': 240, '239': 241, '240': 242, '241': 243, '242': 244, '243': 245, '244': 246, '245': 247, '246': 248, '247': 249, '248': 250, '249': 251, '250': 252, '251': 253, '252': 254, '253': 255, '254': 256, '255': 257, '256': 258, '257': 259, '258': 260, '259': 261, '260': 262, '261': 263, '262': 264, '263': 265, '264': 266, '265': 267, '266': 268, '267': 269, '268': 270, '269': 271, '270': 272, '271': 273, '272': 274, '273': 275, '274': 276, '275': 277, '276': 278, '277': 279, '278': 280, '279': 281, '280': 282, '281': 283, '282': 284, '283': 285, '284': 286, '285': 287, '286': 288, '287': 289, '288': 290, '289': 291, '290': 292, '291': 293, '292': 294, '293': 295, '294': 296, '295': 297, '296': 298, '297': 299, '298': 300, '299': 301, '300': 302, '301': 303, '302': 304, '303': 305, '304': 306, '305': 307, '306': 308, '307': 309, '308': 310, '309': 311, '310': 312, '311': 313, '312': 314, '313': 315, '314': 316, '315': 317, '316': 318, '317': 319, '318': 320, '319': 321, '320': 322, '321': 323, '322': 324, '323': 325, '324': 326, '325': 327, '326': 328, '327': 329, '328': 330, '329': 331, '330': 332, '331': 333, '332': 334, '333': 335, '334': 336, '335': 337, '336': 338, '337': 339, '338': 340, '339': 341, '340': 342, '341': 343, '342': 344, '343': 345, '344': 346, '345': 347, 
'346': 348, '347': 349, '348': 350, '349': 351, '350': 352, '351': 353, '352': 354, '353': 355, '354': 356, '355': 357, '356': 358, '357': 359, '358': 360, '359': 361}\n",
|
131 |
+
"The struct_ry vocab has 362 tokens: {'PAD': 0, 'MASK': 1, '0': 2, '1': 3, '2': 4, '3': 5, '4': 6, '5': 7, '6': 8, '7': 9, '8': 10, '9': 11, '10': 12, '11': 13, '12': 14, '13': 15, '14': 16, '15': 17, '16': 18, '17': 19, '18': 20, '19': 21, '20': 22, '21': 23, '22': 24, '23': 25, '24': 26, '25': 27, '26': 28, '27': 29, '28': 30, '29': 31, '30': 32, '31': 33, '32': 34, '33': 35, '34': 36, '35': 37, '36': 38, '37': 39, '38': 40, '39': 41, '40': 42, '41': 43, '42': 44, '43': 45, '44': 46, '45': 47, '46': 48, '47': 49, '48': 50, '49': 51, '50': 52, '51': 53, '52': 54, '53': 55, '54': 56, '55': 57, '56': 58, '57': 59, '58': 60, '59': 61, '60': 62, '61': 63, '62': 64, '63': 65, '64': 66, '65': 67, '66': 68, '67': 69, '68': 70, '69': 71, '70': 72, '71': 73, '72': 74, '73': 75, '74': 76, '75': 77, '76': 78, '77': 79, '78': 80, '79': 81, '80': 82, '81': 83, '82': 84, '83': 85, '84': 86, '85': 87, '86': 88, '87': 89, '88': 90, '89': 91, '90': 92, '91': 93, '92': 94, '93': 95, '94': 96, '95': 97, '96': 98, '97': 99, '98': 100, '99': 101, '100': 102, '101': 103, '102': 104, '103': 105, '104': 106, '105': 107, '106': 108, '107': 109, '108': 110, '109': 111, '110': 112, '111': 113, '112': 114, '113': 115, '114': 116, '115': 117, '116': 118, '117': 119, '118': 120, '119': 121, '120': 122, '121': 123, '122': 124, '123': 125, '124': 126, '125': 127, '126': 128, '127': 129, '128': 130, '129': 131, '130': 132, '131': 133, '132': 134, '133': 135, '134': 136, '135': 137, '136': 138, '137': 139, '138': 140, '139': 141, '140': 142, '141': 143, '142': 144, '143': 145, '144': 146, '145': 147, '146': 148, '147': 149, '148': 150, '149': 151, '150': 152, '151': 153, '152': 154, '153': 155, '154': 156, '155': 157, '156': 158, '157': 159, '158': 160, '159': 161, '160': 162, '161': 163, '162': 164, '163': 165, '164': 166, '165': 167, '166': 168, '167': 169, '168': 170, '169': 171, '170': 172, '171': 173, '172': 174, '173': 175, '174': 176, '175': 177, '176': 178, '177': 179, '178': 180, '179': 
181, '180': 182, '181': 183, '182': 184, '183': 185, '184': 186, '185': 187, '186': 188, '187': 189, '188': 190, '189': 191, '190': 192, '191': 193, '192': 194, '193': 195, '194': 196, '195': 197, '196': 198, '197': 199, '198': 200, '199': 201, '200': 202, '201': 203, '202': 204, '203': 205, '204': 206, '205': 207, '206': 208, '207': 209, '208': 210, '209': 211, '210': 212, '211': 213, '212': 214, '213': 215, '214': 216, '215': 217, '216': 218, '217': 219, '218': 220, '219': 221, '220': 222, '221': 223, '222': 224, '223': 225, '224': 226, '225': 227, '226': 228, '227': 229, '228': 230, '229': 231, '230': 232, '231': 233, '232': 234, '233': 235, '234': 236, '235': 237, '236': 238, '237': 239, '238': 240, '239': 241, '240': 242, '241': 243, '242': 244, '243': 245, '244': 246, '245': 247, '246': 248, '247': 249, '248': 250, '249': 251, '250': 252, '251': 253, '252': 254, '253': 255, '254': 256, '255': 257, '256': 258, '257': 259, '258': 260, '259': 261, '260': 262, '261': 263, '262': 264, '263': 265, '264': 266, '265': 267, '266': 268, '267': 269, '268': 270, '269': 271, '270': 272, '271': 273, '272': 274, '273': 275, '274': 276, '275': 277, '276': 278, '277': 279, '278': 280, '279': 281, '280': 282, '281': 283, '282': 284, '283': 285, '284': 286, '285': 287, '286': 288, '287': 289, '288': 290, '289': 291, '290': 292, '291': 293, '292': 294, '293': 295, '294': 296, '295': 297, '296': 298, '297': 299, '298': 300, '299': 301, '300': 302, '301': 303, '302': 304, '303': 305, '304': 306, '305': 307, '306': 308, '307': 309, '308': 310, '309': 311, '310': 312, '311': 313, '312': 314, '313': 315, '314': 316, '315': 317, '316': 318, '317': 319, '318': 320, '319': 321, '320': 322, '321': 323, '322': 324, '323': 325, '324': 326, '325': 327, '326': 328, '327': 329, '328': 330, '329': 331, '330': 332, '331': 333, '332': 334, '333': 335, '334': 336, '335': 337, '336': 338, '337': 339, '338': 340, '339': 341, '340': 342, '341': 343, '342': 344, '343': 345, '344': 346, '345': 347, 
'346': 348, '347': 349, '348': 350, '349': 351, '350': 352, '351': 353, '352': 354, '353': 355, '354': 356, '355': 357, '356': 358, '357': 359, '358': 360, '359': 361}\n",
|
132 |
+
"data dirs: ['/home/weiyu/data_drive/data_new_objects/examples_circle_new_objects/result', '/home/weiyu/data_drive/data_new_objects/examples_line_new_objects/result', '/home/weiyu/data_drive/data_new_objects/examples_tower_new_objects/result', '/home/weiyu/data_drive/data_new_objects/examples_dinner_new_objects/result']\n",
|
133 |
+
"40000 valid sequences\n"
|
134 |
+
]
|
135 |
+
}
|
136 |
+
],
|
137 |
+
"source": [
|
138 |
+
"tokenizer = Tokenizer(\"/home/weiyu/data_drive/data_new_objects/type_vocabs_coarse.json\")\n",
|
139 |
+
"\n",
|
140 |
+
"data_roots = []\n",
|
141 |
+
"index_roots = []\n",
|
142 |
+
"for shape, index in [(\"circle\", \"index_10k\"), (\"line\", \"index_10k\"), (\"tower\", \"index_10k\"), (\"dinner\", \"index_10k\")]:\n",
|
143 |
+
" data_roots.append(\"/home/weiyu/data_drive/data_new_objects/examples_{}_new_objects/result\".format(shape))\n",
|
144 |
+
" index_roots.append(index)\n",
|
145 |
+
"\n",
|
146 |
+
"dataset = SemanticArrangementDataset(data_roots=data_roots, index_roots=index_roots, splits=[\"train\", \"valid\", \"test\"], tokenizer=tokenizer)"
|
147 |
+
],
|
148 |
+
"metadata": {
|
149 |
+
"collapsed": false,
|
150 |
+
"pycharm": {
|
151 |
+
"name": "#%%\n"
|
152 |
+
}
|
153 |
+
}
|
154 |
+
},
|
155 |
+
{
|
156 |
+
"cell_type": "code",
|
157 |
+
"execution_count": 4,
|
158 |
+
"outputs": [
|
159 |
+
{
|
160 |
+
"name": "stdout",
|
161 |
+
"output_type": "stream",
|
162 |
+
"text": [
|
163 |
+
"\n",
|
164 |
+
"\n",
|
165 |
+
"{'place_at_once': 'False', 'position': [0.4530459674902468, 0.2866384076623889, 0.011194709806729462], 'rotation': [5.101818936729106e-05, 1.362746309147995e-06, 2.145504341444197], 'type': 'tower'}\n",
|
166 |
+
"[('tower', 'shape'), (2.145504341444197, 'rotation'), (0.4530459674902468, 'position_x'), (0.2866384076623889, 'position_y')]\n",
|
167 |
+
"tower in the middle left of the table facing west\n",
|
168 |
+
"[('tower', 'shape'), (2.145504341444197, 'rotation'), (0.4530459674902468, 'position_x'), (0.2866384076623889, 'position_y')]\n",
|
169 |
+
"tower in the middle left of the table facing west\n",
|
170 |
+
"(('rotation', 'west'), ('shape', 'tower'), ('x', 'middle'), ('y', 'left'))\n",
|
171 |
+
"\n",
|
172 |
+
"\n",
|
173 |
+
"{'length': 0.15789473684210525, 'length_increment': 0.05, 'max_length': 1.0, 'min_length': 0.0, 'place_at_once': 'True', 'position': [0.5744088910421017, 0.0, 0.0], 'rotation': [0.0, -0.0, 0.0], 'type': 'dinner', 'uniform_space': 'False'}\n",
|
174 |
+
"[('dinner', 'shape'), (0.0, 'rotation'), (0.5744088910421017, 'position_x'), (0.0, 'position_y')]\n",
|
175 |
+
"dinner in the middle center of the table facing south\n",
|
176 |
+
"[('dinner', 'shape'), (0.0, 'rotation'), (0.5744088910421017, 'position_x'), (0.0, 'position_y')]\n",
|
177 |
+
"dinner in the middle center of the table facing south\n",
|
178 |
+
"(('rotation', 'south'), ('shape', 'dinner'), ('x', 'middle'), ('y', 'center'))\n",
|
179 |
+
"\n",
|
180 |
+
"\n",
|
181 |
+
"{'place_at_once': 'False', 'position': [0.5300184865230677, -0.11749143967722209, 0.043775766459831195], 'rotation': [8.311828443210225e-05, 2.8403995850279114e-05, -1.9831750137833084], 'type': 'tower'}\n",
|
182 |
+
"[('tower', 'shape'), (-1.9831750137833084, 'rotation'), (0.5300184865230677, 'position_x'), (-0.11749143967722209, 'position_y')]\n",
|
183 |
+
"tower in the middle center of the table facing north\n",
|
184 |
+
"[('tower', 'shape')]\n",
|
185 |
+
"tower\n",
|
186 |
+
"(('shape', 'tower'),)\n",
|
187 |
+
"\n",
|
188 |
+
"\n",
|
189 |
+
"{'length': 0.3157894736842105, 'length_increment': 0.05, 'max_length': 1.0, 'min_length': 0.0, 'place_at_once': 'True', 'position': [0.6482385523146229, 0.0, 0.0], 'rotation': [0.0, -0.0, 0.0], 'type': 'dinner', 'uniform_space': 'False'}\n",
|
190 |
+
"[('dinner', 'shape'), (0.0, 'rotation'), (0.6482385523146229, 'position_x'), (0.0, 'position_y')]\n",
|
191 |
+
"dinner in the top center of the table facing south\n",
|
192 |
+
"[('dinner', 'shape')]\n",
|
193 |
+
"dinner\n",
|
194 |
+
"(('shape', 'dinner'),)\n",
|
195 |
+
"\n",
|
196 |
+
"\n",
|
197 |
+
"{'angle_ratio': 1.0, 'face_center': 'True', 'max_radius': 0.5, 'min_radius': 0.050687861718942046, 'place_at_once': 'True', 'position': [0.2998438437491998, -0.03599718247376027, 0.0], 'radius': 0.0966402394976866, 'radius_increment': 0.005, 'rotation': [0.0, -0.0, 2.053106459668934], 'type': 'circle', 'uniform_angle': 'True'}\n",
|
198 |
+
"[('circle', 'shape'), (2.053106459668934, 'rotation'), (0.2998438437491998, 'position_x'), (-0.03599718247376027, 'position_y'), (0.0966402394976866, 'radius')]\n",
|
199 |
+
"small circle in the middle center of the table facing west\n",
|
200 |
+
"[('circle', 'shape'), (2.053106459668934, 'rotation'), (0.2998438437491998, 'position_x'), (-0.03599718247376027, 'position_y'), (0.0966402394976866, 'radius')]\n",
|
201 |
+
"small circle in the middle center of the table facing west\n",
|
202 |
+
"(('rotation', 'west'), ('shape', 'circle'), ('size', 'small'), ('x', 'middle'), ('y', 'center'))\n",
|
203 |
+
"\n",
|
204 |
+
"\n",
|
205 |
+
"{'length': 0.4245597103515504, 'length_increment': 0.005, 'max_length': 1.0, 'min_length': 0.21760311495166934, 'place_at_once': 'True', 'position': [0.6672547106460816, 0.0, 0.0], 'rotation': [0.0, -0.0, 0.0], 'type': 'line', 'uniform_space': 'True'}\n",
|
206 |
+
"[('line', 'shape'), (0.0, 'rotation'), (0.6672547106460816, 'position_x'), (0.0, 'position_y'), (0.2122798551757752, 'radius')]\n",
|
207 |
+
"medium line in the top center of the table facing south\n",
|
208 |
+
"[('line', 'shape'), (0.0, 'rotation'), (0.6672547106460816, 'position_x'), (0.2122798551757752, 'radius')]\n",
|
209 |
+
"medium line in the top facing south\n",
|
210 |
+
"(('rotation', 'south'), ('shape', 'line'), ('size', 'medium'), ('x', 'top'))\n",
|
211 |
+
"\n",
|
212 |
+
"\n",
|
213 |
+
"{'place_at_once': 'False', 'position': [0.6555576184899171, 0.22241488561049588, 0.006522659915853506], 'rotation': [-0.000139418832574769, -7.243860660016997e-05, 2.2437880740062814], 'type': 'tower'}\n",
|
214 |
+
"[('tower', 'shape'), (2.2437880740062814, 'rotation'), (0.6555576184899171, 'position_x'), (0.22241488561049588, 'position_y')]\n",
|
215 |
+
"tower in the top left of the table facing west\n",
|
216 |
+
"[(2.2437880740062814, 'rotation'), (0.6555576184899171, 'position_x')]\n",
|
217 |
+
"in the top facing west\n",
|
218 |
+
"(('rotation', 'west'), ('x', 'top'))\n",
|
219 |
+
"\n",
|
220 |
+
"\n",
|
221 |
+
"{'length': 0.4925060249864075, 'length_increment': 0.005, 'max_length': 1.0, 'min_length': 0.4925060249864075, 'place_at_once': 'True', 'position': [0.7754676784901477, 0.0, 0.0], 'rotation': [0.0, -0.0, 0.0], 'type': 'line', 'uniform_space': 'False'}\n",
|
222 |
+
"[('line', 'shape'), (0.0, 'rotation'), (0.7754676784901477, 'position_x'), (0.0, 'position_y'), (0.24625301249320375, 'radius')]\n",
|
223 |
+
"medium line in the top center of the table facing south\n",
|
224 |
+
"[(0.0, 'rotation'), (0.7754676784901477, 'position_x')]\n",
|
225 |
+
"in the top facing south\n",
|
226 |
+
"(('rotation', 'south'), ('x', 'top'))\n",
|
227 |
+
"\n",
|
228 |
+
"\n",
|
229 |
+
"{'angle_ratio': 1.0, 'face_center': 'True', 'max_radius': 0.5, 'min_radius': 0.2260219063147572, 'place_at_once': 'True', 'position': [0.6256453430245876, 0.1131426073908803, 0.0], 'radius': 0.2260219063147572, 'radius_increment': 0.005, 'rotation': [0.0, -0.0, 1.6063513593439724], 'type': 'circle', 'uniform_angle': 'True'}\n",
|
230 |
+
"[('circle', 'shape'), (1.6063513593439724, 'rotation'), (0.6256453430245876, 'position_x'), (0.1131426073908803, 'position_y'), (0.2260219063147572, 'radius')]\n",
|
231 |
+
"medium circle in the middle center of the table facing west\n",
|
232 |
+
"[(1.6063513593439724, 'rotation'), (0.6256453430245876, 'position_x')]\n",
|
233 |
+
"in the middle facing west\n",
|
234 |
+
"(('rotation', 'west'), ('x', 'middle'))\n",
|
235 |
+
"\n",
|
236 |
+
"\n",
|
237 |
+
"{'angle_ratio': 1.0, 'face_center': 'True', 'max_radius': 0.5, 'min_radius': 0.14976631196286583, 'place_at_once': 'True', 'position': [0.5157008668336853, 0.11005531020590054, 0.0], 'radius': 0.15991801306539147, 'radius_increment': 0.005, 'rotation': [0.0, -0.0, -2.2145659262893918], 'type': 'circle', 'uniform_angle': 'True'}\n",
|
238 |
+
"[('circle', 'shape'), (-2.2145659262893918, 'rotation'), (0.5157008668336853, 'position_x'), (0.11005531020590054, 'position_y'), (0.15991801306539147, 'radius')]\n",
|
239 |
+
"small circle in the middle center of the table facing north\n",
|
240 |
+
"[('circle', 'shape'), (0.5157008668336853, 'position_x'), (0.15991801306539147, 'radius')]\n",
|
241 |
+
"small circle in the middle\n",
|
242 |
+
"(('shape', 'circle'), ('size', 'small'), ('x', 'middle'))\n"
|
243 |
+
]
|
244 |
+
}
|
245 |
+
],
|
246 |
+
"source": [
|
247 |
+
"idxs = np.random.permutation(len(dataset))\n",
|
248 |
+
"for i in idxs[:10]:\n",
|
249 |
+
" print(\"\\n\")\n",
|
250 |
+
" struct_spec = dataset.get_raw_data(i)\n",
|
251 |
+
" print(struct_spec)\n",
|
252 |
+
" struct_word_spec = tokenizer.convert_structure_params_to_natural_language(struct_spec)\n",
|
253 |
+
" print(struct_word_spec)\n",
|
254 |
+
"\n",
|
255 |
+
" token_idxs = np.random.permutation(len(struct_spec))\n",
|
256 |
+
" token_idxs = token_idxs[:np.random.randint(1, len(struct_spec) + 1)]\n",
|
257 |
+
" token_idxs = sorted(token_idxs)\n",
|
258 |
+
" incomplete_struct_spec = [struct_spec[ti] for ti in token_idxs]\n",
|
259 |
+
"\n",
|
260 |
+
" print(incomplete_struct_spec)\n",
|
261 |
+
" print(tokenizer.convert_structure_params_to_natural_language(incomplete_struct_spec))\n",
|
262 |
+
"\n",
|
263 |
+
" type_value_tuple = tokenizer.convert_structure_params_to_type_value_tuple(incomplete_struct_spec)\n",
|
264 |
+
" print(type_value_tuple)"
|
265 |
+
],
|
266 |
+
"metadata": {
|
267 |
+
"collapsed": false,
|
268 |
+
"pycharm": {
|
269 |
+
"name": "#%%\n"
|
270 |
+
}
|
271 |
+
}
|
272 |
+
},
|
273 |
+
{
|
274 |
+
"cell_type": "code",
|
275 |
+
"execution_count": 49,
|
276 |
+
"outputs": [
|
277 |
+
{
|
278 |
+
"name": "stderr",
|
279 |
+
"output_type": "stream",
|
280 |
+
"text": [
|
281 |
+
"100%|██████████| 40000/40000 [00:23<00:00, 1699.94it/s]"
|
282 |
+
]
|
283 |
+
},
|
284 |
+
{
|
285 |
+
"name": "stdout",
|
286 |
+
"output_type": "stream",
|
287 |
+
"text": [
|
288 |
+
"669\n"
|
289 |
+
]
|
290 |
+
},
|
291 |
+
{
|
292 |
+
"name": "stderr",
|
293 |
+
"output_type": "stream",
|
294 |
+
"text": [
|
295 |
+
"\n"
|
296 |
+
]
|
297 |
+
}
|
298 |
+
],
|
299 |
+
"source": [
|
300 |
+
"unique_type_value_tuples = set()\n",
|
301 |
+
"for i in tqdm.tqdm(idxs):\n",
|
302 |
+
" struct_spec = dataset.get_raw_data(i)\n",
|
303 |
+
"\n",
|
304 |
+
" incomplete_struct_specs = []\n",
|
305 |
+
" for L in range(1, len(struct_spec) + 1):\n",
|
306 |
+
" for subset in itertools.combinations(struct_spec, L):\n",
|
307 |
+
" incomplete_struct_specs.append(subset)\n",
|
308 |
+
"\n",
|
309 |
+
" # print(incomplete_struct_specs)\n",
|
310 |
+
"\n",
|
311 |
+
" type_value_tuples = []\n",
|
312 |
+
" for incomplete_struct_spec in incomplete_struct_specs:\n",
|
313 |
+
" type_value_tuples.append(tokenizer.convert_structure_params_to_type_value_tuple(incomplete_struct_spec))\n",
|
314 |
+
"\n",
|
315 |
+
" unique_type_value_tuples.update(type_value_tuples)\n",
|
316 |
+
"\n",
|
317 |
+
"print(len(unique_type_value_tuples))"
|
318 |
+
],
|
319 |
+
"metadata": {
|
320 |
+
"collapsed": false,
|
321 |
+
"pycharm": {
|
322 |
+
"name": "#%%\n"
|
323 |
+
}
|
324 |
+
}
|
325 |
+
},
|
326 |
+
{
|
327 |
+
"cell_type": "code",
|
328 |
+
"execution_count": null,
|
329 |
+
"outputs": [],
|
330 |
+
"source": [
|
331 |
+
"sentence_template = [\n",
|
332 |
+
" \"Put the objects {in a [size][shape]} on the {[x][y] of} the table {facing [rotation]}.\",\n",
|
333 |
+
" \"Build a [size][shape] of the [objects] on the [x][y] of the table facing [rotation].\",\n",
|
334 |
+
" \"Put the [objects] on the [x][y] of the table and make a [shape] facing [rotation].\",\n",
|
335 |
+
" \"Rearrange the [objects] into a [shape], and put the structure on the [x][y] of the table facing [rotation].\",\n",
|
336 |
+
" \"Could you ...\",\n",
|
337 |
+
" \"Please ...\",\n",
|
338 |
+
" \"Pick up the objects, put them into a [size][shape], place the [shape] on the [x][y] of table, make sure the [shape] is facing [rotation].\"]\n",
|
339 |
+
"\n"
|
340 |
+
],
|
341 |
+
"metadata": {
|
342 |
+
"collapsed": false,
|
343 |
+
"pycharm": {
|
344 |
+
"name": "#%%\n"
|
345 |
+
}
|
346 |
+
}
|
347 |
+
},
|
348 |
+
{
|
349 |
+
"cell_type": "markdown",
|
350 |
+
"source": [
|
351 |
+
"Enumerate all possible combinations of types"
|
352 |
+
],
|
353 |
+
"metadata": {
|
354 |
+
"collapsed": false,
|
355 |
+
"pycharm": {
|
356 |
+
"name": "#%% md\n"
|
357 |
+
}
|
358 |
+
}
|
359 |
+
},
|
360 |
+
{
|
361 |
+
"cell_type": "code",
|
362 |
+
"execution_count": 31,
|
363 |
+
"outputs": [
|
364 |
+
{
|
365 |
+
"name": "stdout",
|
366 |
+
"output_type": "stream",
|
367 |
+
"text": [
|
368 |
+
"31\n",
|
369 |
+
"[('size',), ('shape',), ('x',), ('y',), ('rotation',), ('shape', 'size'), ('size', 'x'), ('size', 'y'), ('rotation', 'size'), ('shape', 'x'), ('shape', 'y'), ('rotation', 'shape'), ('x', 'y'), ('rotation', 'x'), ('rotation', 'y'), ('shape', 'size', 'x'), ('shape', 'size', 'y'), ('rotation', 'shape', 'size'), ('size', 'x', 'y'), ('rotation', 'size', 'x'), ('rotation', 'size', 'y'), ('shape', 'x', 'y'), ('rotation', 'shape', 'x'), ('rotation', 'shape', 'y'), ('rotation', 'x', 'y'), ('shape', 'size', 'x', 'y'), ('rotation', 'shape', 'size', 'x'), ('rotation', 'shape', 'size', 'y'), ('rotation', 'size', 'x', 'y'), ('rotation', 'shape', 'x', 'y'), ('rotation', 'shape', 'size', 'x', 'y')]\n"
|
370 |
+
]
|
371 |
+
}
|
372 |
+
],
|
373 |
+
"source": [
|
374 |
+
"import itertools\n",
|
375 |
+
"types = [\"size\", \"shape\", \"x\", \"y\", \"rotation\"]\n",
|
376 |
+
"\n",
|
377 |
+
"type_combs = []\n",
|
378 |
+
"for L in range(1, len(types) + 1):\n",
|
379 |
+
" for subset in itertools.combinations(types, L):\n",
|
380 |
+
" type_combs.append(tuple(sorted(subset)))\n",
|
381 |
+
"\n",
|
382 |
+
"print(len(type_combs))\n",
|
383 |
+
"print(type_combs)"
|
384 |
+
],
|
385 |
+
"metadata": {
|
386 |
+
"collapsed": false,
|
387 |
+
"pycharm": {
|
388 |
+
"name": "#%%\n"
|
389 |
+
}
|
390 |
+
}
|
391 |
+
},
|
392 |
+
{
|
393 |
+
"cell_type": "code",
|
394 |
+
"execution_count": 46,
|
395 |
+
"outputs": [
|
396 |
+
{
|
397 |
+
"name": "stdout",
|
398 |
+
"output_type": "stream",
|
399 |
+
"text": [
|
400 |
+
"build a [size] shape from the objects ('size',)\n",
|
401 |
+
"put the objects in to a [size] shape ('size',)\n",
|
402 |
+
"place the objects as a [size] shape ('size',)\n",
|
403 |
+
"make a [size] shape from the objects ('size',)\n",
|
404 |
+
"rearrange the objects into a [size] structure ('size',)\n",
|
405 |
+
"build a [shape] ('shape',)\n",
|
406 |
+
"make a [shape] ('shape',)\n",
|
407 |
+
"put the objects into a [shape] ('shape',)\n",
|
408 |
+
"place the objects as a [shape] ('shape',)\n",
|
409 |
+
"pick up the objects, and place them as a [shape] ('shape',)\n",
|
410 |
+
"place the objects on the [x] of the table ('x',)\n",
|
411 |
+
"put the objects on [x] ('x',)\n",
|
412 |
+
"make a structure from the objects and place it on [x] ('x',)\n",
|
413 |
+
"on the [x] of the table, place the objects ('x',)\n",
|
414 |
+
"move the objects to the [x] ('x',)\n",
|
415 |
+
"place the objects on the [y] of the table ('y',)\n",
|
416 |
+
"put the objects on [y] ('y',)\n",
|
417 |
+
"make a structure from the objects and place it on [y] ('y',)\n",
|
418 |
+
"on the [y] of the table, place the objects ('y',)\n",
|
419 |
+
"move the objects to the [y] ('y',)\n",
|
420 |
+
"build a structure facing [rotation] ('rotation',)\n",
|
421 |
+
"make a structure from the objects and make sure it is pointing [rotation] ('rotation',)\n",
|
422 |
+
"put the objects in a structure that faces [rotation] ('rotation',)\n",
|
423 |
+
"rotate the object structure so that it points [rotation] ('rotation',)\n",
|
424 |
+
"[rotation] is the direction the structure built from the objects should be facing ('rotation',)\n",
|
425 |
+
"build a [size] [shape] ('shape', 'size')\n",
|
426 |
+
"make a [size] [shape] ('shape', 'size')\n",
|
427 |
+
"put the objects into a [size] [shape] ('shape', 'size')\n",
|
428 |
+
"place the objects as a [size] [shape] ('shape', 'size')\n",
|
429 |
+
"pick up the objects, and place them as a [size] [shape] ('shape', 'size')\n",
|
430 |
+
"build a [size] shape from the objects on the [x] of the table ('size', 'x')\n",
|
431 |
+
"put the objects in to a [size] shape and place it on [x] ('size', 'x')\n",
|
432 |
+
"on the [x] of the table, place the objects as a [size] shape ('size', 'x')\n",
|
433 |
+
"make a [size] shape from the objects and move it to [x] ('size', 'x')\n",
|
434 |
+
"rearrange the objects into a [size] structure on [x] ('size', 'x')\n",
|
435 |
+
"build a [size] shape from the objects on the [y] of the table ('size', 'y')\n",
|
436 |
+
"put the objects in to a [size] shape and place it on [y] ('size', 'y')\n",
|
437 |
+
"on the [y] of the table, place the objects as a [size] shape ('size', 'y')\n",
|
438 |
+
"make a [size] shape from the objects and move it to [y] ('size', 'y')\n",
|
439 |
+
"rearrange the objects into a [size] structure on [y] ('size', 'y')\n",
|
440 |
+
"build a [size] shape from the objects facing [rotation] ('rotation', 'size')\n",
|
441 |
+
"put the objects in to a [size] shape and place it so that it faces [rotation] ('rotation', 'size')\n",
|
442 |
+
"place the objects as a [size] shape and [rotation] is the direction the shape built from the objects should be facing ('rotation', 'size')\n",
|
443 |
+
"make a [size] structure from the objects and rotate the object structure so that it points [rotation] ('rotation', 'size')\n",
|
444 |
+
"rearrange the objects into a [size] structure that points to [rotation] ('rotation', 'size')\n",
|
445 |
+
"build a [shape] from the objects on the [x] of the table ('shape', 'x')\n",
|
446 |
+
"put the objects in to a [shape] and place it on [x] ('shape', 'x')\n",
|
447 |
+
"on the [x] of the table, place the objects as a [shape] ('shape', 'x')\n",
|
448 |
+
"make a [shape] from the objects and move it to [x] ('shape', 'x')\n",
|
449 |
+
"rearrange the objects into a [shape] on [x] ('shape', 'x')\n",
|
450 |
+
"build a [shape] from the objects on the [y] of the table ('shape', 'y')\n",
|
451 |
+
"put the objects in to a [shape] and place it on [y] ('shape', 'y')\n",
|
452 |
+
"on the [y] of the table, place the objects as a [shape] ('shape', 'y')\n",
|
453 |
+
"make a [shape] from the objects and move it to [y] ('shape', 'y')\n",
|
454 |
+
"rearrange the objects into a [shape] on [y] ('shape', 'y')\n",
|
455 |
+
"build a [shape] from the objects facing [rotation] ('rotation', 'shape')\n",
|
456 |
+
"put the objects in to a [shape] and place it so that it faces [rotation] ('rotation', 'shape')\n",
|
457 |
+
"place the objects as a [shape] and [rotation] is the direction the shape built from the objects should be facing ('rotation', 'shape')\n",
|
458 |
+
"make a [shape] from the objects and rotate the shape so that it points [rotation] ('rotation', 'shape')\n",
|
459 |
+
"rearrange the objects into a [shape] that points to [rotation] ('rotation', 'shape')\n",
|
460 |
+
"place the objects on the [x] and [y] of the table ('x', 'y')\n",
|
461 |
+
"put the objects on [x] [y] of the table ('x', 'y')\n",
|
462 |
+
"make a structure from the objects and place it on [x] [y] ('x', 'y')\n",
|
463 |
+
"on the [x] [y] of the table, place the objects ('x', 'y')\n",
|
464 |
+
"move the objects to the [x] [y] ('x', 'y')\n",
|
465 |
+
"build a structure on the [x] of the table facing [rotation] ('rotation', 'x')\n",
|
466 |
+
"make a structure from the objects and make sure it is pointing [rotation] and on [x] ('rotation', 'x')\n",
|
467 |
+
"rearrange the objects in a structure that faces [rotation] and place it on [x] ('rotation', 'x')\n",
|
468 |
+
"move and rotate the object structure so that it is on [x] and points [rotation] ('rotation', 'x')\n",
|
469 |
+
"[rotation] is the direction the structure built from the objects should be facing, [x] is the location ('rotation', 'x')\n",
|
470 |
+
"build a structure on the [y] of the table facing [rotation] ('rotation', 'y')\n",
|
471 |
+
"make a structure from the objects and make sure it is pointing [rotation] and on [y] ('rotation', 'y')\n",
|
472 |
+
"rearrange the objects in a structure that faces [rotation] and place it on [y] ('rotation', 'y')\n",
|
473 |
+
"move and rotate the object structure so that it is on [y] and points [rotation] ('rotation', 'y')\n",
|
474 |
+
"[rotation] is the direction the structure built from the objects should be facing, [y] is the location ('rotation', 'y')\n",
|
475 |
+
"build a [size] [shape] from the objects on the [x] of the table ('shape', 'size', 'x')\n",
|
476 |
+
"put the objects in to a [size] [shape] and place it on [x] ('shape', 'size', 'x')\n",
|
477 |
+
"on the [x] of the table, place the objects as a [shape], make the shape [size] ('shape', 'size', 'x')\n",
|
478 |
+
"make a [size] [shape] from the objects and move it to [x] ('shape', 'size', 'x')\n",
|
479 |
+
"rearrange the objects into a [size] [shape] on [x] ('shape', 'size', 'x')\n",
|
480 |
+
"build a [size] [shape] from the objects on the [y] of the table ('shape', 'size', 'y')\n",
|
481 |
+
"put the objects in to a [size] [shape] and place it on [y] ('shape', 'size', 'y')\n",
|
482 |
+
"on the [y] of the table, place the objects as a [shape], make the shape [size] ('shape', 'size', 'y')\n",
|
483 |
+
"make a [size] [shape] from the objects and move it to [y] ('shape', 'size', 'y')\n",
|
484 |
+
"rearrange the objects into a [size] [shape] on [y] ('shape', 'size', 'y')\n",
|
485 |
+
"build a [size] [shape] from the objects facing [rotation] ('rotation', 'shape', 'size')\n",
|
486 |
+
"put the objects in to a [size] [shape] and place it so that it faces [rotation] ('rotation', 'shape', 'size')\n",
|
487 |
+
"place the objects as a [size] [shape] and [rotation] is the direction the shape built from the objects should be facing ('rotation', 'shape', 'size')\n",
|
488 |
+
"make a [size] [shape] from the objects and rotate the shape so that it points [rotation] ('rotation', 'shape', 'size')\n",
|
489 |
+
"rearrange the objects into a [size] [shape] that points to [rotation] ('rotation', 'shape', 'size')\n",
|
490 |
+
"build a [size] shape from the objects on the [x] [y] of the table ('size', 'x', 'y')\n",
|
491 |
+
"put the objects in to a [size] shape and place it on [x] and [y] ('size', 'x', 'y')\n",
|
492 |
+
"on the [x] [y] of the table, place the objects as a [size] shape ('size', 'x', 'y')\n",
|
493 |
+
"make a [size] shape from the objects and move it to [x] [y] ('size', 'x', 'y')\n",
|
494 |
+
"rearrange the objects into a [size] structure on [x] and on [y] ('size', 'x', 'y')\n",
|
495 |
+
"build a [size] structure on the [x] of the table facing [rotation] ('rotation', 'size', 'x')\n",
|
496 |
+
"make a [size] structure from the objects and make sure it is pointing [rotation] and on [x] ('rotation', 'size', 'x')\n",
|
497 |
+
"rearrange the objects in a [size] structure that faces [rotation] and place it on [x] ('rotation', 'size', 'x')\n",
|
498 |
+
"move and rotate the [size] object structure so that it is on [x] and points [rotation] ('rotation', 'size', 'x')\n",
|
499 |
+
"[rotation] is the direction the [size] structure built from the objects should be facing, [x] is the location ('rotation', 'size', 'x')\n",
|
500 |
+
"build a [size] structure on the [y] of the table facing [rotation] ('rotation', 'size', 'y')\n",
|
501 |
+
"make a [size] structure from the objects and make sure it is pointing [rotation] and on [y] ('rotation', 'size', 'y')\n",
|
502 |
+
"rearrange the objects in a [size] structure that faces [rotation] and place it on [y] ('rotation', 'size', 'y')\n",
|
503 |
+
"move and rotate the [size] object structure so that it is on [y] and points [rotation] ('rotation', 'size', 'y')\n",
|
504 |
+
"[rotation] is the direction the [size] structure built from the objects should be facing, [y] is the location ('rotation', 'size', 'y')\n",
|
505 |
+
"build a [shape] from the objects on the [x] [y] of the table ('shape', 'x', 'y')\n",
|
506 |
+
"put the objects in to a [shape] and place it on [x] and [y] ('shape', 'x', 'y')\n",
|
507 |
+
"on the [x] [y] of the table, place the objects as a [shape] ('shape', 'x', 'y')\n",
|
508 |
+
"make a [shape] from the objects and move it to [x] [y] ('shape', 'x', 'y')\n",
|
509 |
+
"rearrange the objects into a [shape] on [x] and on [y] ('shape', 'x', 'y')\n",
|
510 |
+
"build a [shape] on the [x] of the table facing [rotation] ('rotation', 'shape', 'x')\n",
|
511 |
+
"make a [shape] from the objects and make sure it is pointing [rotation] and on [x] ('rotation', 'shape', 'x')\n",
|
512 |
+
"rearrange the objects in a [shape] that faces [rotation] and place it on [x] ('rotation', 'shape', 'x')\n",
|
513 |
+
"move and rotate the [shape] so that it is on [x] and points [rotation] ('rotation', 'shape', 'x')\n",
|
514 |
+
"[rotation] is the direction the [shape] built from the objects should be facing, [x] is the location ('rotation', 'shape', 'x')\n",
|
515 |
+
"build a [shape] on the [y] of the table facing [rotation] ('rotation', 'shape', 'y')\n",
|
516 |
+
"make a [shape] from the objects and make sure it is pointing [rotation] and on [y] ('rotation', 'shape', 'y')\n",
|
517 |
+
"rearrange the objects in a [shape] that faces [rotation] and place it on [y] ('rotation', 'shape', 'y')\n",
|
518 |
+
"move and rotate the [shape] so that it is on [y] and points [rotation] ('rotation', 'shape', 'y')\n",
|
519 |
+
"[rotation] is the direction the [shape] built from the objects should be facing, [y] is the location ('rotation', 'shape', 'y')\n",
|
520 |
+
"build a structure on the [x] [y] of the table facing [rotation] ('rotation', 'x', 'y')\n",
|
521 |
+
"make a structure from the objects and make sure it is pointing [rotation] and on [x] [y] ('rotation', 'x', 'y')\n",
|
522 |
+
"rearrange the objects in a structure that faces [rotation] and place it on [x] [y] ('rotation', 'x', 'y')\n",
|
523 |
+
"move and rotate the object structure so that it is on [x] [y] and points [rotation] ('rotation', 'x', 'y')\n",
|
524 |
+
"[rotation] is the direction the structure built from the objects should be facing, [x] [y] is the location ('rotation', 'x', 'y')\n",
|
525 |
+
"build a [shape] from the objects on the [x] [y] of the table, make the [shape] [size] ('shape', 'size', 'x', 'y')\n",
|
526 |
+
"put the objects in to a [size] [shape] and place it on [x] and [y] ('shape', 'size', 'x', 'y')\n",
|
527 |
+
"on the [x] [y] of the table, place the objects as a [size] [shape] ('shape', 'size', 'x', 'y')\n",
|
528 |
+
"make a [size] [shape] from the objects and move it to [x] [y] ('shape', 'size', 'x', 'y')\n",
|
529 |
+
"rearrange the objects into a [size] [shape] on [x] and on [y] ('shape', 'size', 'x', 'y')\n",
|
530 |
+
"build a [size] [shape] on the [x] of the table facing [rotation] ('rotation', 'shape', 'size', 'x')\n",
|
531 |
+
"make a [size] [shape] from the objects and make sure it is pointing [rotation] and on [x] ('rotation', 'shape', 'size', 'x')\n",
|
532 |
+
"rearrange the objects in a [size] [shape] that faces [rotation] and place it on [x] ('rotation', 'shape', 'size', 'x')\n",
|
533 |
+
"move and rotate the [size] [shape] so that it is on [x] and points [rotation] ('rotation', 'shape', 'size', 'x')\n",
|
534 |
+
"[rotation] is the direction the [size] [shape] built from the objects should be facing, [x] is the location ('rotation', 'shape', 'size', 'x')\n",
|
535 |
+
"build a [size] [shape] on the [y] of the table facing [rotation] ('rotation', 'shape', 'size', 'y')\n",
|
536 |
+
"make a [size] [shape] from the objects and make sure it is pointing [rotation] and on [y] ('rotation', 'shape', 'size', 'y')\n",
|
537 |
+
"rearrange the objects in a [size] [shape] that faces [rotation] and place it on [y] ('rotation', 'shape', 'size', 'y')\n",
|
538 |
+
"move and rotate the [size] [shape] so that it is on [y] and points [rotation] ('rotation', 'shape', 'size', 'y')\n",
|
539 |
+
"[rotation] is the direction the [size] [shape] built from the objects should be facing, [y] is the location ('rotation', 'shape', 'size', 'y')\n",
|
540 |
+
"build a [size] structure on the [x] [y] of the table facing [rotation] ('rotation', 'size', 'x', 'y')\n",
|
541 |
+
"make a [size] structure from the objects and make sure it is pointing [rotation] and on [x] [y] ('rotation', 'size', 'x', 'y')\n",
|
542 |
+
"rearrange the objects in a [size] structure that faces [rotation] and place it on [x] [y] ('rotation', 'size', 'x', 'y')\n",
|
543 |
+
"move and rotate the [size] object structure so that it is on [x] [y] and points [rotation] ('rotation', 'size', 'x', 'y')\n",
|
544 |
+
"[rotation] is the direction the [size] structure built from the objects should be facing, [x] [y] is the location ('rotation', 'size', 'x', 'y')\n",
|
545 |
+
"build a [shape] on the [x] [y] of the table facing [rotation] ('rotation', 'shape', 'x', 'y')\n",
|
546 |
+
"make a [shape] from the objects and make sure it is pointing [rotation] and on [x] [y] ('rotation', 'shape', 'x', 'y')\n",
|
547 |
+
"rearrange the objects as a [shape] that faces [rotation] and place it on [x] [y] ('rotation', 'shape', 'x', 'y')\n",
|
548 |
+
"move and rotate the [shape] so that it is on [x] [y] and points [rotation] ('rotation', 'shape', 'x', 'y')\n",
|
549 |
+
"[rotation] is the direction the [shape] built from the objects should be facing, [x] [y] is the location ('rotation', 'shape', 'x', 'y')\n",
|
550 |
+
"build a [size] [shape] on the [x] [y] of the table facing [rotation] ('rotation', 'shape', 'size', 'x', 'y')\n",
|
551 |
+
"make a [size] [shape] from the objects and make sure it is pointing [rotation] and on [x] [y] ('rotation', 'shape', 'size', 'x', 'y')\n",
|
552 |
+
"rearrange the objects as a [size] [shape] that faces [rotation] and place it on [x] [y] ('rotation', 'shape', 'size', 'x', 'y')\n",
|
553 |
+
"move and rotate the [size] [shape] so that it is on [x] [y] and points [rotation] ('rotation', 'shape', 'size', 'x', 'y')\n",
|
554 |
+
"[rotation] is the direction the [size] [shape] built from the objects should be facing, [x] [y] is the location ('rotation', 'shape', 'size', 'x', 'y')\n"
|
555 |
+
]
|
556 |
+
}
|
557 |
+
],
|
558 |
+
"source": [
|
559 |
+
"sentence_template_file = \"/home/weiyu/Research/intern/StructDiffuser/src/StructDiffuser/language/sentence_template.txt\"\n",
|
560 |
+
"\n",
|
561 |
+
"import re\n",
|
562 |
+
"\n",
|
563 |
+
"type_comb_to_templates = {}\n",
|
564 |
+
"for type_comb in type_combs:\n",
|
565 |
+
" type_comb_to_templates[type_comb] = []\n",
|
566 |
+
"\n",
|
567 |
+
"with open(sentence_template_file, \"r\") as fh:\n",
|
568 |
+
" for line in fh:\n",
|
569 |
+
" line = line.strip()\n",
|
570 |
+
" if line:\n",
|
571 |
+
" if line[0] == \"#\":\n",
|
572 |
+
" continue\n",
|
573 |
+
" type_list = re.findall('\\[[^\\]]*\\]', line)\n",
|
574 |
+
" type_comb = tuple(sorted(list(set([t[1:-1] for t in type_list]))))\n",
|
575 |
+
" print(line, type_comb)\n",
|
576 |
+
"\n",
|
577 |
+
" type_comb_to_templates[type_comb].append(line)"
|
578 |
+
],
|
579 |
+
"metadata": {
|
580 |
+
"collapsed": false,
|
581 |
+
"pycharm": {
|
582 |
+
"name": "#%%\n"
|
583 |
+
}
|
584 |
+
}
|
585 |
+
},
|
586 |
+
{
|
587 |
+
"cell_type": "code",
|
588 |
+
"execution_count": 47,
|
589 |
+
"outputs": [],
|
590 |
+
"source": [
|
591 |
+
"for type_comb in type_comb_to_templates:\n",
|
592 |
+
" if len(type_comb_to_templates[type_comb]) != 5:\n",
|
593 |
+
" print(\"{} does not have 5 templates\".format(type_comb))"
|
594 |
+
],
|
595 |
+
"metadata": {
|
596 |
+
"collapsed": false,
|
597 |
+
"pycharm": {
|
598 |
+
"name": "#%%\n"
|
599 |
+
}
|
600 |
+
}
|
601 |
+
},
|
602 |
+
{
|
603 |
+
"cell_type": "code",
|
604 |
+
"execution_count": 58,
|
605 |
+
"outputs": [
|
606 |
+
{
|
607 |
+
"name": "stderr",
|
608 |
+
"output_type": "stream",
|
609 |
+
"text": [
|
610 |
+
"100%|██████████| 669/669 [00:00<00:00, 60546.98it/s]\n"
|
611 |
+
]
|
612 |
+
}
|
613 |
+
],
|
614 |
+
"source": [
|
615 |
+
"template_sentences = []\n",
|
616 |
+
"type_value_tuple_to_template_sentences = defaultdict(set)\n",
|
617 |
+
"for type_value_tuple in tqdm.tqdm(list(unique_type_value_tuples)):\n",
|
618 |
+
" type_comb = tuple(sorted([tv[0] for tv in type_value_tuple]))\n",
|
619 |
+
" template_sentences = copy.deepcopy(type_comb_to_templates[type_comb])\n",
|
620 |
+
"\n",
|
621 |
+
" # print(type_value_tuple)\n",
|
622 |
+
" for template_sentence in template_sentences:\n",
|
623 |
+
" for t, v in type_value_tuple:\n",
|
624 |
+
" template_sentence = template_sentence.replace(\"[{}]\".format(t), v)\n",
|
625 |
+
" # print(template_sentence)\n",
|
626 |
+
"\n",
|
627 |
+
" type_value_tuple_to_template_sentences[type_value_tuple].add(template_sentence)\n",
|
628 |
+
"\n",
|
629 |
+
"# convert to list\n",
|
630 |
+
"for type_value_tuple in type_value_tuple_to_template_sentences:\n",
|
631 |
+
" type_value_tuple_to_template_sentences[type_value_tuple] = list(type_value_tuple_to_template_sentences[type_value_tuple])"
|
632 |
+
],
|
633 |
+
"metadata": {
|
634 |
+
"collapsed": false,
|
635 |
+
"pycharm": {
|
636 |
+
"name": "#%%\n"
|
637 |
+
}
|
638 |
+
}
|
639 |
+
},
|
640 |
+
{
|
641 |
+
"cell_type": "code",
|
642 |
+
"execution_count": 73,
|
643 |
+
"outputs": [
|
644 |
+
{
|
645 |
+
"name": "stdout",
|
646 |
+
"output_type": "stream",
|
647 |
+
"text": [
|
648 |
+
"3345 unique template sentences\n"
|
649 |
+
]
|
650 |
+
}
|
651 |
+
],
|
652 |
+
"source": [
|
653 |
+
"unique_template_sentences = set()\n",
|
654 |
+
"\n",
|
655 |
+
"for type_value_tuple in type_value_tuple_to_template_sentences:\n",
|
656 |
+
" # print(\"\\n\")\n",
|
657 |
+
" # print(type_value_tuple)\n",
|
658 |
+
" for template_sentence in type_value_tuple_to_template_sentences[type_value_tuple]:\n",
|
659 |
+
" # print(template_sentence)\n",
|
660 |
+
" unique_template_sentences.add(template_sentence)\n",
|
661 |
+
"\n",
|
662 |
+
"unique_template_sentences = list(unique_template_sentences)\n",
|
663 |
+
"print(\"{} unique template sentences\".format(len(unique_template_sentences)))"
|
664 |
+
],
|
665 |
+
"metadata": {
|
666 |
+
"collapsed": false,
|
667 |
+
"pycharm": {
|
668 |
+
"name": "#%%\n"
|
669 |
+
}
|
670 |
+
}
|
671 |
+
},
|
672 |
+
{
|
673 |
+
"cell_type": "code",
|
674 |
+
"execution_count": 72,
|
675 |
+
"outputs": [],
|
676 |
+
"source": [
|
677 |
+
"from sentence_transformers import SentenceTransformer\n",
|
678 |
+
"model = SentenceTransformer('all-MiniLM-L6-v2')"
|
679 |
+
],
|
680 |
+
"metadata": {
|
681 |
+
"collapsed": false,
|
682 |
+
"pycharm": {
|
683 |
+
"name": "#%%\n"
|
684 |
+
}
|
685 |
+
}
|
686 |
+
},
|
687 |
+
{
|
688 |
+
"cell_type": "code",
|
689 |
+
"execution_count": 76,
|
690 |
+
"outputs": [
|
691 |
+
{
|
692 |
+
"name": "stdout",
|
693 |
+
"output_type": "stream",
|
694 |
+
"text": [
|
695 |
+
"(3345, 384)\n"
|
696 |
+
]
|
697 |
+
}
|
698 |
+
],
|
699 |
+
"source": [
|
700 |
+
"#Our sentences we like to encode\n",
|
701 |
+
"# sentences = ['This framework generates embeddings for each input sentence',\n",
|
702 |
+
"# 'Sentences are passed as a list of string.',\n",
|
703 |
+
"# 'The quick brown fox jumps over the lazy dog.']\n",
|
704 |
+
"#Sentences are encoded by calling model.encode()\n",
|
705 |
+
"\n",
|
706 |
+
"\n",
|
707 |
+
"embeddings = model.encode(unique_template_sentences)\n",
|
708 |
+
"print(embeddings.shape)"
|
709 |
+
],
|
710 |
+
"metadata": {
|
711 |
+
"collapsed": false,
|
712 |
+
"pycharm": {
|
713 |
+
"name": "#%%\n"
|
714 |
+
}
|
715 |
+
}
|
716 |
+
},
|
717 |
+
{
|
718 |
+
"cell_type": "code",
|
719 |
+
"execution_count": 80,
|
720 |
+
"outputs": [],
|
721 |
+
"source": [
|
722 |
+
"template_sentence_to_embedding = {}\n",
|
723 |
+
"for embedding, template_sentence in zip(embeddings, unique_template_sentences):\n",
|
724 |
+
" template_sentence_to_embedding[template_sentence] = embedding"
|
725 |
+
],
|
726 |
+
"metadata": {
|
727 |
+
"collapsed": false,
|
728 |
+
"pycharm": {
|
729 |
+
"name": "#%%\n"
|
730 |
+
}
|
731 |
+
}
|
732 |
+
},
|
733 |
+
{
|
734 |
+
"cell_type": "code",
|
735 |
+
"execution_count": 82,
|
736 |
+
"outputs": [],
|
737 |
+
"source": [
|
738 |
+
"import pickle\n",
|
739 |
+
"template_sentence_data = {\"template_sentence_to_embedding\": template_sentence_to_embedding,\n",
|
740 |
+
" \"type_value_tuple_to_template_sentences\": type_value_tuple_to_template_sentences}\n",
|
741 |
+
"with open(\"/home/weiyu/Research/intern/StructDiffuser/src/StructDiffuser/language/template_sentence_data.pkl\", \"wb\") as fh:\n",
|
742 |
+
" pickle.dump(template_sentence_data, fh)"
|
743 |
+
],
|
744 |
+
"metadata": {
|
745 |
+
"collapsed": false,
|
746 |
+
"pycharm": {
|
747 |
+
"name": "#%%\n"
|
748 |
+
}
|
749 |
+
}
|
750 |
+
}
|
751 |
+
],
|
752 |
+
"metadata": {
|
753 |
+
"kernelspec": {
|
754 |
+
"display_name": "Python 3",
|
755 |
+
"language": "python",
|
756 |
+
"name": "python3"
|
757 |
+
},
|
758 |
+
"language_info": {
|
759 |
+
"codemirror_mode": {
|
760 |
+
"name": "ipython",
|
761 |
+
"version": 2
|
762 |
+
},
|
763 |
+
"file_extension": ".py",
|
764 |
+
"mimetype": "text/x-python",
|
765 |
+
"name": "python",
|
766 |
+
"nbconvert_exporter": "python",
|
767 |
+
"pygments_lexer": "ipython2",
|
768 |
+
"version": "2.7.6"
|
769 |
+
}
|
770 |
+
},
|
771 |
+
"nbformat": 4,
|
772 |
+
"nbformat_minor": 0
|
773 |
+
}
|
src/StructDiffusion/language/sentence_encoder.py
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from sentence_transformers import SentenceTransformer
|
2 |
+
|
3 |
+
class SentenceBertEncoder:
    """Thin wrapper around a ``sentence_transformers`` model that embeds sentences."""

    def __init__(self, model_name='all-MiniLM-L6-v2'):
        """Load the sentence-embedding model.

        Args:
            model_name: model identifier forwarded to ``SentenceTransformer``.
                Defaults to the model this class previously hard-coded, so
                existing ``SentenceBertEncoder()`` callers are unaffected.
        """
        self.model = SentenceTransformer(model_name)

    def encode(self, sentences):
        """Embed ``sentences`` by delegating to ``self.model.encode``.

        Args:
            sentences: forwarded unchanged to the model; the demo at the bottom
                of this file passes a list of strings.

        Returns:
            Whatever ``self.model.encode`` returns (the demo reads ``.shape``
            on the result).
        """
        return self.model.encode(sentences)
|
18 |
+
|
19 |
+
|
20 |
+
if __name__ == "__main__":
    # Smoke test: embed a single sentence and print the shape of the result.
    encoder = SentenceBertEncoder()
    result = encoder.encode(["this is cool!"])
    print(result.shape)
|
src/StructDiffusion/language/test_parrot_paraphrase.py
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from parrot import Parrot
|
2 |
+
import torch
|
3 |
+
import warnings
|
4 |
+
warnings.filterwarnings("ignore")
|
5 |
+
|
6 |
+
# [top]
|
7 |
+
|
8 |
+
# Put the [objects] in a [size][shape] on the [x][y] of the table facing [rotation].
|
9 |
+
# Build a [size][shape] of the [objects] on the [x][y] of the table facing [rotation].
|
10 |
+
# Put the [objects] on the [x][y] of the table and make a [shape] facing [rotation].
|
11 |
+
# Rearrange the [objects] into a [shape], and put the structure on the [x][y] of the table facing [rotation].
|
12 |
+
# Could you ...
|
13 |
+
# Please ...
|
14 |
+
# Pick up the objects, put them into a [size][shape], place the [shape] on the [x][y] of table, make sure the [shape] is facing [rotation].
|
15 |
+
|
16 |
+
if __name__ == "__main__":
    # Uncomment to get reproducible paraphrase generations:
    #
    # def random_state(seed):
    #     torch.manual_seed(seed)
    #     if torch.cuda.is_available():
    #         torch.cuda.manual_seed_all(seed)
    #
    # random_state(1234)

    # Init models (make sure you init ONLY once if you integrate this to your code)
    parrot = Parrot(model_tag="prithivida/parrot_paraphraser_on_T5")

    phrases = ["Rearrange the mugs in a circle on the top left of the table."]

    for phrase in phrases:
        print("-"*100)
        print("Input_phrase: ", phrase)
        print("-"*100)
        para_phrases = parrot.augment(input_phrase=phrase, use_gpu=False, max_return_phrases=100, do_diverse=True)
        # NOTE(review): augment() is believed to return None when no paraphrase
        # survives its filters -- confirm against the installed Parrot version.
        # The `or []` guard keeps the loop from raising TypeError in that case.
        for para_phrase in para_phrases or []:
            print(para_phrase)
|
src/StructDiffusion/language/tokenizer.py
CHANGED
@@ -517,25 +517,4 @@ class ContinuousTokenizer:
|
|
517 |
idx = value
|
518 |
else:
|
519 |
raise KeyError("Do not recognize the type {} of the given token: {}".format(typ, value))
|
520 |
-
return idx
|
521 |
-
|
522 |
-
|
523 |
-
if __name__ == "__main__":
|
524 |
-
tokenizer = Tokenizer("/home/weiyu/data_drive/data_new_objects/type_vocabs_coarse.json")
|
525 |
-
# print(tokenizer.get_all_values_of_type("class"))
|
526 |
-
# print(tokenizer.get_all_values_of_type("color"))
|
527 |
-
# print(tokenizer.get_all_values_of_type("material"))
|
528 |
-
#
|
529 |
-
# for type in tokenizer.type_vocabs:
|
530 |
-
# print(type, tokenizer.type_vocabs[type])
|
531 |
-
|
532 |
-
tokenizer.prepare_grounding_reference()
|
533 |
-
|
534 |
-
# for i in range(100):
|
535 |
-
# types = list(tokenizer.continuous_types) + list(tokenizer.discrete_types)
|
536 |
-
# for t in types:
|
537 |
-
# v = tokenizer.get_valid_random_value(t)
|
538 |
-
# print(v)
|
539 |
-
# print(tokenizer.tokenize(v, t))
|
540 |
-
|
541 |
-
# build_vocab("/home/weiyu/data_drive/examples_v4/leonardo/vocab.json", "/home/weiyu/data_drive/examples_v4/leonardo/type_vocabs.json")
|
|
|
517 |
idx = value
|
518 |
else:
|
519 |
raise KeyError("Do not recognize the type {} of the given token: {}".format(typ, value))
|
520 |
+
return idx
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/StructDiffusion/models/__pycache__/models.cpython-38.pyc
CHANGED
Binary files a/src/StructDiffusion/models/__pycache__/models.cpython-38.pyc and b/src/StructDiffusion/models/__pycache__/models.cpython-38.pyc differ
|
|
src/StructDiffusion/models/models.py
CHANGED
@@ -26,6 +26,8 @@ class TransformerDiffusionModel(torch.nn.Module):
|
|
26 |
word_emb_dim=160,
|
27 |
time_emb_dim=80,
|
28 |
use_virtual_structure_frame=True,
|
|
|
|
|
29 |
):
|
30 |
super(TransformerDiffusionModel, self).__init__()
|
31 |
|
@@ -53,7 +55,12 @@ class TransformerDiffusionModel(torch.nn.Module):
|
|
53 |
self.virtual_frame_embed = nn.Parameter(torch.randn(1, 1, posed_pc_emb_dim)) # B, 1, posed_pc_emb_dim
|
54 |
|
55 |
# for language
|
56 |
-
self.
|
|
|
|
|
|
|
|
|
|
|
57 |
|
58 |
# for diffusion
|
59 |
self.pose_encoder = nn.Sequential(nn.Linear(action_dim, pose_emb_dim))
|
@@ -88,7 +95,10 @@ class TransformerDiffusionModel(torch.nn.Module):
|
|
88 |
|
89 |
batch_size, num_objects, num_pts, _ = pcs.shape
|
90 |
_, num_poses, _ = poses.shape
|
91 |
-
|
|
|
|
|
|
|
92 |
_, total_len = type_index.shape
|
93 |
|
94 |
pcs = pcs.reshape(batch_size * num_objects, num_pts, -1)
|
@@ -102,7 +112,11 @@ class TransformerDiffusionModel(torch.nn.Module):
|
|
102 |
tgt_obj_embed = torch.cat([pose_embed, posed_pc_embed], dim=-1)
|
103 |
|
104 |
#########################
|
105 |
-
|
|
|
|
|
|
|
|
|
106 |
|
107 |
#########################
|
108 |
|
|
|
26 |
word_emb_dim=160,
|
27 |
time_emb_dim=80,
|
28 |
use_virtual_structure_frame=True,
|
29 |
+
use_sentence_embedding=False,
|
30 |
+
sentence_embedding_dim=None,
|
31 |
):
|
32 |
super(TransformerDiffusionModel, self).__init__()
|
33 |
|
|
|
55 |
self.virtual_frame_embed = nn.Parameter(torch.randn(1, 1, posed_pc_emb_dim)) # B, 1, posed_pc_emb_dim
|
56 |
|
57 |
# for language
|
58 |
+
self.sentence_embedding_dim = sentence_embedding_dim
|
59 |
+
self.use_sentence_embedding = use_sentence_embedding
|
60 |
+
if use_sentence_embedding:
|
61 |
+
self.sentence_embedding_down_sample = torch.nn.Linear(sentence_embedding_dim, word_emb_dim)
|
62 |
+
else:
|
63 |
+
self.word_embeddings = torch.nn.Embedding(vocab_size, word_emb_dim, padding_idx=0)
|
64 |
|
65 |
# for diffusion
|
66 |
self.pose_encoder = nn.Sequential(nn.Linear(action_dim, pose_emb_dim))
|
|
|
95 |
|
96 |
batch_size, num_objects, num_pts, _ = pcs.shape
|
97 |
_, num_poses, _ = poses.shape
|
98 |
+
if self.use_sentence_embedding:
|
99 |
+
assert sentence.shape == (batch_size, self.sentence_embedding_dim), sentence.shape
|
100 |
+
else:
|
101 |
+
_, sentence_len = sentence.shape
|
102 |
_, total_len = type_index.shape
|
103 |
|
104 |
pcs = pcs.reshape(batch_size * num_objects, num_pts, -1)
|
|
|
112 |
tgt_obj_embed = torch.cat([pose_embed, posed_pc_embed], dim=-1)
|
113 |
|
114 |
#########################
|
115 |
+
if self.use_sentence_embedding:
|
116 |
+
# sentence: B, sentence_embedding_dim
|
117 |
+
sentence_embed = self.sentence_embedding_down_sample(sentence).unsqueeze(1) # B, 1, word_emb_dim
|
118 |
+
else:
|
119 |
+
sentence_embed = self.word_embeddings(sentence)
|
120 |
|
121 |
#########################
|
122 |
|
src/StructDiffusion/utils/__pycache__/batch_inference.cpython-38.pyc
CHANGED
Binary files a/src/StructDiffusion/utils/__pycache__/batch_inference.cpython-38.pyc and b/src/StructDiffusion/utils/__pycache__/batch_inference.cpython-38.pyc differ
|
|
src/StructDiffusion/utils/__pycache__/files.cpython-38.pyc
CHANGED
Binary files a/src/StructDiffusion/utils/__pycache__/files.cpython-38.pyc and b/src/StructDiffusion/utils/__pycache__/files.cpython-38.pyc differ
|
|
src/StructDiffusion/utils/__pycache__/rearrangement.cpython-38.pyc
CHANGED
Binary files a/src/StructDiffusion/utils/__pycache__/rearrangement.cpython-38.pyc and b/src/StructDiffusion/utils/__pycache__/rearrangement.cpython-38.pyc differ
|
|
src/StructDiffusion/utils/__pycache__/rotation_continuity.cpython-38.pyc
CHANGED
Binary files a/src/StructDiffusion/utils/__pycache__/rotation_continuity.cpython-38.pyc and b/src/StructDiffusion/utils/__pycache__/rotation_continuity.cpython-38.pyc differ
|
|
src/StructDiffusion/utils/__pycache__/tra3d.cpython-38.pyc
ADDED
Binary file (4.72 kB). View file
|
|
src/StructDiffusion/utils/batch_inference.py
CHANGED
@@ -1,175 +1,10 @@
|
|
1 |
import os
|
2 |
import torch
|
3 |
import numpy as np
|
4 |
-
import pytorch3d.transforms as tra3d
|
5 |
|
6 |
from StructDiffusion.utils.rearrangement import show_pcs_color_order, show_pcs_with_trimesh
|
7 |
from StructDiffusion.utils.pointnet import random_point_sample, index_points
|
8 |
-
|
9 |
-
|
10 |
-
def move_pc_and_create_scene(obj_xyzs, obj_params, struct_pose, current_pc_pose, target_object_inds,
|
11 |
-
num_scene_pts, device, normalize_pc=False,
|
12 |
-
return_pair_pc=False, normalize_pair_pc=False, num_pair_pc_pts=None,
|
13 |
-
return_scene_pts=True, return_scene_pts_and_pc_idxs=False):
|
14 |
-
|
15 |
-
# obj_xyzs: N, P, 3
|
16 |
-
# obj_params: B, N, 6
|
17 |
-
# struct_pose: B x N, 4, 4
|
18 |
-
# current_pc_pose: B x N, 4, 4
|
19 |
-
# target_object_inds: 1, N
|
20 |
-
|
21 |
-
B, N, _ = obj_params.shape
|
22 |
-
_, P, _ = obj_xyzs.shape
|
23 |
-
|
24 |
-
# B, N, 6
|
25 |
-
flat_obj_params = obj_params.reshape(B * N, -1)
|
26 |
-
goal_pc_pose_in_struct = torch.eye(4).repeat(B * N, 1, 1).to(device)
|
27 |
-
goal_pc_pose_in_struct[:, :3, :3] = tra3d.euler_angles_to_matrix(flat_obj_params[:, 3:], "XYZ")
|
28 |
-
goal_pc_pose_in_struct[:, :3, 3] = flat_obj_params[:, :3] # B x N, 4, 4
|
29 |
-
|
30 |
-
goal_pc_pose = struct_pose @ goal_pc_pose_in_struct
|
31 |
-
goal_pc_transform = goal_pc_pose @ torch.inverse(current_pc_pose) # cur_batch_size x N, 4, 4
|
32 |
-
|
33 |
-
# important: pytorch3d uses row-major ordering, need to transpose each transformation matrix
|
34 |
-
transpose = tra3d.Transform3d(matrix=goal_pc_transform.transpose(1, 2))
|
35 |
-
|
36 |
-
# obj_xyzs: N, P, 3
|
37 |
-
new_obj_xyzs = obj_xyzs.repeat(B, 1, 1)
|
38 |
-
new_obj_xyzs = transpose.transform_points(new_obj_xyzs)
|
39 |
-
|
40 |
-
# put it back to B, N, P, 3
|
41 |
-
new_obj_xyzs = new_obj_xyzs.reshape(B, N, P, -1)
|
42 |
-
# visualize_batch_pcs(new_obj_xyzs, S, N, P)
|
43 |
-
|
44 |
-
# ===================================
|
45 |
-
# Pass to discriminator
|
46 |
-
subsampled_scene_xyz = None
|
47 |
-
if return_scene_pts:
|
48 |
-
|
49 |
-
num_indicator = N
|
50 |
-
|
51 |
-
# add one hot
|
52 |
-
indicator_variables = torch.eye(num_indicator).repeat(B, 1, 1, P).reshape(B, num_indicator, P, num_indicator).to(device) # B, N, P, N
|
53 |
-
# print(indicator_variables.shape)
|
54 |
-
# print(new_obj_xyzs.shape)
|
55 |
-
new_obj_xyzs = torch.cat([new_obj_xyzs, indicator_variables], dim=-1) # B, N, P, 3 + N
|
56 |
-
|
57 |
-
# combine pcs in each scene
|
58 |
-
scene_xyzs = new_obj_xyzs.reshape(B, N * P, 3 + N)
|
59 |
-
|
60 |
-
# ToDo: maybe convert this to a batch operation
|
61 |
-
subsampled_scene_xyz = torch.FloatTensor(B, num_scene_pts, 3 + N).to(device)
|
62 |
-
for si, scene_xyz in enumerate(scene_xyzs):
|
63 |
-
# scene_xyz: N*P, 3+N
|
64 |
-
# target_object_inds: 1, N
|
65 |
-
subsample_idx = torch.randint(0, torch.sum(target_object_inds[0]) * P, (num_scene_pts,)).to(device)
|
66 |
-
subsampled_scene_xyz[si] = scene_xyz[subsample_idx]
|
67 |
-
|
68 |
-
# # debug:
|
69 |
-
# print("-"*50)
|
70 |
-
# if si < 10:
|
71 |
-
# trimesh.PointCloud(scene_xyz[:, :3].cpu().numpy(), colors=[255, 0, 0, 255]).show()
|
72 |
-
# trimesh.PointCloud(subsampled_scene_xyz[si, :, :3].cpu().numpy(), colors=[0, 255, 0, 255]).show()
|
73 |
-
|
74 |
-
# subsampled_scene_xyz: B, num_scene_pts, 3+N
|
75 |
-
# new_obj_xyzs: B, N, P, 3
|
76 |
-
# goal_pc_pose: B, N, 4, 4
|
77 |
-
|
78 |
-
# important:
|
79 |
-
if normalize_pc:
|
80 |
-
subsampled_scene_xyz[:, :, 0:3] = pc_normalize_batch(subsampled_scene_xyz[:, :, 0:3])
|
81 |
-
|
82 |
-
# # debug:
|
83 |
-
# for si in range(10):
|
84 |
-
# trimesh.PointCloud(subsampled_scene_xyz[si, :, :3].cpu().numpy(), colors=[0, 0, 255, 255]).show()
|
85 |
-
|
86 |
-
if return_scene_pts_and_pc_idxs:
|
87 |
-
num_indicator = N
|
88 |
-
pc_idxs = torch.arange(0, num_indicator)[:, None].repeat(B, 1, P).reshape(B, num_indicator, P).to(device) # B, N, P
|
89 |
-
# new_obj_xyzs: B, N, P, 3 + 1
|
90 |
-
|
91 |
-
# combine pcs in each scene
|
92 |
-
scene_xyzs = new_obj_xyzs.reshape(B, N * P, 3)
|
93 |
-
pc_idxs = pc_idxs.reshape(B, N*P)
|
94 |
-
|
95 |
-
subsampled_scene_xyz = torch.FloatTensor(B, num_scene_pts, 3).to(device)
|
96 |
-
subsampled_pc_idxs = torch.LongTensor(B, num_scene_pts).to(device)
|
97 |
-
for si, (scene_xyz, pc_idx) in enumerate(zip(scene_xyzs, pc_idxs)):
|
98 |
-
# scene_xyz: N*P, 3+1
|
99 |
-
# target_object_inds: 1, N
|
100 |
-
subsample_idx = torch.randint(0, torch.sum(target_object_inds[0]) * P, (num_scene_pts,)).to(device)
|
101 |
-
subsampled_scene_xyz[si] = scene_xyz[subsample_idx]
|
102 |
-
subsampled_pc_idxs[si] = pc_idx[subsample_idx]
|
103 |
-
|
104 |
-
# subsampled_scene_xyz: B, num_scene_pts, 3
|
105 |
-
# subsampled_pc_idxs: B, num_scene_pts
|
106 |
-
# new_obj_xyzs: B, N, P, 3
|
107 |
-
# goal_pc_pose: B, N, 4, 4
|
108 |
-
|
109 |
-
# important:
|
110 |
-
if normalize_pc:
|
111 |
-
subsampled_scene_xyz[:, :, 0:3] = pc_normalize_batch(subsampled_scene_xyz[:, :, 0:3])
|
112 |
-
|
113 |
-
# TODO: visualize each individual object
|
114 |
-
# debug
|
115 |
-
# print(subsampled_scene_xyz.shape)
|
116 |
-
# print(subsampled_pc_idxs.shape)
|
117 |
-
# print("visualize subsampled scene")
|
118 |
-
# for si in range(5):
|
119 |
-
# trimesh.PointCloud(subsampled_scene_xyz[si, :, :3].cpu().numpy(), colors=[0, 0, 255, 255]).show()
|
120 |
-
|
121 |
-
###############################################
|
122 |
-
# Create input for pairwise collision detector
|
123 |
-
if return_pair_pc:
|
124 |
-
|
125 |
-
assert num_pair_pc_pts is not None
|
126 |
-
|
127 |
-
# new_obj_xyzs: B, N, P, 3 + N
|
128 |
-
# target_object_inds: 1, N
|
129 |
-
# ignore paddings
|
130 |
-
num_objs = torch.sum(target_object_inds[0])
|
131 |
-
obj_pair_idxs = torch.combinations(torch.arange(num_objs), r=2) # num_comb, 2
|
132 |
-
|
133 |
-
# use [:, :, :, :3] to get obj_xyzs without object-wise indicator
|
134 |
-
obj_pair_xyzs = new_obj_xyzs[:, :, :, :3][:, obj_pair_idxs] # B, num_comb, 2 (obj 1 and obj 2), P, 3
|
135 |
-
num_comb = obj_pair_xyzs.shape[1]
|
136 |
-
pair_indicator_variables = torch.eye(2).repeat(B, num_comb, 1, 1, P).reshape(B, num_comb, 2, P, 2).to(device) # B, num_comb, 2, P, 2
|
137 |
-
obj_pair_xyzs = torch.cat([obj_pair_xyzs, pair_indicator_variables], dim=-1) # B, num_comb, 2, P, 3 (pc channels) + 2 (indicator for obj 1 and obj 2)
|
138 |
-
obj_pair_xyzs = obj_pair_xyzs.reshape(B, num_comb, P * 2, 5)
|
139 |
-
|
140 |
-
# random sample: idx = np.random.randint(0, scene_xyz.shape[0], self.num_scene_pts)
|
141 |
-
obj_pair_xyzs = obj_pair_xyzs.reshape(B * num_comb, P * 2, 5)
|
142 |
-
# random_point_sample() input dim: B, N, C
|
143 |
-
rand_idxs = random_point_sample(obj_pair_xyzs, num_pair_pc_pts) # B * num_comb, num_pair_pc_pts
|
144 |
-
obj_pair_xyzs = index_points(obj_pair_xyzs, rand_idxs) # B * num_comb, num_pair_pc_pts, 5
|
145 |
-
|
146 |
-
if normalize_pair_pc:
|
147 |
-
# pc_normalize_batch() input dim: pc: B, num_scene_pts, 3
|
148 |
-
# obj_pair_xyzs = obj_pair_xyzs.reshape(B * num_comb, num_pair_pc_pts, 5)
|
149 |
-
obj_pair_xyzs[:, :, 0:3] = pc_normalize_batch(obj_pair_xyzs[:, :, 0:3])
|
150 |
-
obj_pair_xyzs = obj_pair_xyzs.reshape(B, num_comb, num_pair_pc_pts, 5)
|
151 |
-
|
152 |
-
# # debug
|
153 |
-
# for bi, this_obj_pair_xyzs in enumerate(obj_pair_xyzs):
|
154 |
-
# print("batch id", bi)
|
155 |
-
# for pi, obj_pair_xyz in enumerate(this_obj_pair_xyzs):
|
156 |
-
# print("pair", pi)
|
157 |
-
# # obj_pair_xyzs: 2 * P, 5
|
158 |
-
# print(obj_pair_xyz[:, :3].shape)
|
159 |
-
# trimesh.PointCloud(obj_pair_xyz[:, :3].cpu()).show()
|
160 |
-
|
161 |
-
# obj_pair_xyzs: B, num_comb, num_pair_pc_pts, 3 + 2
|
162 |
-
goal_pc_pose = goal_pc_pose.reshape(B, N, 4, 4)
|
163 |
-
|
164 |
-
# TODO: update the return logic, a mess right now
|
165 |
-
if return_scene_pts_and_pc_idxs:
|
166 |
-
return subsampled_scene_xyz, subsampled_pc_idxs, new_obj_xyzs, goal_pc_pose
|
167 |
-
|
168 |
-
if return_pair_pc:
|
169 |
-
return subsampled_scene_xyz, new_obj_xyzs, goal_pc_pose, obj_pair_xyzs
|
170 |
-
else:
|
171 |
-
return subsampled_scene_xyz, new_obj_xyzs, goal_pc_pose
|
172 |
-
|
173 |
|
174 |
def move_pc_and_create_scene_new(obj_xyzs, obj_params, struct_pose, current_pc_pose, target_object_inds, device,
|
175 |
return_scene_pts=False, return_scene_pts_and_pc_idxs=False, num_scene_pts=None, normalize_pc=False,
|
@@ -193,12 +28,17 @@ def move_pc_and_create_scene_new(obj_xyzs, obj_params, struct_pose, current_pc_p
|
|
193 |
goal_pc_pose = struct_pose @ goal_pc_pose_in_struct
|
194 |
goal_pc_transform = goal_pc_pose @ torch.inverse(current_pc_pose) # cur_batch_size x N, 4, 4
|
195 |
|
196 |
-
# important: pytorch3d uses row-major ordering, need to transpose each transformation matrix
|
197 |
-
transpose = tra3d.Transform3d(matrix=goal_pc_transform.transpose(1, 2))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
198 |
|
199 |
-
# obj_xyzs: N, P, 3
|
200 |
-
new_obj_xyzs = obj_xyzs.repeat(B, 1, 1)
|
201 |
-
new_obj_xyzs = transpose.transform_points(new_obj_xyzs)
|
202 |
|
203 |
# put it back to B, N, P, 3
|
204 |
new_obj_xyzs = new_obj_xyzs.reshape(B, N, P, -1)
|
@@ -332,82 +172,6 @@ def move_pc_and_create_scene_new(obj_xyzs, obj_params, struct_pose, current_pc_p
|
|
332 |
return new_obj_xyzs, goal_pc_pose, subsampled_scene_xyz, subsampled_pc_idxs, obj_pair_xyzs
|
333 |
|
334 |
|
335 |
-
def move_pc(obj_xyzs, obj_params, struct_pose, current_pc_pose, device):
|
336 |
-
|
337 |
-
# obj_xyzs: N, P, 3
|
338 |
-
# obj_params: B, N, 6
|
339 |
-
# struct_pose: B x N, 4, 4
|
340 |
-
# current_pc_pose: B x N, 4, 4
|
341 |
-
# target_object_inds: 1, N
|
342 |
-
|
343 |
-
B, N, _ = obj_params.shape
|
344 |
-
_, P, _ = obj_xyzs.shape
|
345 |
-
|
346 |
-
# B, N, 6
|
347 |
-
flat_obj_params = obj_params.reshape(B * N, -1)
|
348 |
-
goal_pc_pose_in_struct = torch.eye(4).repeat(B * N, 1, 1).to(device)
|
349 |
-
goal_pc_pose_in_struct[:, :3, :3] = tra3d.euler_angles_to_matrix(flat_obj_params[:, 3:], "XYZ")
|
350 |
-
goal_pc_pose_in_struct[:, :3, 3] = flat_obj_params[:, :3] # B x N, 4, 4
|
351 |
-
|
352 |
-
goal_pc_pose = struct_pose @ goal_pc_pose_in_struct
|
353 |
-
goal_pc_transform = goal_pc_pose @ torch.inverse(current_pc_pose) # cur_batch_size x N, 4, 4
|
354 |
-
|
355 |
-
# important: pytorch3d uses row-major ordering, need to transpose each transformation matrix
|
356 |
-
transpose = tra3d.Transform3d(matrix=goal_pc_transform.transpose(1, 2))
|
357 |
-
|
358 |
-
# obj_xyzs: N, P, 3
|
359 |
-
new_obj_xyzs = obj_xyzs.repeat(B, 1, 1)
|
360 |
-
new_obj_xyzs = transpose.transform_points(new_obj_xyzs)
|
361 |
-
|
362 |
-
# put it back to B, N, P, 3
|
363 |
-
new_obj_xyzs = new_obj_xyzs.reshape(B, N, P, -1)
|
364 |
-
# visualize_batch_pcs(new_obj_xyzs, S, N, P)
|
365 |
-
|
366 |
-
# subsampled_scene_xyz: B, num_scene_pts, 3+N
|
367 |
-
# new_obj_xyzs: B, N, P, 3
|
368 |
-
# goal_pc_pose: B, N, 4, 4
|
369 |
-
|
370 |
-
goal_pc_pose = goal_pc_pose.reshape(B, N, 4, 4)
|
371 |
-
return new_obj_xyzs, goal_pc_pose
|
372 |
-
|
373 |
-
|
374 |
-
def move_pc_and_create_scene_simple(obj_xyzs, struct_pose, pc_poses_in_struct):
|
375 |
-
|
376 |
-
device = obj_xyzs.device
|
377 |
-
|
378 |
-
# obj_xyzs: B, N, P, 3 or 6
|
379 |
-
# struct_pose: B, 1, 4, 4
|
380 |
-
# pc_poses_in_struct: B, N, 4, 4
|
381 |
-
|
382 |
-
B, N, _, _ = pc_poses_in_struct.shape
|
383 |
-
_, _, P, _ = obj_xyzs.shape
|
384 |
-
|
385 |
-
current_pc_poses = torch.eye(4).repeat(B, N, 1, 1).to(device) # B, N, 4, 4
|
386 |
-
# print(torch.mean(obj_xyzs, dim=2).shape)
|
387 |
-
current_pc_poses[:, :, :3, 3] = torch.mean(obj_xyzs[:, :, :, :3], dim=2) # B, N, 4, 4
|
388 |
-
current_pc_poses = current_pc_poses.reshape(B * N, 4, 4) # B x N, 4, 4
|
389 |
-
|
390 |
-
struct_pose = struct_pose.repeat(1, N, 1, 1) # B, N, 4, 4
|
391 |
-
struct_pose = struct_pose.reshape(B * N, 4, 4) # B x 1, 4, 4
|
392 |
-
pc_poses_in_struct = pc_poses_in_struct.reshape(B * N, 4, 4) # B x N, 4, 4
|
393 |
-
|
394 |
-
goal_pc_pose = struct_pose @ pc_poses_in_struct # B x N, 4, 4
|
395 |
-
# print("goal pc poses")
|
396 |
-
# print(goal_pc_pose)
|
397 |
-
goal_pc_transform = goal_pc_pose @ torch.inverse(current_pc_poses) # B x N, 4, 4
|
398 |
-
|
399 |
-
# important: pytorch3d uses row-major ordering, need to transpose each transformation matrix
|
400 |
-
transpose = tra3d.Transform3d(matrix=goal_pc_transform.transpose(1, 2))
|
401 |
-
|
402 |
-
new_obj_xyzs = obj_xyzs.reshape(B * N, P, -1) # B x N, P, 3
|
403 |
-
new_obj_xyzs[:, :, :3] = transpose.transform_points(new_obj_xyzs[:, :, :3])
|
404 |
-
|
405 |
-
# put it back to B, N, P, 3
|
406 |
-
new_obj_xyzs = new_obj_xyzs.reshape(B, N, P, -1)
|
407 |
-
|
408 |
-
return new_obj_xyzs
|
409 |
-
|
410 |
-
|
411 |
def compute_current_and_goal_pc_poses(obj_xyzs, struct_pose, pc_poses_in_struct):
|
412 |
|
413 |
device = obj_xyzs.device
|
|
|
1 |
import os
|
2 |
import torch
|
3 |
import numpy as np
|
|
|
4 |
|
5 |
from StructDiffusion.utils.rearrangement import show_pcs_color_order, show_pcs_with_trimesh
|
6 |
from StructDiffusion.utils.pointnet import random_point_sample, index_points
|
7 |
+
import StructDiffusion.utils.tra3d as tra3d
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
|
9 |
def move_pc_and_create_scene_new(obj_xyzs, obj_params, struct_pose, current_pc_pose, target_object_inds, device,
|
10 |
return_scene_pts=False, return_scene_pts_and_pc_idxs=False, num_scene_pts=None, normalize_pc=False,
|
|
|
28 |
goal_pc_pose = struct_pose @ goal_pc_pose_in_struct
|
29 |
goal_pc_transform = goal_pc_pose @ torch.inverse(current_pc_pose) # cur_batch_size x N, 4, 4
|
30 |
|
31 |
+
# # important: pytorch3d uses row-major ordering, need to transpose each transformation matrix
|
32 |
+
# transpose = tra3d.Transform3d(matrix=goal_pc_transform.transpose(1, 2))
|
33 |
+
# # obj_xyzs: N, P, 3
|
34 |
+
# new_obj_xyzs = obj_xyzs.repeat(B, 1, 1)
|
35 |
+
# new_obj_xyzs = transpose.transform_points(new_obj_xyzs)
|
36 |
+
|
37 |
+
# a version that does not rely on pytorch3d
|
38 |
+
new_obj_xyzs = obj_xyzs.repeat(B, 1, 1) # B x N, P, 3
|
39 |
+
new_obj_xyzs = torch.concat([new_obj_xyzs, torch.ones(B * N, P, 1).to(device)], dim=-1) # B x N, P, 4
|
40 |
+
new_obj_xyzs = torch.einsum('bij,bkj->bki', goal_pc_transform, new_obj_xyzs)[:, :, :3] # # B x N, P, 3
|
41 |
|
|
|
|
|
|
|
42 |
|
43 |
# put it back to B, N, P, 3
|
44 |
new_obj_xyzs = new_obj_xyzs.reshape(B, N, P, -1)
|
|
|
172 |
return new_obj_xyzs, goal_pc_pose, subsampled_scene_xyz, subsampled_pc_idxs, obj_pair_xyzs
|
173 |
|
174 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
175 |
def compute_current_and_goal_pc_poses(obj_xyzs, struct_pose, pc_poses_in_struct):
|
176 |
|
177 |
device = obj_xyzs.device
|
src/StructDiffusion/utils/files.py
CHANGED
@@ -1,9 +1,17 @@
|
|
1 |
import os
|
2 |
|
|
|
3 |
def get_checkpoint_path_from_dir(checkpoint_dir):
|
4 |
checkpoint_path = None
|
5 |
for file in os.listdir(checkpoint_dir):
|
6 |
if "ckpt" in file:
|
7 |
checkpoint_path = os.path.join(checkpoint_dir, file)
|
8 |
assert checkpoint_path is not None
|
9 |
-
return checkpoint_path
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import os
|
2 |
|
3 |
+
|
4 |
def get_checkpoint_path_from_dir(checkpoint_dir):
    """Return the path of a checkpoint file located directly in ``checkpoint_dir``.

    An entry counts as a checkpoint when its name contains ``"ckpt"``. When
    several entries match, the one whose name sorts last is returned, so the
    choice no longer depends on the order in which ``os.listdir`` happens to
    report the directory contents.

    Args:
        checkpoint_dir: directory to search (not searched recursively).

    Returns:
        ``os.path.join(checkpoint_dir, name)`` for the selected entry.

    Raises:
        AssertionError: if no entry name contains ``"ckpt"``.
    """
    candidates = sorted(name for name in os.listdir(checkpoint_dir) if "ckpt" in name)
    if not candidates:
        # Raised explicitly instead of via `assert` so the check still runs
        # under `python -O`; the exception type is kept for existing callers.
        raise AssertionError("no checkpoint file found in {}".format(checkpoint_dir))
    return os.path.join(checkpoint_dir, candidates[-1])
|
11 |
+
|
12 |
+
|
13 |
+
def replace_config_for_testing_data(cfg, testing_data_cfg):
    """Point ``cfg`` at the testing data described by ``testing_data_cfg``.

    Copies ``data_roots``, ``index_roots`` and ``vocab_dir`` from
    ``testing_data_cfg.DATASET`` onto ``cfg.DATASET``. ``cfg`` is modified in
    place and nothing is returned.
    """
    source = testing_data_cfg.DATASET
    for field in ("data_roots", "index_roots", "vocab_dir"):
        setattr(cfg.DATASET, field, getattr(source, field))
|
17 |
+
|
src/StructDiffusion/utils/np_speed_test.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import time
|
3 |
+
|
4 |
+
def numpy_test(w, h, r):
    """Compute a rotated pixel-coordinate map for a ``w`` x ``h`` angular grid.

    Each pixel ``(u, v)`` is converted to a pair of angles (``theta`` spans
    ``2*pi`` across the width, ``phi`` spans ``pi`` across the height), turned
    into a 3-D ray, multiplied by ``r.T``, and converted back to fractional
    pixel coordinates.

    Args:
        w: grid width in pixels.
        h: grid height in pixels.
        r: array with a ``.T`` attribute that right-multiplies ``(N, 3)`` rays;
            the script in this file passes a 3x3 matrix.

    Returns:
        ``float32`` array of shape ``(h, w, 2)`` holding ``(u, v)`` for every
        pixel.
    """
    v, u = np.indices((h, w))
    theta = (u - w / 2.) * 2 * np.pi / w
    phi = (v - h / 2.) * np.pi / h
    cos_phi = np.cos(phi)
    x = cos_phi * np.sin(theta)
    y = np.sin(phi)
    z = cos_phi * np.cos(theta)

    ray = np.dstack((x, y, z))
    # reshape() instead of assigning to `.shape`, which NumPy discourages.
    ray = ray.reshape(-1, 3).dot(r.T).reshape(h, w, 3)

    x, y, z = np.dsplit(ray, 3)
    theta = np.arctan2(x, z)
    # Rounding (or an `r` that is only approximately orthonormal) can push y a
    # hair outside [-1, 1], where arcsin returns NaN; clamp before inverting.
    phi = np.arcsin(np.clip(y, -1.0, 1.0))
    u = theta * w / 2 / np.pi + w / 2.
    v = phi * h / np.pi + h / 2.
    xymap = np.dstack((u, v)).astype(np.float32)

    return xymap
|
25 |
+
|
26 |
+
def matrix_multiplication(iterations=100, size=1000):
    """Benchmark workload: repeatedly multiply two fresh random square matrices.

    Args:
        iterations: Number of multiplications to perform (default 100).
        size: Side length of each random matrix (default 1000).

    The products are discarded; the function exists only to consume time and
    returns ``None``. Matrices are drawn from NumPy's global random state.
    """
    for _ in range(iterations):
        np.random.random((size, size)) @ np.random.random((size, size))
|
29 |
+
|
30 |
+
|
31 |
+
if __name__ == "__main__":
    # Inputs for numpy_test: grid width, height (half the width) and a 3x3
    # matrix. They are unused while the matrix benchmark below is selected.
    w = 3584
    h = w // 2
    r = np.array([
        [0.61566148, -0.78369395, 0.08236955],
        [0.78801075, 0.61228882, -0.06435415],
        [0, 0.10452846, 0.9945219],
    ])
    # perf_counter is monotonic and high-resolution, so the measured interval
    # cannot be distorted by wall-clock adjustments the way time.time() can.
    begin_time = time.perf_counter()
    # Swap in `numpy_test(w, h, r)` here to time that workload instead.
    matrix_multiplication()
    print(time.perf_counter() - begin_time)
|
src/StructDiffusion/utils/rearrangement.py
CHANGED
@@ -558,9 +558,12 @@ def fit_gaussians(samples, sigma_eps=0.01):
|
|
558 |
return mus, sigmas
|
559 |
|
560 |
|
561 |
-
def show_pcs_with_trimesh(obj_xyzs, obj_rgbs, return_scene=False):
|
562 |
-
|
563 |
-
|
|
|
|
|
|
|
564 |
scene = trimesh.Scene()
|
565 |
# add the coordinate frame first
|
566 |
geom = trimesh.creation.axis(0.01)
|
@@ -582,13 +585,35 @@ def show_pcs_with_trimesh(obj_xyzs, obj_rgbs, return_scene=False):
|
|
582 |
RT_4x4 = np.linalg.inv(RT_4x4)
|
583 |
RT_4x4 = RT_4x4 @ np.diag([1, -1, -1, 1])
|
584 |
scene.camera_transform = RT_4x4
|
585 |
-
|
586 |
if return_scene:
|
587 |
return scene
|
588 |
else:
|
589 |
scene.show()
|
590 |
|
591 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
592 |
def show_pcs_with_predictions(xyz, rgb, gts, predictions, add_coordinate_frame=False, return_buffer=False, add_table=True, side_view=True):
|
593 |
""" Display point clouds """
|
594 |
|
|
|
558 |
return mus, sigmas
|
559 |
|
560 |
|
561 |
+
def show_pcs_with_trimesh(obj_xyzs, obj_rgbs=None, return_scene=False):
|
562 |
+
if obj_rgbs is not None:
|
563 |
+
vis_pcs = [trimesh.PointCloud(obj_xyz, colors=np.concatenate([obj_rgb * 255, np.ones([obj_rgb.shape[0], 1]) * 255], axis=-1)) for
|
564 |
+
obj_xyz, obj_rgb in zip(obj_xyzs, obj_rgbs)]
|
565 |
+
else:
|
566 |
+
vis_pcs = [trimesh.PointCloud(obj_xyz) for obj_xyz in obj_xyzs]
|
567 |
scene = trimesh.Scene()
|
568 |
# add the coordinate frame first
|
569 |
geom = trimesh.creation.axis(0.01)
|
|
|
585 |
RT_4x4 = np.linalg.inv(RT_4x4)
|
586 |
RT_4x4 = RT_4x4 @ np.diag([1, -1, -1, 1])
|
587 |
scene.camera_transform = RT_4x4
|
|
|
588 |
if return_scene:
|
589 |
return scene
|
590 |
else:
|
591 |
scene.show()
|
592 |
|
593 |
|
594 |
+
def get_trimesh_scene_with_table():
    """Create a trimesh scene containing an axis marker, a table and a fixed camera.

    The table is a 1.0 x 1.0 x 0.02 box translated by (0.5, 0, -0.01);
    assuming trimesh creates boxes centred on the origin, its top face then
    lies in the z = 0 plane. The camera transform is the inverse of a
    hard-coded 4x4 matrix with its y and z axes negated.

    Returns:
        The populated ``trimesh.Scene``.
    """
    scene = trimesh.Scene()

    # Add the coordinate frame first.
    scene.add_geometry(trimesh.creation.axis(0.01))

    table = trimesh.creation.box(extents=[1.0, 1.0, 0.02])
    table.apply_translation([0.5, 0, -0.01])
    table.visual.vertex_colors = [150, 111, 87, 125]
    scene.add_geometry(table)

    # A bounding sphere used to be constructed here but was never added to
    # the scene, so that dead work has been dropped.

    RT_4x4 = np.array([
        [-0.39560353822208355, -0.9183993826406329, 0.006357240869497738, 0.2651463080169481],
        [-0.797630370081598, 0.3401340617616391, -0.4980909683511864, 0.2225696480721997],
        [0.45528412367406523, -0.2021172778236285, -0.8671014777611122, 0.9449050652025951],
        [0.0, 0.0, 0.0, 1.0],
    ])
    # NOTE(review): inverting and negating y/z presumably converts a stored
    # extrinsic into the camera-pose convention trimesh expects - confirm.
    RT_4x4 = np.linalg.inv(RT_4x4) @ np.diag([1, -1, -1, 1])
    scene.camera_transform = RT_4x4
    return scene
|
616 |
+
|
617 |
def show_pcs_with_predictions(xyz, rgb, gts, predictions, add_coordinate_frame=False, return_buffer=False, add_table=True, side_view=True):
|
618 |
""" Display point clouds """
|
619 |
|
src/StructDiffusion/utils/tra3d.py
ADDED
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
|
4 |
+
# source: https://pytorch3d.readthedocs.io/en/latest/_modules/pytorch3d/transforms/rotation_conversions.html#matrix_to_euler_angles
|
5 |
+
# we don't want to build pytorch3d, so only pick functions we need to use
|
6 |
+
|
7 |
+
def _index_from_letter(letter: str) -> int:
|
8 |
+
if letter == "X":
|
9 |
+
return 0
|
10 |
+
if letter == "Y":
|
11 |
+
return 1
|
12 |
+
if letter == "Z":
|
13 |
+
return 2
|
14 |
+
raise ValueError("letter must be either X, Y or Z.")
|
15 |
+
|
16 |
+
|
17 |
+
def _angle_from_tan(
|
18 |
+
axis: str, other_axis: str, data, horizontal: bool, tait_bryan: bool
|
19 |
+
) -> torch.Tensor:
|
20 |
+
"""
|
21 |
+
Extract the first or third Euler angle from the two members of
|
22 |
+
the matrix which are positive constant times its sine and cosine.
|
23 |
+
|
24 |
+
Args:
|
25 |
+
axis: Axis label "X" or "Y or "Z" for the angle we are finding.
|
26 |
+
other_axis: Axis label "X" or "Y or "Z" for the middle axis in the
|
27 |
+
convention.
|
28 |
+
data: Rotation matrices as tensor of shape (..., 3, 3).
|
29 |
+
horizontal: Whether we are looking for the angle for the third axis,
|
30 |
+
which means the relevant entries are in the same row of the
|
31 |
+
rotation matrix. If not, they are in the same column.
|
32 |
+
tait_bryan: Whether the first and third axes in the convention differ.
|
33 |
+
|
34 |
+
Returns:
|
35 |
+
Euler Angles in radians for each matrix in data as a tensor
|
36 |
+
of shape (...).
|
37 |
+
"""
|
38 |
+
|
39 |
+
i1, i2 = {"X": (2, 1), "Y": (0, 2), "Z": (1, 0)}[axis]
|
40 |
+
if horizontal:
|
41 |
+
i2, i1 = i1, i2
|
42 |
+
even = (axis + other_axis) in ["XY", "YZ", "ZX"]
|
43 |
+
if horizontal == even:
|
44 |
+
return torch.atan2(data[..., i1], data[..., i2])
|
45 |
+
if tait_bryan:
|
46 |
+
return torch.atan2(-data[..., i2], data[..., i1])
|
47 |
+
return torch.atan2(data[..., i2], -data[..., i1])
|
48 |
+
|
49 |
+
|
50 |
+
def _axis_angle_rotation(axis: str, angle: torch.Tensor) -> torch.Tensor:
|
51 |
+
"""
|
52 |
+
Return the rotation matrices for one of the rotations about an axis
|
53 |
+
of which Euler angles describe, for each value of the angle given.
|
54 |
+
|
55 |
+
Args:
|
56 |
+
axis: Axis label "X" or "Y or "Z".
|
57 |
+
angle: any shape tensor of Euler angles in radians
|
58 |
+
|
59 |
+
Returns:
|
60 |
+
Rotation matrices as tensor of shape (..., 3, 3).
|
61 |
+
"""
|
62 |
+
|
63 |
+
cos = torch.cos(angle)
|
64 |
+
sin = torch.sin(angle)
|
65 |
+
one = torch.ones_like(angle)
|
66 |
+
zero = torch.zeros_like(angle)
|
67 |
+
|
68 |
+
if axis == "X":
|
69 |
+
R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos)
|
70 |
+
elif axis == "Y":
|
71 |
+
R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos)
|
72 |
+
elif axis == "Z":
|
73 |
+
R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one)
|
74 |
+
else:
|
75 |
+
raise ValueError("letter must be either X, Y or Z.")
|
76 |
+
|
77 |
+
return torch.stack(R_flat, -1).reshape(angle.shape + (3, 3))
|
78 |
+
|
79 |
+
|
80 |
+
def matrix_to_euler_angles(matrix: torch.Tensor, convention: str) -> torch.Tensor:
    """
    Convert rotation matrices to Euler angles in radians.

    Args:
        matrix: Rotation matrices as a tensor of shape (..., 3, 3).
        convention: Three uppercase letters from "X", "Y", "Z"; the middle
            letter must differ from both of its neighbours.

    Returns:
        Euler angles in radians as a tensor of shape (..., 3).

    Raises:
        ValueError: If the convention string or the matrix shape is invalid.
    """
    if len(convention) != 3:
        raise ValueError("Convention must have 3 letters.")
    first, middle, last = convention[0], convention[1], convention[2]
    if middle in (first, last):
        raise ValueError(f"Invalid convention {convention}.")
    for letter in convention:
        if letter not in ("X", "Y", "Z"):
            raise ValueError(f"Invalid letter {letter} in convention string.")
    if not (matrix.size(-1) == 3 and matrix.size(-2) == 3):
        raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")

    first_idx = _index_from_letter(first)
    last_idx = _index_from_letter(last)
    tait_bryan = first_idx != last_idx

    # The middle angle comes from a single matrix entry: an arcsine (with a
    # sign depending on the axis order) when the outer axes differ, otherwise
    # an arccosine of a diagonal entry.
    if tait_bryan:
        sign = -1.0 if first_idx - last_idx in [-1, 2] else 1.0
        central_angle = torch.asin(matrix[..., first_idx, last_idx] * sign)
    else:
        central_angle = torch.acos(matrix[..., first_idx, first_idx])

    # First angle is read from column `last_idx`, third from row `first_idx`.
    first_angle = _angle_from_tan(
        first, middle, matrix[..., last_idx], False, tait_bryan
    )
    last_angle = _angle_from_tan(
        last, middle, matrix[..., first_idx, :], True, tait_bryan
    )
    return torch.stack((first_angle, central_angle, last_angle), -1)
|
120 |
+
|
121 |
+
|
122 |
+
def euler_angles_to_matrix(euler_angles: torch.Tensor, convention: str) -> torch.Tensor:
    """
    Convert Euler angles in radians to rotation matrices.

    Args:
        euler_angles: Euler angles in radians as a tensor of shape (..., 3).
        convention: Three uppercase letters from "X", "Y", "Z"; the middle
            letter must differ from both of its neighbours.

    Returns:
        Rotation matrices as a tensor of shape (..., 3, 3), formed as the
        product of the three single-axis rotations in convention order.

    Raises:
        ValueError: If the angles tensor or the convention string is invalid.
    """
    if euler_angles.dim() == 0 or euler_angles.shape[-1] != 3:
        raise ValueError("Invalid input euler angles.")
    if len(convention) != 3:
        raise ValueError("Convention must have 3 letters.")
    if convention[1] in (convention[0], convention[2]):
        raise ValueError(f"Invalid convention {convention}.")
    for letter in convention:
        if letter not in ("X", "Y", "Z"):
            raise ValueError(f"Invalid letter {letter} in convention string.")

    per_axis_angles = torch.unbind(euler_angles, -1)
    rot_first, rot_middle, rot_last = [
        _axis_angle_rotation(axis_label, axis_angle)
        for axis_label, axis_angle in zip(convention, per_axis_angles)
    ]
    # Left-to-right product: (first @ middle) @ last.
    return torch.matmul(torch.matmul(rot_first, rot_middle), rot_last)
|
tmp_data/input_scene.glb
ADDED
Binary file (121 kB). View file
|
|
tmp_data/input_scene_102.glb
ADDED
Binary file (80.8 kB). View file
|
|
tmp_data/input_scene_None.glb
ADDED
Binary file (66.3 kB). View file
|
|
tmp_data/output_scene.glb
ADDED
Binary file (121 kB). View file
|
|
tmp_data/output_scene_102.glb
ADDED
Binary file (91.3 kB). View file
|
|
wandb_logs/StructDiffusion/CollisionDiscriminator/checkpoints/epoch=199-step=653400.ckpt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ed90c4976e69f96324d0245ecc635fa987d66d84c512d1e2105ce9f7f3df39ea
|
3 |
+
size 34533413
|
wandb_logs/StructDiffusion/ConditionalPoseDiffusionLanguage/checkpoints/epoch=199-step=100000.ckpt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:edd6c08c5fafd129f365fd70fc1b6ad68e643a9111e522432279afb7ba387a89
|
3 |
+
size 59947673
|