Weiyu Liu committed
Commit 405900e
1 Parent(s): b7dd541

remove dep on pytorch3d

app.py CHANGED
@@ -16,10 +16,50 @@ from StructDiffusion.models.pl_models import ConditionalPoseDiffusionModel
 from StructDiffusion.diffusion.sampler import Sampler
 from StructDiffusion.diffusion.pose_conversion import get_struct_objs_poses
 from StructDiffusion.utils.files import get_checkpoint_path_from_dir
-from StructDiffusion.utils.batch_inference import move_pc_and_create_scene_simple, visualize_batch_pcs
 from StructDiffusion.utils.rearrangement import show_pcs_with_trimesh
 
 
+def move_pc_and_create_scene_simple(obj_xyzs, struct_pose, pc_poses_in_struct):
+
+    device = obj_xyzs.device
+
+    # obj_xyzs: B, N, P, 3 or 6
+    # struct_pose: B, 1, 4, 4
+    # pc_poses_in_struct: B, N, 4, 4
+
+    B, N, _, _ = pc_poses_in_struct.shape
+    _, _, P, _ = obj_xyzs.shape
+
+    current_pc_poses = torch.eye(4).repeat(B, N, 1, 1).to(device)  # B, N, 4, 4
+    # print(torch.mean(obj_xyzs, dim=2).shape)
+    current_pc_poses[:, :, :3, 3] = torch.mean(obj_xyzs[:, :, :, :3], dim=2)  # B, N, 3
+    current_pc_poses = current_pc_poses.reshape(B * N, 4, 4)  # B x N, 4, 4
+
+    struct_pose = struct_pose.repeat(1, N, 1, 1)  # B, N, 4, 4
+    struct_pose = struct_pose.reshape(B * N, 4, 4)  # B x N, 4, 4
+    pc_poses_in_struct = pc_poses_in_struct.reshape(B * N, 4, 4)  # B x N, 4, 4
+
+    goal_pc_pose = struct_pose @ pc_poses_in_struct  # B x N, 4, 4
+    # print("goal pc poses")
+    # print(goal_pc_pose)
+    goal_pc_transform = goal_pc_pose @ torch.inverse(current_pc_poses)  # B x N, 4, 4
+
+    # # important: pytorch3d uses row-major ordering, need to transpose each transformation matrix
+    # transpose = tra3d.Transform3d(matrix=goal_pc_transform.transpose(1, 2))
+    # new_obj_xyzs = obj_xyzs.reshape(B * N, P, -1)  # B x N, P, 3
+    # new_obj_xyzs[:, :, :3] = transpose.transform_points(new_obj_xyzs[:, :, :3])
+
+    # a version that does not rely on pytorch3d: apply the transforms in homogeneous coordinates
+    new_obj_xyzs = obj_xyzs.reshape(B * N, P, -1)[:, :, :3]  # B x N, P, 3
+    new_obj_xyzs = torch.concat([new_obj_xyzs, torch.ones(B * N, P, 1).to(device)], dim=-1)  # B x N, P, 4
+    new_obj_xyzs = torch.einsum('bij,bkj->bki', goal_pc_transform, new_obj_xyzs)[:, :, :3]  # B x N, P, 3
+
+    # put it back to B, N, P, 3
+    obj_xyzs[:, :, :, :3] = new_obj_xyzs.reshape(B, N, P, -1)
+
+    return obj_xyzs
+
+
 class Infer_Wrapper:
 
     def __init__(self, args, cfg):
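
The inlined function above replaces pytorch3d's row-major `Transform3d.transform_points` with a plain homogeneous-coordinate matrix multiply via `torch.einsum`. A minimal, self-contained sketch of that equivalence (the toy shapes and names below are illustrative, not from the repo):

import math
import torch

# Illustrative check: apply a batch of 4x4 poses to point clouds in homogeneous
# coordinates with torch.einsum, the same way the inlined
# move_pc_and_create_scene_simple does, and compare against an explicit
# "rotate then translate" reference.
B, P = 2, 5
c, s = math.cos(0.3), math.sin(0.3)
rot_z = torch.tensor([[c, -s, 0.0],
                      [s,  c, 0.0],
                      [0.0, 0.0, 1.0]])

transforms = torch.eye(4).repeat(B, 1, 1)   # B, 4, 4
transforms[:, :3, :3] = rot_z               # shared rotation
transforms[:, :3, 3] = torch.randn(B, 3)    # per-batch translation
points = torch.randn(B, P, 3)

# einsum route: pad points to homogeneous coordinates, then batched matrix-vector products
homo = torch.cat([points, torch.ones(B, P, 1)], dim=-1)            # B, P, 4
moved = torch.einsum('bij,bkj->bki', transforms, homo)[:, :, :3]   # B, P, 3

# reference route: R @ p + t written out explicitly
reference = torch.einsum('bij,bkj->bki', transforms[:, :3, :3], points) + transforms[:, :3, 3].unsqueeze(1)

assert torch.allclose(moved, reference, atol=1e-5)
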
requirements.txt CHANGED
@@ -7,8 +7,4 @@ pyglet==1.5.0
 openpyxl
 pytorch_lightning==1.6.1
 wandb===0.13.10
-omegaconf==2.2.2
-torch==1.12.0+cpu
-torchvision==0.13.0+cpu
-torchaudio==0.12.0
-git+https://github.com/facebookresearch/pytorch3d.git@stable
+omegaconf==2.2.2
 
src/StructDiffusion/diffusion/__pycache__/pose_conversion.cpython-38.pyc CHANGED
Binary files a/src/StructDiffusion/diffusion/__pycache__/pose_conversion.cpython-38.pyc and b/src/StructDiffusion/diffusion/__pycache__/pose_conversion.cpython-38.pyc differ
 
src/StructDiffusion/diffusion/__pycache__/sampler.cpython-38.pyc CHANGED
Binary files a/src/StructDiffusion/diffusion/__pycache__/sampler.cpython-38.pyc and b/src/StructDiffusion/diffusion/__pycache__/sampler.cpython-38.pyc differ
 
src/StructDiffusion/diffusion/pose_conversion.py CHANGED
@@ -1,6 +1,5 @@
 import os
 import torch
-import pytorch3d.transforms as tra3d
 
 from StructDiffusion.utils.rotation_continuity import compute_rotation_matrix_from_ortho6d
 
src/StructDiffusion/diffusion/sampler.py CHANGED
@@ -1,6 +1,6 @@
 import torch
 from tqdm import tqdm
-import pytorch3d.transforms as tra3d
+# import pytorch3d.transforms as tra3d
 
 from StructDiffusion.diffusion.noise_schedule import extract
 from StructDiffusion.diffusion.pose_conversion import get_struct_objs_poses
@@ -61,236 +61,236 @@ class Sampler:
         xs = list(reversed(xs))
         return xs
 
-class SamplerV2:
-
-    def __init__(self, diffusion_model_class, diffusion_checkpoint_path,
-                 collision_model_class, collision_checkpoint_path,
-                 device, debug=False):
-
-        self.debug = debug
-        self.device = device
-
-        self.diffusion_model = diffusion_model_class.load_from_checkpoint(diffusion_checkpoint_path)
-        self.diffusion_backbone = self.diffusion_model.model
-        self.diffusion_backbone.to(device)
-        self.diffusion_backbone.eval()
-
-        self.collision_model = collision_model_class.load_from_checkpoint(collision_checkpoint_path)
-        self.collision_backbone = self.collision_model.model
-        self.collision_backbone.to(device)
-        self.collision_backbone.eval()
-
-    def sample(self, batch, num_poses):
-
-        noise_schedule = self.diffusion_model.noise_schedule
-
-        B = batch["pcs"].shape[0]
-
-        x_noisy = torch.randn((B, num_poses, 9), device=self.device)
-
-        xs = []
-        for t_index in tqdm(reversed(range(0, noise_schedule.timesteps)),
-                            desc='sampling loop time step', total=noise_schedule.timesteps):
-
-            t = torch.full((B,), t_index, device=self.device, dtype=torch.long)
-
-            # noise schedule
-            betas_t = extract(noise_schedule.betas, t, x_noisy.shape)
-            sqrt_one_minus_alphas_cumprod_t = extract(noise_schedule.sqrt_one_minus_alphas_cumprod, t, x_noisy.shape)
-            sqrt_recip_alphas_t = extract(noise_schedule.sqrt_recip_alphas, t, x_noisy.shape)
-
-            # predict noise
-            pcs = batch["pcs"]
-            sentence = batch["sentence"]
-            type_index = batch["type_index"]
-            position_index = batch["position_index"]
-            pad_mask = batch["pad_mask"]
-            # calling the backbone instead of the pytorch-lightning model
-            with torch.no_grad():
-                predicted_noise = self.diffusion_backbone.forward(t, pcs, sentence, x_noisy, type_index, position_index, pad_mask)
-
-            # compute noisy x at t
-            model_mean = sqrt_recip_alphas_t * (x_noisy - betas_t * predicted_noise / sqrt_one_minus_alphas_cumprod_t)
-            if t_index == 0:
-                x_noisy = model_mean
-            else:
-                posterior_variance_t = extract(noise_schedule.posterior_variance, t, x_noisy.shape)
-                noise = torch.randn_like(x_noisy)
-                x_noisy = model_mean + torch.sqrt(posterior_variance_t) * noise
-
-            xs.append(x_noisy)
-
-        xs = list(reversed(xs))
-
-        visualize = True
-
-        struct_pose, pc_poses_in_struct = get_struct_objs_poses(xs[0])
-        # struct_pose: B, 1, 4, 4
-        # pc_poses_in_struct: B, N, 4, 4
-
-        S = B
-        num_elite = 10
-        ####################################################
-        # only keep one copy
-
-        # N, P, 3
-        obj_xyzs = batch["pcs"][0][:, :, :3]
-        print("obj_xyzs shape", obj_xyzs.shape)
-
-        # 1, N
-        # object_pad_mask: padding location has 1
-        num_target_objs = num_poses
-        if self.diffusion_backbone.use_virtual_structure_frame:
-            num_target_objs -= 1
-        object_pad_mask = batch["pad_mask"][0][-num_target_objs:].unsqueeze(0)
-        target_object_inds = 1 - object_pad_mask
-        print("target_object_inds shape", target_object_inds.shape)
-        print("target_object_inds", target_object_inds)
-
-        N, P, _ = obj_xyzs.shape
-        print("S, N, P: {}, {}, {}".format(S, N, P))
-
-        ####################################################
-        # S, N, ...
-
-        struct_pose = struct_pose.repeat(1, N, 1, 1)  # S, N, 4, 4
-        struct_pose = struct_pose.reshape(S * N, 4, 4)  # S x N, 4, 4
-
-        new_obj_xyzs = obj_xyzs.repeat(S, 1, 1, 1)  # S, N, P, 3
-        current_pc_pose = torch.eye(4).repeat(S, N, 1, 1).to(self.device)  # S, N, 4, 4
-        current_pc_pose[:, :, :3, 3] = torch.mean(new_obj_xyzs, dim=2)  # S, N, 4, 4
-        current_pc_pose = current_pc_pose.reshape(S * N, 4, 4)  # S x N, 4, 4
-
-        # optimize xyzrpy
-        obj_params = torch.zeros((S, N, 6)).to(self.device)
-        obj_params[:, :, :3] = pc_poses_in_struct[:, :, :3, 3]
-        obj_params[:, :, 3:] = tra3d.matrix_to_euler_angles(pc_poses_in_struct[:, :, :3, :3], "XYZ")  # S, N, 6
-        #
-        # new_obj_xyzs_before_cem, goal_pc_pose_before_cem = move_pc(obj_xyzs, obj_params, struct_pose, current_pc_pose, device)
-        #
-        # if visualize:
-        #     print("visualizing rearrangements predicted by the generator")
-        #     visualize_batch_pcs(new_obj_xyzs_before_cem, S, N, P, limit_B=5)
-
-        ####################################################
-        # rank
-
-        # evaluate in batches
-        scores = torch.zeros(S).to(self.device)
-        no_intersection_scores = torch.zeros(S).to(self.device)  # the higher the better
-        num_batches = int(S / B)
-        if S % B != 0:
-            num_batches += 1
-        for b in range(num_batches):
-            if b + 1 == num_batches:
-                cur_batch_idxs_start = b * B
-                cur_batch_idxs_end = S
-            else:
-                cur_batch_idxs_start = b * B
-                cur_batch_idxs_end = (b + 1) * B
-            cur_batch_size = cur_batch_idxs_end - cur_batch_idxs_start
-
-            # print("current batch idxs start", cur_batch_idxs_start)
-            # print("current batch idxs end", cur_batch_idxs_end)
-            # print("size of the current batch", cur_batch_size)
-
-            batch_obj_params = obj_params[cur_batch_idxs_start: cur_batch_idxs_end]
-            batch_struct_pose = struct_pose[cur_batch_idxs_start * N: cur_batch_idxs_end * N]
-            batch_current_pc_pose = current_pc_pose[cur_batch_idxs_start * N:cur_batch_idxs_end * N]
-
-            new_obj_xyzs, _, subsampled_scene_xyz, _, obj_pair_xyzs = \
-                move_pc_and_create_scene_new(obj_xyzs, batch_obj_params, batch_struct_pose, batch_current_pc_pose,
-                                             target_object_inds, self.device,
-                                             return_scene_pts=False,
-                                             return_scene_pts_and_pc_idxs=False,
-                                             num_scene_pts=False,
-                                             normalize_pc=False,
-                                             return_pair_pc=True,
-                                             num_pair_pc_pts=self.collision_model.data_cfg.num_scene_pts,
-                                             normalize_pair_pc=self.collision_model.data_cfg.normalize_pc)
-
-            #######################################
-            # predict whether there are pairwise collisions
-            # if collision_score_weight > 0:
-            with torch.no_grad():
-                _, num_comb, num_pair_pc_pts, _ = obj_pair_xyzs.shape
-                # obj_pair_xyzs = obj_pair_xyzs.reshape(cur_batch_size * num_comb, num_pair_pc_pts, -1)
-                collision_logits = self.collision_backbone.forward(obj_pair_xyzs.reshape(cur_batch_size * num_comb, num_pair_pc_pts, -1))
-                collision_scores = self.collision_backbone.convert_logits(collision_logits).reshape(cur_batch_size, num_comb)  # cur_batch_size, num_comb
-
-                # debug
-                # for bi, this_obj_pair_xyzs in enumerate(obj_pair_xyzs):
-                #     print("batch id", bi)
-                #     for pi, obj_pair_xyz in enumerate(this_obj_pair_xyzs):
-                #         print("pair", pi)
-                #         # obj_pair_xyzs: 2 * P, 5
-                #         print("collision score", collision_scores[bi, pi])
-                #         trimesh.PointCloud(obj_pair_xyz[:, :3].cpu()).show()
-
-                # 1 - mean() since the collision model predicts 1 if there is a collision
-                no_intersection_scores[cur_batch_idxs_start:cur_batch_idxs_end] = 1 - torch.mean(collision_scores, dim=1)
-            if visualize:
-                print("no intersection scores", no_intersection_scores)
-            # #######################################
-            # if discriminator_score_weight > 0:
-            #     # # debug:
-            #     # print(subsampled_scene_xyz.shape)
-            #     # print(subsampled_scene_xyz[0])
-            #     # trimesh.PointCloud(subsampled_scene_xyz[0, :, :3].cpu().numpy()).show()
-            #     #
-            #     with torch.no_grad():
-            #
-            #         # Important: since this discriminator only uses local structure param, takes sentence from the first and last position
-            #         # local_sentence = sentence[:, [0, 4]]
-            #         # local_sentence_pad_mask = sentence_pad_mask[:, [0, 4]]
-            #         # sentence_disc, sentence_pad_mask_disc, position_index_dic = discriminator_inference.dataset.tensorfy_sentence(raw_sentence_discriminator, raw_sentence_pad_mask_discriminator, raw_position_index_discriminator)
-            #
-            #         sentence_disc = torch.LongTensor(
-            #             [discriminator_tokenizer.tokenize(*i) for i in raw_sentence_discriminator])
-            #         sentence_pad_mask_disc = torch.LongTensor(raw_sentence_pad_mask_discriminator)
-            #         position_index_dic = torch.LongTensor(raw_position_index_discriminator)
-            #
-            #         preds = discriminator_model.forward(subsampled_scene_xyz,
-            #                                             sentence_disc.unsqueeze(0).repeat(cur_batch_size, 1).to(device),
-            #                                             sentence_pad_mask_disc.unsqueeze(0).repeat(cur_batch_size, 1).to(device),
-            #                                             position_index_dic.unsqueeze(0).repeat(cur_batch_size, 1).to(device))
-            #         # preds = discriminator_model.forward(subsampled_scene_xyz)
-            #         preds = discriminator_model.convert_logits(preds)
-            #         preds = preds["is_circle"]  # cur_batch_size,
-            #         scores[cur_batch_idxs_start:cur_batch_idxs_end] = preds
-            #         if visualize:
-            #             print("discriminator scores", scores)
-
-        # scores = scores * discriminator_score_weight + no_intersection_scores * collision_score_weight
-        scores = no_intersection_scores
-        sort_idx = torch.argsort(scores).flip(dims=[0])[:num_elite]
-        elite_obj_params = obj_params[sort_idx]  # num_elite, N, 6
-        elite_struct_poses = struct_pose.reshape(S, N, 4, 4)[sort_idx]  # num_elite, N, 4, 4
-        elite_struct_poses = elite_struct_poses.reshape(num_elite * N, 4, 4)  # num_elite x N, 4, 4
-        elite_scores = scores[sort_idx]
-        print("elite scores:", elite_scores)
-
-        ####################################################
-        # # visualize best samples
-        # num_scene_pts = 4096  # if discriminator_num_scene_pts is None else discriminator_num_scene_pts
-        # batch_current_pc_pose = current_pc_pose[0: num_elite * N]
-        # best_new_obj_xyzs, best_goal_pc_pose, best_subsampled_scene_xyz, _, _ = \
-        #     move_pc_and_create_scene_new(obj_xyzs, elite_obj_params, elite_struct_poses, batch_current_pc_pose,
-        #                                  target_object_inds, self.device,
-        #                                  return_scene_pts=True, num_scene_pts=num_scene_pts, normalize_pc=True)
-        # if visualize:
-        #     print("visualizing elite rearrangements ranked by collision model/discriminator")
-        #     visualize_batch_pcs(best_new_obj_xyzs, num_elite, limit_B=num_elite)
-
-        # num_elite, N, 6
-        elite_obj_params = elite_obj_params.reshape(num_elite * N, -1)
-        pc_poses_in_struct = torch.eye(4).repeat(num_elite * N, 1, 1).to(self.device)
-        pc_poses_in_struct[:, :3, :3] = tra3d.euler_angles_to_matrix(elite_obj_params[:, 3:], "XYZ")
-        pc_poses_in_struct[:, :3, 3] = elite_obj_params[:, :3]
-        pc_poses_in_struct = pc_poses_in_struct.reshape(num_elite, N, 4, 4)  # num_elite, N, 4, 4
-
-        struct_pose = elite_struct_poses.reshape(num_elite, N, 4, 4)[:, 0].unsqueeze(1)  # num_elite, 1, 4, 4
-
-        return struct_pose, pc_poses_in_struct
+# class SamplerV2:
+#
+#     def __init__(self, diffusion_model_class, diffusion_checkpoint_path,
+#                  collision_model_class, collision_checkpoint_path,
+#                  device, debug=False):
+#
+#         self.debug = debug
+#         self.device = device
+#
+#         self.diffusion_model = diffusion_model_class.load_from_checkpoint(diffusion_checkpoint_path)
+#         self.diffusion_backbone = self.diffusion_model.model
+#         self.diffusion_backbone.to(device)
+#         self.diffusion_backbone.eval()
+#
+#         self.collision_model = collision_model_class.load_from_checkpoint(collision_checkpoint_path)
+#         self.collision_backbone = self.collision_model.model
+#         self.collision_backbone.to(device)
+#         self.collision_backbone.eval()
+#
+#     def sample(self, batch, num_poses):
+#
+#         noise_schedule = self.diffusion_model.noise_schedule
+#
+#         B = batch["pcs"].shape[0]
+#
+#         x_noisy = torch.randn((B, num_poses, 9), device=self.device)
+#
+#         xs = []
+#         for t_index in tqdm(reversed(range(0, noise_schedule.timesteps)),
+#                             desc='sampling loop time step', total=noise_schedule.timesteps):
+#
+#             t = torch.full((B,), t_index, device=self.device, dtype=torch.long)
+#
+#             # noise schedule
+#             betas_t = extract(noise_schedule.betas, t, x_noisy.shape)
+#             sqrt_one_minus_alphas_cumprod_t = extract(noise_schedule.sqrt_one_minus_alphas_cumprod, t, x_noisy.shape)
+#             sqrt_recip_alphas_t = extract(noise_schedule.sqrt_recip_alphas, t, x_noisy.shape)
+#
+#             # predict noise
+#             pcs = batch["pcs"]
+#             sentence = batch["sentence"]
+#             type_index = batch["type_index"]
+#             position_index = batch["position_index"]
+#             pad_mask = batch["pad_mask"]
+#             # calling the backbone instead of the pytorch-lightning model
+#             with torch.no_grad():
+#                 predicted_noise = self.diffusion_backbone.forward(t, pcs, sentence, x_noisy, type_index, position_index, pad_mask)
+#
+#             # compute noisy x at t
+#             model_mean = sqrt_recip_alphas_t * (x_noisy - betas_t * predicted_noise / sqrt_one_minus_alphas_cumprod_t)
+#             if t_index == 0:
+#                 x_noisy = model_mean
+#             else:
+#                 posterior_variance_t = extract(noise_schedule.posterior_variance, t, x_noisy.shape)
+#                 noise = torch.randn_like(x_noisy)
+#                 x_noisy = model_mean + torch.sqrt(posterior_variance_t) * noise
+#
+#             xs.append(x_noisy)
+#
+#         xs = list(reversed(xs))
+#
+#         visualize = True
+#
+#         struct_pose, pc_poses_in_struct = get_struct_objs_poses(xs[0])
+#         # struct_pose: B, 1, 4, 4
+#         # pc_poses_in_struct: B, N, 4, 4
+#
+#         S = B
+#         num_elite = 10
+#         ####################################################
+#         # only keep one copy
+#
+#         # N, P, 3
+#         obj_xyzs = batch["pcs"][0][:, :, :3]
+#         print("obj_xyzs shape", obj_xyzs.shape)
+#
+#         # 1, N
+#         # object_pad_mask: padding location has 1
+#         num_target_objs = num_poses
+#         if self.diffusion_backbone.use_virtual_structure_frame:
+#             num_target_objs -= 1
+#         object_pad_mask = batch["pad_mask"][0][-num_target_objs:].unsqueeze(0)
+#         target_object_inds = 1 - object_pad_mask
+#         print("target_object_inds shape", target_object_inds.shape)
+#         print("target_object_inds", target_object_inds)
+#
+#         N, P, _ = obj_xyzs.shape
+#         print("S, N, P: {}, {}, {}".format(S, N, P))
+#
+#         ####################################################
+#         # S, N, ...
+#
+#         struct_pose = struct_pose.repeat(1, N, 1, 1)  # S, N, 4, 4
+#         struct_pose = struct_pose.reshape(S * N, 4, 4)  # S x N, 4, 4
+#
+#         new_obj_xyzs = obj_xyzs.repeat(S, 1, 1, 1)  # S, N, P, 3
+#         current_pc_pose = torch.eye(4).repeat(S, N, 1, 1).to(self.device)  # S, N, 4, 4
+#         current_pc_pose[:, :, :3, 3] = torch.mean(new_obj_xyzs, dim=2)  # S, N, 4, 4
+#         current_pc_pose = current_pc_pose.reshape(S * N, 4, 4)  # S x N, 4, 4
+#
+#         # optimize xyzrpy
+#         obj_params = torch.zeros((S, N, 6)).to(self.device)
+#         obj_params[:, :, :3] = pc_poses_in_struct[:, :, :3, 3]
+#         obj_params[:, :, 3:] = tra3d.matrix_to_euler_angles(pc_poses_in_struct[:, :, :3, :3], "XYZ")  # S, N, 6
+#         #
+#         # new_obj_xyzs_before_cem, goal_pc_pose_before_cem = move_pc(obj_xyzs, obj_params, struct_pose, current_pc_pose, device)
+#         #
+#         # if visualize:
+#         #     print("visualizing rearrangements predicted by the generator")
+#         #     visualize_batch_pcs(new_obj_xyzs_before_cem, S, N, P, limit_B=5)
+#
+#         ####################################################
+#         # rank
+#
+#         # evaluate in batches
+#         scores = torch.zeros(S).to(self.device)
+#         no_intersection_scores = torch.zeros(S).to(self.device)  # the higher the better
+#         num_batches = int(S / B)
+#         if S % B != 0:
+#             num_batches += 1
+#         for b in range(num_batches):
+#             if b + 1 == num_batches:
+#                 cur_batch_idxs_start = b * B
+#                 cur_batch_idxs_end = S
+#             else:
+#                 cur_batch_idxs_start = b * B
+#                 cur_batch_idxs_end = (b + 1) * B
+#             cur_batch_size = cur_batch_idxs_end - cur_batch_idxs_start
+#
+#             # print("current batch idxs start", cur_batch_idxs_start)
+#             # print("current batch idxs end", cur_batch_idxs_end)
+#             # print("size of the current batch", cur_batch_size)
+#
+#             batch_obj_params = obj_params[cur_batch_idxs_start: cur_batch_idxs_end]
+#             batch_struct_pose = struct_pose[cur_batch_idxs_start * N: cur_batch_idxs_end * N]
+#             batch_current_pc_pose = current_pc_pose[cur_batch_idxs_start * N:cur_batch_idxs_end * N]
+#
+#             new_obj_xyzs, _, subsampled_scene_xyz, _, obj_pair_xyzs = \
+#                 move_pc_and_create_scene_new(obj_xyzs, batch_obj_params, batch_struct_pose, batch_current_pc_pose,
+#                                              target_object_inds, self.device,
+#                                              return_scene_pts=False,
+#                                              return_scene_pts_and_pc_idxs=False,
+#                                              num_scene_pts=False,
+#                                              normalize_pc=False,
+#                                              return_pair_pc=True,
+#                                              num_pair_pc_pts=self.collision_model.data_cfg.num_scene_pts,
+#                                              normalize_pair_pc=self.collision_model.data_cfg.normalize_pc)
+#
+#             #######################################
+#             # predict whether there are pairwise collisions
+#             # if collision_score_weight > 0:
+#             with torch.no_grad():
+#                 _, num_comb, num_pair_pc_pts, _ = obj_pair_xyzs.shape
+#                 # obj_pair_xyzs = obj_pair_xyzs.reshape(cur_batch_size * num_comb, num_pair_pc_pts, -1)
+#                 collision_logits = self.collision_backbone.forward(obj_pair_xyzs.reshape(cur_batch_size * num_comb, num_pair_pc_pts, -1))
+#                 collision_scores = self.collision_backbone.convert_logits(collision_logits).reshape(cur_batch_size, num_comb)  # cur_batch_size, num_comb
+#
+#                 # debug
+#                 # for bi, this_obj_pair_xyzs in enumerate(obj_pair_xyzs):
+#                 #     print("batch id", bi)
+#                 #     for pi, obj_pair_xyz in enumerate(this_obj_pair_xyzs):
+#                 #         print("pair", pi)
+#                 #         # obj_pair_xyzs: 2 * P, 5
+#                 #         print("collision score", collision_scores[bi, pi])
+#                 #         trimesh.PointCloud(obj_pair_xyz[:, :3].cpu()).show()
+#
+#                 # 1 - mean() since the collision model predicts 1 if there is a collision
+#                 no_intersection_scores[cur_batch_idxs_start:cur_batch_idxs_end] = 1 - torch.mean(collision_scores, dim=1)
+#             if visualize:
+#                 print("no intersection scores", no_intersection_scores)
+#             # #######################################
+#             # if discriminator_score_weight > 0:
+#             #     # # debug:
+#             #     # print(subsampled_scene_xyz.shape)
+#             #     # print(subsampled_scene_xyz[0])
+#             #     # trimesh.PointCloud(subsampled_scene_xyz[0, :, :3].cpu().numpy()).show()
+#             #     #
+#             #     with torch.no_grad():
+#             #
+#             #         # Important: since this discriminator only uses local structure param, takes sentence from the first and last position
+#             #         # local_sentence = sentence[:, [0, 4]]
+#             #         # local_sentence_pad_mask = sentence_pad_mask[:, [0, 4]]
+#             #         # sentence_disc, sentence_pad_mask_disc, position_index_dic = discriminator_inference.dataset.tensorfy_sentence(raw_sentence_discriminator, raw_sentence_pad_mask_discriminator, raw_position_index_discriminator)
+#             #
+#             #         sentence_disc = torch.LongTensor(
+#             #             [discriminator_tokenizer.tokenize(*i) for i in raw_sentence_discriminator])
+#             #         sentence_pad_mask_disc = torch.LongTensor(raw_sentence_pad_mask_discriminator)
+#             #         position_index_dic = torch.LongTensor(raw_position_index_discriminator)
+#             #
+#             #         preds = discriminator_model.forward(subsampled_scene_xyz,
+#             #                                             sentence_disc.unsqueeze(0).repeat(cur_batch_size, 1).to(device),
+#             #                                             sentence_pad_mask_disc.unsqueeze(0).repeat(cur_batch_size, 1).to(device),
+#             #                                             position_index_dic.unsqueeze(0).repeat(cur_batch_size, 1).to(device))
+#             #         # preds = discriminator_model.forward(subsampled_scene_xyz)
+#             #         preds = discriminator_model.convert_logits(preds)
+#             #         preds = preds["is_circle"]  # cur_batch_size,
+#             #         scores[cur_batch_idxs_start:cur_batch_idxs_end] = preds
+#             #         if visualize:
+#             #             print("discriminator scores", scores)
+#
+#         # scores = scores * discriminator_score_weight + no_intersection_scores * collision_score_weight
+#         scores = no_intersection_scores
+#         sort_idx = torch.argsort(scores).flip(dims=[0])[:num_elite]
+#         elite_obj_params = obj_params[sort_idx]  # num_elite, N, 6
+#         elite_struct_poses = struct_pose.reshape(S, N, 4, 4)[sort_idx]  # num_elite, N, 4, 4
+#         elite_struct_poses = elite_struct_poses.reshape(num_elite * N, 4, 4)  # num_elite x N, 4, 4
+#         elite_scores = scores[sort_idx]
+#         print("elite scores:", elite_scores)
+#
+#         ####################################################
+#         # # visualize best samples
+#         # num_scene_pts = 4096  # if discriminator_num_scene_pts is None else discriminator_num_scene_pts
+#         # batch_current_pc_pose = current_pc_pose[0: num_elite * N]
+#         # best_new_obj_xyzs, best_goal_pc_pose, best_subsampled_scene_xyz, _, _ = \
+#         #     move_pc_and_create_scene_new(obj_xyzs, elite_obj_params, elite_struct_poses, batch_current_pc_pose,
+#         #                                  target_object_inds, self.device,
+#         #                                  return_scene_pts=True, num_scene_pts=num_scene_pts, normalize_pc=True)
+#         # if visualize:
+#         #     print("visualizing elite rearrangements ranked by collision model/discriminator")
+#         #     visualize_batch_pcs(best_new_obj_xyzs, num_elite, limit_B=num_elite)
+#
+#         # num_elite, N, 6
+#         elite_obj_params = elite_obj_params.reshape(num_elite * N, -1)
+#         pc_poses_in_struct = torch.eye(4).repeat(num_elite * N, 1, 1).to(self.device)
+#         pc_poses_in_struct[:, :3, :3] = tra3d.euler_angles_to_matrix(elite_obj_params[:, 3:], "XYZ")
+#         pc_poses_in_struct[:, :3, 3] = elite_obj_params[:, :3]
+#         pc_poses_in_struct = pc_poses_in_struct.reshape(num_elite, N, 4, 4)  # num_elite, N, 4, 4
+#
+#         struct_pose = elite_struct_poses.reshape(num_elite, N, 4, 4)[:, 0].unsqueeze(1)  # num_elite, 1, 4, 4
+#
+#         return struct_pose, pc_poses_in_struct
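
SamplerV2 is commented out rather than ported because it still depends on `tra3d.matrix_to_euler_angles` and `tra3d.euler_angles_to_matrix`. If it were re-enabled without pytorch3d, those two calls would need torch-only stand-ins. A rough sketch under the assumption that the "XYZ" convention composes rotations as Rx @ Ry @ Rz (the function names below are hypothetical, the formulas are not verified against pytorch3d, and gimbal lock is ignored):

import torch

# Hypothetical torch-only stand-ins for the two pytorch3d calls used by SamplerV2.
# Assumes "XYZ" means R = Rx(a) @ Ry(b) @ Rz(c); verify against pytorch3d before
# relying on this, and note that the gimbal-lock case (cos(b) == 0) is not handled.

def euler_angles_to_matrix_xyz(angles):
    # angles: (..., 3) -> rotation matrices (..., 3, 3)
    a, b, c = angles.unbind(-1)
    ca, sa = torch.cos(a), torch.sin(a)
    cb, sb = torch.cos(b), torch.sin(b)
    cc, sc = torch.cos(c), torch.sin(c)
    row0 = torch.stack([cb * cc, -cb * sc, sb], dim=-1)
    row1 = torch.stack([ca * sc + sa * sb * cc, ca * cc - sa * sb * sc, -sa * cb], dim=-1)
    row2 = torch.stack([sa * sc - ca * sb * cc, sa * cc + ca * sb * sc, ca * cb], dim=-1)
    return torch.stack([row0, row1, row2], dim=-2)

def matrix_to_euler_angles_xyz(R):
    # R: (..., 3, 3) -> angles (..., 3); valid away from cos(b) == 0
    b = torch.asin(R[..., 0, 2].clamp(-1.0, 1.0))
    a = torch.atan2(-R[..., 1, 2], R[..., 2, 2])
    c = torch.atan2(-R[..., 0, 1], R[..., 0, 0])
    return torch.stack([a, b, c], dim=-1)

# quick round-trip check on angles well inside (-pi/2, pi/2)
angles = torch.empty(4, 3).uniform_(-1.0, 1.0)
R = euler_angles_to_matrix_xyz(angles)
assert torch.allclose(matrix_to_euler_angles_xyz(R), angles, atol=1e-5)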