thuzhaowang committed
Commit b6a9b6d
1 Parent(s): ad31d76
This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitmodules +3 -0
  2. README.md +1 -0
  3. app.py +258 -0
  4. configs/demo/basket_demo.yaml +20 -0
  5. configs/demo/chair_demo.yaml +20 -0
  6. configs/demo/dandelion_demo.yaml +20 -0
  7. configs/demo/flower_demo.yaml +20 -0
  8. configs/demo/table_demo.yaml +20 -0
  9. configs/demo/vase_demo.yaml +20 -0
  10. configs/infinigen/base.gin +89 -0
  11. configs/test/basket_test.yaml +24 -0
  12. configs/test/chair_test.yaml +24 -0
  13. configs/test/dandelion_test.yaml +24 -0
  14. configs/test/flower_test.yaml +24 -0
  15. configs/test/table_test.yaml +24 -0
  16. configs/test/vase_test.yaml +24 -0
  17. configs/train/basket_train.yaml +17 -0
  18. configs/train/chair_train.yaml +17 -0
  19. configs/train/dandelion_train.yaml +17 -0
  20. configs/train/flower_train.yaml +17 -0
  21. configs/train/table_train.yaml +17 -0
  22. configs/train/vase_train.yaml +17 -0
  23. core/__pycache__/dataset.cpython-310.pyc +0 -0
  24. core/__pycache__/models.cpython-310.pyc +0 -0
  25. core/assets/__pycache__/basket.cpython-310.pyc +0 -0
  26. core/assets/__pycache__/chair.cpython-310.pyc +0 -0
  27. core/assets/__pycache__/dandelion.cpython-310.pyc +0 -0
  28. core/assets/__pycache__/flower.cpython-310.pyc +0 -0
  29. core/assets/__pycache__/table.cpython-310.pyc +0 -0
  30. core/assets/__pycache__/vase.cpython-310.pyc +0 -0
  31. core/assets/basket.py +576 -0
  32. core/assets/chair.py +657 -0
  33. core/assets/dandelion.py +1097 -0
  34. core/assets/flower.py +1002 -0
  35. core/assets/table.py +493 -0
  36. core/assets/vase.py +486 -0
  37. core/dataset.py +40 -0
  38. core/diffusion/__init__.py +46 -0
  39. core/diffusion/__pycache__/__init__.cpython-310.pyc +0 -0
  40. core/diffusion/__pycache__/diffusion_utils.cpython-310.pyc +0 -0
  41. core/diffusion/__pycache__/gaussian_diffusion.cpython-310.pyc +0 -0
  42. core/diffusion/__pycache__/respace.cpython-310.pyc +0 -0
  43. core/diffusion/diffusion_utils.py +88 -0
  44. core/diffusion/gaussian_diffusion.py +873 -0
  45. core/diffusion/respace.py +129 -0
  46. core/diffusion/timestep_sampler.py +150 -0
  47. core/models.py +331 -0
  48. core/utils/__pycache__/camera.cpython-310.pyc +0 -0
  49. core/utils/__pycache__/dinov2.cpython-310.pyc +0 -0
  50. core/utils/__pycache__/io.cpython-310.pyc +0 -0
.gitmodules ADDED
@@ -0,0 +1,3 @@
+ [submodule "third_party/infinigen"]
+ path = third_party/infinigen
+ url = https://github.com/princeton-vl/infinigen
README.md CHANGED
@@ -5,6 +5,7 @@ colorFrom: purple
  colorTo: blue
  sdk: gradio
  sdk_version: 5.9.1
+ python_version: 3.10.14
  app_file: app.py
  pinned: false
  license: apache-2.0
app.py ADDED
@@ -0,0 +1,258 @@
+ import os
+ os.environ["no_proxy"] = "localhost,127.0.0.1,::1"
+ import yaml
+ import numpy as np
+ from PIL import Image
+ import rembg
+ import importlib
+ import torch
+ import tempfile
+ import gradio as gr
+ #import spaces
+ from core.models import DiT_models
+ from core.diffusion import create_diffusion
+ from core.utils.dinov2 import Dinov2Model
+ from core.utils.math_utils import unnormalize_params
+
+ from huggingface_hub import hf_hub_download
+
+ # Setup PyTorch:
+ device = torch.device('cuda')
+
+ # Define the cache directory for model files
+ #model_cache_dir = './ckpts/'
+ #os.makedirs(model_cache_dir, exist_ok=True)
+
+ # load generators & models
+ generators_choices = ["chair", "table", "vase", "basket", "flower", "dandelion"]
+ factory_names = ["ChairFactory", "TableDiningFactory", "VaseFactory", "BasketBaseFactory", "FlowerFactory", "DandelionFactory"]
+ generator_path = "./core/assets/"
+ generators, configs, models = [], [], []
+ for category, factory in zip(generators_choices, factory_names):
+     # load generator
+     module = importlib.import_module(f"core.assets.{category}")
+     gen = getattr(module, factory)
+     generator = gen(0)
+     generators.append(generator)
+     # load configs
+     config_path = f"./configs/demo/{category}_demo.yaml"
+     with open(config_path) as f:
+         cfg = yaml.load(f, Loader=yaml.FullLoader)
+     configs.append(cfg)
+     # load models
+     latent_size = cfg["num_params"]
+     model = DiT_models[cfg["model"]](input_size=latent_size).to(device)
+     # load a custom DiT checkpoint from train.py;
+     # download the checkpoint if not found:
+     if not os.path.exists(cfg["ckpt_path"]):
+         model_dir, model_name = os.path.dirname(cfg["ckpt_path"]), os.path.basename(cfg["ckpt_path"])
+         os.makedirs(model_dir, exist_ok=True)
+         print("Downloading checkpoint {} from Hugging Face Hub...".format(model_name))
+         checkpoint_path = hf_hub_download(repo_id="TencentARC/DI-PCG",
+                                           local_dir=model_dir, filename=model_name)
+     print("Loading model from {}".format(cfg["ckpt_path"]))
+
+     state_dict = torch.load(cfg["ckpt_path"], map_location=lambda storage, loc: storage)
+     if "ema" in state_dict:  # supports checkpoints from train.py
+         state_dict = state_dict["ema"]
+     model.load_state_dict(state_dict)
+     model.eval()
+     models.append(model)
+
+ diffusion = create_diffusion(str(cfg["num_sampling_steps"]))
+ # feature model
+ feature_model = Dinov2Model()
+
+
+ def check_input_image(input_image):
+     if input_image is None:
+         raise gr.Error("No image uploaded!")
+
+
+ def preprocess(input_image, do_remove_background):
+     # resize
+     if input_image.size[0] != 256 or input_image.size[1] != 256:
+         input_image = input_image.resize((256, 256))
+     # remove background
+     if do_remove_background:
+         processed_image = rembg.remove(np.array(input_image))
+     else:
+         # keep the original image; it is composited onto a white background later in sample()
+         processed_image = input_image
+     return processed_image
+
+ #@spaces.GPU
+ def sample(image, seed, category):
+     # seed
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+     # generator & model
+     idx = generators_choices.index(category)
+     generator, cfg, model = generators[idx], configs[idx], models[idx]
+
+     # encode condition image feature
+     # convert RGBA images to RGB, white background
+     input_image_np = np.array(image)
+     mask = input_image_np[:, :, -1:] > 0
+     input_image_np = input_image_np[:, :, :3] * mask + 255 * (1 - mask)
+     image = input_image_np.astype(np.uint8)
+
+     img_feat = feature_model.encode_batch_imgs([np.array(image)], global_feat=False)
+
+     # Create sampling noise:
+     latent_size = int(cfg['num_params'])
+     z = torch.randn(1, 1, latent_size, device=device)
+     y = img_feat
+
+     # No classifier-free guidance:
+     model_kwargs = dict(y=y)
+
+     # Sample target params:
+     samples = diffusion.p_sample_loop(
+         model.forward, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=True, device=device
+     )
+     samples = samples[0].squeeze(0).cpu().numpy()
+
+     # unnormalize params
+     params_dict = generator.params_dict
+     params_original = unnormalize_params(samples, params_dict)
+
+     mesh_fpath = tempfile.NamedTemporaryFile(suffix=".glb", delete=False).name
+     params_fpath = tempfile.NamedTemporaryFile(suffix=".npy", delete=False).name
+     np.save(params_fpath, params_original)
+     print(mesh_fpath)
+     print(params_fpath)
+     # generate 3D using the sampled params - TODO: this is a hacky way to go through the PCG pipeline while avoiding conflicts with gradio
+     command = f"python ./scripts/generate.py --config ./configs/demo/{category}_demo.yaml --output_path {mesh_fpath} --seed {seed} --params_path {params_fpath}"
+     os.system(command)
+
+     return mesh_fpath
+
+
+ # Gradio UI
+
+ _HEADER_ = '''
+ <h2><b>Official 🤗 Gradio Demo</b></h2><h2><a href='https://github.com/TencentARC/DI-PCG' target='_blank'><b>DI-PCG: Diffusion-based Efficient Inverse Procedural Content Generation for High-quality 3D Asset Creation</b></a></h2>
+
+ **DI-PCG** is a diffusion model which directly generates a procedural generator's parameters from a single image, resulting in high-quality 3D meshes.
+
+ Code: <a href='https://github.com/TencentARC/DI-PCG' target='_blank'>GitHub</a>. Technical report: <a href='' target='_blank'>ArXiv</a>.
+
+ ❗️❗️❗️**Important Notes:**
+ - DI-PCG trains a diffusion model for each procedural generator. Currently supported generators are: Chair, Table, Vase, Basket, Flower, Dandelion from <a href="https://github.com/princeton-vl/infinigen">Infinigen</a>.
+ - The diversity of the generated meshes is strictly bounded by the procedural generators. For out-of-domain shapes, DI-PCG may only provide the closest approximation.
+ '''
+
+ _CITE_ = r"""
+ If DI-PCG is helpful, please help to ⭐ the <a href='https://github.com/TencentARC/DI-PCG' target='_blank'>GitHub repo</a>. Thanks! [![GitHub Stars](https://img.shields.io/github/stars/TencentARC/DI-PCG?style=social)](https://github.com/TencentARC/DI-PCG)
+ ---
+ 📝 **Citation**
+
+ If you find our work useful for your research or applications, please cite using this bibtex:
+ ```bibtex
+
+ ```
+
+ 📋 **License**
+
+ Apache-2.0 LICENSE. Please refer to the [LICENSE file]() for details.
+
+ 📧 **Contact**
+
+ If you have any questions, feel free to open a discussion or contact us at <b></b>.
+ """
+ def update_examples(category):
+     samples = [[os.path.join(f"examples/{category}", img_name)]
+                for img_name in sorted(os.listdir(f"examples/{category}"))]
+     print(samples)
+     return gr.Dataset(samples=samples)
+
+ with gr.Blocks() as demo:
+     gr.Markdown(_HEADER_)
+     with gr.Row(variant="panel"):
+         with gr.Column():
+             # select the generator category
+             with gr.Row():
+                 with gr.Group():
+                     generator_category = gr.Radio(
+                         choices=[
+                             "chair",
+                             "table",
+                             "vase",
+                             "basket",
+                             "flower",
+                             "dandelion",
+                         ],
+                         value="chair",
+                         label="category",
+                     )
+             with gr.Row():
+                 input_image = gr.Image(
+                     label="Input Image",
+                     image_mode="RGB",
+                     sources='upload',
+                     width=256,
+                     height=256,
+                     type="pil",
+                     elem_id="content_image",
+                 )
+                 processed_image = gr.Image(
+                     label="Processed Image",
+                     image_mode="RGBA",
+                     width=256,
+                     height=256,
+                     type="pil",
+                     interactive=False
+                 )
+             with gr.Row():
+                 with gr.Group():
+                     do_remove_background = gr.Checkbox(
+                         label="Remove Background", value=False
+                     )
+                     sample_seed = gr.Number(value=0, label="Seed Value", precision=0)
+
+             with gr.Row():
+                 submit = gr.Button("Generate", elem_id="generate", variant="primary")
+
+             with gr.Row(variant="panel"):
+                 examples = gr.Examples(
+                     [os.path.join("examples/chair", img_name) for img_name in sorted(os.listdir("examples/chair"))],
+                     inputs=[input_image],
+                     label="Examples",
+                     examples_per_page=5
+                 )
+             generator_category.change(update_examples, generator_category, outputs=examples.dataset)
+
+         with gr.Column():
+             with gr.Row():
+                 with gr.Tab("Geometry"):
+                     output_model_obj = gr.Model3D(
+                         label="Output Model",
+                         #width=768,
+                         display_mode="wireframe",
+                         interactive=False
+                     )
+                 #with gr.Tab("Textured"):
+                 #    output_model_obj = gr.Model3D(
+                 #        label="Output Model (STL Format)",
+                 #        #width=768,
+                 #        interactive=False,
+                 #    )
+                 #    gr.Markdown("Note: Texture and Material are randomly assigned by the procedural generator.")
+
+
+     gr.Markdown(_CITE_)
+     mv_images = gr.State()
+
+     submit.click(fn=check_input_image, inputs=[input_image]).success(
+         fn=preprocess,
+         inputs=[input_image, do_remove_background],
+         outputs=[processed_image],
+     ).success(
+         fn=sample,
+         inputs=[processed_image, sample_seed, generator_category],
+         outputs=[output_model_obj],
+     )
+
+ demo.queue(max_size=10)
+ demo.launch(server_name="0.0.0.0", server_port=43839)
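
A note for readers following `sample()`: the diffusion model operates on parameters normalized to a fixed range, and `unnormalize_params` maps them back through the generator's `params_dict`, whose entries are `[type, range]` pairs (see `core/assets/*.py` below). The real implementation lives in `core/utils/math_utils.py`, which is not part of this 50-file view; the sketch below only illustrates the idea, assuming latents normalized to [-1, 1] and nearest-choice snapping for discrete parameters.

```python
# Minimal sketch of the unnormalization step (assumed convention: latents in [-1, 1]).
# The authoritative implementation is core/utils/math_utils.py, not shown in this diff.
import numpy as np

def unnormalize_params_sketch(samples, params_dict):
    params = {}
    for x, (name, (ptype, prange)) in zip(samples, params_dict.items()):
        if ptype == "continuous":
            lo, hi = prange[0], prange[-1]
            params[name] = lo + (float(x) + 1) / 2 * (hi - lo)  # [-1, 1] -> [lo, hi]
        else:  # "discrete": snap to the nearest allowed choice
            idx = round((float(x) + 1) / 2 * (len(prange) - 1))
            params[name] = prange[int(np.clip(idx, 0, len(prange) - 1))]
    return params
```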
configs/demo/basket_demo.yaml ADDED
@@ -0,0 +1,20 @@
+ # Test
+ condition_img_dir: examples/basket
+ save_dir: logs/basket_demo
+ num_sampling_steps: 250
+ ckpt_path: pretrained_models/basket.pt
+
+ # Generator
+ generator_root: core/assets
+ generator: BasketBaseFactory
+ seed: 0
+
+ # Model
+ model: DiT_mini
+ num_params: 14
+
+ # Render
+ r_cam_dists: [1.6]
+ r_cam_elevations: [60]
+ r_cam_azimuths: [30]
+ r_zoff: 0.0
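
The `r_cam_*` fields in these demo/test configs describe the render viewpoint as spherical coordinates (distance, elevation, azimuth, in degrees) plus a vertical offset `r_zoff`. The actual camera setup lives in the repo's render utilities (`core/utils/camera.py`, only its compiled `.pyc` appears in this view), so the conversion below is just a sketch of the usual convention, assuming elevation is measured from the +z axis and the camera looks at the origin.

```python
# Hypothetical helper: place a camera from the r_cam_* config fields.
# Assumed convention: elevation is the polar angle from +z, angles in degrees.
import numpy as np

def camera_position(dist, elevation_deg, azimuth_deg, z_off=0.0):
    elev, azim = np.deg2rad(elevation_deg), np.deg2rad(azimuth_deg)
    x = dist * np.sin(elev) * np.cos(azim)
    y = dist * np.sin(elev) * np.sin(azim)
    z = dist * np.cos(elev) + z_off
    return np.array([x, y, z])  # camera placed here, looking toward the origin

# e.g. the basket demo view above: camera_position(1.6, 60, 30, z_off=0.0)
```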
configs/demo/chair_demo.yaml ADDED
@@ -0,0 +1,20 @@
+ # Test
+ condition_img_dir: examples/chair
+ save_dir: logs/chair_demo
+ num_sampling_steps: 250
+ ckpt_path: pretrained_models/chair.pt
+
+ # Generator
+ generator_root: core/assets
+ generator: ChairFactory
+ seed: 0
+
+ # Model
+ model: DiT_mini
+ num_params: 48
+
+ # Render
+ r_cam_dists: [2.0]
+ r_cam_elevations: [60]
+ r_cam_azimuths: [30]
+ r_zoff: 0.0
configs/demo/dandelion_demo.yaml ADDED
@@ -0,0 +1,20 @@
+ # Test
+ condition_img_dir: examples/dandelion
+ save_dir: logs/dandelion_demo
+ num_sampling_steps: 250
+ ckpt_path: pretrained_models/dandelion.pt
+
+ # Generator
+ generator_root: core/assets
+ generator: DandelionFactory
+ seed: 0
+
+ # Model
+ model: DiT_mini
+ num_params: 15
+
+ # Render
+ r_cam_dists: [3.0]
+ r_cam_elevations: [90]
+ r_cam_azimuths: [0]
+ r_zoff: 0.5
configs/demo/flower_demo.yaml ADDED
@@ -0,0 +1,20 @@
+ # Test
+ condition_img_dir: examples/flower
+ save_dir: logs/flower_demo
+ num_sampling_steps: 250
+ ckpt_path: pretrained_models/flower.pt
+
+ # Generator
+ generator_root: core/assets
+ generator: FlowerFactory
+ seed: 0
+
+ # Model
+ model: DiT_mini
+ num_params: 9
+
+ # Render
+ r_cam_dists: [4.0]
+ r_cam_elevations: [60]
+ r_cam_azimuths: [0]
+ r_zoff: 0.0
configs/demo/table_demo.yaml ADDED
@@ -0,0 +1,20 @@
+ # Test
+ condition_img_dir: examples/table
+ save_dir: logs/table_demo
+ num_sampling_steps: 250
+ ckpt_path: pretrained_models/table.pt
+
+ # Generator
+ generator_root: core/assets
+ generator: TableDiningFactory
+ seed: 0
+
+ # Model
+ model: DiT_mini
+ num_params: 19
+
+ # Render
+ r_cam_dists: [5.0]
+ r_cam_elevations: [60]
+ r_cam_azimuths: [30]
+ r_zoff: 0.1
configs/demo/vase_demo.yaml ADDED
@@ -0,0 +1,20 @@
+ # Test
+ condition_img_dir: examples/vase
+ save_dir: logs/vase_demo
+ num_sampling_steps: 250
+ ckpt_path: pretrained_models/vase.pt
+
+ # Generator
+ generator_root: core/assets
+ generator: VaseFactory
+ seed: 0
+
+ # Model
+ model: DiT_mini
+ num_params: 12
+
+ # Render
+ r_cam_dists: [2.0]
+ r_cam_elevations: [60]
+ r_cam_azimuths: [0]
+ r_zoff: 0.3
configs/infinigen/base.gin ADDED
@@ -0,0 +1,89 @@
+ include 'surface_registry.gin'
+
+ OVERALL_SEED = 0
+ LOG_DIR = '.'
+
+ Terrain.asset_folder = ""  # Will read from the $INFINIGEN_ASSET_FOLDER environment var when set to None, and generate on the fly when set to ""
+ Terrain.asset_version = 'May27'
+
+ util.math.FixedSeed.seed = %OVERALL_SEED
+
+ execute_tasks.frame_range = [1, 1]  # Which start/end frames should this job consider? Increase the end frame to render video
+ execute_tasks.camera_id = [0, 0]  # Which camera rig
+
+ save_obj_and_instances.output_folder = "saved_mesh.obj"
+
+ util.logging.create_text_file.log_dir = %LOG_DIR
+
+ target_face_size.global_multiplier = 2
+ scatter_res_distance.dist = 4
+
+ random_color_mapping.hue_stddev = 0.05  # Note: 1.0 is the whole color spectrum
+
+ render.render_image_func = @full/render_image
+ configure_render_cycles.time_limit = 0
+
+ configure_render_cycles.min_samples = 0
+ configure_render_cycles.num_samples = 8192
+ configure_render_cycles.adaptive_threshold = 0.01
+ configure_render_cycles.denoise = False
+ configure_render_cycles.exposure = 1
+ configure_blender.motion_blur_shutter = 0.15
+ render_image.use_dof = False
+ render_image.dof_aperture_fstop = 3
+ compositor_postprocessing.distort = False
+ compositor_postprocessing.color_correct = False
+
+ flat/configure_render_cycles.min_samples = 1
+ flat/configure_render_cycles.num_samples = 16
+ flat/render_image.flat_shading = True
+ full/render_image.passes_to_save = [
+     ['diffuse_direct', 'DiffDir'],
+     ['diffuse_color', 'DiffCol'],
+     ['diffuse_indirect', 'DiffInd'],
+     ['glossy_direct', 'GlossDir'],
+     ['glossy_color', 'GlossCol'],
+     ['glossy_indirect', 'GlossInd'],
+     ['transmission_direct', 'TransDir'],
+     ['transmission_color', 'TransCol'],
+     ['transmission_indirect', 'TransInd'],
+     ['volume_direct', 'VolumeDir'],
+     ['emit', 'Emit'],
+     ['environment', 'Env'],
+     ['ambient_occlusion', 'AO']
+ ]
+ flat/render_image.passes_to_save = [
+     ['z', 'Depth'],
+     ['normal', 'Normal'],
+     ['vector', 'Vector'],
+     ['object_index', 'IndexOB']
+ ]
+
+ execute_tasks.generate_resolution = (1280, 720)
+ execute_tasks.fps = 24
+ get_sensor_coords.H = 720
+ get_sensor_coords.W = 1280
+
+ min_terrain_distance = 2
+ keep_cam_pose_proposal.min_terrain_distance = %min_terrain_distance
+ SphericalMesher.r_min = %min_terrain_distance
+
+ build_terrain_bvh_and_attrs.avoid_border = False  # disabled due to crashes 5/15
+
+ animate_cameras.follow_poi_chance = 0.0
+ camera.camera_pose_proposal.altitude = ("weighted_choice",
+     (0.975, ("clip_gaussian", 2, 0.3, 0.5, 3)),  # usually person height
+     (0.025, ("clip_gaussian", 15, 7, 5, 30))  # sometimes drone height
+ )
+
+ camera.camera_pose_proposal.pitch = ("clip_gaussian", 90, 30, 20, 160)
+
+ # WARNING: Large camera rig translations or rotations require special handling.
+ # If your cameras are not all approximately forward-facing within a few centimeters, you must either:
+ # - configure the pipeline to generate assets / terrain for each camera separately, rather than sharing it between the whole rig
+ # - or treat your camera rig as multiple camera rigs, each with one camera, and implement code to position them correctly
+ camera.spawn_camera_rigs.n_camera_rigs = 1
+ camera.spawn_camera_rigs.camera_rig_config = [
+     {'loc': (0, 0, 0), 'rot_euler': (0, 0, 0)},
+     {'loc': (0.075, 0, 0), 'rot_euler': (0, 0, 0)}
+ ]
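
The tuple values above (`weighted_choice`, `clip_gaussian`) are distribution specs consumed by Infinigen's `random_general` utility (imported in chair.py below as `rg`). As a rough sketch of how such a spec could be interpreted; the authoritative dispatcher is `infinigen.core.util.random`, not this code:

```python
# Sketch only: sampling ("clip_gaussian", mean, std, lo, hi) and ("weighted_choice", ...)
# specs in the style of infinigen.core.util.random.random_general.
import numpy as np

def sample_spec(spec):
    if not isinstance(spec, tuple):
        return spec  # plain constants pass through unchanged
    kind, *args = spec
    if kind == "clip_gaussian":
        mean, std, lo, hi = args
        return float(np.clip(np.random.normal(mean, std), lo, hi))
    if kind == "weighted_choice":
        weights = [w for w, _ in args]
        choices = [c for _, c in args]
        idx = np.random.choice(len(choices), p=np.array(weights) / sum(weights))
        return sample_spec(choices[idx])  # specs can nest, as in the altitude config
    raise ValueError(f"unknown spec kind: {kind}")
```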
configs/test/basket_test.yaml ADDED
@@ -0,0 +1,24 @@
+ # Test
+ save_dir: logs/basket_test
+ data_root: /group/40034/wangzhao/data/ipcg/basket_new
+ test_file: test_list_mv.txt
+ batch_size: 100
+ num_workers: 24
+ num_sampling_steps: 250
+ ckpt_path: /your/path/to/trained/model/ckpt.pt
+
+ # Generator
+ run_generate: False
+ generator: BasketBaseFactory
+ params_dict_file: params_dict.txt
+ seed: 0
+
+ # Model
+ model: DiT_mini
+ num_params: 14
+
+ # Render
+ r_cam_dists: [1.6]
+ r_cam_elevations: [60]
+ r_cam_azimuths: [30]
+ r_zoff: 0.0
configs/test/chair_test.yaml ADDED
@@ -0,0 +1,24 @@
+ # Test
+ save_dir: logs/chair_test
+ data_root: /group/40034/wangzhao/data/ipcg/chair_new
+ test_file: test_list_mv.txt
+ batch_size: 100
+ num_workers: 24
+ num_sampling_steps: 250
+ ckpt_path: /your/path/to/trained/model/ckpt.pt
+
+ # Generator
+ run_generate: True
+ generator: ChairFactory
+ params_dict_file: params_dict.txt
+ seed: 0
+
+ # Model
+ model: DiT_mini
+ num_params: 48
+
+ # Render
+ r_cam_dists: [2.0]
+ r_cam_elevations: [60]
+ r_cam_azimuths: [30]
+ r_zoff: 0.0
configs/test/dandelion_test.yaml ADDED
@@ -0,0 +1,24 @@
+ # Test
+ save_dir: logs/dandelion_test
+ data_root: /group/40075/wangzhao/ipcg/dandelion_new
+ test_file: test_list_mv.txt
+ batch_size: 100
+ num_workers: 24
+ num_sampling_steps: 250
+ ckpt_path: /your/path/to/trained/model/ckpt.pt
+
+ # Generator
+ run_generate: True
+ generator: DandelionFactory
+ params_dict_file: params_dict.txt
+ seed: 0
+
+ # Model
+ model: DiT_mini
+ num_params: 15
+
+ # Render
+ r_cam_dists: [3.0]
+ r_cam_elevations: [90]
+ r_cam_azimuths: [0]
+ r_zoff: 0.5
configs/test/flower_test.yaml ADDED
@@ -0,0 +1,24 @@
+ # Test
+ save_dir: logs/flower_test
+ data_root: /group/40075/wangzhao/ipcg/flower_new
+ test_file: test_list_mv.txt
+ batch_size: 100
+ num_workers: 24
+ num_sampling_steps: 250
+ ckpt_path: /your/path/to/trained/model/ckpt.pt
+
+ # Generator
+ run_generate: True
+ generator: FlowerFactory
+ params_dict_file: params_dict.txt
+ seed: 0
+
+ # Model
+ model: DiT_mini
+ num_params: 9
+
+ # Render
+ r_cam_dists: [4.0]
+ r_cam_elevations: [60]
+ r_cam_azimuths: [0]
+ r_zoff: 0.0
configs/test/table_test.yaml ADDED
@@ -0,0 +1,24 @@
+ # Test
+ save_dir: logs/table_test
+ data_root: /group/40034/wangzhao/data/ipcg/table_new
+ test_file: test_list_mv.txt
+ batch_size: 100
+ num_workers: 24
+ num_sampling_steps: 250
+ ckpt_path: /your/path/to/trained/model/ckpt.pt
+
+ # Generator
+ run_generate: True
+ generator: TableDiningFactory
+ params_dict_file: params_dict.txt
+ seed: 0
+
+ # Model
+ model: DiT_mini
+ num_params: 19
+
+ # Render
+ r_cam_dists: [5.0]
+ r_cam_elevations: [60]
+ r_cam_azimuths: [30]
+ r_zoff: 0.1
configs/test/vase_test.yaml ADDED
@@ -0,0 +1,24 @@
+ # Test
+ save_dir: logs/vase_test
+ data_root: /group/40034/wangzhao/data/ipcg/vase_new
+ test_file: test_list_mv.txt
+ batch_size: 100
+ num_workers: 24
+ num_sampling_steps: 250
+ ckpt_path: /your/path/to/trained/model/ckpt.pt
+
+ # Generator
+ run_generate: True
+ generator: VaseFactory
+ params_dict_file: params_dict.txt
+ seed: 0
+
+ # Model
+ model: DiT_mini
+ num_params: 12
+
+ # Render
+ r_cam_dists: [2.0]
+ r_cam_elevations: [60]
+ r_cam_azimuths: [0]
+ r_zoff: 0.3
configs/train/basket_train.yaml ADDED
@@ -0,0 +1,17 @@
+ # Train
+ save_dir: logs/basket_train
+ data_root: /group/40075/wangzhao/ipcg/basket
+ train_file: train_list_mv_withaug.txt
+ test_file: test_list_mv.txt
+ params_dict_file: params_dict.txt
+ epochs: 200
+ batch_size: 128
+ num_workers: 64
+ lr: 0.0001
+ seed: 0
+ logging_iter: 100
+ ckpt_iter: 10000
+
+ # Model
+ model: DiT_mini
+ num_params: 14
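
train.py itself is not part of this 50-file view, and the real dataset class lives in core/dataset.py (content not shown). The sketch below only illustrates how the fields above could be wired up, with a synthetic TensorDataset standing in for the repo's parameter/image-feature dataset.

```python
# Sketch only: wiring the train-config fields into a training setup.
# A synthetic TensorDataset stands in for the real dataset from core/dataset.py.
import yaml
import torch
from torch.utils.data import DataLoader, TensorDataset
from core.models import DiT_models  # core/models.py is added in this commit

cfg = yaml.safe_load(open("configs/train/basket_train.yaml"))
fake_params = torch.randn(1024, 1, cfg["num_params"])  # stand-in training data
loader = DataLoader(TensorDataset(fake_params), batch_size=cfg["batch_size"],
                    num_workers=cfg["num_workers"], shuffle=True)
model = DiT_models[cfg["model"]](input_size=cfg["num_params"])
opt = torch.optim.AdamW(model.parameters(), lr=cfg["lr"])
# The training loop would then run for cfg["epochs"] epochs, logging every
# cfg["logging_iter"] steps and checkpointing every cfg["ckpt_iter"] steps.
```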
configs/train/chair_train.yaml ADDED
@@ -0,0 +1,17 @@
+ # Train
+ save_dir: logs/chair_train
+ data_root: /group/40046/public_datasets/IPCG/chair_new
+ train_file: train_list_mv_withaug.txt
+ test_file: test_list_mv.txt
+ params_dict_file: params_dict.txt
+ epochs: 200
+ batch_size: 128
+ num_workers: 64
+ lr: 0.0001
+ seed: 0
+ logging_iter: 100
+ ckpt_iter: 10000
+
+ # Model
+ model: DiT_mini
+ num_params: 48
configs/train/dandelion_train.yaml ADDED
@@ -0,0 +1,17 @@
+ # Train
+ save_dir: logs/dandelion_train
+ data_root: /group/40034/wangzhao/data/ipcg/dandelion
+ train_file: train_list_mv_withaug.txt
+ test_file: test_list_mv.txt
+ params_dict_file: params_dict.txt
+ epochs: 200
+ batch_size: 128
+ num_workers: 64
+ lr: 0.0001
+ seed: 0
+ logging_iter: 100
+ ckpt_iter: 10000
+
+ # Model
+ model: DiT_mini
+ num_params: 15
configs/train/flower_train.yaml ADDED
@@ -0,0 +1,17 @@
+ # Train
+ save_dir: logs/flower_train
+ data_root: /workspace/40075_wangzhao/ipcg/flower_new
+ train_file: train_list_mv_withaug.txt
+ test_file: test_list_mv.txt
+ params_dict_file: params_dict.txt
+ epochs: 200
+ batch_size: 128
+ num_workers: 64
+ lr: 0.0001
+ seed: 0
+ logging_iter: 100
+ ckpt_iter: 10000
+
+ # Model
+ model: DiT_mini
+ num_params: 9
configs/train/table_train.yaml ADDED
@@ -0,0 +1,17 @@
+ # Train
+ save_dir: logs/table_train
+ data_root: /group/40034/wangzhao/data/ipcg/table_new
+ train_file: train_list_mv_withaug.txt
+ test_file: test_list_mv.txt
+ params_dict_file: params_dict.txt
+ epochs: 200
+ batch_size: 128
+ num_workers: 64
+ lr: 0.0001
+ seed: 0
+ logging_iter: 100
+ ckpt_iter: 10000
+
+ # Model
+ model: DiT_mini
+ num_params: 19
configs/train/vase_train.yaml ADDED
@@ -0,0 +1,17 @@
+ # Train
+ save_dir: logs/vase_train
+ data_root: /group/40034/wangzhao/data/ipcg/vase_new
+ train_file: train_list_mv_withaug.txt
+ test_file: test_list_mv.txt
+ params_dict_file: params_dict.txt
+ epochs: 200
+ batch_size: 128
+ num_workers: 64
+ lr: 0.0001
+ seed: 0
+ logging_iter: 100
+ ckpt_iter: 10000
+
+ # Model
+ model: DiT_mini
+ num_params: 12
core/__pycache__/dataset.cpython-310.pyc ADDED
Binary file (1.85 kB).
core/__pycache__/models.cpython-310.pyc ADDED
Binary file (11 kB).
core/assets/__pycache__/basket.cpython-310.pyc ADDED
Binary file (9.96 kB).
core/assets/__pycache__/chair.cpython-310.pyc ADDED
Binary file (17.9 kB).
core/assets/__pycache__/dandelion.cpython-310.pyc ADDED
Binary file (18.3 kB).
core/assets/__pycache__/flower.cpython-310.pyc ADDED
Binary file (16.3 kB).
core/assets/__pycache__/table.cpython-310.pyc ADDED
Binary file (10.8 kB).
core/assets/__pycache__/vase.cpython-310.pyc ADDED
Binary file (8.91 kB).
core/assets/basket.py ADDED
@@ -0,0 +1,576 @@
+ # Copyright (C) 2023, Princeton University.
+ # This source code is licensed under the BSD 3-Clause license found in the LICENSE file in the root directory of this source tree.
+
+ # Authors: Beining Han
+
+ import bpy
+ import numpy as np
+ from numpy.random import uniform
+ import random
+ import time
+
+ from infinigen.assets.materials.plastics.plastic_rough import shader_rough_plastic
+ from infinigen.core import surface, tagging
+ from infinigen.core.nodes import node_utils
+ from infinigen.core.nodes.node_wrangler import Nodes, NodeWrangler
+ from infinigen.core.placement.factory import AssetFactory
+
+
+ @node_utils.to_nodegroup("nodegroup_holes", singleton=False, type="GeometryNodeTree")
+ def nodegroup_holes(nw: NodeWrangler):
+     # Code generated using version 2.6.4 of the node_transpiler
+
+     group_input = nw.new_node(
+         Nodes.GroupInput,
+         expose_input=[
+             ("NodeSocketFloat", "height", 0.5000),
+             ("NodeSocketFloat", "gap_size", 0.5000),
+             ("NodeSocketFloat", "hole_edge_gap", 0.5000),
+             ("NodeSocketFloat", "hole_size", 0.5000),
+             ("NodeSocketFloat", "depth", 0.5000),
+             ("NodeSocketFloat", "width", 0.5000),
+         ],
+     )
+
+     add = nw.new_node(
+         Nodes.Math,
+         input_kwargs={0: group_input.outputs["hole_edge_gap"], 1: 0.0000},
+         attrs={"operation": "ADD"},
+     )
+
+     subtract = nw.new_node(
+         Nodes.Math,
+         input_kwargs={0: group_input.outputs["height"], 1: add},
+         attrs={"operation": "SUBTRACT"},
+     )
+
+     add_1 = nw.new_node(
+         Nodes.Math,
+         input_kwargs={0: group_input.outputs["width"], 1: 0.0000},
+         attrs={"operation": "ADD"},
+     )
+
+     subtract_1 = nw.new_node(
+         Nodes.Math, input_kwargs={0: add_1, 1: add}, attrs={"operation": "SUBTRACT"}
+     )
+
+     add_2 = nw.new_node(
+         Nodes.Math, input_kwargs={0: group_input.outputs["hole_size"], 1: 0.0000}, attrs={"operation": "ADD"}
+     )
+
+     add_3 = nw.new_node(
+         Nodes.Math, input_kwargs={0: add_2, 1: group_input.outputs["gap_size"]}, attrs={"operation": "ADD"}
+     )
+
+     divide = nw.new_node(
+         Nodes.Math, input_kwargs={0: subtract, 1: add_3}, attrs={"operation": "DIVIDE"}
+     )
+
+     divide_1 = nw.new_node(
+         Nodes.Math,
+         input_kwargs={0: subtract_1, 1: add_3},
+         attrs={"operation": "DIVIDE"},
+     )
+
+     grid = nw.new_node(
+         Nodes.MeshGrid,
+         input_kwargs={
+             "Size X": subtract,
+             "Size Y": subtract_1,
+             "Vertices X": divide,
+             "Vertices Y": divide_1,
+         },
+     )
+
+     store_named_attribute = nw.new_node(
+         Nodes.StoreNamedAttribute,
+         input_kwargs={
+             "Geometry": grid.outputs["Mesh"],
+             "Name": "uv_map",
+             3: grid.outputs["UV Map"],
+         },
+         attrs={"domain": "CORNER", "data_type": "FLOAT_VECTOR"},
+     )
+
+     transform_1 = nw.new_node(
+         Nodes.Transform,
+         input_kwargs={
+             "Geometry": store_named_attribute,
+             "Rotation": (0.0000, 1.5708, 0.0000),
+         },
+     )
+
+     add_4 = nw.new_node(
+         Nodes.Math, input_kwargs={0: group_input.outputs["depth"], 1: 0.0000}, attrs={"operation": "ADD"}
+     )
+
+     add_5 = nw.new_node(Nodes.Math, input_kwargs={0: add_4, 1: 0.1}, attrs={"operation": "ADD"})
+
+     combine_xyz_3 = nw.new_node(
+         Nodes.CombineXYZ, input_kwargs={"X": add_5, "Y": add_2, "Z": add_2}
+     )
+
+     cube_2 = nw.new_node(Nodes.MeshCube, input_kwargs={"Size": combine_xyz_3})
+
+     store_named_attribute_1 = nw.new_node(
+         Nodes.StoreNamedAttribute,
+         input_kwargs={
+             "Geometry": cube_2.outputs["Mesh"],
+             "Name": "uv_map",
+             3: cube_2.outputs["UV Map"],
+         },
+         attrs={"domain": "CORNER", "data_type": "FLOAT_VECTOR"},
+     )
+
+     instance_on_points = nw.new_node(
+         Nodes.InstanceOnPoints,
+         input_kwargs={"Points": transform_1, "Instance": store_named_attribute_1},
+     )
+
+     subtract_2 = nw.new_node(
+         Nodes.Math, input_kwargs={0: add_4, 1: add}, attrs={"operation": "SUBTRACT"}
+     )
+
+     divide_2 = nw.new_node(
+         Nodes.Math,
+         input_kwargs={0: subtract_2, 1: add_3},
+         attrs={"operation": "DIVIDE"},
+     )
+
+     grid_1 = nw.new_node(
+         Nodes.MeshGrid,
+         input_kwargs={
+             "Size X": subtract_2,
+             "Size Y": subtract,
+             "Vertices X": divide_2,
+             "Vertices Y": divide,
+         },
+     )
+
+     store_named_attribute_2 = nw.new_node(
+         Nodes.StoreNamedAttribute,
+         input_kwargs={
+             "Geometry": grid_1.outputs["Mesh"],
+             "Name": "uv_map",
+             3: grid_1.outputs["UV Map"],
+         },
+         attrs={"domain": "CORNER", "data_type": "FLOAT_VECTOR"},
+     )
+
+     transform_2 = nw.new_node(
+         Nodes.Transform,
+         input_kwargs={
+             "Geometry": store_named_attribute_2,
+             "Rotation": (1.5708, 0.0000, 0.0000),
+         },
+     )
+
+     add_6 = nw.new_node(Nodes.Math, input_kwargs={0: add_1, 1: 0.1}, attrs={"operation": "ADD"})
+
+     combine_xyz_4 = nw.new_node(
+         Nodes.CombineXYZ, input_kwargs={"X": add_2, "Y": add_6, "Z": add_2}
+     )
+
+     cube_3 = nw.new_node(Nodes.MeshCube, input_kwargs={"Size": combine_xyz_4})
+
+     store_named_attribute_3 = nw.new_node(
+         Nodes.StoreNamedAttribute,
+         input_kwargs={
+             "Geometry": cube_3.outputs["Mesh"],
+             "Name": "uv_map",
+             3: cube_3.outputs["UV Map"],
+         },
+         attrs={"domain": "CORNER", "data_type": "FLOAT_VECTOR"},
+     )
+
+     instance_on_points_1 = nw.new_node(
+         Nodes.InstanceOnPoints,
+         input_kwargs={"Points": transform_2, "Instance": store_named_attribute_3},
+     )
+
+     group_output = nw.new_node(
+         Nodes.GroupOutput,
+         input_kwargs={
+             "Instances1": instance_on_points,
+             "Instances2": instance_on_points_1,
+         },
+         attrs={"is_active_output": True},
+     )
+
+
+ @node_utils.to_nodegroup(
+     "nodegroup_handle_hole", singleton=False, type="GeometryNodeTree"
+ )
+ def nodegroup_handle_hole(nw: NodeWrangler):
+     # Code generated using version 2.6.4 of the node_transpiler
+
+     group_input = nw.new_node(
+         Nodes.GroupInput,
+         expose_input=[
+             ("NodeSocketFloat", "X", 0.0000),
+             ("NodeSocketFloat", "Z", 0.0000),
+             ("NodeSocketFloat", "height", 0.5000),
+             ("NodeSocketFloat", "hole_dist", 0.5000),
+             ("NodeSocketInt", "Level", 0),
+         ],
+     )
+
+     combine_xyz_3 = nw.new_node(
+         Nodes.CombineXYZ,
+         input_kwargs={
+             "X": group_input.outputs["X"],
+             "Y": 1.0000,
+             "Z": group_input.outputs["Z"],
+         },
+     )
+
+     cube_2 = nw.new_node(Nodes.MeshCube, input_kwargs={"Size": combine_xyz_3})
+
+     store_named_attribute = nw.new_node(
+         Nodes.StoreNamedAttribute,
+         input_kwargs={
+             "Geometry": cube_2.outputs["Mesh"],
+             "Name": "uv_map",
+             3: cube_2.outputs["UV Map"],
+         },
+         attrs={"domain": "CORNER", "data_type": "FLOAT_VECTOR"},
+     )
+
+     subdivide_mesh_2 = nw.new_node(
+         Nodes.SubdivideMesh, input_kwargs={"Mesh": store_named_attribute}
+     )
+
+     subdivision_surface_2 = nw.new_node(
+         Nodes.SubdivisionSurface,
+         input_kwargs={"Mesh": subdivide_mesh_2, "Level": group_input.outputs["Level"]},
+     )
+
+     multiply = nw.new_node(
+         Nodes.Math,
+         input_kwargs={0: group_input.outputs["height"]},
+         attrs={"operation": "MULTIPLY"},
+     )
+
+     subtract = nw.new_node(
+         Nodes.Math,
+         input_kwargs={0: multiply, 1: group_input.outputs["hole_dist"]},
+         attrs={"operation": "SUBTRACT"},
+     )
+
+     combine_xyz_4 = nw.new_node(Nodes.CombineXYZ, input_kwargs={"Z": subtract})
+
+     transform_1 = nw.new_node(
+         Nodes.Transform,
+         input_kwargs={"Geometry": subdivision_surface_2, "Translation": combine_xyz_4},
+     )
+
+     group_output = nw.new_node(
+         Nodes.GroupOutput,
+         input_kwargs={"Geometry": transform_1},
+         attrs={"is_active_output": True},
+     )
+
+
+ def geometry_nodes(nw: NodeWrangler, **kwargs):
+     # Code generated using version 2.6.4 of the node_transpiler
+
+     depth = nw.new_node(Nodes.Value, label="depth")
+     depth.outputs[0].default_value = kwargs["depth"]
+
+     width = nw.new_node(Nodes.Value, label="width")
+     width.outputs[0].default_value = kwargs["width"]
+
+     height = nw.new_node(Nodes.Value, label="height")
+     height.outputs[0].default_value = kwargs["height"]
+
+     combine_xyz = nw.new_node(
+         Nodes.CombineXYZ, input_kwargs={"X": depth, "Y": width, "Z": height}
+     )
+
+     cube = nw.new_node(Nodes.MeshCube, input_kwargs={"Size": combine_xyz})
+
+     store_named_attribute = nw.new_node(
+         Nodes.StoreNamedAttribute,
+         input_kwargs={
+             "Geometry": cube.outputs["Mesh"],
+             "Name": "uv_map",
+             3: cube.outputs["UV Map"],
+         },
+         attrs={"domain": "CORNER", "data_type": "FLOAT_VECTOR"},
+     )
+
+     subdivide_mesh = nw.new_node(
+         Nodes.SubdivideMesh, input_kwargs={"Mesh": store_named_attribute, "Level": 2}
+     )
+
+     sub_level = nw.new_node(Nodes.Integer, label="sub_level")
+     sub_level.integer = kwargs["frame_sub_level"]
+
+     subdivision_surface = nw.new_node(
+         Nodes.SubdivisionSurface,
+         input_kwargs={"Mesh": subdivide_mesh, "Level": sub_level},
+     )
+
+     differences = []
+
+     if kwargs["has_handle"]:
+         hole_depth = nw.new_node(Nodes.Value, label="hole_depth")
+         hole_depth.outputs[0].default_value = kwargs["handle_depth"]
+
+         hole_height = nw.new_node(Nodes.Value, label="hole_height")
+         hole_height.outputs[0].default_value = kwargs["handle_height"]
+
+         hole_dist = nw.new_node(Nodes.Value, label="hole_dist")
+         hole_dist.outputs[0].default_value = kwargs["handle_dist_to_top"]
+
+         handle_level = nw.new_node(Nodes.Integer, label="handle_level")
+         handle_level.integer = kwargs["handle_sub_level"]
+         handle_hole = nw.new_node(
+             nodegroup_handle_hole().name,
+             input_kwargs={
+                 "X": hole_depth,
+                 "Z": hole_height,
+                 "height": height,
+                 "hole_dist": hole_dist,
+                 "Level": handle_level,
+             },
+         )
+         differences.append(handle_hole)
+
+     thickness = nw.new_node(Nodes.Value, label="thickness")
+     thickness.outputs[0].default_value = kwargs["thickness"]
+
+     subtract = nw.new_node(
+         Nodes.Math,
+         input_kwargs={0: depth, 1: thickness},
+         attrs={"operation": "SUBTRACT"},
+     )
+
+     subtract_1 = nw.new_node(
+         Nodes.Math,
+         input_kwargs={0: width, 1: thickness},
+         attrs={"operation": "SUBTRACT"},
+     )
+
+     combine_xyz_1 = nw.new_node(
+         Nodes.CombineXYZ, input_kwargs={"X": subtract, "Y": subtract_1, "Z": height}
+     )
+
+     cube_1 = nw.new_node(Nodes.MeshCube, input_kwargs={"Size": combine_xyz_1})
+
+     store_named_attribute_1 = nw.new_node(
+         Nodes.StoreNamedAttribute,
+         input_kwargs={
+             "Geometry": cube_1.outputs["Mesh"],
+             "Name": "uv_map",
+             3: cube_1.outputs["UV Map"],
+         },
+         attrs={"domain": "CORNER", "data_type": "FLOAT_VECTOR"},
+     )
+
+     subdivide_mesh_1 = nw.new_node(
+         Nodes.SubdivideMesh, input_kwargs={"Mesh": store_named_attribute_1, "Level": 2}
+     )
+
+     subdivision_surface_1 = nw.new_node(
+         Nodes.SubdivisionSurface,
+         input_kwargs={"Mesh": subdivide_mesh_1, "Level": sub_level},
+     )
+
+     multiply = nw.new_node(
+         Nodes.Math,
+         input_kwargs={0: thickness, 2: 0.2500},
+         attrs={"operation": "MULTIPLY"},
+     )
+
+     combine_xyz_2 = nw.new_node(Nodes.CombineXYZ, input_kwargs={"Z": multiply})
+
+     transform = nw.new_node(
+         Nodes.Transform,
+         input_kwargs={"Geometry": subdivision_surface_1, "Translation": combine_xyz_2},
+     )
+
+     if kwargs["has_holes"]:
+         gap_size = nw.new_node(Nodes.Value, label="gap_size")
+         gap_size.outputs[0].default_value = kwargs["hole_gap_size"]
+
+         hole_edge_gap = nw.new_node(Nodes.Value, label="hole_edge_gap")
+         hole_edge_gap.outputs[0].default_value = kwargs["hole_edge_gap"]
+
+         hole_size = nw.new_node(Nodes.Value, label="hole_size")
+         hole_size.outputs[0].default_value = kwargs["hole_size"]
+         holes = nw.new_node(
+             nodegroup_holes().name,
+             input_kwargs={
+                 "height": height,
+                 "gap_size": gap_size,
+                 "hole_edge_gap": hole_edge_gap,
+                 "hole_size": hole_size,
+                 "depth": depth,
+                 "width": width,
+             },
+         )
+         differences.extend([holes.outputs["Instances1"], holes.outputs["Instances2"]])
+
+     difference = nw.new_node(
+         Nodes.MeshBoolean,
+         input_kwargs={
+             "Mesh 1": subdivision_surface,
+             "Mesh 2": [transform] + differences,
+         },
+     )
+
+     realize_instances = nw.new_node(
+         Nodes.RealizeInstances, input_kwargs={"Geometry": difference.outputs["Mesh"]}
+     )
+
+     multiply_1 = nw.new_node(
+         Nodes.Math, input_kwargs={0: height}, attrs={"operation": "MULTIPLY"}
+     )
+
+     combine_xyz_3 = nw.new_node(Nodes.CombineXYZ, input_kwargs={"Z": multiply_1})
+
+     transform_geometry = nw.new_node(
+         Nodes.Transform,
+         input_kwargs={"Geometry": realize_instances, "Translation": combine_xyz_3},
+     )
+
+     set_material = nw.new_node(
+         Nodes.SetMaterial,
+         input_kwargs={
+             "Geometry": transform_geometry,
+             "Material": surface.shaderfunc_to_material(shader_rough_plastic),
+         },
+     )
+
+     group_output = nw.new_node(
+         Nodes.GroupOutput,
+         input_kwargs={"Geometry": set_material},
+         attrs={"is_active_output": True},
+     )
+
+
+ class BasketBaseFactory(AssetFactory):
+     def __init__(self, factory_seed, coarse=False):
+         super(BasketBaseFactory, self).__init__(factory_seed, coarse=coarse)
+         self.params = self.get_asset_params()
+         self.seed = factory_seed
+         self.get_params_dict()
+
+     def get_params_dict(self):
+         self.params_dict = {
+             "depth": ['continuous', (0.1, 0.6)],
+             "width": ['continuous', (0.1, 0.7)],
+             "height": ['continuous', (0.05, 0.4)],
+             "frame_sub_level": ['discrete', [0, 3]],
+             "thickness": ['continuous', (0.001, 0.03)],
+             "has_handle": ['discrete', [0, 1]],
+             "handle_sub_level": ['discrete', [0, 1, 2]],
+             "handle_depth": ['continuous', (0.2, 0.6)],
+             "handle_height": ['continuous', (0.1, 0.3)],
+             "handle_dist_to_top": ['continuous', (0.08, 0.4)],
+             "has_holes": ['discrete', [0, 1]],
+             "hole_gap_size": ['continuous', (0.5, 2.0)],
+             "hole_edge_gap": ['continuous', (0.04, 0.1)],
+             "hole_size": ['continuous', (0.007, 0.02)]
+         }
+
+     def fix_unused_params(self, params):
+         if params['height'] < 0.12:
+             params['has_holes'] = 0
+         if params['has_handle'] == 0:
+             params["handle_sub_level"] = 1
+             params["handle_depth"] = 0.3
+             params["handle_height"] = 0.2
+             params["handle_dist_to_top"] = 0.115
+         if params['has_holes'] == 0:
+             params["hole_gap_size"] = 0.95
+             params["hole_edge_gap"] = 0.05
+             params["hole_size"] = 0.0075
+         return params
+
+     def update_params(self, params):
+         # TODO: to allow random material
+         self.seed = int(1000 * time.time()) % 2**32
+
+         handle_depth = params['depth'] * params['handle_depth']
+         handle_height = params['height'] * params['handle_height']
+         handle_dist_to_top = handle_height * 0.5 + params['height'] * params["handle_dist_to_top"]
+         if params['height'] < 0.12:
+             params["has_holes"] = 0
+         hole_gap_size = params['hole_size'] * params["hole_gap_size"]
+         parameters = {
+             "depth": params["depth"],
+             "width": params["width"],
+             "height": params["height"],
+             "frame_sub_level": params["frame_sub_level"],
+             "thickness": params["thickness"],
+             "has_handle": params["has_handle"] > 0,
+             "handle_sub_level": params["handle_sub_level"],
+             "handle_depth": handle_depth,
+             "handle_height": handle_height,
+             "handle_dist_to_top": handle_dist_to_top,
+             "has_holes": params["has_holes"] > 0,
+             "hole_gap_size": hole_gap_size,
+             "hole_edge_gap": params["hole_edge_gap"],
+             "hole_size": params["hole_size"],
+         }
+         self.params.update(parameters)
+
+     def get_asset_params(self, i=0):
+         params = {}
+         if params.get("depth", None) is None:
+             params["depth"] = uniform(0.15, 0.4)
+         if params.get("width", None) is None:
+             params["width"] = uniform(0.2, 0.6)
+         if params.get("height", None) is None:
+             params["height"] = uniform(0.06, 0.24)
+         if params.get("frame_sub_level", None) is None:
+             params["frame_sub_level"] = np.random.choice([0, 3], p=[0.5, 0.5])
+         if params.get("thickness", None) is None:
+             params["thickness"] = uniform(0.001, 0.005)
+
+         if params.get("has_handle", None) is None:
+             params["has_handle"] = np.random.choice([True, False], p=[0.8, 0.2])
+         if params.get("handle_sub_level", None) is None:
+             params["handle_sub_level"] = np.random.choice([0, 1, 2], p=[0.2, 0.4, 0.4])
+         if params.get("handle_depth", None) is None:
+             params["handle_depth"] = params["depth"] * uniform(0.2, 0.4)
+         if params.get("handle_height", None) is None:
+             params["handle_height"] = params["height"] * uniform(0.1, 0.25)
+         if params.get("handle_dist_to_top", None) is None:
+             params["handle_dist_to_top"] = params["handle_height"] * 0.5 + params[
+                 "height"
+             ] * uniform(0.08, 0.15)
+
+         if params.get("has_holes", None) is None:
+             if params["height"] < 0.12:
+                 params["has_holes"] = False
+             else:
+                 params["has_holes"] = np.random.choice([True, False], p=[0.5, 0.5])
+         if params.get("hole_size", None) is None:
+             params["hole_size"] = uniform(0.005, 0.01)
+         if params.get("hole_gap_size", None) is None:
+             params["hole_gap_size"] = params["hole_size"] * uniform(0.8, 1.1)
+         if params.get("hole_edge_gap", None) is None:
+             params["hole_edge_gap"] = uniform(0.04, 0.06)
+
+         return params
+
+     def create_asset(self, i=0, **params):
+         bpy.ops.mesh.primitive_plane_add(
+             size=1,
+             enter_editmode=False,
+             align="WORLD",
+             location=(0, 0, 0),
+             scale=(1, 1, 1),
+         )
+         obj = bpy.context.active_object
+         np.random.seed(self.seed)
+         random.seed(self.seed)
+
+         surface.add_geomod(
+             obj, geometry_nodes, attributes=[], apply=True, input_kwargs=self.params
+         )
+         tagging.tag_system.relabel_obj(obj)
+         return obj
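
A note on how these factories are driven, mirroring app.py and the generate.py handoff above: `params_dict` declares each parameter's type and range, `fix_unused_params` pins parameters that have no geometric effect, `update_params` consumes an unnormalized parameter set, and `create_asset` builds the mesh through Blender geometry nodes. A minimal usage sketch, assuming `bpy` and Infinigen are importable in the running environment:

```python
# Minimal usage sketch for BasketBaseFactory (requires Blender's bpy + infinigen at runtime).
from core.assets.basket import BasketBaseFactory

factory = BasketBaseFactory(factory_seed=0)
params = factory.get_asset_params()         # a random but valid parameter set
params = factory.fix_unused_params(params)  # pin parameters that have no effect
factory.update_params(params)               # push values into the node-graph kwargs
obj = factory.create_asset()                # builds the basket mesh via geometry nodes
```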
core/assets/chair.py ADDED
@@ -0,0 +1,657 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2024, Princeton University.
2
+ # This source code is licensed under the BSD 3-Clause license found in the LICENSE file in the root directory of this source tree.
3
+
4
+ # Authors: Lingjie Mei
5
+ import bpy
6
+ import numpy as np
7
+ from numpy.random import uniform
8
+
9
+ import infinigen
10
+ from infinigen.assets.material_assignments import AssetList
11
+ from infinigen.assets.utils.decorate import (
12
+ read_co,
13
+ read_edge_center,
14
+ read_edge_direction,
15
+ remove_edges,
16
+ remove_vertices,
17
+ select_edges,
18
+ solidify,
19
+ subsurf,
20
+ write_attribute,
21
+ write_co,
22
+ )
23
+ from infinigen.assets.utils.draw import align_bezier, bezier_curve
24
+ from infinigen.assets.utils.nodegroup import geo_radius
25
+ from infinigen.assets.utils.object import join_objects, new_bbox
26
+ from infinigen.core import surface
27
+ from infinigen.core.placement.factory import AssetFactory
28
+ from infinigen.core.surface import NoApply
29
+ from infinigen.core.util import blender as butil
30
+ from infinigen.core.util.blender import deep_clone_obj
31
+ from infinigen.core.util.math import FixedSeed
32
+ from infinigen.core.util.random import log_uniform
33
+ from infinigen.core.util.random import random_general as rg
34
+
35
+
36
+ class ChairFactory(AssetFactory):
37
+ back_types = {
38
+ 0: "whole",
39
+ 1: "partial",
40
+ 2: "horizontal-bar",
41
+ 3: "vertical-bar",
42
+ }
43
+ leg_types = {
44
+ 0: "vertical",
45
+ 1: "straight",
46
+ 2: "up-curved",
47
+ 3: "down-curved",
48
+ }
49
+
50
+ def __init__(self, factory_seed, coarse=False):
51
+ super().__init__(factory_seed, coarse)
52
+
53
+ self.get_params_dict()
54
+ # random init with seed
55
+ with FixedSeed(self.factory_seed):
56
+ self.width = uniform(0.4, 0.5)
57
+ self.size = uniform(0.38, 0.45)
58
+ self.thickness = uniform(0.04, 0.08)
59
+ self.bevel_width = self.thickness * (0.1 if uniform() < 0.4 else 0.5)
60
+ self.seat_back = uniform(0.7, 1.0) if uniform() < 0.75 else 1.0
61
+ self.seat_mid = uniform(0.7, 0.8)
62
+ self.seat_mid_x = uniform(
63
+ self.seat_back + self.seat_mid * (1 - self.seat_back), 1
64
+ )
65
+ self.seat_mid_z = uniform(0, 0.5)
66
+ self.seat_front = uniform(1.0, 1.2)
67
+ self.is_seat_round = uniform() < 0.6
68
+ self.is_seat_subsurf = uniform() < 0.5
69
+
70
+ self.leg_thickness = uniform(0.04, 0.06)
71
+ self.limb_profile = uniform(1.5, 2.5)
72
+ self.leg_height = uniform(0.45, 0.5)
73
+ self.back_height = uniform(0.4, 0.5)
74
+ self.is_leg_round = uniform() < 0.5
75
+ self.leg_type = np.random.choice(
76
+ ["vertical", "straight", "up-curved", "down-curved"]
77
+ )
78
+
79
+ self.leg_x_offset = 0
80
+ self.leg_y_offset = 0, 0
81
+ self.back_x_offset = 0
82
+ self.back_y_offset = 0
83
+
84
+ self.has_leg_x_bar = uniform() < 0.6
85
+ self.has_leg_y_bar = uniform() < 0.6
86
+ self.leg_offset_bar = uniform(0.2, 0.4), uniform(0.6, 0.8)
87
+
88
+ self.has_arm = uniform() < 0.7
89
+ self.arm_thickness = uniform(0.04, 0.06)
90
+ self.arm_height = self.arm_thickness * uniform(0.6, 1)
91
+ self.arm_y = uniform(0.8, 1) * self.size
92
+ self.arm_z = uniform(0.3, 0.6) * self.back_height
93
+ self.arm_mid = np.array(
94
+ [uniform(-0.03, 0.03), uniform(-0.03, 0.09), uniform(-0.09, 0.03)]
95
+ )
96
+ self.arm_profile = log_uniform(0.1, 3, 2)
97
+
98
+ self.back_thickness = uniform(0.04, 0.05)
99
+ self.back_type = rg(self.back_types)
100
+ self.back_profile = [(0, 1)]
101
+ self.back_vertical_cuts = np.random.randint(1, 4)
102
+ self.back_partial_scale = uniform(1, 1.4)
103
+
104
+ materials = AssetList["ChairFactory"]()
105
+ self.limb_surface = materials["limb"].assign_material()
106
+ self.surface = materials["surface"].assign_material()
107
+ if uniform() < 0.3:
108
+ self.panel_surface = self.surface
109
+ else:
110
+ self.panel_surface = materials["panel"].assign_material()
111
+
112
+ scratch_prob, edge_wear_prob = materials["wear_tear_prob"]
113
+ self.scratch, self.edge_wear = materials["wear_tear"]
114
+ is_scratch = uniform() < scratch_prob
115
+ is_edge_wear = uniform() < edge_wear_prob
116
+ if not is_scratch:
117
+ self.scratch = None
118
+ if not is_edge_wear:
119
+ self.edge_wear = None
120
+
121
+ # from infinigen.assets.clothes import blanket
122
+ # from infinigen.assets.scatters.clothes import ClothesCover
123
+ # self.clothes_scatter = ClothesCover(factory_fn=blanket.BlanketFactory, width=log_uniform(.8, 1.2),
124
+ # size=uniform(.8, 1.2)) if uniform() < .3 else NoApply()
125
+ self.clothes_scatter = NoApply()
126
+ self.post_init()
127
+
128
+ def get_params_dict(self):
129
+ # all the parameters (key:name, value: [type, range]) used in this generator
130
+ self.params_dict = {
131
+ "width": ['continuous', [0.3, 0.8]], # seat width
132
+ "size": ['continuous', [0.35, 0.5]], # seat length
133
+ "thickness": ['continuous', [0.02, 0.1]], # seat thickness
134
+ "bevel_width": ['discrete', [0.1, 0.5]],
135
+ "seat_back": ['continuous', [0.6, 1.0]], # seat back width
136
+ "seat_mid": ['continuous', [0.7, 0.8]],
137
+ "seat_mid_z": ['continuous', [0.0, 0.7]], # seat mid point height
138
+ "seat_front": ['continuous', [1.0, 1.2]], # seat front point
139
+ "is_seat_round": ['discrete', [0, 1]],
140
+ "is_seat_subsurf": ['discrete', [0, 1]],
141
+ "leg_thickness": ['continuous', [0.02, 0.07]], # leg thickness
142
+ "limb_profile": ['continuous', [1.5, 2.5]],
143
+ "leg_height": ['continuous', [0.2, 1.0]], # leg height
144
+ "is_leg_round": ['discrete', [0, 1]],
145
+ "leg_type": ['discrete', [0,1,2,3]],
146
+ "has_leg_x_bar": ['discrete', [0, 1]],
147
+ "has_leg_y_bar": ['discrete', [0, 1]],
148
+ "leg_offset_bar0": ['continuous', [0.1, 0.9]], # leg y bar offset, only for has_leg_y_bar is 1
149
+ "leg_offset_bar1": ['continuous', [0.1, 0.9]], # leg x bar offset, only for has_leg_x_bar is 1
150
+ "leg_x_offset": ['continuous', [0.0, 0.2]], # leg end point x offset
151
+ "leg_y_offset0": ['continuous', [0.0, 0.2]], # leg end point y offset
152
+ "leg_y_offset1": ['continuous', [0.0, 0.2]], # leg end point y offset
153
+ "has_arm": ['discrete', [0, 1]],
154
+ "arm_thickness": ['continuous', [0.02, 0.07]], # arm thickness, only for has_arm is 1
155
+ "arm_height": ['continuous', [0.6, 1]], # only for has_arm is 1
156
+ "arm_y": ['continuous', [0.5, 1]], # arm y end point, only for has_arm is 1
157
+ "arm_z": ['continuous', [0.25, 0.6]], # arm z end point, only for has_arm is 1
158
+ "arm_mid0": ['continuous', [-0.03, 0.03]], # arm mid point x coord, only for has_arm is 1
159
+ "arm_mid1": ['continuous', [-0.03, 0.2]], # arm mid point y coord, only for has_arm is 1
160
+ "arm_mid2": ['continuous', [-0.09, 0.03]], # arm mid point z coord, only for has_arm is 1
161
+ "arm_profile0": ['continuous', [0.0, 2.0]], # arm curve control, only for has_arm is 1
162
+ "arm_profile1": ['continuous', [0.0, 2]], # arm curve control, only for has_arm is 1
163
+ "back_height": ['continuous', [0.3, 0.6]], # back height
164
+ "back_thickness": ['continuous', [0.02, 0.07]], # back thickness
165
+ "back_type": ['discrete', [0, 1, 2, 3]],
166
+ "back_vertical_cuts": ['discrete', [1,2,3,4]], # only for back type 3
167
+ "back_partial_scale": ['continuous', [1.0, 1.4]], # only for back type 1
168
+ "back_x_offset": ['continuous', [-0.1, 0.15]], # back top x length
169
+ "back_y_offset": ['continuous', [0.0, 0.4]], # back top y coord
170
+ "back_profile_partial": ['continuous', [0.4, 0.8]], # only for back type 1
171
+ "back_profile_horizontal_ncuts": ['discrete', [2, 3, 4]], # only for back type 2
172
+ "back_profile_horizontal_locs0": ['continuous', [1, 2]], # only for back type 2
173
+ "back_profile_horizontal_locs1": ['continuous', [1, 2]], # only for back type 2
174
+ "back_profile_horizontal_locs2": ['continuous', [1, 2]], # only for back type 2
175
+ "back_profile_horizontal_locs3": ['continuous', [1, 2]], # only for back type 2
176
+ "back_profile_horizontal_ratio": ['continuous', [0.2, 0.8]], # only for back type 2
177
+ "back_profile_horizontal_lowest": ['continuous', [0, 0.4]], # only for back type 2
178
+ "back_profile_vertical": ['continuous', [0.8, 0.9]], # only for back type 3
179
+ }
180
+
181
+ def fix_unused_params(self, params):
182
+ # check unused parameters inside a given parameter set, and fix them into mid value - for training
183
+ if params['leg_type'] != 2 and params['leg_type'] != 3:
184
+ params['limb_profile'] = (self.params_dict['limb_profile'][1][0] + self.params_dict['limb_profile'][1][-1]) / 2
185
+ if params['has_leg_x_bar'] == 0:
186
+ params['leg_offset_bar1'] = (self.params_dict['leg_offset_bar1'][1][0] + self.params_dict['leg_offset_bar1'][1][-1]) / 2
187
+ if params['has_leg_y_bar'] == 0:
188
+ params['leg_offset_bar0'] = (self.params_dict['leg_offset_bar0'][1][0] + self.params_dict['leg_offset_bar0'][1][-1]) / 2
189
+ if params['has_arm'] == 0:
190
+ params['arm_thickness'] = (self.params_dict['arm_thickness'][1][0] + self.params_dict['arm_thickness'][1][-1]) / 2
191
+ params['arm_height'] = (self.params_dict['arm_height'][1][0] + self.params_dict['arm_height'][1][-1]) / 2
192
+ params['arm_y'] = (self.params_dict['arm_y'][1][0] + self.params_dict['arm_y'][1][-1]) / 2
193
+ params['arm_z'] = (self.params_dict['arm_z'][1][0] + self.params_dict['arm_z'][1][-1]) / 2
194
+ params['arm_mid0'] = (self.params_dict['arm_mid0'][1][0] + self.params_dict['arm_mid0'][1][-1]) / 2
195
+ params['arm_mid1'] = (self.params_dict['arm_mid1'][1][0] + self.params_dict['arm_mid1'][1][-1]) / 2
196
+ params['arm_mid2'] = (self.params_dict['arm_mid2'][1][0] + self.params_dict['arm_mid2'][1][-1]) / 2
197
+ params['arm_profile0'] = (self.params_dict['arm_profile0'][1][0] + self.params_dict['arm_profile0'][1][-1]) / 2
198
+ params['arm_profile1'] = (self.params_dict['arm_profile1'][1][0] + self.params_dict['arm_profile1'][1][-1]) / 2
199
+ if params['back_type'] != 3:
200
+ params['back_vertical_cuts'] = (self.params_dict['back_vertical_cuts'][1][0] + self.params_dict['back_vertical_cuts'][1][-1]) / 2
201
+ params['back_profile_vertical'] = (self.params_dict['back_profile_vertical'][1][0] + self.params_dict['back_profile_vertical'][1][-1]) / 2
202
+ if params['back_type'] != 2:
203
+ params['back_profile_horizontal_ncuts'] = (self.params_dict['back_profile_horizontal_ncuts'][1][0] + self.params_dict['back_profile_horizontal_ncuts'][1][-1]) / 2
204
+ params['back_profile_horizontal_locs0'] = (self.params_dict['back_profile_horizontal_locs0'][1][0] + self.params_dict['back_profile_horizontal_locs0'][1][-1]) / 2
205
+ params['back_profile_horizontal_locs1'] = (self.params_dict['back_profile_horizontal_locs1'][1][0] + self.params_dict['back_profile_horizontal_locs1'][1][-1]) / 2
206
+ params['back_profile_horizontal_locs2'] = (self.params_dict['back_profile_horizontal_locs2'][1][0] + self.params_dict['back_profile_horizontal_locs2'][1][-1]) / 2
207
+ params['back_profile_horizontal_ratio'] = (self.params_dict['back_profile_horizontal_ratio'][1][0] + self.params_dict['back_profile_horizontal_ratio'][1][-1]) / 2
208
+ params['back_profile_horizontal_lowest'] = (self.params_dict['back_profile_horizontal_lowest'][1][0] + self.params_dict['back_profile_horizontal_lowest'][1][-1]) / 2
209
+ if params['back_type'] != 1:
210
+ params['back_partial_scale'] = (self.params_dict['back_partial_scale'][1][0] + self.params_dict['back_partial_scale'][1][-1]) / 2
211
+ params['back_profile_partial'] = (self.params_dict['back_profile_partial'][1][0] + self.params_dict['back_profile_partial'][1][-1]) / 2
212
+ return params
213
+
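Every branch of `fix_unused_params` repeats the same midpoint expression over `self.params_dict`; a small helper would state the intent more directly. A behavior-preserving sketch (the `_mid` name is hypothetical, not in the source):

```python
# hypothetical refactor of the repeated expression above; same behavior
def _mid(self, name):
    rng = self.params_dict[name][1]
    return (rng[0] + rng[-1]) / 2  # midpoint of the declared range

# e.g.: if params['has_arm'] == 0:
#           for k in ('arm_thickness', 'arm_height', 'arm_y', 'arm_z'):
#               params[k] = self._mid(k)
```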
214
+ def update_params(self, new_params):
215
+ # replace the parameters and recompute all derived values
216
+ self.width = new_params["width"]
217
+ self.size = new_params["size"]
218
+ self.thickness = new_params["thickness"]
219
+ self.bevel_width = self.thickness * new_params["bevel_width"]
220
+ self.seat_back = new_params["seat_back"]
221
+ self.seat_mid = new_params["seat_mid"]
222
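+ # note: seat_mid_x is still drawn at random here rather than taken from new_params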
+ self.seat_mid_x = uniform(
223
+ self.seat_back + self.seat_mid * (1 - self.seat_back), 1
224
+ )
225
+ self.seat_mid_z = new_params["seat_mid_z"]
226
+ self.seat_front = new_params["seat_front"]
227
+ self.is_seat_round = new_params["is_seat_round"]
228
+ self.is_seat_subsurf = new_params["is_seat_subsurf"]
229
+
230
+ self.leg_thickness = new_params["leg_thickness"]
231
+ self.limb_profile = new_params["limb_profile"]
232
+ self.leg_height = new_params["leg_height"]
233
+ self.back_height = new_params["back_height"]
234
+ self.is_leg_round = new_params["is_leg_round"]
235
+ self.leg_type = self.leg_types[new_params["leg_type"]]
236
+
237
+ self.leg_x_offset = 0
238
+ self.leg_y_offset = 0, 0
239
+ self.back_x_offset = 0
240
+ self.back_y_offset = 0
241
+
242
+ self.has_leg_x_bar = new_params["has_leg_x_bar"]
243
+ self.has_leg_y_bar = new_params["has_leg_y_bar"]
244
+ self.leg_offset_bar = new_params["leg_offset_bar0"], new_params["leg_offset_bar1"]
245
+
246
+ self.has_arm = new_params["has_arm"]
247
+ self.arm_thickness = new_params["arm_thickness"]
248
+ self.arm_height = self.arm_thickness * new_params["arm_height"]
249
+ self.arm_y = new_params["arm_y"] * self.size
250
+ self.arm_z = new_params["arm_z"] * self.back_height
251
+ self.arm_mid = np.array(
252
+ [new_params["arm_mid0"], new_params["arm_mid1"], new_params["arm_mid2"]]
253
+ )
254
+ self.arm_profile = (new_params["arm_profile0"], new_params["arm_profile1"])
255
+
256
+ self.back_thickness = new_params["back_thickness"]
257
+ self.back_type = self.back_types[new_params["back_type"]]
258
+ self.back_profile = [(0, 1)]
259
+ self.back_vertical_cuts = new_params["back_vertical_cuts"]
260
+ self.back_partial_scale = new_params["back_partial_scale"]
261
+
262
+ if self.leg_type == "vertical":
263
+ self.leg_x_offset = 0
264
+ self.leg_y_offset = 0, 0
265
+ self.back_x_offset = 0
266
+ self.back_y_offset = 0
267
+ else:
268
+ self.leg_x_offset = self.width * new_params["leg_x_offset"]
269
+ self.leg_y_offset = self.size * np.array([new_params["leg_y_offset0"], new_params["leg_y_offset1"]])
270
+ self.back_x_offset = self.width * new_params["back_x_offset"]
271
+ self.back_y_offset = self.size * new_params["back_y_offset"]
272
+
273
+ match self.back_type:
274
+ case "partial":
275
+ self.back_profile = ((new_params["back_profile_partial"], 1),)
276
+ case "horizontal-bar":
277
+ n_cuts = int(new_params["back_profile_horizontal_ncuts"])
278
+ locs = np.array([new_params["back_profile_horizontal_locs0"], new_params["back_profile_horizontal_locs1"],
279
+ new_params["back_profile_horizontal_locs2"], new_params["back_profile_horizontal_locs3"]])[:n_cuts].cumsum()
280
+ locs = locs / locs[-1]
281
+ ratio = new_params["back_profile_horizontal_ratio"]
282
+ locs = np.array(
283
+ [
284
+ (p + ratio * (l - p), l)
285
+ for p, l in zip([0, *locs[:-1]], locs)
286
+ ]
287
+ )
288
+ lowest = new_params["back_profile_horizontal_lowest"]
289
+ self.back_profile = locs * (1 - lowest) + lowest
290
+ case "vertical-bar":
291
+ self.back_profile = ((new_params["back_profile_vertical"], 1),)
292
+ case _:
293
+ self.back_profile = [(0, 1)]
294
+
295
+ # TODO: bring the material choice into the optimization loop
296
+ materials = AssetList["ChairFactory"]()
297
+ self.limb_surface = materials["limb"].assign_material()
298
+ self.surface = materials["surface"].assign_material()
299
+ if uniform() < 0.3:
300
+ self.panel_surface = self.surface
301
+ else:
302
+ self.panel_surface = materials["panel"].assign_material()
303
+
304
+ scratch_prob, edge_wear_prob = materials["wear_tear_prob"]
305
+ self.scratch, self.edge_wear = materials["wear_tear"]
306
+ is_scratch = uniform() < scratch_prob
307
+ is_edge_wear = uniform() < edge_wear_prob
308
+ if not is_scratch:
309
+ self.scratch = None
310
+ if not is_edge_wear:
311
+ self.edge_wear = None
312
+
313
+ # from infinigen.assets.clothes import blanket
314
+ # from infinigen.assets.scatters.clothes import ClothesCover
315
+ # self.clothes_scatter = ClothesCover(factory_fn=blanket.BlanketFactory, width=log_uniform(.8, 1.2),
316
+ # size=uniform(.8, 1.2)) if uniform() < .3 else NoApply()
317
+ self.clothes_scatter = NoApply()
318
+
319
+
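Together, `fix_unused_params` and `update_params` let a caller replay an externally supplied (e.g. model-predicted) parameter set before building geometry. A hedged sketch of the expected call order, assuming a running Blender (bpy) session and that `ChairFactory` keeps its sampled values in `self.params`, as `DandelionFactory` later in this commit does:

```python
# sketch only; self.params and the call order are assumptions from this file
fac = ChairFactory(factory_seed=0)
params = dict(fac.params)
params["has_arm"] = 0                   # override one predicted value
params = fac.fix_unused_params(params)  # neutralize parameters the config ignores
fac.update_params(params)               # recompute all derived attributes
chair = fac.create_asset()              # build the Blender object
```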
320
+ def post_init(self):
321
+ with FixedSeed(self.factory_seed):
322
+ if self.leg_type == "vertical":
323
+ self.leg_x_offset = 0
324
+ self.leg_y_offset = 0, 0
325
+ self.back_x_offset = 0
326
+ self.back_y_offset = 0
327
+ else:
328
+ self.leg_x_offset = self.width * uniform(0.05, 0.2)
329
+ self.leg_y_offset = self.size * uniform(0.05, 0.2, 2)
330
+ self.back_x_offset = self.width * uniform(-0.1, 0.15)
331
+ self.back_y_offset = self.size * uniform(0.1, 0.25)
332
+
333
+ match self.back_type:
334
+ case "partial":
335
+ self.back_profile = ((uniform(0.4, 0.8), 1),)
336
+ case "horizontal-bar":
337
+ n_cuts = np.random.randint(2, 4)
338
+ locs = uniform(1, 2, n_cuts).cumsum()
339
+ locs = locs / locs[-1]
340
+ ratio = uniform(0.5, 0.75)
341
+ locs = np.array(
342
+ [
343
+ (p + ratio * (l - p), l)
344
+ for p, l in zip([0, *locs[:-1]], locs)
345
+ ]
346
+ )
347
+ lowest = uniform(0, 0.4)
348
+ self.back_profile = locs * (1 - lowest) + lowest
349
+ case "vertical-bar":
350
+ self.back_profile = ((uniform(0.8, 0.9), 1),)
351
+ case _:
352
+ self.back_profile = [(0, 1)]
353
+
354
+ def create_placeholder(self, **kwargs) -> bpy.types.Object:
355
+ obj = new_bbox(
356
+ -self.width / 2 - max(self.leg_x_offset, self.back_x_offset),
357
+ self.width / 2 + max(self.leg_x_offset, self.back_x_offset),
358
+ -self.size - self.leg_y_offset[1] - self.leg_thickness * 0.5,
359
+ max(self.leg_y_offset[0], self.back_y_offset),
360
+ -self.leg_height,
361
+ self.back_height * 1.2,
362
+ )
363
+ obj.rotation_euler.z += np.pi / 2
364
+ butil.apply_transform(obj)
365
+ return obj
366
+
367
+ def create_asset(self, **params) -> bpy.types.Object:
368
+ obj = self.make_seat()
369
+ legs = self.make_legs()
370
+ backs = self.make_backs()
371
+
372
+ parts = [obj] + legs + backs
373
+ parts.extend(self.make_leg_decors(legs))
374
+ if self.has_arm:
375
+ parts.extend(self.make_arms(obj, backs))
376
+ parts.extend(self.make_back_decors(backs))
377
+
378
+ for obj in legs:
379
+ self.solidify(obj, 2)
380
+ for obj in backs:
381
+ self.solidify(obj, 2, self.back_thickness)
382
+
383
+ obj = join_objects(parts)
384
+ obj.rotation_euler.z += np.pi / 2
385
+ butil.apply_transform(obj)
386
+
387
+ with FixedSeed(self.factory_seed):
388
+ # TODO: wasteful to create unique materials for each individual asset
389
+ self.surface.apply(obj)
390
+ self.panel_surface.apply(obj, selection="panel")
391
+ self.limb_surface.apply(obj, selection="limb")
392
+
393
+ return obj
394
+
395
+ def finalize_assets(self, assets):
396
+ if self.scratch:
397
+ self.scratch.apply(assets)
398
+ if self.edge_wear:
399
+ self.edge_wear.apply(assets)
400
+
401
+ def make_seat(self):
402
+ x_anchors = (
403
+ np.array(
404
+ [
405
+ 0,
406
+ -self.seat_back,
407
+ -self.seat_mid_x,
408
+ -1,
409
+ 0,
410
+ 1,
411
+ self.seat_mid_x,
412
+ self.seat_back,
413
+ 0,
414
+ ]
415
+ )
416
+ * self.width
417
+ / 2
418
+ )
419
+ y_anchors = (
420
+ np.array(
421
+ [0, 0, -self.seat_mid, -1, -self.seat_front, -1, -self.seat_mid, 0, 0]
422
+ )
423
+ * self.size
424
+ )
425
+ z_anchors = (
426
+ np.array([0, 0, self.seat_mid_z, 0, 0, 0, self.seat_mid_z, 0, 0])
427
+ * self.thickness
428
+ )
429
+ vector_locations = [1, 7] if self.is_seat_round else [1, 3, 5, 7]
430
+ obj = bezier_curve((x_anchors, y_anchors, z_anchors), vector_locations, 8)
431
+ with butil.ViewportMode(obj, "EDIT"):
432
+ bpy.ops.mesh.select_all(action="SELECT")
433
+ bpy.ops.mesh.fill_grid(use_interp_simple=True)
434
+ butil.modify_mesh(obj, "SOLIDIFY", thickness=self.thickness, offset=0)
435
+ subsurf(obj, 1, not self.is_seat_subsurf)
436
+ butil.modify_mesh(obj, "BEVEL", width=self.bevel_width, segments=8)
437
+ return obj
438
+
439
+ def make_legs(self):
440
+ leg_starts = np.array(
441
+ [[-self.seat_back, 0, 0], [-1, -1, 0], [1, -1, 0], [self.seat_back, 0, 0]]
442
+ ) * np.array([[self.width / 2, self.size, 0]])
443
+ leg_ends = leg_starts.copy()
444
+ leg_ends[[0, 1], 0] -= self.leg_x_offset
445
+ leg_ends[[2, 3], 0] += self.leg_x_offset
446
+ leg_ends[[0, 3], 1] += self.leg_y_offset[0]
447
+ leg_ends[[1, 2], 1] -= self.leg_y_offset[1]
448
+ leg_ends[:, -1] = -self.leg_height
449
+ return self.make_limb(leg_ends, leg_starts)
450
+
451
+ def make_limb(self, leg_ends, leg_starts):
452
+ limbs = []
453
+ for leg_start, leg_end in zip(leg_starts, leg_ends):
454
+ match self.leg_type:
455
+ case "up-curved":
456
+ axes = [(0, 0, 1), None]
457
+ scale = [self.limb_profile, 1]
458
+ case "down-curved":
459
+ axes = [None, (0, 0, 1)]
460
+ scale = [1, self.limb_profile]
461
+ case _:
462
+ axes = None
463
+ scale = None
464
+ limb = align_bezier(
465
+ np.stack([leg_start, leg_end], -1), axes, scale, resolution=64
466
+ )
467
+ limb.location = (
468
+ np.array(
469
+ [
470
+ 1 if leg_start[0] < 0 else -1,
471
+ 1 if leg_start[1] < -self.size / 2 else -1,
472
+ 0,
473
+ ]
474
+ )
475
+ * self.leg_thickness
476
+ / 2
477
+ )
478
+ butil.apply_transform(limb, True)
479
+ limbs.append(limb)
480
+ return limbs
481
+
482
+ def make_backs(self):
483
+ back_starts = (
484
+ np.array([[-self.seat_back, 0, 0], [self.seat_back, 0, 0]]) * self.width / 2
485
+ )
486
+ back_ends = back_starts.copy()
487
+ back_ends[:, 0] += np.array([self.back_x_offset, -self.back_x_offset])
488
+ back_ends[:, 1] = self.back_y_offset
489
+ back_ends[:, 2] = self.back_height
490
+ return self.make_limb(back_starts, back_ends)
491
+
492
+ def make_leg_decors(self, legs):
493
+ decors = []
494
+ if self.has_leg_x_bar:
495
+ z_height = -self.leg_height * uniform(*self.leg_offset_bar)
496
+ locs = []
497
+ for leg in legs:
498
+ co = read_co(leg)
499
+ locs.append(co[np.argmin(np.abs(co[:, -1] - z_height))])
500
+ decors.append(
501
+ self.solidify(bezier_curve(np.stack([locs[0], locs[3]], -1)), 0)
502
+ )
503
+ decors.append(
504
+ self.solidify(bezier_curve(np.stack([locs[1], locs[2]], -1)), 0)
505
+ )
506
+ if self.has_leg_y_bar:
507
+ z_height = -self.leg_height * uniform(*self.leg_offset_bar)
508
+ locs = []
509
+ for leg in legs:
510
+ co = read_co(leg)
511
+ locs.append(co[np.argmin(np.abs(co[:, -1] - z_height))])
512
+ decors.append(
513
+ self.solidify(bezier_curve(np.stack([locs[0], locs[1]], -1)), 1)
514
+ )
515
+ decors.append(
516
+ self.solidify(bezier_curve(np.stack([locs[2], locs[3]], -1)), 1)
517
+ )
518
+ for d in decors:
519
+ write_attribute(d, 1, "limb", "FACE")
520
+ return decors
521
+
522
+ def make_back_decors(self, backs, finalize=True):
523
+ obj = join_objects([deep_clone_obj(b) for b in backs])
524
+ x, y, z = read_co(obj).T
525
+ x += np.where(x > 0, self.back_thickness / 2, -self.back_thickness / 2)
526
+ write_co(obj, np.stack([x, y, z], -1))
527
+ smoothness = uniform(0, 1)
528
+ profile_shape_factor = uniform(0, 0.4)
529
+ with butil.ViewportMode(obj, "EDIT"):
530
+ bpy.ops.mesh.select_mode(type="EDGE")
531
+ center = read_edge_center(obj)
532
+ for z_min, z_max in self.back_profile:
533
+ select_edges(
534
+ obj,
535
+ (z_min * self.back_height <= center[:, -1])
536
+ & (center[:, -1] <= z_max * self.back_height),
537
+ )
538
+ bpy.ops.mesh.bridge_edge_loops(
539
+ number_cuts=32,
540
+ interpolation="LINEAR",
541
+ smoothness=smoothness,
542
+ profile_shape_factor=profile_shape_factor,
543
+ )
544
+ bpy.ops.mesh.select_loose()
545
+ bpy.ops.mesh.delete()
546
+ butil.modify_mesh(
547
+ obj,
548
+ "SOLIDIFY",
549
+ thickness=np.minimum(self.thickness, self.back_thickness),
550
+ offset=0,
551
+ )
552
+ if finalize:
553
+ butil.modify_mesh(obj, "BEVEL", width=self.bevel_width, segments=8)
554
+ parts = [obj]
555
+ if self.back_type == "vertical-bar":
556
+ other = join_objects([deep_clone_obj(b) for b in backs])
557
+ with butil.ViewportMode(other, "EDIT"):
558
+ bpy.ops.mesh.select_mode(type="EDGE")
559
+ bpy.ops.mesh.select_all(action="SELECT")
560
+ bpy.ops.mesh.bridge_edge_loops(
561
+ number_cuts=self.back_vertical_cuts,
562
+ interpolation="LINEAR",
563
+ smoothness=smoothness,
564
+ profile_shape_factor=profile_shape_factor,
565
+ )
566
+ bpy.ops.mesh.select_all(action="INVERT")
567
+ bpy.ops.mesh.delete()
568
+ bpy.ops.mesh.select_all(action="SELECT")
569
+ bpy.ops.mesh.delete(type="ONLY_FACE")
570
+ remove_edges(other, np.abs(read_edge_direction(other)[:, -1]) < 0.5)
571
+ remove_vertices(other, lambda x, y, z: z < -self.thickness / 2)
572
+ remove_vertices(
573
+ other,
574
+ lambda x, y, z: z
575
+ > (self.back_profile[0][0] + self.back_profile[0][1])
576
+ * self.back_height
577
+ / 2,
578
+ )
579
+ parts.append(self.solidify(other, 2, self.back_thickness))
580
+ elif self.back_type == "partial":
581
+ co = read_co(obj)
582
+ co[:, 1] *= self.back_partial_scale
583
+ write_co(obj, co)
584
+ for p in parts:
585
+ write_attribute(p, 1, "panel", "FACE")
586
+ return parts
587
+
588
+ def make_arms(self, base, backs):
589
+ co = read_co(base)
590
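+ # the boolean term rewards vertices with |y + arm_y| < 0.02, so argmin picks the smallest-x vertex near that band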
+ end = co[np.argmin(co[:, 0] - (np.abs(co[:, 1] + self.arm_y) < 0.02))]
591
+ end[0] += self.arm_thickness / 4
592
+ end_ = end.copy()
593
+ end_[0] = -end[0]
594
+ arms = []
595
+ co = read_co(backs[0])
596
+ start = co[np.argmin(co[:, 0] - (np.abs(co[:, -1] - self.arm_z) < 0.02))]
597
+ start[0] -= self.arm_thickness / 4
598
+ start_ = start.copy()
599
+ start_[0] = -start[0]
600
+ for start, end in zip([start, start_], [end, end_]):
601
+ mid = np.array(
602
+ [
603
+ end[0] + self.arm_mid[0] * (-1 if end[0] > 0 else 1),
604
+ end[1] + self.arm_mid[1],
605
+ start[2] + self.arm_mid[2],
606
+ ]
607
+ )
608
+ arm = align_bezier(
609
+ np.stack([start, mid, end], -1),
610
+ np.array(
611
+ [
612
+ [end[0] - start[0], end[1] - start[1], 0],
613
+ [0, 1 / np.sqrt(2), 1 / np.sqrt(2)],
614
+ [0, 0, 1],
615
+ ]
616
+ ),
617
+ [1, *self.arm_profile, 1],
618
+ )
619
+ if self.is_leg_round:
620
+ surface.add_geomod(
621
+ arm,
622
+ geo_radius,
623
+ apply=True,
624
+ input_args=[self.arm_thickness / 2, 32],
625
+ input_kwargs={"to_align_tilt": False},
626
+ )
627
+ else:
628
+ with butil.ViewportMode(arm, "EDIT"):
629
+ bpy.ops.mesh.select_all(action="SELECT")
630
+ bpy.ops.mesh.extrude_edges_move(
631
+ TRANSFORM_OT_translate={
632
+ "value": (
633
+ self.arm_thickness
634
+ if end[0] < 0
635
+ else -self.arm_thickness,
636
+ 0,
637
+ 0,
638
+ )
639
+ }
640
+ )
641
+ butil.modify_mesh(arm, "SOLIDIFY", thickness=self.arm_height, offset=0)
642
+ write_attribute(arm, 1, "limb", "FACE")
643
+ arms.append(arm)
644
+ return arms
645
+
646
+ def solidify(self, obj, axis, thickness=None):
647
+ if thickness is None:
648
+ thickness = self.leg_thickness
649
+ if self.is_leg_round:
650
+ solidify(obj, axis, thickness)
651
+ butil.modify_mesh(obj, "BEVEL", width=self.bevel_width, segments=8)
652
+ else:
653
+ surface.add_geomod(
654
+ obj, geo_radius, apply=True, input_args=[thickness / 2, 32]
655
+ )
656
+ write_attribute(obj, 1, "limb", "FACE")
657
+ return obj
core/assets/dandelion.py ADDED
@@ -0,0 +1,1097 @@
1
+ # Copyright (C) 2023, Princeton University.
2
+ # This source code is licensed under the BSD 3-Clause license found in the LICENSE file in the root directory of this source tree.
3
+
4
+ # Authors: Beining Han
5
+ # Acknowledgement: This file draws inspiration from https://www.youtube.com/watch?v=61Sk8j1Ml9c by BradleyAnimation
6
+
7
+ import bpy
8
+ import numpy as np
9
+ from numpy.random import normal, randint, uniform
10
+
11
+ import infinigen
12
+ from infinigen.assets.materials import simple_brownish, simple_greenery, simple_whitish
13
+ from infinigen.core import surface
14
+ from infinigen.core.nodes import node_utils
15
+ from infinigen.core.nodes.node_wrangler import Nodes, NodeWrangler
16
+ from infinigen.core.placement.factory import AssetFactory
17
+ from infinigen.core.tagging import tag_nodegroup, tag_object
18
+ from infinigen.core.util.math import FixedSeed
19
+
20
+
21
+
22
+ @node_utils.to_nodegroup(
23
+ "nodegroup_pedal_stem_head_geometry", singleton=False, type="GeometryNodeTree"
24
+ )
25
+ def nodegroup_pedal_stem_head_geometry(nw: NodeWrangler):
26
+ # Code generated using version 2.4.3 of the node_transpiler
27
+
28
+ group_input = nw.new_node(
29
+ Nodes.GroupInput,
30
+ expose_input=[
31
+ ("NodeSocketVectorTranslation", "Translation", (0.0, 0.0, 1.0)),
32
+ ("NodeSocketFloatDistance", "Radius", 0.04),
33
+ ],
34
+ )
35
+
36
+ uv_sphere_1 = nw.new_node(
37
+ Nodes.MeshUVSphere,
38
+ input_kwargs={"Segments": 64, "Radius": group_input.outputs["Radius"]},
39
+ )
40
+
41
+ transform_1 = nw.new_node(
42
+ Nodes.Transform,
43
+ input_kwargs={
44
+ "Geometry": uv_sphere_1,
45
+ "Translation": group_input.outputs["Translation"],
46
+ },
47
+ )
48
+
49
+ set_material = nw.new_node(
50
+ Nodes.SetMaterial,
51
+ input_kwargs={
52
+ "Geometry": transform_1,
53
+ "Material": surface.shaderfunc_to_material(
54
+ simple_brownish.shader_simple_brown
55
+ ),
56
+ },
57
+ )
58
+
59
+ group_output = nw.new_node(
60
+ Nodes.GroupOutput, input_kwargs={"Geometry": set_material}
61
+ )
62
+
63
+
64
+ @node_utils.to_nodegroup(
65
+ "nodegroup_pedal_stem_end_geometry", singleton=False, type="GeometryNodeTree"
66
+ )
67
+ def nodegroup_pedal_stem_end_geometry(nw: NodeWrangler):
68
+ # Code generated using version 2.4.3 of the node_transpiler
69
+
70
+ group_input = nw.new_node(
71
+ Nodes.GroupInput, expose_input=[("NodeSocketGeometry", "Points", None)]
72
+ )
73
+
74
+ endpoint_selection = nw.new_node(
75
+ "GeometryNodeCurveEndpointSelection", input_kwargs={"End Size": 0}
76
+ )
77
+
78
+ uv_sphere = nw.new_node(
79
+ Nodes.MeshUVSphere, input_kwargs={"Segments": 64, "Radius": 0.04}
80
+ )
81
+
82
+ vector = nw.new_node(Nodes.Vector)
83
+ vector.vector = (uniform(0.45, 0.7), uniform(0.45, 0.7), uniform(2, 3))
84
+
85
+ transform = nw.new_node(
86
+ Nodes.Transform, input_kwargs={"Geometry": uv_sphere, "Scale": vector}
87
+ )
88
+
89
+ cone = nw.new_node(
90
+ "GeometryNodeMeshCone", input_kwargs={"Radius Bottom": 0.0040, "Depth": 0.0040}
91
+ )
92
+
93
+ normal = nw.new_node(Nodes.InputNormal)
94
+
95
+ align_euler_to_vector_1 = nw.new_node(
96
+ Nodes.AlignEulerToVector, input_kwargs={"Vector": normal}, attrs={"axis": "Z"}
97
+ )
98
+
99
+ instance_on_points_1 = nw.new_node(
100
+ Nodes.InstanceOnPoints,
101
+ input_kwargs={
102
+ "Points": transform,
103
+ "Instance": cone.outputs["Mesh"],
104
+ "Rotation": align_euler_to_vector_1,
105
+ },
106
+ )
107
+
108
+ join_geometry = nw.new_node(
109
+ Nodes.JoinGeometry, input_kwargs={"Geometry": [instance_on_points_1, transform]}
110
+ )
111
+
112
+ set_material = nw.new_node(
113
+ Nodes.SetMaterial,
114
+ input_kwargs={
115
+ "Geometry": join_geometry,
116
+ "Material": surface.shaderfunc_to_material(
117
+ simple_brownish.shader_simple_brown
118
+ ),
119
+ },
120
+ )
121
+
122
+ geometry_to_instance = nw.new_node(
123
+ "GeometryNodeGeometryToInstance", input_kwargs={"Geometry": set_material}
124
+ )
125
+
126
+ curve_tangent = nw.new_node(Nodes.CurveTangent)
127
+
128
+ align_euler_to_vector = nw.new_node(
129
+ Nodes.AlignEulerToVector,
130
+ input_kwargs={"Vector": curve_tangent},
131
+ attrs={"axis": "Z"},
132
+ )
133
+
134
+ instance_on_points = nw.new_node(
135
+ Nodes.InstanceOnPoints,
136
+ input_kwargs={
137
+ "Points": group_input.outputs["Points"],
138
+ "Selection": endpoint_selection,
139
+ "Instance": geometry_to_instance,
140
+ "Rotation": align_euler_to_vector,
141
+ },
142
+ )
143
+
144
+ realize_instances = nw.new_node(
145
+ Nodes.RealizeInstances, input_kwargs={"Geometry": instance_on_points}
146
+ )
147
+
148
+ group_output = nw.new_node(
149
+ Nodes.GroupOutput, input_kwargs={"Geometry": realize_instances}
150
+ )
151
+
152
+
153
+ @node_utils.to_nodegroup(
154
+ "nodegroup_pedal_stem_branch_shape", singleton=False, type="GeometryNodeTree"
155
+ )
156
+ def nodegroup_pedal_stem_branch_shape(nw: NodeWrangler):
157
+ # Code generated using version 2.6.4 of the node_transpiler
158
+
159
+ pedal_stem_branches_num = nw.new_node(
160
+ Nodes.Integer, label="pedal_stem_branches_num"
161
+ )
162
+ pedal_stem_branches_num.integer = 40
163
+
164
+ group_input = nw.new_node(
165
+ Nodes.GroupInput, expose_input=[("NodeSocketFloatDistance", "Radius", 0.0100)]
166
+ )
167
+
168
+ curve_circle_1 = nw.new_node(
169
+ Nodes.CurveCircle,
170
+ input_kwargs={
171
+ "Resolution": pedal_stem_branches_num,
172
+ "Radius": group_input.outputs["Radius"],
173
+ },
174
+ )
175
+
176
+ pedal_stem_branch_length = nw.new_node(
177
+ Nodes.Value, label="pedal_stem_branch_length"
178
+ )
179
+ pedal_stem_branch_length.outputs[0].default_value = 0.5000
180
+
181
+ combine_xyz_1 = nw.new_node(
182
+ Nodes.CombineXYZ, input_kwargs={"X": pedal_stem_branch_length}
183
+ )
184
+
185
+ curve_line_1 = nw.new_node(Nodes.CurveLine, input_kwargs={"End": combine_xyz_1})
186
+
187
+ resample_curve = nw.new_node(
188
+ Nodes.ResampleCurve, input_kwargs={"Curve": curve_line_1, "Count": 40}
189
+ )
190
+
191
+ spline_parameter = nw.new_node(Nodes.SplineParameter)
192
+
193
+ float_curve = nw.new_node(
194
+ Nodes.FloatCurve, input_kwargs={"Value": spline_parameter.outputs["Factor"]}
195
+ )
196
+ node_utils.assign_curve(
197
+ float_curve.mapping.curves[0],
198
+ [
199
+ (0.0000, 0.0000),
200
+ (0.2, 0.08 * np.random.normal(1.0, 0.15)),
201
+ (0.4, 0.22 * np.random.normal(1.0, 0.2)),
202
+ (0.6, 0.45 * np.random.normal(1.0, 0.2)),
203
+ (0.8, 0.7 * np.random.normal(1.0, 0.1)),
204
+ (1.0000, 1.0000),
205
+ ],
206
+ )
207
+
208
+ multiply = nw.new_node(
209
+ Nodes.Math,
210
+ input_kwargs={0: float_curve, 1: uniform(0.15, 0.4)},
211
+ attrs={"operation": "MULTIPLY"},
212
+ )
213
+
214
+ combine_xyz = nw.new_node(Nodes.CombineXYZ, input_kwargs={"Z": multiply})
215
+
216
+ set_position = nw.new_node(
217
+ Nodes.SetPosition,
218
+ input_kwargs={"Geometry": resample_curve, "Offset": combine_xyz},
219
+ )
220
+
221
+ normal = nw.new_node(Nodes.InputNormal)
222
+
223
+ align_euler_to_vector = nw.new_node(
224
+ Nodes.AlignEulerToVector, input_kwargs={"Vector": normal}
225
+ )
226
+
227
+ instance_on_points = nw.new_node(
228
+ Nodes.InstanceOnPoints,
229
+ input_kwargs={
230
+ "Points": curve_circle_1.outputs["Curve"],
231
+ "Instance": set_position,
232
+ "Rotation": align_euler_to_vector,
233
+ },
234
+ )
235
+
236
+ random_value_1 = nw.new_node(
237
+ Nodes.RandomValue, input_kwargs={2: -0.2000, 3: 0.2000, "Seed": 2}
238
+ )
239
+
240
+ random_value_2 = nw.new_node(
241
+ Nodes.RandomValue, input_kwargs={2: -0.2000, 3: 0.2000, "Seed": 1}
242
+ )
243
+
244
+ random_value = nw.new_node(Nodes.RandomValue, input_kwargs={2: -0.2000, 3: 0.2000})
245
+
246
+ combine_xyz_2 = nw.new_node(
247
+ Nodes.CombineXYZ,
248
+ input_kwargs={
249
+ "X": random_value_1.outputs[1],
250
+ "Y": random_value_2.outputs[1],
251
+ "Z": random_value.outputs[1],
252
+ },
253
+ )
254
+
255
+ rotate_instances = nw.new_node(
256
+ Nodes.RotateInstances,
257
+ input_kwargs={"Instances": instance_on_points, "Rotation": combine_xyz_2},
258
+ )
259
+
260
+ random_value_3 = nw.new_node(Nodes.RandomValue, input_kwargs={2: 0.8000})
261
+
262
+ scale_instances = nw.new_node(
263
+ Nodes.ScaleInstances,
264
+ input_kwargs={
265
+ "Instances": rotate_instances,
266
+ "Scale": random_value_3.outputs[1],
267
+ },
268
+ )
269
+
270
+ group_output = nw.new_node(
271
+ Nodes.GroupOutput,
272
+ input_kwargs={"Instances": scale_instances},
273
+ attrs={"is_active_output": True},
274
+ )
275
+
276
+
277
+ @node_utils.to_nodegroup(
278
+ "nodegroup_pedal_stem_branch_contour", singleton=False, type="GeometryNodeTree"
279
+ )
280
+ def nodegroup_pedal_stem_branch_contour(nw: NodeWrangler):
281
+ # Code generated using version 2.4.3 of the node_transpiler
282
+
283
+ group_input = nw.new_node(
284
+ Nodes.GroupInput, expose_input=[("NodeSocketGeometry", "Geometry", None)]
285
+ )
286
+
287
+ realize_instances = nw.new_node(
288
+ Nodes.RealizeInstances,
289
+ input_kwargs={"Geometry": group_input.outputs["Geometry"]},
290
+ )
291
+
292
+ pedal_stem_branch_rsample = nw.new_node(
293
+ Nodes.Value, label="pedal_stem_branch_rsample"
294
+ )
295
+ pedal_stem_branch_rsample.outputs[0].default_value = 10.0
296
+
297
+ resample_curve = nw.new_node(
298
+ Nodes.ResampleCurve,
299
+ input_kwargs={"Curve": realize_instances, "Count": pedal_stem_branch_rsample},
300
+ )
301
+
302
+ index = nw.new_node(Nodes.Index)
303
+
304
+ capture_attribute = nw.new_node(
305
+ Nodes.CaptureAttribute,
306
+ input_kwargs={"Geometry": resample_curve, 5: index},
307
+ attrs={"domain": "CURVE", "data_type": "INT"},
308
+ )
309
+
310
+ spline_parameter = nw.new_node(Nodes.SplineParameter)
311
+
312
+ float_curve = nw.new_node(
313
+ Nodes.FloatCurve, input_kwargs={"Value": spline_parameter.outputs["Factor"]}
314
+ )
315
+
316
+ # generate pedal branch contour
317
+ dist = uniform(-0.05, -0.25)
318
+ node_utils.assign_curve(
319
+ float_curve.mapping.curves[0],
320
+ [
321
+ (0.0, 0.0),
322
+ (0.2, 0.2 + (dist + normal(0, 0.05)) / 2.0),
323
+ (0.4, 0.4 + (dist + normal(0, 0.05))),
324
+ (0.6, 0.6 + (dist + normal(0, 0.05)) / 1.2),
325
+ (0.8, 0.8 + (dist + normal(0, 0.05)) / 2.4),
326
+ (1.0, 0.95 + normal(0, 0.05)),
327
+ ],
328
+ )
329
+
330
+ random_value = nw.new_node(
331
+ Nodes.RandomValue,
332
+ input_kwargs={2: 0.05, 3: 0.35, "ID": capture_attribute.outputs[5]},
333
+ )
334
+
335
+ multiply = nw.new_node(
336
+ Nodes.Math,
337
+ input_kwargs={0: float_curve, 1: random_value.outputs[1]},
338
+ attrs={"operation": "MULTIPLY"},
339
+ )
340
+
341
+ combine_xyz = nw.new_node(Nodes.CombineXYZ, input_kwargs={"Z": multiply})
342
+
343
+ set_position = nw.new_node(
344
+ Nodes.SetPosition,
345
+ input_kwargs={
346
+ "Geometry": capture_attribute.outputs["Geometry"],
347
+ "Offset": combine_xyz,
348
+ },
349
+ )
350
+
351
+ group_output = nw.new_node(
352
+ Nodes.GroupOutput, input_kwargs={"Geometry": set_position}
353
+ )
354
+
355
+
356
+ @node_utils.to_nodegroup(
357
+ "nodegroup_pedal_stem_branch_geometry", singleton=False, type="GeometryNodeTree"
358
+ )
359
+ def nodegroup_pedal_stem_branch_geometry(nw: NodeWrangler):
360
+ # Code generated using version 2.4.3 of the node_transpiler
361
+
362
+ group_input = nw.new_node(
363
+ Nodes.GroupInput,
364
+ expose_input=[
365
+ ("NodeSocketGeometry", "Curve", None),
366
+ ("NodeSocketVectorTranslation", "Translation", (0.0, 0.0, 1.0)),
367
+ ],
368
+ )
369
+
370
+ set_curve_radius_1 = nw.new_node(
371
+ Nodes.SetCurveRadius,
372
+ input_kwargs={"Curve": group_input.outputs["Curve"], "Radius": 1.0},
373
+ )
374
+
375
+ curve_circle_2 = nw.new_node(
376
+ Nodes.CurveCircle,
377
+ input_kwargs={"Radius": uniform(0.001, 0.0025), "Resolution": 4},
378
+ )
379
+
380
+ curve_to_mesh_1 = nw.new_node(
381
+ Nodes.CurveToMesh,
382
+ input_kwargs={
383
+ "Curve": set_curve_radius_1,
384
+ "Profile Curve": curve_circle_2.outputs["Curve"],
385
+ "Fill Caps": True,
386
+ },
387
+ )
388
+
389
+ transform_2 = nw.new_node(
390
+ Nodes.Transform,
391
+ input_kwargs={
392
+ "Geometry": curve_to_mesh_1,
393
+ "Translation": group_input.outputs["Translation"],
394
+ },
395
+ )
396
+
397
+ group_output = nw.new_node(
398
+ Nodes.GroupOutput, input_kwargs={"Geometry": transform_2}
399
+ )
400
+
401
+
402
+ @node_utils.to_nodegroup(
403
+ "nodegroup_pedal_stem_geometry", singleton=False, type="GeometryNodeTree"
404
+ )
405
+ def nodegroup_pedal_stem_geometry(nw: NodeWrangler):
406
+ # Code generated using version 2.4.3 of the node_transpiler
407
+
408
+ group_input = nw.new_node(
409
+ Nodes.GroupInput,
410
+ expose_input=[
411
+ ("NodeSocketVectorTranslation", "End", (0.0, 0.0, 1.0)),
412
+ ("NodeSocketVectorTranslation", "Middle", (0.0, 0.0, 0.5)),
413
+ ("NodeSocketFloatDistance", "Radius", 0.05),
414
+ ],
415
+ )
416
+
417
+ quadratic_bezier = nw.new_node(
418
+ Nodes.QuadraticBezier,
419
+ input_kwargs={
420
+ "Start": (0.0, 0.0, 0.0),
421
+ "Middle": group_input.outputs["Middle"],
422
+ "End": group_input.outputs["End"],
423
+ },
424
+ )
425
+
426
+ set_curve_radius = nw.new_node(
427
+ Nodes.SetCurveRadius,
428
+ input_kwargs={
429
+ "Curve": quadratic_bezier,
430
+ "Radius": group_input.outputs["Radius"],
431
+ },
432
+ )
433
+
434
+ curve_circle = nw.new_node(
435
+ Nodes.CurveCircle, input_kwargs={"Radius": 0.2, "Resolution": 8}
436
+ )
437
+
438
+ curve_to_mesh = nw.new_node(
439
+ Nodes.CurveToMesh,
440
+ input_kwargs={
441
+ "Curve": set_curve_radius,
442
+ "Profile Curve": curve_circle.outputs["Curve"],
443
+ "Fill Caps": True,
444
+ },
445
+ )
446
+
447
+ set_material_2 = nw.new_node(
448
+ Nodes.SetMaterial,
449
+ input_kwargs={
450
+ "Geometry": curve_to_mesh,
451
+ "Material": surface.shaderfunc_to_material(
452
+ simple_whitish.shader_simple_white
453
+ ),
454
+ },
455
+ )
456
+
457
+ group_output = nw.new_node(
458
+ Nodes.GroupOutput,
459
+ input_kwargs={"Geometry": set_material_2, "Curve": quadratic_bezier},
460
+ )
461
+
462
+
463
+ @node_utils.to_nodegroup(
464
+ "nodegroup_pedal_selection", singleton=False, type="GeometryNodeTree"
465
+ )
466
+ def nodegroup_pedal_selection(nw: NodeWrangler, params):
467
+ # Code generated using version 2.4.3 of the node_transpiler
468
+
469
+ random_value = nw.new_node(Nodes.RandomValue, input_kwargs={5: 1})
470
+
471
+ greater_than = nw.new_node(
472
+ Nodes.Math,
473
+ input_kwargs={0: params["random_dropout"], 1: random_value.outputs[1]},
474
+ attrs={"operation": "GREATER_THAN"},
475
+ )
476
+
477
+ index_1 = nw.new_node(Nodes.Index)
478
+
479
+ group_input = nw.new_node(
480
+ Nodes.GroupInput, expose_input=[("NodeSocketFloat", "num_segments", 0.5)]
481
+ )
482
+
483
+ divide = nw.new_node(
484
+ Nodes.Math,
485
+ input_kwargs={0: index_1, 1: group_input.outputs["num_segments"]},
486
+ attrs={"operation": "DIVIDE"},
487
+ )
488
+
489
+ less_than = nw.new_node(
490
+ Nodes.Math,
491
+ input_kwargs={0: divide, 1: params["row_less_than"]},
492
+ attrs={"operation": "LESS_THAN"},
493
+ )
494
+
495
+ greater_than_1 = nw.new_node(
496
+ Nodes.Math,
497
+ input_kwargs={0: divide, 1: params["row_great_than"]},
498
+ attrs={"operation": "GREATER_THAN"},
499
+ )
500
+
501
+ op_and = nw.new_node(
502
+ Nodes.BooleanMath, input_kwargs={0: less_than, 1: greater_than_1}
503
+ )
504
+
505
+ modulo = nw.new_node(
506
+ Nodes.Math,
507
+ input_kwargs={0: index_1, 1: group_input.outputs["num_segments"]},
508
+ attrs={"operation": "MODULO"},
509
+ )
510
+
511
+ less_than_1 = nw.new_node(
512
+ Nodes.Math,
513
+ input_kwargs={0: modulo, 1: params["col_less_than"]},
514
+ attrs={"operation": "LESS_THAN"},
515
+ )
516
+
517
+ greater_than_2 = nw.new_node(
518
+ Nodes.Math,
519
+ input_kwargs={0: modulo, 1: params["col_great_than"]},
520
+ attrs={"operation": "GREATER_THAN"},
521
+ )
522
+
523
+ op_and_1 = nw.new_node(
524
+ Nodes.BooleanMath, input_kwargs={0: less_than_1, 1: greater_than_2}
525
+ )
526
+
527
+ nand = nw.new_node(
528
+ Nodes.BooleanMath,
529
+ input_kwargs={0: op_and, 1: op_and_1},
530
+ attrs={"operation": "NAND"},
531
+ )
532
+
533
+ op_and_2 = nw.new_node(Nodes.BooleanMath, input_kwargs={0: greater_than, 1: nand})
534
+
535
+ group_output = nw.new_node(Nodes.GroupOutput, input_kwargs={"Boolean": op_and_2})
536
+
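The node group above keeps a petal point unless it falls inside a (row, column) window on the UV-sphere grid, with an extra random dropout; `index // num_segments` plays the role of the row and `index % num_segments` the column. An assumed NumPy rendition of the same mask, for reference only (the node graph uses a float divide, which behaves the same for these threshold comparisons):

```python
import numpy as np

# assumed NumPy equivalent of the selection above (not part of the node graph)
def pedal_mask(n_points, num_segments, dropout, row_lt, row_gt, col_lt, col_gt):
    idx = np.arange(n_points)
    row, col = idx // num_segments, idx % num_segments
    in_window = (row < row_lt) & (row > row_gt) & (col < col_lt) & (col > col_gt)
    keep = dropout > np.random.uniform(size=n_points)  # the GREATER_THAN node
    return keep & ~in_window                           # AND with the NAND branch
```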
537
+
538
+ @node_utils.to_nodegroup(
539
+ "nodegroup_stem_geometry", singleton=False, type="GeometryNodeTree"
540
+ )
541
+ def nodegroup_stem_geometry(nw: NodeWrangler, params):
542
+ # Code generated using version 2.4.3 of the node_transpiler
543
+
544
+ group_input = nw.new_node(
545
+ Nodes.GroupInput,
546
+ expose_input=[
547
+ ("NodeSocketGeometry", "Curve", None),
548
+ ]
549
+ )
550
+
551
+ spline_parameter = nw.new_node(Nodes.SplineParameter)
552
+
553
+ value = nw.new_node(Nodes.Value)
554
+ value.outputs[0].default_value = params["stem_map_range"]
555
+
556
+ map_range = nw.new_node(
557
+ Nodes.MapRange,
558
+ input_kwargs={"Value": spline_parameter.outputs["Factor"], 3: 0.4, 4: value},
559
+ )
560
+
561
+ set_curve_radius_2 = nw.new_node(
562
+ Nodes.SetCurveRadius,
563
+ input_kwargs={
564
+ "Curve": group_input.outputs["Curve"],
565
+ "Radius": map_range.outputs["Result"],
566
+ },
567
+ )
568
+
569
+ stem_radius = nw.new_node(Nodes.Value, label="stem_radius")
570
+ stem_radius.outputs[0].default_value = params["stem_radius"]
571
+
572
+ curve_circle_3 = nw.new_node(
573
+ Nodes.CurveCircle, input_kwargs={"Radius": stem_radius}
574
+ )
575
+
576
+ curve_to_mesh_2 = nw.new_node(
577
+ Nodes.CurveToMesh,
578
+ input_kwargs={
579
+ "Curve": set_curve_radius_2,
580
+ "Profile Curve": curve_circle_3.outputs["Curve"],
581
+ "Fill Caps": True,
582
+ },
583
+ )
584
+
585
+ set_material = nw.new_node(
586
+ Nodes.SetMaterial,
587
+ input_kwargs={
588
+ "Geometry": curve_to_mesh_2,
589
+ "Material": surface.shaderfunc_to_material(
590
+ simple_greenery.shader_simple_greenery
591
+ ),
592
+ },
593
+ )
594
+
595
+ group_output = nw.new_node(
596
+ Nodes.GroupOutput,
597
+ input_kwargs={"Mesh": tag_nodegroup(nw, set_material, "stem")},
598
+ )
599
+
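In `nodegroup_stem_geometry`, the Map Range node sends the spline factor t in [0, 1] to [0.4, stem_map_range], and the result scales the profile circle's radius, so for the sampled `stem_map_range` values below 0.4 the stem narrows toward the tip. The assumed closed form:

```python
import numpy as np

# assumed closed form of the taper above: Map Range maps t in [0, 1]
# to [0.4, stem_map_range], which multiplies the profile-circle radius
def stem_tube_radius(t, stem_radius, stem_map_range):
    return stem_radius * (0.4 + t * (stem_map_range - 0.4))

print(stem_tube_radius(np.linspace(0, 1, 5), 0.02, 0.3))
```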
600
+
601
+ @node_utils.to_nodegroup(
602
+ "nodegroup_pedal_stem", singleton=False, type="GeometryNodeTree"
603
+ )
604
+ def nodegroup_pedal_stem(nw: NodeWrangler, params):
605
+ # Code generated using version 2.4.3 of the node_transpiler
606
+ pedal_stem_top_point = nw.new_node(Nodes.Vector, label="pedal_stem_top_point")
607
+ pedal_stem_top_point.vector = (0.0, 0.0, 1.0)
608
+
609
+ pedal_stem_mid_point = nw.new_node(Nodes.Vector, label="pedal_stem_mid_point")
610
+ pedal_stem_mid_point.vector = (
611
+ params["pedal_stem_mid_point_x"],
612
+ params["pedal_stem_mid_point_y"],
613
+ 0.5
614
+ )
615
+
616
+ pedal_stem_radius = nw.new_node(Nodes.Value, label="pedal_stem_radius")
617
+ pedal_stem_radius.outputs[0].default_value = params["pedal_stem_radius"]
618
+
619
+ pedal_stem_geometry = nw.new_node(
620
+ nodegroup_pedal_stem_geometry().name,
621
+ input_kwargs={
622
+ "End": pedal_stem_top_point,
623
+ "Middle": pedal_stem_mid_point,
624
+ "Radius": pedal_stem_radius,
625
+ },
626
+ )
627
+
628
+ pedal_stem_top_radius = nw.new_node(Nodes.Value, label="pedal_stem_top_radius")
629
+ pedal_stem_top_radius.outputs[0].default_value = params["pedal_stem_top_radius"]
630
+
631
+ pedal_stem_branch_shape = nw.new_node(
632
+ nodegroup_pedal_stem_branch_shape().name,
633
+ input_kwargs={"Radius": pedal_stem_top_radius},
634
+ )
635
+
636
+ pedal_stem_branch_geometry = nw.new_node(
637
+ nodegroup_pedal_stem_branch_geometry().name,
638
+ input_kwargs={
639
+ "Curve": pedal_stem_branch_shape,
640
+ "Translation": pedal_stem_top_point,
641
+ },
642
+ )
643
+
644
+ set_material_3 = nw.new_node(
645
+ Nodes.SetMaterial,
646
+ input_kwargs={
647
+ "Geometry": pedal_stem_branch_geometry,
648
+ "Material": surface.shaderfunc_to_material(
649
+ simple_whitish.shader_simple_white
650
+ ),
651
+ },
652
+ )
653
+
654
+ resample_curve = nw.new_node(
655
+ Nodes.ResampleCurve,
656
+ input_kwargs={"Curve": pedal_stem_geometry.outputs["Curve"]},
657
+ )
658
+
659
+ pedal_stem_end_geometry = nw.new_node(
660
+ nodegroup_pedal_stem_end_geometry().name,
661
+ input_kwargs={"Points": resample_curve},
662
+ )
663
+
664
+ pedal_stem_head_geometry = nw.new_node(
665
+ nodegroup_pedal_stem_head_geometry().name,
666
+ input_kwargs={
667
+ "Translation": pedal_stem_top_point,
668
+ "Radius": pedal_stem_top_radius,
669
+ },
670
+ )
671
+
672
+ join_geometry = nw.new_node(
673
+ Nodes.JoinGeometry,
674
+ input_kwargs={
675
+ "Geometry": [
676
+ pedal_stem_geometry.outputs["Geometry"],
677
+ set_material_3,
678
+ pedal_stem_end_geometry,
679
+ pedal_stem_head_geometry,
680
+ ]
681
+ },
682
+ )
683
+
684
+ group_output = nw.new_node(
685
+ Nodes.GroupOutput, input_kwargs={"Geometry": join_geometry}
686
+ )
687
+
688
+
689
+ @node_utils.to_nodegroup(
690
+ "nodegroup_flower_geometry", singleton=False, type="GeometryNodeTree"
691
+ )
692
+ def nodegroup_flower_geometry(nw: NodeWrangler, params):
693
+ # Code generated using version 2.4.3 of the node_transpiler
694
+
695
+ num_core_segments = nw.new_node(
696
+ Nodes.Integer, label="num_core_segments", attrs={"integer": 10}
697
+ )
698
+ num_core_segments.integer = params["flower_num_core_segments"]
699
+
700
+ num_core_rings = nw.new_node(
701
+ Nodes.Integer, label="num_core_rings", attrs={"integer": 10}
702
+ )
703
+ num_core_rings.integer = params["flower_num_core_rings"]
704
+
705
+ uv_sphere_2 = nw.new_node(
706
+ Nodes.MeshUVSphere,
707
+ input_kwargs={
708
+ "Segments": num_core_segments,
709
+ "Rings": num_core_rings,
710
+ "Radius": params["flower_radius"],
711
+ },
712
+ )
713
+
714
+ flower_core_shape = nw.new_node(Nodes.Vector, label="flower_core_shape")
715
+ flower_core_shape.vector = (params["flower_core_shape_x"], params["flower_core_shape_y"], params["flower_core_shape_z"])
716
+
717
+ transform = nw.new_node(
718
+ Nodes.Transform,
719
+ input_kwargs={"Geometry": uv_sphere_2, "Scale": flower_core_shape},
720
+ )
721
+
722
+ selection_params = {
723
+ "random_dropout": params["random_dropout"],
724
+ "row_less_than": int(params["row_less_than"] * num_core_rings.integer),
725
+ "row_great_than": int(params["row_great_than"] * num_core_rings.integer),
726
+ "col_less_than": int(params["col_less_than"] * num_core_segments.integer),
727
+ "col_great_than": int(params["col_less_than"] * num_core_segments.integer),
728
+ }
729
+ pedal_selection = nw.new_node(
730
+ nodegroup_pedal_selection(params=selection_params).name,
731
+ input_kwargs={"num_segments": num_core_segments},
732
+ )
733
+
734
+ group_input = nw.new_node(
735
+ Nodes.GroupInput, expose_input=[("NodeSocketGeometry", "Instance", None)]
736
+ )
737
+
738
+ normal_1 = nw.new_node(Nodes.InputNormal)
739
+
740
+ align_euler_to_vector_1 = nw.new_node(
741
+ Nodes.AlignEulerToVector, input_kwargs={"Vector": normal_1}, attrs={"axis": "Z"}
742
+ )
743
+
744
+ random_value_1 = nw.new_node(Nodes.RandomValue, input_kwargs={2: 0.4, 3: 0.7})
745
+
746
+ multiply = nw.new_node(
747
+ Nodes.Math,
748
+ input_kwargs={0: random_value_1.outputs[1]},
749
+ attrs={"operation": "MULTIPLY"},
750
+ )
751
+
752
+ instance_on_points_1 = nw.new_node(
753
+ Nodes.InstanceOnPoints,
754
+ input_kwargs={
755
+ "Points": transform,
756
+ "Selection": pedal_selection,
757
+ "Instance": group_input.outputs["Instance"],
758
+ "Rotation": align_euler_to_vector_1,
759
+ "Scale": multiply,
760
+ },
761
+ )
762
+
763
+ realize_instances_1 = nw.new_node(
764
+ Nodes.RealizeInstances, input_kwargs={"Geometry": instance_on_points_1}
765
+ )
766
+
767
+ set_material = nw.new_node(
768
+ Nodes.SetMaterial,
769
+ input_kwargs={
770
+ "Geometry": transform,
771
+ "Material": surface.shaderfunc_to_material(
772
+ simple_whitish.shader_simple_white
773
+ ),
774
+ },
775
+ )
776
+
777
+ join_geometry_1 = nw.new_node(
778
+ Nodes.JoinGeometry,
779
+ input_kwargs={"Geometry": [realize_instances_1, set_material]},
780
+ )
781
+
782
+ group_output = nw.new_node(
783
+ Nodes.GroupOutput,
784
+ input_kwargs={"Geometry": tag_nodegroup(nw, join_geometry_1, "flower")},
785
+ )
786
+
787
+
788
+ @node_utils.to_nodegroup(
789
+ "nodegroup_flower_on_stem", singleton=False, type="GeometryNodeTree"
790
+ )
791
+ def nodegroup_flower_on_stem(nw: NodeWrangler):
792
+ # Code generated using version 2.4.3 of the node_transpiler
793
+
794
+ group_input = nw.new_node(
795
+ Nodes.GroupInput,
796
+ expose_input=[
797
+ ("NodeSocketGeometry", "Points", None),
798
+ ("NodeSocketGeometry", "Instance", None),
799
+ ],
800
+ )
801
+
802
+ endpoint_selection = nw.new_node(
803
+ "GeometryNodeCurveEndpointSelection", input_kwargs={"Start Size": 0}
804
+ )
805
+
806
+ curve_tangent = nw.new_node(Nodes.CurveTangent)
807
+
808
+ align_euler_to_vector_2 = nw.new_node(
809
+ Nodes.AlignEulerToVector,
810
+ input_kwargs={"Vector": curve_tangent},
811
+ attrs={"axis": "Z"},
812
+ )
813
+
814
+ instance_on_points_2 = nw.new_node(
815
+ Nodes.InstanceOnPoints,
816
+ input_kwargs={
817
+ "Points": group_input.outputs["Points"],
818
+ "Selection": endpoint_selection,
819
+ "Instance": group_input.outputs["Instance"],
820
+ "Rotation": align_euler_to_vector_2,
821
+ },
822
+ )
823
+
824
+ realize_instances_2 = nw.new_node(
825
+ Nodes.RealizeInstances, input_kwargs={"Geometry": instance_on_points_2}
826
+ )
827
+
828
+ group_output = nw.new_node(
829
+ Nodes.GroupOutput, input_kwargs={"Instances": realize_instances_2}
830
+ )
831
+
832
+
833
+ def geometry_dandelion_nodes(nw: NodeWrangler, **kwargs):
834
+ # Code generated using version 2.4.3 of the node_transpiler
835
+
836
+ quadratic_bezier_1 = nw.new_node(
837
+ Nodes.QuadraticBezier,
838
+ input_kwargs={
839
+ "Start": (0.0, 0.0, 0.0),
840
+ "Middle": (kwargs["bezier_middle_x"], kwargs["bezier_middle_y"], 0.5),
841
+ "End": (kwargs["bezier_end_x"], kwargs["bezier_end_y"], 1.0),
842
+ },
843
+ )
844
+
845
+ resample_curve = nw.new_node(
846
+ Nodes.ResampleCurve, input_kwargs={"Curve": quadratic_bezier_1}
847
+ )
848
+
849
+ pedal_stem = nw.new_node(
850
+ nodegroup_pedal_stem(kwargs).name,
851
+ input_kwargs={},
852
+ )
853
+
854
+ geometry_to_instance = nw.new_node(
855
+ "GeometryNodeGeometryToInstance", input_kwargs={"Geometry": pedal_stem}
856
+ )
857
+
858
+ flower_geometry = nw.new_node(
859
+ nodegroup_flower_geometry(kwargs).name,
860
+ input_kwargs={"Instance": geometry_to_instance},
861
+ )
862
+
863
+ geometry_to_instance_1 = nw.new_node(
864
+ "GeometryNodeGeometryToInstance", input_kwargs={"Geometry": flower_geometry}
865
+ )
866
+
867
+ value_2 = nw.new_node(Nodes.Value)
868
+ value_2.outputs[0].default_value = kwargs["transform_scale"]
869
+
870
+ transform_3 = nw.new_node(
871
+ Nodes.Transform,
872
+ input_kwargs={"Geometry": geometry_to_instance_1, "Scale": value_2},
873
+ )
874
+
875
+ flower_on_stem = nw.new_node(
876
+ nodegroup_flower_on_stem().name,
877
+ input_kwargs={"Points": resample_curve, "Instance": transform_3},
878
+ )
879
+
880
+ stem_geometry = nw.new_node(
881
+ nodegroup_stem_geometry(kwargs).name,
882
+ input_kwargs={
883
+ "Curve": quadratic_bezier_1,
884
+ }
885
+ )
886
+
887
+ join_geometry_2 = nw.new_node(
888
+ Nodes.JoinGeometry, input_kwargs={"Geometry": [flower_on_stem, stem_geometry]}
889
+ )
890
+
891
+ realize_instances = nw.new_node(
892
+ Nodes.RealizeInstances, input_kwargs={"Geometry": join_geometry_2}
893
+ )
894
+
895
+ group_output = nw.new_node(
896
+ Nodes.GroupOutput, input_kwargs={"Geometry": realize_instances}
897
+ )
898
+
899
+
900
+ def geometry_dandelion_seed_nodes(nw: NodeWrangler, **kwargs):
901
+ # Code generated using version 2.4.3 of the node_transpiler
902
+
903
+ # nodegroup_pedal_stem requires a parameter dict; fall back to the same
+ # distributions as DandelionFactory.sample_parameters when none are given
+ stem_params = {
+ "pedal_stem_mid_point_x": kwargs.get("pedal_stem_mid_point_x", normal(0.0, 0.05)),
+ "pedal_stem_mid_point_y": kwargs.get("pedal_stem_mid_point_y", normal(0.0, 0.05)),
+ "pedal_stem_radius": kwargs.get("pedal_stem_radius", uniform(0.02, 0.045)),
+ "pedal_stem_top_radius": kwargs.get("pedal_stem_top_radius", uniform(0.005, 0.008)),
+ }
+ pedal_stem = nw.new_node(nodegroup_pedal_stem(stem_params).name)
904
+
905
+ geometry_to_instance = nw.new_node(
906
+ "GeometryNodeGeometryToInstance", input_kwargs={"Geometry": pedal_stem}
907
+ )
908
+
909
+ group_output = nw.new_node(
910
+ Nodes.GroupOutput, input_kwargs={"Geometry": geometry_to_instance}
911
+ )
912
+
913
+ flower_modes_dict = {
914
+ 0: "full_flower",
915
+ 1: "no_flower",
916
+ 2: "sparse_flower",
917
+ }
918
+ class DandelionFactory(AssetFactory):
919
+ def __init__(self, factory_seed, coarse=False):
920
+ super(DandelionFactory, self).__init__(factory_seed, coarse=coarse)
921
+ self.get_params_dict()
922
+
923
+ with FixedSeed(factory_seed):
924
+ self.sample_parameters()
925
+
926
+ def get_params_dict(self):
927
+ # list all the parameters (key:name, value: [type, range]) used in this generator
928
+ self.params_dict = {
929
+ "flower_mode": ["discrete", (0, 1, 2)],
930
+ "random_dropout": ["continuous", (0.2, 0.6)],
931
+ "row_less_than": ["continuous", (0.0, 1.0)],
932
+ "col_less_than": ["continuous", (0.0, 1.0)],
933
+ "row_great_than": ["continuous", (0.0, 1.0)],
934
+ "col_great_than": ["continuous", (0.0, 1.0)],
935
+ "bezier_middle_x": ["continuous", (-0.6, 0.6)],
936
+ "bezier_middle_y": ["continuous", (-0.6, 0.6)],
937
+ "bezier_end_x": ["continuous", (-0.6, 0.6)],
938
+ "bezier_end_y": ["continuous", (-0.6, 0.6)],
939
+ "flower_num_core_segments": ["discrete", (8, 15, 20, 25)],
940
+ "flower_num_core_rings": ["discrete", (8, 15, 20)],
941
+ "transform_scale": ["continuous", (-0.7, -0.1)],
942
+ "stem_map_range": ["continuous", (0.1, 0.6)],
943
+ "stem_radius": ["continuous", (0.01, 0.03)],
944
+ }
945
+
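Given the spec above, one value per parameter can also be drawn generically. A minimal sketch (the `sample_from_spec` helper is hypothetical, distinct from the hand-written `sample_parameters` below):

```python
from numpy.random import choice, uniform

# hypothetical generic sampler over the ['discrete'|'continuous', range] spec
def sample_from_spec(params_dict):
    out = {}
    for name, (ptype, rng) in params_dict.items():
        out[name] = choice(rng) if ptype == "discrete" else uniform(rng[0], rng[-1])
    return out
```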
946
+ def sample_parameters(self):
947
+ # sample all the parameters
948
+ flower_mode = flower_modes_dict[randint(0, 3)] # numpy randint excludes the high bound; 0-2 covers all three modes
949
+ if flower_mode == "full_flower":
950
+ random_dropout = 1.0
951
+ row_less_than = 0.0
952
+ row_great_than = 0.0
953
+ col_less_than = 0.0
954
+ col_great_than = 0.0
955
+ elif flower_mode == "no_flower":
956
+ random_dropout = 0.0
957
+ row_less_than = 1.0
958
+ row_great_than = 0.0
959
+ col_less_than = 1.0
960
+ col_great_than = 0.0
961
+ elif flower_mode == "sparse_flower":
962
+ random_dropout = uniform(0.2, 0.6)
963
+ row_less_than = 0.0
964
+ row_great_than = 0.0
965
+ col_less_than = 0.0
966
+ col_great_than = 0.0
967
+ else:
968
+ raise ValueError("Invalid flower mode")
969
+ self.params = {
970
+ "flower_mode": flower_mode,
971
+ "random_dropout": random_dropout,
972
+ "row_less_than": row_less_than,
973
+ "row_great_than": row_great_than,
974
+ "col_less_than": col_less_than,
975
+ "col_great_than": col_great_than,
976
+ "bezier_middle_x": normal(0.0, 0.1),
977
+ "bezier_middle_y": normal(0.0, 0.1),
978
+ "bezier_end_x": normal(0.0, 0.1),
979
+ "bezier_end_y": normal(0.0, 0.1),
980
+ "pedal_stem_mid_point_x": normal(0.0, 0.05),
981
+ "pedal_stem_mid_point_y": normal(0.0, 0.05),
982
+ "pedal_stem_radius": uniform(0.02, 0.045),
983
+ "pedal_stem_top_radius": uniform(0.005, 0.008),
984
+ "flower_num_core_segments": randint(8, 25),
985
+ "flower_num_core_rings": randint(8, 20),
986
+ "flower_radius": uniform(0.02, 0.05),
987
+ "flower_core_shape_x": uniform(0.8, 1.2),
988
+ "flower_core_shape_y": uniform(0.8, 1.2),
989
+ "flower_core_shape_z": uniform(0.5, 0.8),
990
+ "transform_scale": uniform(-0.5, -0.15),
991
+ "stem_map_range": uniform(0.2, 0.4),
992
+ "stem_radius": uniform(0.01, 0.024),
993
+ }
994
+
995
+ def fix_unused_params(self, params):
996
+ return params
997
+
998
+ def update_params(self, params):
999
+ # update the parameters in the node graph
1000
+ flower_mode = flower_modes_dict[params["flower_mode"]]
1001
+ if flower_mode == "full_flower":
1002
+ random_dropout = uniform(0.7, 1.0)
1003
+ row_less_than = 0.0
1004
+ row_great_than = 0.0
1005
+ col_less_than = 0.0
1006
+ col_great_than = 0.0
1007
+ elif flower_mode == "no_flower":
1008
+ random_dropout = 0.0
1009
+ row_less_than = 1.0
1010
+ row_great_than = 0.0
1011
+ col_less_than = 1.0
1012
+ col_great_than = 0.0
1013
+ elif flower_mode == "sparse_flower":
1014
+ random_dropout = params["random_dropout"]
1015
+ row_less_than = params["row_less_than"]
1016
+ row_great_than = params["row_great_than"]
1017
+ col_less_than = params["col_less_than"]
1018
+ col_great_than = params["col_great_than"]
1019
+ else:
1020
+ raise ValueError("Invalid flower mode")
1021
+ params = {
1022
+ "flower_mode": flower_mode,
1023
+ "random_dropout": random_dropout,
1024
+ "row_less_than": row_less_than,
1025
+ "row_great_than": row_great_than,
1026
+ "col_less_than": col_less_than,
1027
+ "col_great_than": col_great_than,
1028
+ "bezier_middle_x": params["bezier_middle_x"],
1029
+ "bezier_middle_y": params["bezier_middle_y"],
1030
+ "bezier_end_x": params["bezier_end_x"],
1031
+ "bezier_end_y": params["bezier_end_y"],
1032
+ "flower_num_core_segments": int(params["flower_num_core_segments"]),
1033
+ "flower_num_core_rings": int(params["flower_num_core_rings"]),
1034
+ "flower_radius": uniform(0.02, 0.05),
1035
+ "flower_core_shape_x": uniform(0.8, 1.2),
1036
+ "flower_core_shape_y": uniform(0.8, 1.2),
1037
+ "flower_core_shape_z": uniform(0.5, 0.8),
1038
+ "pedal_stem_mid_point_x": normal(0.0, 0.05),
1039
+ "pedal_stem_mid_point_y": normal(0.0, 0.05),
1040
+ "pedal_stem_radius": uniform(0.02, 0.045),
1041
+ "pedal_stem_top_radius": uniform(0.005, 0.008),
1042
+ "transform_scale": params["transform_scale"],
1043
+ "stem_map_range": params["stem_map_range"],
1044
+ "stem_radius": params["stem_radius"],
1045
+ }
1046
+ self.params.update(params)
1047
+
1048
+
1049
+ def create_asset(self, **params):
1050
+ bpy.ops.mesh.primitive_plane_add(
1051
+ size=1,
1052
+ enter_editmode=False,
1053
+ align="WORLD",
1054
+ location=(0, 0, 0),
1055
+ scale=(1, 1, 1),
1056
+ )
1057
+ obj = bpy.context.active_object
1058
+
1059
+ surface.add_geomod(
1060
+ obj,
1061
+ geometry_dandelion_nodes,
1062
+ apply=True,
1063
+ attributes=[],
1064
+ input_kwargs=self.params,
1065
+ )
1066
+ tag_object(obj, "dandelion")
1067
+ return obj
1068
+
1069
+
1070
+ class DandelionSeedFactory(AssetFactory):
1071
+ def __init__(self, factory_seed, coarse=False):
1072
+ super(DandelionSeedFactory, self).__init__(factory_seed, coarse=coarse)
1073
+
1074
+ def create_asset(self, **params):
1075
+ bpy.ops.mesh.primitive_plane_add(
1076
+ size=1,
1077
+ enter_editmode=False,
1078
+ align="WORLD",
1079
+ location=(0, 0, 0),
1080
+ scale=(1, 1, 1),
1081
+ )
1082
+ obj = bpy.context.active_object
1083
+
1084
+ surface.add_geomod(
1085
+ obj,
1086
+ geometry_dandelion_seed_nodes,
1087
+ apply=True,
1088
+ attributes=[],
1089
+ input_kwargs=params,
1090
+ )
1091
+ tag_object(obj, "seed")
1092
+ return obj
1093
+
1094
+
1095
+ if __name__ == "__main__":
1096
+ f = DandelionSeedFactory(0)
1097
+ obj = f.create_asset()
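Note that this `__main__` block only works inside Blender's bundled Python, since `bpy` and the infinigen node machinery must be importable:

```python
# assumed headless invocation (not documented in this diff):
#   blender --background --python core/assets/dandelion.py
```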
core/assets/flower.py ADDED
@@ -0,0 +1,1002 @@
+ # Copyright (C) 2023, Princeton University.
+ # This source code is licensed under the BSD 3-Clause license found in the LICENSE file in the root directory of this source tree.
+
+ # Authors: Alexander Raistrick, Alejandro Newell
+
+
+ # Code generated using version v2.0.1 of the node_transpiler
+ import bpy
+ import numpy as np
+ from numpy.random import normal, uniform
+
+ import infinigen
+ from infinigen.core import surface
+ from infinigen.core.nodes import node_utils
+ from infinigen.core.nodes.node_wrangler import Nodes
+ from infinigen.core.placement.factory import AssetFactory
+ from infinigen.core.tagging import tag_nodegroup, tag_object
+ from infinigen.core.util import blender as butil
+ from infinigen.core.util import color
+ from infinigen.core.util.math import FixedSeed, dict_lerp
+
+
+ @node_utils.to_nodegroup("nodegroup_polar_to_cart_old", singleton=True)
+ def nodegroup_polar_to_cart_old(nw):
+     group_input = nw.new_node(
+         Nodes.GroupInput,
+         expose_input=[
+             ("NodeSocketVector", "Addend", (0.0, 0.0, 0.0)),
+             ("NodeSocketFloat", "Value", 0.5),
+             ("NodeSocketVector", "Vector", (0.0, 0.0, 0.0)),
+         ],
+     )
+
+     cosine = nw.new_node(
+         Nodes.Math,
+         input_kwargs={0: group_input.outputs["Value"]},
+         attrs={"operation": "COSINE"},
+     )
+
+     sine = nw.new_node(
+         Nodes.Math,
+         input_kwargs={0: group_input.outputs["Value"]},
+         attrs={"operation": "SINE"},
+     )
+
+     combine_xyz_4 = nw.new_node(Nodes.CombineXYZ, input_kwargs={"Y": cosine, "Z": sine})
+
+     multiply_add = nw.new_node(
+         Nodes.VectorMath,
+         input_kwargs={
+             0: group_input.outputs["Vector"],
+             1: combine_xyz_4,
+             2: group_input.outputs["Addend"],
+         },
+         attrs={"operation": "MULTIPLY_ADD"},
+     )
+
+     group_output = nw.new_node(
+         Nodes.GroupOutput, input_kwargs={"Vector": multiply_add.outputs["Vector"]}
+     )
+
+
+ @node_utils.to_nodegroup("nodegroup_follow_curve", singleton=True)
+ def nodegroup_follow_curve(nw):
+     group_input = nw.new_node(
+         Nodes.GroupInput,
+         expose_input=[
+             ("NodeSocketGeometry", "Geometry", None),
+             ("NodeSocketGeometry", "Curve", None),
+             ("NodeSocketFloat", "Curve Min", 0.5),
+             ("NodeSocketFloat", "Curve Max", 1.0),
+         ],
+     )
+
+     position = nw.new_node(Nodes.InputPosition)
+
+     capture_attribute = nw.new_node(
+         Nodes.CaptureAttribute,
+         input_kwargs={"Geometry": group_input.outputs["Geometry"], 1: position},
+         attrs={"data_type": "FLOAT_VECTOR"},
+     )
+
+     separate_xyz = nw.new_node(
+         Nodes.SeparateXYZ,
+         input_kwargs={"Vector": capture_attribute.outputs["Attribute"]},
+     )
+
+     attribute_statistic = nw.new_node(
+         Nodes.AttributeStatistic,
+         input_kwargs={
+             "Geometry": capture_attribute.outputs["Geometry"],
+             2: separate_xyz.outputs["Z"],
+         },
+     )
+
+     map_range = nw.new_node(
+         Nodes.MapRange,
+         input_kwargs={
+             "Value": separate_xyz.outputs["Z"],
+             1: attribute_statistic.outputs["Min"],
+             2: attribute_statistic.outputs["Max"],
+             3: group_input.outputs["Curve Min"],
+             4: group_input.outputs["Curve Max"],
+         },
+     )
+
+     curve_length = nw.new_node(
+         Nodes.CurveLength, input_kwargs={"Curve": group_input.outputs["Curve"]}
+     )
+
+     multiply = nw.new_node(
+         Nodes.Math,
+         input_kwargs={0: map_range.outputs["Result"], 1: curve_length},
+         attrs={"operation": "MULTIPLY"},
+     )
+
+     sample_curve = nw.new_node(
+         Nodes.SampleCurve,
+         input_kwargs={"Curves": group_input.outputs["Curve"], "Length": multiply},
+         attrs={"mode": "LENGTH"},
+     )
+
+     cross_product = nw.new_node(
+         Nodes.VectorMath,
+         input_kwargs={
+             0: sample_curve.outputs["Tangent"],
+             1: sample_curve.outputs["Normal"],
+         },
+         attrs={"operation": "CROSS_PRODUCT"},
+     )
+
+     scale = nw.new_node(
+         Nodes.VectorMath,
+         input_kwargs={
+             0: cross_product.outputs["Vector"],
+             "Scale": separate_xyz.outputs["X"],
+         },
+         attrs={"operation": "SCALE"},
+     )
+
+     scale_1 = nw.new_node(
+         Nodes.VectorMath,
+         input_kwargs={
+             0: sample_curve.outputs["Normal"],
+             "Scale": separate_xyz.outputs["Y"],
+         },
+         attrs={"operation": "SCALE"},
+     )
+
+     add = nw.new_node(
+         Nodes.VectorMath,
+         input_kwargs={0: scale.outputs["Vector"], 1: scale_1.outputs["Vector"]},
+     )
+
+     set_position = nw.new_node(
+         Nodes.SetPosition,
+         input_kwargs={
+             "Geometry": capture_attribute.outputs["Geometry"],
+             "Position": sample_curve.outputs["Position"],
+             "Offset": add.outputs["Vector"],
+         },
+     )
+
+     group_output = nw.new_node(
+         Nodes.GroupOutput, input_kwargs={"Geometry": set_position}
+     )
+
+
+ @node_utils.to_nodegroup("nodegroup_norm_index", singleton=True)
+ def nodegroup_norm_index(nw):
+     index = nw.new_node(Nodes.Index)
+
+     group_input = nw.new_node(
+         Nodes.GroupInput, expose_input=[("NodeSocketInt", "Count", 0)]
+     )
+
+     divide = nw.new_node(
+         Nodes.Math,
+         input_kwargs={0: index, 1: group_input.outputs["Count"]},
+         attrs={"operation": "DIVIDE"},
+     )
+
+     group_output = nw.new_node(Nodes.GroupOutput, input_kwargs={"T": divide})
+
+
+ @node_utils.to_nodegroup("nodegroup_flower_petal", singleton=True)
+ def nodegroup_flower_petal(nw):
+     group_input = nw.new_node(
+         Nodes.GroupInput,
+         expose_input=[
+             ("NodeSocketGeometry", "Geometry", None),
+             ("NodeSocketFloat", "Length", 0.2),
+             ("NodeSocketFloat", "Point", 1.0),
+             ("NodeSocketFloat", "Point height", 0.5),
+             ("NodeSocketFloat", "Bevel", 6.8),
+             ("NodeSocketFloat", "Base width", 0.2),
+             ("NodeSocketFloat", "Upper width", 0.3),
+             ("NodeSocketInt", "Resolution H", 8),
+             ("NodeSocketInt", "Resolution V", 4),
+             ("NodeSocketFloat", "Wrinkle", 0.1),
+             ("NodeSocketFloat", "Curl", 0.0),
+         ],
+     )
+
+     multiply_add = nw.new_node(
+         Nodes.Math,
+         input_kwargs={0: group_input.outputs["Resolution H"], 1: 2.0, 2: 1.0},
+         attrs={"operation": "MULTIPLY_ADD"},
+     )
+
+     grid = nw.new_node(
+         Nodes.MeshGrid,
+         input_kwargs={
+             "Vertices X": group_input.outputs["Resolution V"],
+             "Vertices Y": multiply_add,
+         },
+     )
+
+     position = nw.new_node(Nodes.InputPosition)
+
+     capture_attribute = nw.new_node(
+         Nodes.CaptureAttribute,
+         input_kwargs={"Geometry": grid, 1: position},
+         attrs={"data_type": "FLOAT_VECTOR"},
+     )
+
+     separate_xyz = nw.new_node(
+         Nodes.SeparateXYZ,
+         input_kwargs={"Vector": capture_attribute.outputs["Attribute"]},
+     )
+
+     multiply = nw.new_node(
+         Nodes.Math,
+         input_kwargs={0: separate_xyz.outputs["X"], 1: 0.05},
+         attrs={"operation": "MULTIPLY"},
+     )
+
+     combine_xyz = nw.new_node(
+         Nodes.CombineXYZ, input_kwargs={"X": multiply, "Y": separate_xyz.outputs["Y"]}
+     )
+
+     noise_texture = nw.new_node(
+         Nodes.NoiseTexture,
+         input_kwargs={
+             "Vector": combine_xyz,
+             "Scale": 7.9,
+             "Detail": 0.0,
+             "Distortion": 0.2,
+         },
+         attrs={"noise_dimensions": "2D"},
+     )
+
+     add = nw.new_node(
+         Nodes.Math, input_kwargs={0: noise_texture.outputs["Fac"], 1: -0.5}
+     )
+
+     multiply_1 = nw.new_node(
+         Nodes.Math,
+         input_kwargs={0: add, 1: group_input.outputs["Wrinkle"]},
+         attrs={"operation": "MULTIPLY"},
+     )
+
+     separate_xyz_1 = nw.new_node(
+         Nodes.SeparateXYZ,
+         input_kwargs={"Vector": capture_attribute.outputs["Attribute"]},
+     )
+
+     add_1 = nw.new_node(Nodes.Math, input_kwargs={0: separate_xyz_1.outputs["X"]})
+
+     absolute = nw.new_node(
+         Nodes.Math,
+         input_kwargs={0: separate_xyz_1.outputs["Y"]},
+         attrs={"operation": "ABSOLUTE"},
+     )
+
+     multiply_2 = nw.new_node(
+         Nodes.Math, input_kwargs={0: absolute, 1: 2.0}, attrs={"operation": "MULTIPLY"}
+     )
+
+     power = nw.new_node(
+         Nodes.Math,
+         input_kwargs={0: multiply_2, 1: group_input.outputs["Bevel"]},
+         attrs={"operation": "POWER"},
+     )
+
+     multiply_add_1 = nw.new_node(
+         Nodes.Math,
+         input_kwargs={0: power, 1: -1.0, 2: 1.0},
+         attrs={"operation": "MULTIPLY_ADD"},
+     )
+
+     multiply_3 = nw.new_node(
+         Nodes.Math,
+         input_kwargs={0: add_1, 1: multiply_add_1},
+         attrs={"operation": "MULTIPLY"},
+     )
+
+     multiply_add_2 = nw.new_node(
+         Nodes.Math,
+         input_kwargs={
+             0: multiply_3,
+             1: group_input.outputs["Upper width"],
+             2: group_input.outputs["Base width"],
+         },
+         attrs={"operation": "MULTIPLY_ADD"},
+     )
+
+     multiply_4 = nw.new_node(
+         Nodes.Math,
+         input_kwargs={0: separate_xyz_1.outputs["Y"], 1: multiply_add_2},
+         attrs={"operation": "MULTIPLY"},
+     )
+
+     power_1 = nw.new_node(
+         Nodes.Math,
+         input_kwargs={0: absolute, 1: group_input.outputs["Point"]},
+         attrs={"operation": "POWER"},
+     )
+
+     multiply_add_3 = nw.new_node(
+         Nodes.Math,
+         input_kwargs={0: power_1, 1: -1.0, 2: 1.0},
+         attrs={"operation": "MULTIPLY_ADD"},
+     )
+
+     multiply_5 = nw.new_node(
+         Nodes.Math,
+         input_kwargs={0: multiply_add_3, 1: group_input.outputs["Point height"]},
+         attrs={"operation": "MULTIPLY"},
+     )
+
+     multiply_add_4 = nw.new_node(
+         Nodes.Math,
+         input_kwargs={0: group_input.outputs["Point height"], 1: -1.0, 2: 1.0},
+         attrs={"operation": "MULTIPLY_ADD"},
+     )
+
+     add_2 = nw.new_node(Nodes.Math, input_kwargs={0: multiply_5, 1: multiply_add_4})
+
+     multiply_6 = nw.new_node(
+         Nodes.Math,
+         input_kwargs={0: add_2, 1: multiply_add_1},
+         attrs={"operation": "MULTIPLY"},
+     )
+
+     multiply_7 = nw.new_node(
+         Nodes.Math,
+         input_kwargs={0: add_1, 1: multiply_6},
+         attrs={"operation": "MULTIPLY"},
+     )
+
+     combine_xyz_1 = nw.new_node(
+         Nodes.CombineXYZ,
+         input_kwargs={"X": multiply_1, "Y": multiply_4, "Z": multiply_7},
+     )
+
+     set_position = nw.new_node(
+         Nodes.SetPosition,
+         input_kwargs={
+             "Geometry": capture_attribute.outputs["Geometry"],
+             "Position": combine_xyz_1,
+         },
+     )
+
+     multiply_8 = nw.new_node(
+         Nodes.Math,
+         input_kwargs={0: group_input.outputs["Length"]},
+         attrs={"operation": "MULTIPLY"},
+     )
+
+     combine_xyz_3 = nw.new_node(Nodes.CombineXYZ, input_kwargs={"Y": multiply_8})
+
+     reroute = nw.new_node(
+         Nodes.Reroute, input_kwargs={"Input": group_input.outputs["Curl"]}
+     )
+
+     group_1 = nw.new_node(
+         nodegroup_polar_to_cart_old().name,
+         input_kwargs={"Addend": combine_xyz_3, "Value": reroute, "Vector": multiply_8},
+     )
+
+     quadratic_bezier = nw.new_node(
+         Nodes.QuadraticBezier,
+         input_kwargs={
+             "Resolution": 8,
+             "Start": (0.0, 0.0, 0.0),
+             "Middle": combine_xyz_3,
+             "End": group_1,
+         },
+     )
+
+     group = nw.new_node(
+         nodegroup_follow_curve().name,
+         input_kwargs={
+             "Geometry": set_position,
+             "Curve": quadratic_bezier,
+             "Curve Min": 0.0,
+         },
+     )
+
+     group_output = nw.new_node(
+         Nodes.GroupOutput, input_kwargs={"Geometry": tag_nodegroup(nw, group, "petal")}
+     )
+
+
+ @node_utils.to_nodegroup("nodegroup_phyllo_points", singleton=True)
+ def nodegroup_phyllo_points(nw):
+     group_input = nw.new_node(
+         Nodes.GroupInput,
+         expose_input=[
+             ("NodeSocketInt", "Count", 50),
+             ("NodeSocketFloat", "Min Radius", 0.0),
+             ("NodeSocketFloat", "Max Radius", 2.0),
+             ("NodeSocketFloat", "Radius exp", 0.5),
+             ("NodeSocketFloat", "Min angle", -0.5236),
+             ("NodeSocketFloat", "Max angle", 0.7854),
+             ("NodeSocketFloat", "Min z", 0.0),
+             ("NodeSocketFloat", "Max z", 1.0),
+             ("NodeSocketFloat", "Clamp z", 1.0),
+             ("NodeSocketFloat", "Yaw offset", -1.5708),
+         ],
+     )
+
+     mesh_line = nw.new_node(
+         Nodes.MeshLine, input_kwargs={"Count": group_input.outputs["Count"]}
+     )
+
+     mesh_to_points = nw.new_node(Nodes.MeshToPoints, input_kwargs={"Mesh": mesh_line})
+
+     position = nw.new_node(Nodes.InputPosition)
+
+     capture_attribute = nw.new_node(
+         Nodes.CaptureAttribute,
+         input_kwargs={"Geometry": mesh_to_points, 1: position},
+         attrs={"data_type": "FLOAT_VECTOR"},
+     )
+
+     index = nw.new_node(Nodes.Index)
+
+     cosine = nw.new_node(
+         Nodes.Math, input_kwargs={0: index}, attrs={"operation": "COSINE"}
+     )
+
+     sine = nw.new_node(Nodes.Math, input_kwargs={0: index}, attrs={"operation": "SINE"})
+
+     combine_xyz = nw.new_node(Nodes.CombineXYZ, input_kwargs={"X": cosine, "Y": sine})
+
+     divide = nw.new_node(
+         Nodes.Math,
+         input_kwargs={0: index, 1: group_input.outputs["Count"]},
+         attrs={"operation": "DIVIDE"},
+     )
+
+     power = nw.new_node(
+         Nodes.Math,
+         input_kwargs={0: divide, 1: group_input.outputs["Radius exp"]},
+         attrs={"operation": "POWER"},
+     )
+
+     map_range = nw.new_node(
+         Nodes.MapRange,
+         input_kwargs={
+             "Value": power,
+             3: group_input.outputs["Min Radius"],
+             4: group_input.outputs["Max Radius"],
+         },
+     )
+
+     multiply = nw.new_node(
+         Nodes.VectorMath,
+         input_kwargs={0: combine_xyz, 1: map_range.outputs["Result"]},
+         attrs={"operation": "MULTIPLY"},
+     )
+
+     separate_xyz = nw.new_node(
+         Nodes.SeparateXYZ, input_kwargs={"Vector": multiply.outputs["Vector"]}
+     )
+
+     map_range_2 = nw.new_node(
+         Nodes.MapRange,
+         input_kwargs={
+             "Value": divide,
+             2: group_input.outputs["Clamp z"],
+             3: group_input.outputs["Min z"],
+             4: group_input.outputs["Max z"],
+         },
+     )
+
+     combine_xyz_1 = nw.new_node(
+         Nodes.CombineXYZ,
+         input_kwargs={
+             "X": separate_xyz.outputs["X"],
+             "Y": separate_xyz.outputs["Y"],
+             "Z": map_range_2.outputs["Result"],
+         },
+     )
+
+     set_position = nw.new_node(
+         Nodes.SetPosition,
+         input_kwargs={
+             "Geometry": capture_attribute.outputs["Geometry"],
+             "Position": combine_xyz_1,
+         },
+     )
+
+     map_range_3 = nw.new_node(
+         Nodes.MapRange,
+         input_kwargs={
+             "Value": divide,
+             3: group_input.outputs["Min angle"],
+             4: group_input.outputs["Max angle"],
+         },
+     )
+
+     random_value = nw.new_node(Nodes.RandomValue, input_kwargs={2: -0.1, 3: 0.1})
+
+     add = nw.new_node(
+         Nodes.Math, input_kwargs={0: index, 1: group_input.outputs["Yaw offset"]}
+     )
+
+     combine_xyz_2 = nw.new_node(
+         Nodes.CombineXYZ,
+         input_kwargs={
+             "X": map_range_3.outputs["Result"],
+             "Y": random_value.outputs[1],
+             "Z": add,
+         },
+     )
+
+     group_output = nw.new_node(
+         Nodes.GroupOutput,
+         input_kwargs={"Points": set_position, "Rotation": combine_xyz_2},
+     )
+
536
+ @node_utils.to_nodegroup("nodegroup_plant_seed", singleton=True)
537
+ def nodegroup_plant_seed(nw):
538
+ group_input = nw.new_node(
539
+ Nodes.GroupInput,
540
+ expose_input=[
541
+ ("NodeSocketVector", "Dimensions", (0.0, 0.0, 0.0)),
542
+ ("NodeSocketIntUnsigned", "U", 4),
543
+ ("NodeSocketInt", "V", 8),
544
+ ],
545
+ )
546
+
547
+ separate_xyz = nw.new_node(
548
+ Nodes.SeparateXYZ, input_kwargs={"Vector": group_input.outputs["Dimensions"]}
549
+ )
550
+
551
+ combine_xyz = nw.new_node(
552
+ Nodes.CombineXYZ, input_kwargs={"X": separate_xyz.outputs["X"]}
553
+ )
554
+
555
+ multiply_add = nw.new_node(
556
+ Nodes.VectorMath,
557
+ input_kwargs={0: combine_xyz, 1: (0.5, 0.5, 0.5)},
558
+ attrs={"operation": "MULTIPLY_ADD"},
559
+ )
560
+
561
+ quadratic_bezier_1 = nw.new_node(
562
+ Nodes.QuadraticBezier,
563
+ input_kwargs={
564
+ "Resolution": group_input.outputs["U"],
565
+ "Start": (0.0, 0.0, 0.0),
566
+ "Middle": multiply_add.outputs["Vector"],
567
+ "End": combine_xyz,
568
+ },
569
+ )
570
+
571
+ group = nw.new_node(
572
+ nodegroup_norm_index().name, input_kwargs={"Count": group_input.outputs["U"]}
573
+ )
574
+
575
+ float_curve = nw.new_node(Nodes.FloatCurve, input_kwargs={"Value": group})
576
+ node_utils.assign_curve(
577
+ float_curve.mapping.curves[0], [(0.0, 0.0), (0.3159, 0.4469), (1.0, 0.0156)]
578
+ )
579
+
580
+ map_range = nw.new_node(Nodes.MapRange, input_kwargs={"Value": float_curve, 4: 3.0})
581
+
582
+ set_curve_radius = nw.new_node(
583
+ Nodes.SetCurveRadius,
584
+ input_kwargs={
585
+ "Curve": quadratic_bezier_1,
586
+ "Radius": map_range.outputs["Result"],
587
+ },
588
+ )
589
+
590
+ curve_circle = nw.new_node(
591
+ Nodes.CurveCircle,
592
+ input_kwargs={
593
+ "Resolution": group_input.outputs["V"],
594
+ "Radius": separate_xyz.outputs["Y"],
595
+ },
596
+ )
597
+
598
+ curve_to_mesh = nw.new_node(
599
+ Nodes.CurveToMesh,
600
+ input_kwargs={
601
+ "Curve": set_curve_radius,
602
+ "Profile Curve": curve_circle.outputs["Curve"],
603
+ "Fill Caps": True,
604
+ },
605
+ )
606
+
607
+ group_output = nw.new_node(
608
+ Nodes.GroupOutput,
609
+ input_kwargs={"Mesh": tag_nodegroup(nw, curve_to_mesh, "seed")},
610
+ )
611
+
612
+
613
+ def shader_flower_center(nw):
614
+ ambient_occlusion = nw.new_node(Nodes.AmbientOcclusion)
615
+
616
+ colorramp = nw.new_node(
617
+ Nodes.ColorRamp, input_kwargs={"Fac": ambient_occlusion.outputs["Color"]}
618
+ )
619
+ colorramp.color_ramp.elements.new(1)
620
+ colorramp.color_ramp.elements[0].position = 0.4841
621
+ colorramp.color_ramp.elements[0].color = (0.0127, 0.0075, 0.0026, 1.0)
622
+ colorramp.color_ramp.elements[1].position = 0.8591
623
+ colorramp.color_ramp.elements[1].color = (0.0848, 0.0066, 0.0007, 1.0)
624
+ colorramp.color_ramp.elements[2].position = 1.0
625
+ colorramp.color_ramp.elements[2].color = (1.0, 0.6228, 0.1069, 1.0)
626
+
627
+ principled_bsdf = nw.new_node(
628
+ Nodes.PrincipledBSDF, input_kwargs={"Base Color": colorramp.outputs["Color"]}
629
+ )
630
+
631
+ material_output = nw.new_node(
632
+ Nodes.MaterialOutput, input_kwargs={"Surface": principled_bsdf}
633
+ )
634
+
635
+
636
+ def shader_petal(nw):
637
+ translucent_color_change = uniform(0.1, 0.6)
638
+ specular = normal(0.6, 0.1)
639
+ roughness = normal(0.4, 0.05)
640
+ translucent_amt = normal(0.3, 0.05)
641
+
642
+ petal_color = nw.new_node(Nodes.RGB)
643
+ petal_color.outputs[0].default_value = color.color_category("petal")
644
+
645
+ translucent_color = nw.new_node(
646
+ Nodes.MixRGB,
647
+ [translucent_color_change, petal_color, color.color_category("petal")],
648
+ )
649
+
650
+ translucent_bsdf = nw.new_node(
651
+ Nodes.TranslucentBSDF, input_kwargs={"Color": translucent_color}
652
+ )
653
+
654
+ principled_bsdf = nw.new_node(
655
+ Nodes.PrincipledBSDF,
656
+ input_kwargs={
657
+ "Base Color": petal_color,
658
+ "Specular": specular,
659
+ "Roughness": roughness,
660
+ },
661
+ )
662
+
663
+ mix_shader = nw.new_node(
664
+ Nodes.MixShader,
665
+ input_kwargs={"Fac": translucent_amt, 1: principled_bsdf, 2: translucent_bsdf},
666
+ )
667
+
668
+ material_output = nw.new_node(
669
+ Nodes.MaterialOutput, input_kwargs={"Surface": mix_shader}
670
+ )
671
+
672
+
673
+ def geo_flower(nw, petal_material, center_material):
674
+ group_input = nw.new_node(
675
+ Nodes.GroupInput,
676
+ expose_input=[
677
+ ("NodeSocketGeometry", "Geometry", None),
678
+ ("NodeSocketFloat", "Center Rad", 0.0),
679
+ ("NodeSocketVector", "Petal Dims", (0.0, 0.0, 0.0)),
680
+ ("NodeSocketFloat", "Seed Size", 0.0),
681
+ ("NodeSocketFloat", "Min Petal Angle", 0.1),
682
+ ("NodeSocketFloat", "Max Petal Angle", 1.36),
683
+ ("NodeSocketFloat", "Wrinkle", 0.01),
684
+ ("NodeSocketFloat", "Curl", 13.89),
685
+ ],
686
+ )
687
+
688
+ uv_sphere = nw.new_node(
689
+ Nodes.MeshUVSphere,
690
+ input_kwargs={
691
+ "Segments": 8,
692
+ "Rings": 8,
693
+ "Radius": group_input.outputs["Center Rad"],
694
+ },
695
+ )
696
+
697
+ transform = nw.new_node(
698
+ Nodes.Transform, input_kwargs={"Geometry": uv_sphere, "Scale": (1.0, 1.0, 0.05)}
699
+ )
700
+
701
+ multiply = nw.new_node(
702
+ Nodes.Math,
703
+ input_kwargs={0: group_input.outputs["Seed Size"], 1: 1.5},
704
+ attrs={"operation": "MULTIPLY"},
705
+ )
706
+
707
+ distribute_points_on_faces = nw.new_node(
708
+ Nodes.DistributePointsOnFaces,
709
+ input_kwargs={
710
+ "Mesh": transform,
711
+ "Distance Min": multiply,
712
+ "Density Max": 50000.0,
713
+ },
714
+ attrs={"distribute_method": "POISSON"},
715
+ )
716
+
717
+ multiply_1 = nw.new_node(
718
+ Nodes.Math,
719
+ input_kwargs={0: group_input.outputs["Seed Size"], 1: 10.0},
720
+ attrs={"operation": "MULTIPLY"},
721
+ )
722
+
723
+ combine_xyz = nw.new_node(
724
+ Nodes.CombineXYZ,
725
+ input_kwargs={"X": multiply_1, "Y": group_input.outputs["Seed Size"]},
726
+ )
727
+
728
+ group_3 = nw.new_node(
729
+ nodegroup_plant_seed().name,
730
+ input_kwargs={"Dimensions": combine_xyz, "U": 6, "V": 6},
731
+ )
732
+
733
+ musgrave_texture = nw.new_node(
734
+ Nodes.MusgraveTexture,
735
+ input_kwargs={"W": 13.8, "Scale": 2.41},
736
+ attrs={"musgrave_dimensions": "4D"},
737
+ )
738
+
739
+ map_range = nw.new_node(
740
+ Nodes.MapRange, input_kwargs={"Value": musgrave_texture, 3: 0.34, 4: 1.21}
741
+ )
742
+
743
+ combine_xyz_1 = nw.new_node(
744
+ Nodes.CombineXYZ,
745
+ input_kwargs={"X": map_range.outputs["Result"], "Y": 1.0, "Z": 1.0},
746
+ )
747
+
748
+ instance_on_points_1 = nw.new_node(
749
+ Nodes.InstanceOnPoints,
750
+ input_kwargs={
751
+ "Points": distribute_points_on_faces.outputs["Points"],
752
+ "Instance": group_3,
753
+ "Rotation": (0.0, -1.5708, 0.0541),
754
+ "Scale": combine_xyz_1,
755
+ },
756
+ )
757
+
758
+ realize_instances = nw.new_node(
759
+ Nodes.RealizeInstances, input_kwargs={"Geometry": instance_on_points_1}
760
+ )
761
+
762
+ join_geometry_1 = nw.new_node(
763
+ Nodes.JoinGeometry, input_kwargs={"Geometry": [realize_instances, transform]}
764
+ )
765
+
766
+ set_material_1 = nw.new_node(
767
+ Nodes.SetMaterial,
768
+ input_kwargs={"Geometry": join_geometry_1, "Material": center_material},
769
+ )
770
+
771
+ multiply_2 = nw.new_node(
772
+ Nodes.Math,
773
+ input_kwargs={0: group_input.outputs["Center Rad"], 1: 6.2832},
774
+ attrs={"operation": "MULTIPLY"},
775
+ )
776
+
777
+ separate_xyz = nw.new_node(
778
+ Nodes.SeparateXYZ, input_kwargs={"Vector": group_input.outputs["Petal Dims"]}
779
+ )
780
+
781
+ divide = nw.new_node(
782
+ Nodes.Math,
783
+ input_kwargs={0: multiply_2, 1: separate_xyz.outputs["Y"]},
784
+ attrs={"operation": "DIVIDE"},
785
+ )
786
+
787
+ multiply_3 = nw.new_node(
788
+ Nodes.Math, input_kwargs={0: divide, 1: 1.2}, attrs={"operation": "MULTIPLY"}
789
+ )
790
+
791
+ reroute_3 = nw.new_node(
792
+ Nodes.Reroute, input_kwargs={"Input": group_input.outputs["Center Rad"]}
793
+ )
794
+
795
+ reroute_1 = nw.new_node(
796
+ Nodes.Reroute, input_kwargs={"Input": group_input.outputs["Min Petal Angle"]}
797
+ )
798
+
799
+ reroute = nw.new_node(
800
+ Nodes.Reroute, input_kwargs={"Input": group_input.outputs["Max Petal Angle"]}
801
+ )
802
+
803
+ group_1 = nw.new_node(
804
+ nodegroup_phyllo_points().name,
805
+ input_kwargs={
806
+ "Count": multiply_3,
807
+ "Min Radius": reroute_3,
808
+ "Max Radius": reroute_3,
809
+ "Radius exp": 0.0,
810
+ "Min angle": reroute_1,
811
+ "Max angle": reroute,
812
+ "Max z": 0.0,
813
+ },
814
+ )
815
+
816
+ subtract = nw.new_node(
817
+ Nodes.Math,
818
+ input_kwargs={0: separate_xyz.outputs["Z"], 1: separate_xyz.outputs["Y"]},
819
+ attrs={"operation": "SUBTRACT", "use_clamp": True},
820
+ )
821
+
822
+ reroute_2 = nw.new_node(
823
+ Nodes.Reroute, input_kwargs={"Input": group_input.outputs["Wrinkle"]}
824
+ )
825
+
826
+ reroute_4 = nw.new_node(
827
+ Nodes.Reroute, input_kwargs={"Input": group_input.outputs["Curl"]}
828
+ )
829
+
830
+ group = nw.new_node(
831
+ nodegroup_flower_petal().name,
832
+ input_kwargs={
833
+ "Length": separate_xyz.outputs["X"],
834
+ "Point": 0.56,
835
+ "Point height": -0.1,
836
+ "Bevel": 1.83,
837
+ "Base width": separate_xyz.outputs["Y"],
838
+ "Upper width": subtract,
839
+ "Resolution H": 8,
840
+ "Resolution V": 16,
841
+ "Wrinkle": reroute_2,
842
+ "Curl": reroute_4,
843
+ },
844
+ )
845
+
846
+ instance_on_points = nw.new_node(
847
+ Nodes.InstanceOnPoints,
848
+ input_kwargs={
849
+ "Points": group_1.outputs["Points"],
850
+ "Instance": group,
851
+ "Rotation": group_1.outputs["Rotation"],
852
+ },
853
+ )
854
+
855
+ realize_instances_1 = nw.new_node(
856
+ Nodes.RealizeInstances, input_kwargs={"Geometry": instance_on_points}
857
+ )
858
+
859
+ noise_texture = nw.new_node(
860
+ Nodes.NoiseTexture,
861
+ input_kwargs={"Scale": 3.73, "Detail": 5.41, "Distortion": -1.0},
862
+ )
863
+
864
+ subtract_1 = nw.new_node(
865
+ Nodes.VectorMath,
866
+ input_kwargs={0: noise_texture.outputs["Color"], 1: (0.5, 0.5, 0.5)},
867
+ attrs={"operation": "SUBTRACT"},
868
+ )
869
+
870
+ value = nw.new_node(Nodes.Value)
871
+ value.outputs[0].default_value = 0.025
872
+
873
+ multiply_4 = nw.new_node(
874
+ Nodes.VectorMath,
875
+ input_kwargs={0: subtract_1.outputs["Vector"], 1: value},
876
+ attrs={"operation": "MULTIPLY"},
877
+ )
878
+
879
+ set_position = nw.new_node(
880
+ Nodes.SetPosition,
881
+ input_kwargs={
882
+ "Geometry": realize_instances_1,
883
+ "Offset": multiply_4.outputs["Vector"],
884
+ },
885
+ )
886
+
887
+ set_material = nw.new_node(
888
+ Nodes.SetMaterial,
889
+ input_kwargs={"Geometry": set_position, "Material": petal_material},
890
+ )
891
+
892
+ join_geometry = nw.new_node(
893
+ Nodes.JoinGeometry, input_kwargs={"Geometry": [set_material_1, set_material]}
894
+ )
895
+
896
+ set_shade_smooth = nw.new_node(
897
+ Nodes.SetShadeSmooth,
898
+ input_kwargs={"Geometry": join_geometry, "Shade Smooth": False},
899
+ )
900
+
901
+ group_output = nw.new_node(
902
+ Nodes.GroupOutput, input_kwargs={"Geometry": set_shade_smooth}
903
+ )
904
+
905
+
906
+ class FlowerFactory(AssetFactory):
907
+ def __init__(self, factory_seed, rad=0.15, diversity_fac=0.25):
908
+ super(FlowerFactory, self).__init__(factory_seed=factory_seed)
909
+
910
+ self.get_params_dict()
911
+
912
+ self.rad = rad
913
+ self.diversity_fac = diversity_fac
914
+
915
+ with FixedSeed(factory_seed):
916
+ self.petal_material = surface.shaderfunc_to_material(shader_petal)
917
+ self.center_material = surface.shaderfunc_to_material(shader_flower_center)
918
+ #self.species_params = self.get_flower_params(self.rad)
919
+ self.params = self.get_flower_params(self.rad * normal(1.0, 0.05))
920
+
921
+ def get_params_dict(self):
922
+ self.params_dict = {
923
+ "overall_rad": ['continuous', (0.7, 1.3)],
924
+ "pct_inner": ['continuous', (0.05, 0.5)],
925
+ "base_width": ['continuous', (4, 16)],
926
+ "top_width": ['continuous', (0.0, 1.6)],
927
+ "min_angle": ['continuous', (-20, 100)],
928
+ "max_angle": ['continuous', (-20, 100)],
929
+ "seed_size": ['continuous', (0.005, 0.03)],
930
+ "wrinkle": ['continuous', (0.003, 0.02)],
931
+ "curl": ['continuous', (-120, 120)],
932
+ }
933
+
934
+ @staticmethod
935
+ def get_flower_params(overall_rad=0.05):
936
+ pct_inner = uniform(0.05, 0.4)
937
+ base_width = 2 * np.pi * overall_rad * pct_inner / normal(20, 5)
938
+ top_width = overall_rad * np.clip(normal(0.7, 0.3), base_width * 1.2, 100)
939
+
940
+ min_angle, max_angle = np.deg2rad(np.sort(uniform(-20, 100, 2)))
941
+
942
+ return {
943
+ "Center Rad": overall_rad * pct_inner,
944
+ "Petal Dims": np.array(
945
+ [overall_rad * (1 - pct_inner), base_width, top_width], dtype=np.float32
946
+ ),
947
+ "Seed Size": uniform(0.005, 0.01),
948
+ "Min Petal Angle": min_angle,
949
+ "Max Petal Angle": max_angle,
950
+ "Wrinkle": uniform(0.003, 0.02),
951
+ "Curl": np.deg2rad(normal(30, 50)),
952
+ }
953
+
954
+ def update_params(self, params):
955
+ overall_rad = params['overall_rad']
956
+ pct_inner = params['pct_inner']
957
+ base_width = 2 * np.pi * overall_rad * pct_inner / params['base_width']
958
+ top_width = overall_rad * np.clip(params['top_width'], base_width * 1.2, 100)
959
+
960
+ min_angle = np.deg2rad(params['min_angle'])
961
+ max_angle = np.deg2rad(params['max_angle'])
962
+ if min_angle > max_angle:
963
+ min_angle, max_angle = max_angle, min_angle
964
+
965
+ parameters = {
966
+ "Center Rad": overall_rad * pct_inner,
967
+ "Petal Dims": np.array(
968
+ [overall_rad * (1 - pct_inner), base_width, top_width], dtype=np.float32
969
+ ),
970
+ "Seed Size": params['seed_size'],
971
+ "Min Petal Angle": min_angle,
972
+ "Max Petal Angle": max_angle,
973
+ "Wrinkle": params['wrinkle'],
974
+ "Curl": np.deg2rad(params['curl']),
975
+ }
976
+ self.params.update(parameters)
977
+ self.petal_material = surface.shaderfunc_to_material(shader_petal)
978
+ self.center_material = surface.shaderfunc_to_material(shader_flower_center)
979
+
980
+ def fix_unused_params(self, params):
981
+ return params
982
+
983
+ def create_asset(self, **kwargs) -> bpy.types.Object:
984
+ vert = butil.spawn_vert("flower")
985
+ mod = surface.add_geomod(
986
+ vert,
987
+ geo_flower,
988
+ input_kwargs={
989
+ "petal_material": self.petal_material,
990
+ "center_material": self.center_material,
991
+ },
992
+ )
993
+
994
+ #inst_params = self.get_flower_params(self.rad * normal(1, 0.05))
995
+ #params = dict_lerp(self.species_params, inst_params, 0.25)
996
+ butil.set_geomod_inputs(mod, self.params)
997
+
998
+ butil.apply_modifiers(vert, mod=mod)
999
+
1000
+ vert.rotation_euler.z = uniform(0, 360)
1001
+ tag_object(vert, "flower")
1002
+ return vert
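
FlowerFactory shows the parameter interface these generators share: params_dict declares a normalized search space, update_params maps one sampled dictionary onto the geometry-node inputs, and create_asset builds the mesh. A sketch of the round trip, to be run inside Blender's Python with the repo root on sys.path; sample_params is a hypothetical helper, not part of this commit:

    import numpy as np

    from core.assets.flower import FlowerFactory

    def sample_params(params_dict, rng):
        # Draw one value per declared parameter, respecting its type and range.
        out = {}
        for name, (kind, spec) in params_dict.items():
            if kind == "continuous":
                lo, hi = spec
                out[name] = rng.uniform(lo, hi)
            else:  # "discrete": spec is the tuple of allowed values
                out[name] = rng.choice(spec)
        return out

    factory = FlowerFactory(factory_seed=0)
    params = sample_params(factory.params_dict, np.random.default_rng(0))
    factory.update_params(factory.fix_unused_params(params))
    flower = factory.create_asset()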
core/assets/table.py ADDED
@@ -0,0 +1,493 @@
+ # Copyright (C) 2023, Princeton University.
+ # This source code is licensed under the BSD 3-Clause license found in the LICENSE file in the root directory of this source tree.
+
+ # Authors: Yiming Zuo
+
+
+ import bpy
+ from numpy.random import choice, normal, uniform
+
+ from infinigen.assets.material_assignments import AssetList
+ from infinigen.assets.objects.tables.legs.single_stand import (
+     nodegroup_generate_single_stand,
+ )
+ from infinigen.assets.objects.tables.legs.square import nodegroup_generate_leg_square
+ from infinigen.assets.objects.tables.legs.straight import (
+     nodegroup_generate_leg_straight,
+ )
+ from infinigen.assets.objects.tables.strechers import nodegroup_strecher
+ from infinigen.assets.objects.tables.table_top import nodegroup_generate_table_top
+ from infinigen.assets.objects.tables.table_utils import (
+     nodegroup_create_anchors,
+     nodegroup_create_legs_and_strechers,
+ )
+ from infinigen.core import surface, tagging
+ from infinigen.core import tags as t
+ from infinigen.core.nodes import node_utils
+
+ # from infinigen.assets.materials import metal, metal_shader_list
+ # from infinigen.assets.materials.fabrics import fabric
+ from infinigen.core.nodes.node_wrangler import Nodes, NodeWrangler
+ from infinigen.core.placement.factory import AssetFactory
+ from infinigen.core.surface import NoApply
+ from infinigen.core.util.math import FixedSeed
+
+
+ @node_utils.to_nodegroup(
+     "geometry_create_legs", singleton=False, type="GeometryNodeTree"
+ )
+ def geometry_create_legs(nw: NodeWrangler, **kwargs):
+     createanchors = nw.new_node(
+         nodegroup_create_anchors().name,
+         input_kwargs={
+             "Profile N-gon": kwargs["Leg Number"],
+             "Profile Width": kwargs["Leg Placement Top Relative Scale"]
+             * kwargs["Top Profile Width"],
+             "Profile Aspect Ratio": kwargs["Top Profile Aspect Ratio"],
+         },
+     )
+
+     if kwargs["Leg Style"] == "single_stand":
+         leg = nw.new_node(
+             nodegroup_generate_single_stand(**kwargs).name,
+             input_kwargs={
+                 "Leg Height": kwargs["Leg Height"],
+                 "Leg Diameter": kwargs["Leg Diameter"],
+                 "Resolution": 64,
+             },
+         )
+
+         leg = nw.new_node(
+             nodegroup_create_legs_and_strechers().name,
+             input_kwargs={
+                 "Anchors": createanchors,
+                 "Keep Legs": True,
+                 "Leg Instance": leg,
+                 "Table Height": kwargs["Top Height"],
+                 "Leg Bottom Relative Scale": kwargs[
+                     "Leg Placement Bottom Relative Scale"
+                 ],
+                 "Align Leg X rot": True,
+             },
+         )
+
+     elif kwargs["Leg Style"] == "straight":
+         leg = nw.new_node(
+             nodegroup_generate_leg_straight(**kwargs).name,
+             input_kwargs={
+                 "Leg Height": kwargs["Leg Height"],
+                 "Leg Diameter": kwargs["Leg Diameter"],
+                 "Resolution": 32,
+                 "N-gon": kwargs["Leg NGon"],
+                 "Fillet Ratio": 0.1,
+             },
+         )
+
+         strecher = nw.new_node(
+             nodegroup_strecher().name,
+             input_kwargs={"Profile Width": kwargs["Leg Diameter"] * 0.5},
+         )
+
+         leg = nw.new_node(
+             nodegroup_create_legs_and_strechers().name,
+             input_kwargs={
+                 "Anchors": createanchors,
+                 "Keep Legs": True,
+                 "Leg Instance": leg,
+                 "Table Height": kwargs["Top Height"],
+                 "Strecher Instance": strecher,
+                 "Strecher Index Increment": kwargs["Strecher Increament"],
+                 "Strecher Relative Position": kwargs["Strecher Relative Pos"],
+                 "Leg Bottom Relative Scale": kwargs[
+                     "Leg Placement Bottom Relative Scale"
+                 ],
+                 "Align Leg X rot": True,
+             },
+         )
+
+     elif kwargs["Leg Style"] == "square":
+         leg = nw.new_node(
+             nodegroup_generate_leg_square(**kwargs).name,
+             input_kwargs={
+                 "Height": kwargs["Leg Height"],
+                 "Width": 0.707
+                 * kwargs["Leg Placement Top Relative Scale"]
+                 * kwargs["Top Profile Width"]
+                 * kwargs["Top Profile Aspect Ratio"],
+                 "Has Bottom Connector": (kwargs["Strecher Increament"] > 0),
+                 "Profile Width": kwargs["Leg Diameter"],
+             },
+         )
+
+         leg = nw.new_node(
+             nodegroup_create_legs_and_strechers().name,
+             input_kwargs={
+                 "Anchors": createanchors,
+                 "Keep Legs": True,
+                 "Leg Instance": leg,
+                 "Table Height": kwargs["Top Height"],
+                 "Leg Bottom Relative Scale": kwargs[
+                     "Leg Placement Bottom Relative Scale"
+                 ],
+                 "Align Leg X rot": True,
+             },
+         )
+
+     else:
+         raise NotImplementedError
+
+     leg = nw.new_node(
+         Nodes.SetMaterial,
+         input_kwargs={"Geometry": leg, "Material": kwargs["LegMaterial"]},
+     )
+
+     group_output = nw.new_node(
+         Nodes.GroupOutput,
+         input_kwargs={"Geometry": leg},
+         attrs={"is_active_output": True},
+     )
+
+
+ def geometry_assemble_table(nw: NodeWrangler, **kwargs):
+     # Code generated using version 2.6.4 of the node_transpiler
+
+     generatetabletop = nw.new_node(
+         nodegroup_generate_table_top().name,
+         input_kwargs={
+             "Thickness": kwargs["Top Thickness"],
+             "N-gon": kwargs["Top Profile N-gon"],
+             "Profile Width": kwargs["Top Profile Width"],
+             "Aspect Ratio": kwargs["Top Profile Aspect Ratio"],
+             "Fillet Ratio": kwargs["Top Profile Fillet Ratio"],
+             "Fillet Radius Vertical": kwargs["Top Vertical Fillet Ratio"],
+         },
+     )
+
+     tabletop_instance = nw.new_node(
+         Nodes.Transform,
+         input_kwargs={
+             "Geometry": generatetabletop,
+             "Translation": (0.0000, 0.0000, kwargs["Top Height"]),
+         },
+     )
+
+     tabletop_instance = nw.new_node(
+         Nodes.SetMaterial,
+         input_kwargs={"Geometry": tabletop_instance, "Material": kwargs["TopMaterial"]},
+     )
+
+     legs = nw.new_node(geometry_create_legs(**kwargs).name)
+
+     join_geometry = nw.new_node(
+         Nodes.JoinGeometry, input_kwargs={"Geometry": [tabletop_instance, legs]}
+     )
+
+     group_output = nw.new_node(
+         Nodes.GroupOutput,
+         input_kwargs={"Geometry": join_geometry},
+         attrs={"is_active_output": True},
+     )
+
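+ # Note: both sample_parameters and update_params below set "Top Profile Width" to
+ # 1.414 * x. The factor is sqrt(2), consistent with the 4-gon table-top profile
+ # being sized by its circumscribed diameter rather than its side length (an
+ # inference from the code, not a documented contract): a requested x of 1.6 m
+ # becomes a profile width of about 2.26 m, and "Top Profile Aspect Ratio" = y / x
+ # restores the rectangular proportions (0.9 / 1.6 ~= 0.56 for a 1.6 x 0.9 m top).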
+
+ class TableDiningFactory(AssetFactory):
+     def __init__(self, factory_seed, coarse=False, dimensions=None):
+         super(TableDiningFactory, self).__init__(factory_seed, coarse=coarse)
+
+         self.dimensions = dimensions
+         self.get_params_dict()
+         self.leg_styles = ["single_stand", "square", "straight"]
+
+         with FixedSeed(factory_seed):
+             self.params = self.sample_parameters(dimensions)
+
+             # self.clothes_scatter = ClothesCover(factory_fn=blanket.BlanketFactory, width=log_uniform(.8, 1.2),
+             #     size=uniform(.8, 1.2)) if uniform() < .3 else NoApply()
+             self.clothes_scatter = NoApply()
+             self.material_params, self.scratch, self.edge_wear = (
+                 self.get_material_params()
+             )
+             self.params.update(self.material_params)
+
+     def get_params_dict(self):
+         # list all the parameters (key: name, value: [type, range]) used in this generator
+         self.params_dict = {
+             "ngon": ["discrete", (4, 36)],
+             "dimension_x": ["continuous", (0.9, 2.2)],
+             "dimension_y": ["continuous", (0.9, 2.2)],
+             "dimension_z": ["continuous", (0.5, 0.9)],
+             "leg_style": ["discrete", (0, 1, 2)],
+             "leg_number": ["discrete", (1, 2, 4)],
+             "leg_ngon": ["discrete", (4, 12)],
+             "leg_diameter": ["continuous", (0, 1)],
+             "leg_height": ["continuous", (0.6, 2.0)],
+             "leg_curve_ctrl_pts0": ["continuous", (0, 1)],
+             "leg_curve_ctrl_pts1": ["continuous", (0, 1)],
+             "leg_curve_ctrl_pts2": ["continuous", (0, 1)],
+             "top_scale": ["continuous", (0.6, 0.8)],  # leg start point relative position
+             "bottom_scale": ["continuous", (0.9, 1.3)],  # leg end point relative position
+             "top_thickness": ["continuous", (0.02, 0.1)],
+             "top_profile_fillet_ratio": ["continuous", (-0.6, 0.6)],  # table corner round / square
+             "top_vertical_fillet_ratio": ["continuous", (0.0, 0.2)],  # table corner round / square
+             "strecher_relative_pos": ["continuous", (0.15, 0.8)],
+             "strecher_increament": ["discrete", (0, 1, 2)],
+         }
+
+     def get_material_params(self):
+         material_assignments = AssetList["TableDiningFactory"]()
+         params = {
+             "TopMaterial": material_assignments["top"].assign_material(),
+             "LegMaterial": material_assignments["leg"].assign_material(),
+         }
+         wrapped_params = {
+             k: surface.shaderfunc_to_material(v) for k, v in params.items()
+         }
+
+         scratch_prob, edge_wear_prob = material_assignments["wear_tear_prob"]
+         scratch, edge_wear = material_assignments["wear_tear"]
+
+         is_scratch = uniform() < scratch_prob
+         is_edge_wear = uniform() < edge_wear_prob
+         if not is_scratch:
+             scratch = None
+
+         if not is_edge_wear:
+             edge_wear = None
+
+         return wrapped_params, scratch, edge_wear
+
+     @staticmethod
+     def sample_parameters(dimensions):
+         # not used in DI-PCG
+         if dimensions is None:
+             width = uniform(0.91, 1.16)
+
+             if uniform() < 0.7:
+                 # oblong
+                 length = uniform(1.4, 2.8)
+             else:
+                 # approx square
+                 length = width * normal(1, 0.1)
+
+             dimensions = (length, width, uniform(0.65, 0.85))
+
+         # all in meters
+         x, y, z = dimensions
+
+         NGon = 4
+
+         leg_style = choice(["straight", "single_stand", "square"], p=[0.5, 0.1, 0.4])
+         # leg_style = choice(['straight'])
+
+         if leg_style == "single_stand":
+             leg_number = 2
+             leg_diameter = uniform(0.22 * x, 0.28 * x)
+
+             leg_curve_ctrl_pts = [
+                 (0.0, uniform(0.1, 0.2)),
+                 (0.5, uniform(0.1, 0.2)),
+                 (0.9, uniform(0.2, 0.3)),
+                 (1.0, 1.0),
+             ]
+
+             top_scale = uniform(0.6, 0.7)
+             bottom_scale = 1.0
+
+         elif leg_style == "square":
+             leg_number = 2
+             leg_diameter = uniform(0.07, 0.10)
+
+             leg_curve_ctrl_pts = None
+
+             top_scale = 0.8
+             bottom_scale = 1.0
+
+         elif leg_style == "straight":
+             leg_diameter = uniform(0.05, 0.07)
+
+             leg_number = 4
+
+             leg_curve_ctrl_pts = [
+                 (0.0, 1.0),
+                 (0.4, uniform(0.85, 0.95)),
+                 (1.0, uniform(0.4, 0.6)),
+             ]
+
+             top_scale = 0.8
+             bottom_scale = uniform(1.0, 1.2)
+
+         else:
+             raise NotImplementedError
+
+         top_thickness = uniform(0.03, 0.06)
+
+         parameters = {
+             "Top Profile N-gon": NGon,
+             "Top Profile Width": 1.414 * x,
+             "Top Profile Aspect Ratio": y / x,
+             "Top Profile Fillet Ratio": uniform(0.0, 0.02),
+             "Top Thickness": top_thickness,
+             "Top Vertical Fillet Ratio": uniform(0.1, 0.3),
+             # 'Top Material': choice(['marble', 'tiled_wood', 'metal', 'fabric'], p=[.3, .3, .2, .2]),
+             "Height": z,
+             "Top Height": z - top_thickness,
+             "Leg Number": leg_number,
+             "Leg Style": leg_style,
+             "Leg NGon": 4,
+             "Leg Placement Top Relative Scale": top_scale,
+             "Leg Placement Bottom Relative Scale": bottom_scale,
+             "Leg Height": 1.0,
+             "Leg Diameter": leg_diameter,
+             "Leg Curve Control Points": leg_curve_ctrl_pts,
+             # 'Leg Material': choice(['metal', 'wood', 'glass', 'plastic']),
+             "Strecher Relative Pos": uniform(0.2, 0.6),
+             "Strecher Increament": choice([0, 1, 2]),
+         }
+
+         return parameters
+
+     def fix_unused_params(self, params):
+         if params["leg_style"] == 0:
+             # single stand only allows 1 or 2 legs
+             if params["leg_number"] == 4:
+                 params["leg_number"] = 2
+             params["bottom_scale"] = 1.1
+             params["strecher_increament"] = 1
+         elif params["leg_style"] == 1:
+             params["leg_number"] = 2
+             params["leg_curve_ctrl_pts0"] = 0.5
+             params["leg_curve_ctrl_pts1"] = 0.5
+             params["leg_curve_ctrl_pts2"] = 0.5
+             params["bottom_scale"] = 1.1
+             params["top_scale"] = 0.8
+             params["strecher_increament"] = 1
+         elif params["leg_style"] == 2:
+             params["leg_number"] = 4
+             params["leg_curve_ctrl_pts0"] = 0.5
+             params["top_scale"] = 0.8
+         if params["ngon"] == 36:
+             params["top_profile_fillet_ratio"] = 0.0
+             params["top_vertical_fillet_ratio"] = 0.0
+         return params
+
+     def update_params(self, params):
+         x, y, z = params["dimension_x"], params["dimension_y"], params["dimension_z"]
+         NGon = params["ngon"]
+
+         leg_style = self.leg_styles[int(params["leg_style"])]
+
+         if leg_style == "single_stand":
+             leg_number = params["leg_number"]
+             if leg_number == 4:
+                 leg_number = 2
+             leg_diameter = (0.2 + 0.2 * params["leg_diameter"]) * x
+             leg_curve_ctrl_pts = [
+                 (0.0, 0.1 + 0.8 * params["leg_curve_ctrl_pts0"]),
+                 (0.5, 0.1 + 0.8 * params["leg_curve_ctrl_pts1"]),
+                 (0.9, 0.2 + 0.8 * params["leg_curve_ctrl_pts2"]),
+                 (1.0, 1.0),
+             ]
+             top_scale = params["top_scale"]
+             bottom_scale = 1.0
+             strecher_increament = 1
+
+         elif leg_style == "square":
+             leg_number = 2
+             leg_diameter = 0.05 + 0.2 * params["leg_diameter"]
+             leg_curve_ctrl_pts = None
+             top_scale = 0.8
+             bottom_scale = 1.0
+             strecher_increament = 1
+
+         elif leg_style == "straight":
+             leg_diameter = 0.05 + 0.2 * params["leg_diameter"]
+             leg_number = 4
+             leg_curve_ctrl_pts = [
+                 (0.0, 1.0),
+                 (0.4, 0.5 + 0.5 * params["leg_curve_ctrl_pts1"]),
+                 (1.0, 0.3 + 0.5 * params["leg_curve_ctrl_pts2"]),
+             ]
+             top_scale = 0.8
+             bottom_scale = params["bottom_scale"]
+             strecher_increament = params["strecher_increament"]
+         else:
+             raise NotImplementedError
+
+         if params["ngon"] == 36:
+             top_profile_fillet_ratio = 0.0
+             top_vertical_fillet_ratio = 0.0
+         else:
+             top_profile_fillet_ratio = params["top_profile_fillet_ratio"]
+             top_vertical_fillet_ratio = params["top_vertical_fillet_ratio"]
+
+         top_thickness = params["top_thickness"]
+         parameters = {
+             "Top Profile N-gon": NGon,
+             "Top Profile Width": 1.414 * x,
+             "Top Profile Aspect Ratio": y / x,
+             "Top Profile Fillet Ratio": top_profile_fillet_ratio,
+             "Top Thickness": top_thickness,
+             "Top Vertical Fillet Ratio": top_vertical_fillet_ratio,
+             "Height": z,
+             "Top Height": z - top_thickness,
+             "Leg Number": leg_number,
+             "Leg Style": leg_style,
+             "Leg NGon": params["leg_ngon"],
+             "Leg Placement Top Relative Scale": top_scale,
+             "Leg Placement Bottom Relative Scale": bottom_scale,
+             "Leg Height": params["leg_height"],
+             "Leg Diameter": leg_diameter,
+             "Leg Curve Control Points": leg_curve_ctrl_pts,
+             "Strecher Relative Pos": params["strecher_relative_pos"],
+             "Strecher Increament": strecher_increament,
+         }
+         self.params.update(parameters)
+         self.clothes_scatter = NoApply()
+         self.material_params, self.scratch, self.edge_wear = (
+             self.get_material_params()
+         )
+         self.params.update(self.material_params)
+
+     def create_asset(self, **params):
+         bpy.ops.mesh.primitive_plane_add(
+             size=2,
+             enter_editmode=False,
+             align="WORLD",
+             location=(0, 0, 0),
+             scale=(1, 1, 1),
+         )
+         obj = bpy.context.active_object
+
+         # surface.add_geomod(obj, geometry_assemble_table, apply=False, input_kwargs=self.params)
+         surface.add_geomod(
+             obj, geometry_assemble_table, apply=True, input_kwargs=self.params
+         )
+         tagging.tag_system.relabel_obj(obj)
+         assert tagging.tagged_face_mask(obj, {t.Subpart.SupportSurface}).sum() != 0
+
+         return obj
+
+     def finalize_assets(self, assets):
+         if self.scratch:
+             self.scratch.apply(assets)
+         if self.edge_wear:
+             self.edge_wear.apply(assets)
+
+     # def finalize_assets(self, assets):
+     #     self.clothes_scatter.apply(assets)
+
+
+ class SideTableFactory(TableDiningFactory):
+     def __init__(self, factory_seed, coarse=False, dimensions=None):
+         if dimensions is None:
+             w = 0.55 * normal(1, 0.05)
+             h = 0.95 * w * normal(1, 0.05)
+             dimensions = (w, w, h)
+         super().__init__(factory_seed, coarse=coarse, dimensions=dimensions)
+
+
+ class CoffeeTableFactory(TableDiningFactory):
+     def __init__(self, factory_seed, coarse=False, dimensions=None):
+         if dimensions is None:
+             dimensions = (uniform(1, 1.5), uniform(0.6, 0.9), uniform(0.4, 0.5))
+         super().__init__(factory_seed, coarse=coarse, dimensions=dimensions)
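
TableDiningFactory follows the same flow (fix_unused_params, then update_params, then create_asset), with the extra wrinkle that leg_style is a discrete index into self.leg_styles, so 0, 1, and 2 select single_stand, square, and straight respectively. A hand-picked example within the declared ranges, again assuming Blender's Python and the inferred module path:

    from core.assets.table import TableDiningFactory

    factory = TableDiningFactory(factory_seed=0)
    params = {
        "ngon": 4, "dimension_x": 1.6, "dimension_y": 0.9, "dimension_z": 0.75,
        "leg_style": 2,  # "straight"
        "leg_number": 4, "leg_ngon": 4, "leg_diameter": 0.5, "leg_height": 1.0,
        "leg_curve_ctrl_pts0": 0.5, "leg_curve_ctrl_pts1": 0.5,
        "leg_curve_ctrl_pts2": 0.5, "top_scale": 0.8, "bottom_scale": 1.1,
        "top_thickness": 0.04, "top_profile_fillet_ratio": 0.02,
        "top_vertical_fillet_ratio": 0.1, "strecher_relative_pos": 0.4,
        "strecher_increament": 1,
    }
    factory.update_params(factory.fix_unused_params(params))
    table = factory.create_asset()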
core/assets/vase.py ADDED
@@ -0,0 +1,486 @@
+ # Copyright (C) 2023, Princeton University.
+ # This source code is licensed under the BSD 3-Clause license found in the LICENSE file in the root directory of this source tree.
+
+ # Authors: Yiming Zuo
+
+ import bpy
+ import numpy as np
+ from numpy.random import choice, randint, uniform
+
+ import infinigen
+ import infinigen.core.util.blender as butil
+ from infinigen.assets.material_assignments import AssetList
+ from infinigen.assets.objects.table_decorations.utils import (
+     nodegroup_lofting,
+     nodegroup_star_profile,
+ )
+ from infinigen.core import surface
+ from infinigen.core.nodes import node_utils
+ from infinigen.core.nodes.node_wrangler import Nodes, NodeWrangler
+ from infinigen.core.placement.factory import AssetFactory
+ from infinigen.core.util.math import FixedSeed
+
+
+ class VaseFactory(AssetFactory):
+     def __init__(self, factory_seed, coarse=False, dimensions=None):
+         super(VaseFactory, self).__init__(factory_seed, coarse=coarse)
+
+         if dimensions is None:
+             z = uniform(0.17, 0.5)
+             x = z * uniform(0.3, 0.6)
+             dimensions = (x, x, z)
+         self.dimensions = dimensions
+         self.get_params_dict()
+
+         with FixedSeed(factory_seed):
+             self.params = self.sample_parameters(dimensions)
+             self.material_params, self.scratch, self.edge_wear = (
+                 self.get_material_params()
+             )
+
+             self.params.update(self.material_params)
+
+     def get_params_dict(self):
+         # list all the parameters (key: name, value: [type, range]) used in this generator
+         self.params_dict = {
+             "dimension_x": ["continuous", (0.05, 0.4)],
+             "dimension_z": ["continuous", (0.2, 0.8)],
+             "neck_scale": ["continuous", (0.15, 0.8)],
+             "profile_inner_radius": ["continuous", (0.8, 1.2)],
+             "profile_star_points": ["discrete", (2, 3, 4, 5, 6, 7, 8, 9, 10, 16, 18, 20, 22, 24, 26, 28, 30)],
+             "top_scale": ["continuous", (0.6, 1.4)],
+             "neck_mid_position": ["continuous", (0.5, 1.5)],
+             "neck_position": ["continuous", (-0.2, 0.2)],
+             "shoulder_position": ["continuous", (0.1, 0.8)],
+             "shoulder_thickness": ["continuous", (0.1, 0.3)],
+             "foot_scale": ["continuous", (0.2, 0.8)],
+             "foot_height": ["continuous", (0.01, 0.1)],
+         }
+
+     def get_material_params(self):
+         material_assignments = AssetList["VaseFactory"]()
+         params = {
+             "Material": material_assignments["surface"].assign_material(),
+         }
+         wrapped_params = {
+             k: surface.shaderfunc_to_material(v) for k, v in params.items()
+         }
+
+         scratch_prob, edge_wear_prob = material_assignments["wear_tear_prob"]
+         scratch, edge_wear = material_assignments["wear_tear"]
+
+         is_scratch = uniform() < scratch_prob
+         is_edge_wear = uniform() < edge_wear_prob
+         if not is_scratch:
+             scratch = None
+
+         if not is_edge_wear:
+             edge_wear = None
+
+         return wrapped_params, scratch, edge_wear
+
+     @staticmethod
+     def sample_parameters(dimensions):
+         # all in meters
+         if dimensions is None:
+             z = uniform(0.25, 0.40)
+             x = uniform(0.2, 0.4) * z
+             dimensions = (x, x, z)
+
+         x, y, z = dimensions
+
+         U_resolution = 64
+         V_resolution = 64
+
+         neck_scale = uniform(0.2, 0.8)
+
+         parameters = {
+             "Profile Inner Radius": choice([1.0, uniform(0.8, 1.0)]),
+             "Profile Star Points": randint(16, U_resolution // 2 + 1),
+             "U_resolution": U_resolution,
+             "V_resolution": V_resolution,
+             "Height": z,
+             "Diameter": x,
+             "Top Scale": neck_scale * uniform(0.8, 1.2),
+             "Neck Mid Position": uniform(0.7, 0.95),
+             "Neck Position": 0.5 * neck_scale + 0.5 + uniform(-0.05, 0.05),
+             "Neck Scale": neck_scale,
+             "Shoulder Position": uniform(0.3, 0.7),
+             "Shoulder Thickness": uniform(0.1, 0.25),
+             "Foot Scale": uniform(0.4, 0.6),
+             "Foot Height": uniform(0.01, 0.1),
+         }
+
+         return parameters
+
+     def fix_unused_params(self, params):
+         return params
+
+     def update_params(self, params):
+         x, y, z = params["dimension_x"], params["dimension_x"], params["dimension_z"]
+         U_resolution = 64
+         V_resolution = 64
+         neck_scale = params["neck_scale"]
+         parameters = {
+             "Profile Inner Radius": np.clip(params["profile_inner_radius"], 0.8, 1.0),
+             "Profile Star Points": params["profile_star_points"],
+             "U_resolution": U_resolution,
+             "V_resolution": V_resolution,
+             "Height": z,
+             "Diameter": x,
+             "Top Scale": neck_scale * params["top_scale"],
+             "Neck Mid Position": params["neck_mid_position"],
+             "Neck Position": 0.5 * neck_scale + 0.5 + params["neck_position"],
+             "Neck Scale": neck_scale,
+             "Shoulder Position": params["shoulder_position"],
+             "Shoulder Thickness": params["shoulder_thickness"],
+             "Foot Scale": params["foot_scale"],
+             "Foot Height": params["foot_height"],
+         }
+         self.params.update(parameters)
+         self.material_params, self.scratch, self.edge_wear = (
+             self.get_material_params()
+         )
+
+         self.params.update(self.material_params)
+
+     def create_asset(self, **params):
+         bpy.ops.mesh.primitive_plane_add(
+             size=2,
+             enter_editmode=False,
+             align="WORLD",
+             location=(0, 0, 0),
+             scale=(1, 1, 1),
+         )
+         obj = bpy.context.active_object
+
+         surface.add_geomod(obj, geometry_vases, apply=True, input_kwargs=self.params)
+         butil.modify_mesh(obj, "SOLIDIFY", apply=True, thickness=0.002)
+         butil.modify_mesh(obj, "SUBSURF", apply=True, levels=2, render_levels=2)
+
+         return obj
+
+     def finalize_assets(self, assets):
+         if self.scratch:
+             self.scratch.apply(assets)
+         if self.edge_wear:
+             self.edge_wear.apply(assets)
+
169
+
170
+ @node_utils.to_nodegroup(
171
+ "nodegroup_vase_profile", singleton=False, type="GeometryNodeTree"
172
+ )
173
+ def nodegroup_vase_profile(nw: NodeWrangler):
174
+ # Code generated using version 2.6.4 of the node_transpiler
175
+
176
+ group_input = nw.new_node(
177
+ Nodes.GroupInput,
178
+ expose_input=[
179
+ ("NodeSocketGeometry", "Profile Curve", None),
180
+ ("NodeSocketFloat", "Height", 0.0000),
181
+ ("NodeSocketFloat", "Diameter", 0.0000),
182
+ ("NodeSocketFloat", "Top Scale", 0.0000),
183
+ ("NodeSocketFloat", "Neck Mid Position", 0.0000),
184
+ ("NodeSocketFloat", "Neck Position", 0.5000),
185
+ ("NodeSocketFloat", "Neck Scale", 0.0000),
186
+ ("NodeSocketFloat", "Shoulder Position", 0.0000),
187
+ ("NodeSocketFloat", "Shoulder Thickness", 0.0000),
188
+ ("NodeSocketFloat", "Foot Scale", 0.0000),
189
+ ("NodeSocketFloat", "Foot Height", 0.0000),
190
+ ],
191
+ )
192
+
193
+ combine_xyz_1 = nw.new_node(
194
+ Nodes.CombineXYZ, input_kwargs={"Z": group_input.outputs["Height"]}
195
+ )
196
+
197
+ multiply = nw.new_node(
198
+ Nodes.Math,
199
+ input_kwargs={
200
+ 0: group_input.outputs["Top Scale"],
201
+ 1: group_input.outputs["Diameter"],
202
+ },
203
+ attrs={"operation": "MULTIPLY"},
204
+ )
205
+
206
+ neck_top = nw.new_node(
207
+ Nodes.Transform,
208
+ input_kwargs={
209
+ "Geometry": group_input.outputs["Profile Curve"],
210
+ "Translation": combine_xyz_1,
211
+ "Scale": multiply,
212
+ },
213
+ )
214
+
215
+ multiply_1 = nw.new_node(
216
+ Nodes.Math,
217
+ input_kwargs={
218
+ 0: group_input.outputs["Height"],
219
+ 1: group_input.outputs["Neck Position"],
220
+ },
221
+ attrs={"operation": "MULTIPLY"},
222
+ )
223
+
224
+ combine_xyz = nw.new_node(Nodes.CombineXYZ, input_kwargs={"Z": multiply_1})
225
+
226
+ multiply_2 = nw.new_node(
227
+ Nodes.Math,
228
+ input_kwargs={
229
+ 0: group_input.outputs["Diameter"],
230
+ 1: group_input.outputs["Neck Scale"],
231
+ },
232
+ attrs={"operation": "MULTIPLY"},
233
+ )
234
+
235
+ neck = nw.new_node(
236
+ Nodes.Transform,
237
+ input_kwargs={
238
+ "Geometry": group_input.outputs["Profile Curve"],
239
+ "Translation": combine_xyz,
240
+ "Scale": multiply_2,
241
+ },
242
+ )
243
+
244
+ subtract = nw.new_node(
245
+ Nodes.Math,
246
+ input_kwargs={0: 1.0000, 1: group_input.outputs["Neck Position"]},
247
+ attrs={"use_clamp": True, "operation": "SUBTRACT"},
248
+ )
249
+
250
+ multiply_add = nw.new_node(
251
+ Nodes.Math,
252
+ input_kwargs={
253
+ 0: subtract,
254
+ 1: group_input.outputs["Neck Mid Position"],
255
+ 2: group_input.outputs["Neck Position"],
256
+ },
257
+ attrs={"operation": "MULTIPLY_ADD"},
258
+ )
259
+
260
+ multiply_3 = nw.new_node(
261
+ Nodes.Math,
262
+ input_kwargs={0: multiply_add, 1: group_input.outputs["Height"]},
263
+ attrs={"operation": "MULTIPLY"},
264
+ )
265
+
266
+ combine_xyz_2 = nw.new_node(Nodes.CombineXYZ, input_kwargs={"Z": multiply_3})
267
+
268
+ add = nw.new_node(
269
+ Nodes.Math,
270
+ input_kwargs={
271
+ 0: group_input.outputs["Neck Scale"],
272
+ 1: group_input.outputs["Top Scale"],
273
+ },
274
+ )
275
+
276
+ divide = nw.new_node(
277
+ Nodes.Math, input_kwargs={0: add, 1: 2.0000}, attrs={"operation": "DIVIDE"}
278
+ )
279
+
280
+ multiply_4 = nw.new_node(
281
+ Nodes.Math,
282
+ input_kwargs={0: group_input.outputs["Diameter"], 1: divide},
283
+ attrs={"operation": "MULTIPLY"},
284
+ )
285
+
286
+ neck_middle = nw.new_node(
287
+ Nodes.Transform,
288
+ input_kwargs={
289
+ "Geometry": group_input.outputs["Profile Curve"],
290
+ "Translation": combine_xyz_2,
291
+ "Scale": multiply_4,
292
+ },
293
+ )
294
+
295
+ neck_geometry = nw.new_node(
296
+ Nodes.JoinGeometry, input_kwargs={"Geometry": [neck, neck_middle, neck_top]}
297
+ )
298
+
299
+ map_range = nw.new_node(
300
+ Nodes.MapRange,
301
+ input_kwargs={
302
+ "Value": group_input.outputs["Shoulder Position"],
303
+ 3: group_input.outputs["Foot Height"],
304
+ 4: group_input.outputs["Neck Position"],
305
+ },
306
+ )
307
+
308
+ subtract_1 = nw.new_node(
309
+ Nodes.Math,
310
+ input_kwargs={
311
+ 0: group_input.outputs["Neck Position"],
312
+ 1: group_input.outputs["Foot Height"],
313
+ },
314
+ attrs={"operation": "SUBTRACT"},
315
+ )
316
+
317
+ multiply_5 = nw.new_node(
318
+ Nodes.Math,
319
+ input_kwargs={0: subtract_1, 1: group_input.outputs["Shoulder Thickness"]},
320
+ attrs={"operation": "MULTIPLY"},
321
+ )
322
+
323
+ add_1 = nw.new_node(
324
+ Nodes.Math, input_kwargs={0: map_range.outputs["Result"], 1: multiply_5}
325
+ )
326
+
327
+ minimum = nw.new_node(
328
+ Nodes.Math,
329
+ input_kwargs={0: add_1, 1: group_input.outputs["Neck Position"]},
330
+ attrs={"operation": "MINIMUM"},
331
+ )
332
+
333
+ multiply_6 = nw.new_node(
334
+ Nodes.Math,
335
+ input_kwargs={0: minimum, 1: group_input.outputs["Height"]},
336
+ attrs={"operation": "MULTIPLY"},
337
+ )
338
+
339
+ combine_xyz_3 = nw.new_node(Nodes.CombineXYZ, input_kwargs={"Z": multiply_6})
340
+
341
+ body_top = nw.new_node(
342
+ Nodes.Transform,
343
+ input_kwargs={
344
+ "Geometry": group_input.outputs["Profile Curve"],
345
+ "Translation": combine_xyz_3,
346
+ "Scale": group_input.outputs["Diameter"],
347
+ },
348
+ )
349
+
350
+ subtract_2 = nw.new_node(
351
+ Nodes.Math,
352
+ input_kwargs={0: map_range.outputs["Result"], 1: multiply_5},
353
+ attrs={"operation": "SUBTRACT"},
354
+ )
355
+
356
+ maximum = nw.new_node(
357
+ Nodes.Math,
358
+ input_kwargs={0: subtract_2, 1: group_input.outputs["Foot Height"]},
359
+ attrs={"operation": "MAXIMUM"},
360
+ )
361
+
362
+ multiply_7 = nw.new_node(
363
+ Nodes.Math,
364
+ input_kwargs={0: maximum, 1: group_input.outputs["Height"]},
365
+ attrs={"operation": "MULTIPLY"},
366
+ )
367
+
368
+ combine_xyz_5 = nw.new_node(Nodes.CombineXYZ, input_kwargs={"Z": multiply_7})
369
+
370
+ body_bottom = nw.new_node(
371
+ Nodes.Transform,
372
+ input_kwargs={
373
+ "Geometry": group_input.outputs["Profile Curve"],
374
+ "Translation": combine_xyz_5,
375
+ "Scale": group_input.outputs["Diameter"],
376
+ },
377
+ )
378
+
379
+ body_geometry = nw.new_node(
380
+ Nodes.JoinGeometry, input_kwargs={"Geometry": [body_bottom, body_top]}
381
+ )
382
+
383
+ multiply_8 = nw.new_node(
384
+ Nodes.Math,
385
+ input_kwargs={
386
+ 0: group_input.outputs["Foot Height"],
387
+ 1: group_input.outputs["Height"],
388
+ },
389
+ attrs={"operation": "MULTIPLY"},
390
+ )
391
+
392
+ combine_xyz_4 = nw.new_node(Nodes.CombineXYZ, input_kwargs={"Z": multiply_8})
393
+
394
+ multiply_9 = nw.new_node(
395
+ Nodes.Math,
396
+ input_kwargs={
397
+ 0: group_input.outputs["Diameter"],
398
+ 1: group_input.outputs["Foot Scale"],
399
+ },
400
+ attrs={"operation": "MULTIPLY"},
401
+ )
402
+
403
+ foot_top = nw.new_node(
404
+ Nodes.Transform,
405
+ input_kwargs={
406
+ "Geometry": group_input,
407
+ "Translation": combine_xyz_4,
408
+ "Scale": multiply_9,
409
+ },
410
+ )
411
+
412
+ foot_bottom = nw.new_node(
413
+ Nodes.Transform, input_kwargs={"Geometry": group_input, "Scale": multiply_9}
414
+ )
415
+
416
+ foot_geometry = nw.new_node(
417
+ Nodes.JoinGeometry, input_kwargs={"Geometry": [foot_bottom, foot_top]}
418
+ )
419
+
420
+ join_geometry_2 = nw.new_node(
421
+ Nodes.JoinGeometry,
422
+ input_kwargs={"Geometry": [foot_geometry, body_geometry, neck_geometry]},
423
+ )
424
+
425
+ group_output = nw.new_node(
426
+ Nodes.GroupOutput,
427
+ input_kwargs={"Geometry": join_geometry_2},
428
+ attrs={"is_active_output": True},
429
+ )
430
+
431
+
432
+ def geometry_vases(nw: NodeWrangler, **kwargs):
433
+ # Code generated using version 2.6.4 of the node_transpiler
434
+ starprofile = nw.new_node(
435
+ nodegroup_star_profile().name,
436
+ input_kwargs={
437
+ "Resolution": kwargs["U_resolution"],
438
+ "Points": kwargs["Profile Star Points"],
439
+ "Inner Radius": kwargs["Profile Inner Radius"],
440
+ },
441
+ )
442
+
443
+ vaseprofile = nw.new_node(
444
+ nodegroup_vase_profile().name,
445
+ input_kwargs={
446
+ "Profile Curve": starprofile.outputs["Curve"],
447
+ "Height": kwargs["Height"],
448
+ "Diameter": kwargs["Diameter"],
449
+ "Top Scale": kwargs["Top Scale"],
450
+ "Neck Mid Position": kwargs["Neck Mid Position"],
451
+ "Neck Position": kwargs["Neck Position"],
452
+ "Neck Scale": kwargs["Neck Scale"],
453
+ "Shoulder Position": kwargs["Shoulder Position"],
454
+ "Shoulder Thickness": kwargs["Shoulder Thickness"],
455
+ "Foot Scale": kwargs["Foot Scale"],
456
+ "Foot Height": kwargs["Foot Height"],
457
+ },
458
+ )
459
+
460
+ lofting = nw.new_node(
461
+ nodegroup_lofting().name,
462
+ input_kwargs={
463
+ "Profile Curves": vaseprofile,
464
+ "U Resolution": 64,
465
+ "V Resolution": 64,
466
+ },
467
+ )
468
+
469
+ delete_geometry = nw.new_node(
470
+ Nodes.DeleteGeometry,
471
+ input_kwargs={
472
+ "Geometry": lofting.outputs["Geometry"],
473
+ "Selection": lofting.outputs["Top"],
474
+ },
475
+ )
476
+
477
+ set_material = nw.new_node(
478
+ Nodes.SetMaterial,
479
+ input_kwargs={"Geometry": delete_geometry, "Material": kwargs["Material"]},
480
+ )
481
+
482
+ group_output = nw.new_node(
483
+ Nodes.GroupOutput,
484
+ input_kwargs={"Geometry": set_material},
485
+ attrs={"is_active_output": True},
486
+ )
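For reference, a minimal sketch of the parameter dictionary that `geometry_vases` above consumes. The keys mirror the kwargs lookups in the function; every numeric value below is an illustrative placeholder rather than a repo default (the real values are derived from the sampled params block at the top of this file), and `Material` would be an actual `bpy` material at runtime:

```python
# Hypothetical kwargs for geometry_vases; values are made-up placeholders.
vase_kwargs = {
    "U_resolution": 64,           # points per profile ring
    "Profile Star Points": 6,     # lobes of the star profile
    "Profile Inner Radius": 0.9,  # 1.0 would give a plain circle
    "Height": 0.35,
    "Diameter": 0.18,
    "Top Scale": 0.6,
    "Neck Mid Position": 0.5,
    "Neck Position": 0.8,
    "Neck Scale": 0.4,
    "Shoulder Position": 0.5,
    "Shoulder Thickness": 0.2,
    "Foot Scale": 0.7,
    "Foot Height": 0.05,
    "Material": None,             # replaced by a bpy material when building the asset
}
```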
core/dataset.py ADDED
@@ -0,0 +1,40 @@
1
+ import os
2
+ import sys
3
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
4
+ import torch
5
+ from torch.utils.data import Dataset
6
+ import numpy as np
7
+ import cv2
8
+ import json
9
+ from core.utils.io import read_list_from_txt
10
+ from core.utils.math_utils import normalize_params
11
+
12
+ class ImageParamsDataset(Dataset):
13
+ def __init__(self, data_root, list_file, params_dict_file):
14
+ self.data_root = data_root
15
+ self.data_lists = read_list_from_txt(os.path.join(data_root, list_file))
16
+ self.params_dict = json.load(open(os.path.join(data_root, params_dict_file), 'r'))
17
+
18
+ def __len__(self):
19
+ return len(self.data_lists)
20
+
21
+ def __getitem__(self, idx):
22
+ name = self.data_lists[idx]
23
+ asset_id = name.split("/")[0]  # avoid shadowing the builtin id()
24
+ params = json.load(open(os.path.join(self.data_root, asset_id, "params.txt"), 'r'))
25
+ # normalize the params to [-1, 1] range for training diffusion
26
+ normalized_params = normalize_params(params, self.params_dict)
27
+ normalized_params_values = np.array(list(normalized_params.values()))
28
+ img = cv2.cvtColor(cv2.imread(os.path.join(self.data_root, name)), cv2.COLOR_BGR2RGB)
29
+
30
+ img_feat_name = os.path.join(self.data_root, name.replace(".png", "_dino_token.npy"))
31
+ if not os.path.exists(img_feat_name):
32
+ img_feat_file = np.load(os.path.join(self.data_root, name.replace(".png", "_dino_token.npz")))
33
+ img_feat = img_feat_file['arr_0']
34
+ img_feat_file.close()
35
+ else:
36
+ img_feat = np.load(img_feat_name)
37
+ img_feat_t = torch.from_numpy(img_feat).float()
38
+ return torch.from_numpy(normalized_params_values).float(), img_feat_t, img
39
+
40
+
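A minimal usage sketch for `ImageParamsDataset`. The directory and file names here are hypothetical; the class only assumes the layout read in `__getitem__` above (a per-id `params.txt`, a rendered `.png`, and a cached `*_dino_token.npy`/`.npz` next to it):

```python
from torch.utils.data import DataLoader
from core.dataset import ImageParamsDataset

# Hypothetical paths; the real list/dict file names depend on data generation.
dataset = ImageParamsDataset("data/vase", "train_list.txt", "params_dict.json")
loader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)

params, img_feat, img = dataset[0]
# params:   generator parameters normalized to [-1, 1], shape (num_params,)
# img_feat: cached DINOv2 image tokens used as conditioning
# img:      the raw RGB image (returned for visualization, not for the loss)
```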
core/diffusion/__init__.py ADDED
@@ -0,0 +1,46 @@
1
+ # Modified from OpenAI's diffusion repos
2
+ # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3
+ # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4
+ # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5
+
6
+ from . import gaussian_diffusion as gd
7
+ from .respace import SpacedDiffusion, space_timesteps
8
+
9
+
10
+ def create_diffusion(
11
+ timestep_respacing,
12
+ noise_schedule="linear",
13
+ use_kl=False,
14
+ sigma_small=False,
15
+ predict_xstart=False,
16
+ learn_sigma=True,
17
+ rescale_learned_sigmas=False,
18
+ diffusion_steps=1000
19
+ ):
20
+ betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps)
21
+ if use_kl:
22
+ loss_type = gd.LossType.RESCALED_KL
23
+ elif rescale_learned_sigmas:
24
+ loss_type = gd.LossType.RESCALED_MSE
25
+ else:
26
+ loss_type = gd.LossType.MSE
27
+ if timestep_respacing is None or timestep_respacing == "":
28
+ timestep_respacing = [diffusion_steps]
29
+ return SpacedDiffusion(
30
+ use_timesteps=space_timesteps(diffusion_steps, timestep_respacing),
31
+ betas=betas,
32
+ model_mean_type=(
33
+ gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X
34
+ ),
35
+ model_var_type=(
36
+ (
37
+ gd.ModelVarType.FIXED_LARGE
38
+ if not sigma_small
39
+ else gd.ModelVarType.FIXED_SMALL
40
+ )
41
+ if not learn_sigma
42
+ else gd.ModelVarType.LEARNED_RANGE
43
+ ),
44
+ loss_type=loss_type
45
+ # rescale_timesteps=rescale_timesteps,
46
+ )
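A usage sketch for `create_diffusion`; the respacing strings follow the conventions of `space_timesteps` in `core/diffusion/respace.py`:

```python
from core.diffusion import create_diffusion

# Full 1000-step chain for training; learn_sigma=True (the default) means the
# network is expected to output 2*C channels (mean plus variance range).
train_diffusion = create_diffusion(timestep_respacing="")

# Respaced 50-step chain for fast DDIM sampling.
sample_diffusion = create_diffusion(timestep_respacing="ddim50")
```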
core/diffusion/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (1.01 kB)

core/diffusion/__pycache__/diffusion_utils.cpython-310.pyc ADDED
Binary file (2.84 kB)

core/diffusion/__pycache__/gaussian_diffusion.cpython-310.pyc ADDED
Binary file (24.4 kB)

core/diffusion/__pycache__/respace.cpython-310.pyc ADDED
Binary file (4.97 kB)
core/diffusion/diffusion_utils.py ADDED
@@ -0,0 +1,88 @@
1
+ # Modified from OpenAI's diffusion repos
2
+ # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3
+ # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4
+ # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5
+
6
+ import torch as th
7
+ import numpy as np
8
+
9
+
10
+ def normal_kl(mean1, logvar1, mean2, logvar2):
11
+ """
12
+ Compute the KL divergence between two gaussians.
13
+ Shapes are automatically broadcasted, so batches can be compared to
14
+ scalars, among other use cases.
15
+ """
16
+ tensor = None
17
+ for obj in (mean1, logvar1, mean2, logvar2):
18
+ if isinstance(obj, th.Tensor):
19
+ tensor = obj
20
+ break
21
+ assert tensor is not None, "at least one argument must be a Tensor"
22
+
23
+ # Force variances to be Tensors. Broadcasting helps convert scalars to
24
+ # Tensors, but it does not work for th.exp().
25
+ logvar1, logvar2 = [
26
+ x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor)
27
+ for x in (logvar1, logvar2)
28
+ ]
29
+
30
+ return 0.5 * (
31
+ -1.0
32
+ + logvar2
33
+ - logvar1
34
+ + th.exp(logvar1 - logvar2)
35
+ + ((mean1 - mean2) ** 2) * th.exp(-logvar2)
36
+ )
37
+
38
+
39
+ def approx_standard_normal_cdf(x):
40
+ """
41
+ A fast approximation of the cumulative distribution function of the
42
+ standard normal.
43
+ """
44
+ return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))
45
+
46
+
47
+ def continuous_gaussian_log_likelihood(x, *, means, log_scales):
48
+ """
49
+ Compute the log-likelihood of a continuous Gaussian distribution.
50
+ :param x: the targets
51
+ :param means: the Gaussian mean Tensor.
52
+ :param log_scales: the Gaussian log stddev Tensor.
53
+ :return: a tensor like x of log probabilities (in nats).
54
+ """
55
+ centered_x = x - means
56
+ inv_stdv = th.exp(-log_scales)
57
+ normalized_x = centered_x * inv_stdv
58
+ log_probs = th.distributions.Normal(th.zeros_like(x), th.ones_like(x)).log_prob(normalized_x)
59
+ return log_probs
60
+
61
+
62
+ def discretized_gaussian_log_likelihood(x, *, means, log_scales):
63
+ """
64
+ Compute the log-likelihood of a Gaussian distribution discretizing to a
65
+ given image.
66
+ :param x: the target images. It is assumed that this was uint8 values,
67
+ rescaled to the range [-1, 1].
68
+ :param means: the Gaussian mean Tensor.
69
+ :param log_scales: the Gaussian log stddev Tensor.
70
+ :return: a tensor like x of log probabilities (in nats).
71
+ """
72
+ assert x.shape == means.shape == log_scales.shape
73
+ centered_x = x - means
74
+ inv_stdv = th.exp(-log_scales)
75
+ plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
76
+ cdf_plus = approx_standard_normal_cdf(plus_in)
77
+ min_in = inv_stdv * (centered_x - 1.0 / 255.0)
78
+ cdf_min = approx_standard_normal_cdf(min_in)
79
+ log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
80
+ log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
81
+ cdf_delta = cdf_plus - cdf_min
82
+ log_probs = th.where(
83
+ x < -0.999,
84
+ log_cdf_plus,
85
+ th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
86
+ )
87
+ assert log_probs.shape == x.shape
88
+ return log_probs
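A quick sanity check for `normal_kl`: scalar arguments broadcast against tensors, and the KL divergence between identical Gaussians is exactly zero:

```python
import torch as th
from core.diffusion.diffusion_utils import normal_kl

mean, logvar = th.zeros(4), th.zeros(4)
kl = normal_kl(mean, logvar, 0.0, 0.0)  # scalar args are broadcast to tensors
assert th.allclose(kl, th.zeros(4))     # KL(N(0,1) || N(0,1)) == 0
```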
core/diffusion/gaussian_diffusion.py ADDED
@@ -0,0 +1,873 @@
1
+ # Modified from OpenAI's diffusion repos
2
+ # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3
+ # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4
+ # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5
+
6
+
7
+ import math
8
+
9
+ import numpy as np
10
+ import torch as th
11
+ import enum
12
+
13
+ from .diffusion_utils import discretized_gaussian_log_likelihood, normal_kl
14
+
15
+
16
+ def mean_flat(tensor):
17
+ """
18
+ Take the mean over all non-batch dimensions.
19
+ """
20
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
21
+
22
+
23
+ class ModelMeanType(enum.Enum):
24
+ """
25
+ Which type of output the model predicts.
26
+ """
27
+
28
+ PREVIOUS_X = enum.auto() # the model predicts x_{t-1}
29
+ START_X = enum.auto() # the model predicts x_0
30
+ EPSILON = enum.auto() # the model predicts epsilon
31
+
32
+
33
+ class ModelVarType(enum.Enum):
34
+ """
35
+ What is used as the model's output variance.
36
+ The LEARNED_RANGE option has been added to allow the model to predict
37
+ values between FIXED_SMALL and FIXED_LARGE, making its job easier.
38
+ """
39
+
40
+ LEARNED = enum.auto()
41
+ FIXED_SMALL = enum.auto()
42
+ FIXED_LARGE = enum.auto()
43
+ LEARNED_RANGE = enum.auto()
44
+
45
+
46
+ class LossType(enum.Enum):
47
+ MSE = enum.auto() # use raw MSE loss (and KL when learning variances)
48
+ RESCALED_MSE = (
49
+ enum.auto()
50
+ ) # use raw MSE loss (with RESCALED_KL when learning variances)
51
+ KL = enum.auto() # use the variational lower-bound
52
+ RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB
53
+
54
+ def is_vb(self):
55
+ return self == LossType.KL or self == LossType.RESCALED_KL
56
+
57
+
58
+ def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac):
59
+ betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
60
+ warmup_time = int(num_diffusion_timesteps * warmup_frac)
61
+ betas[:warmup_time] = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64)
62
+ return betas
63
+
64
+
65
+ def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
66
+ """
67
+ This is the deprecated API for creating beta schedules.
68
+ See get_named_beta_schedule() for the new library of schedules.
69
+ """
70
+ if beta_schedule == "quad":
71
+ betas = (
72
+ np.linspace(
73
+ beta_start ** 0.5,
74
+ beta_end ** 0.5,
75
+ num_diffusion_timesteps,
76
+ dtype=np.float64,
77
+ )
78
+ ** 2
79
+ )
80
+ elif beta_schedule == "linear":
81
+ betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
82
+ elif beta_schedule == "warmup10":
83
+ betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1)
84
+ elif beta_schedule == "warmup50":
85
+ betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.5)
86
+ elif beta_schedule == "const":
87
+ betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
88
+ elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1
89
+ betas = 1.0 / np.linspace(
90
+ num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64
91
+ )
92
+ else:
93
+ raise NotImplementedError(beta_schedule)
94
+ assert betas.shape == (num_diffusion_timesteps,)
95
+ return betas
96
+
97
+
98
+ def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
99
+ """
100
+ Get a pre-defined beta schedule for the given name.
101
+ The beta schedule library consists of beta schedules which remain similar
102
+ in the limit of num_diffusion_timesteps.
103
+ Beta schedules may be added, but should not be removed or changed once
104
+ they are committed to maintain backwards compatibility.
105
+ """
106
+ if schedule_name == "linear":
107
+ # Linear schedule from Ho et al, extended to work for any number of
108
+ # diffusion steps.
109
+ scale = 1000 / num_diffusion_timesteps
110
+ return get_beta_schedule(
111
+ "linear",
112
+ beta_start=scale * 0.0001,
113
+ beta_end=scale * 0.02,
114
+ num_diffusion_timesteps=num_diffusion_timesteps,
115
+ )
116
+ elif schedule_name == "squaredcos_cap_v2":
117
+ return betas_for_alpha_bar(
118
+ num_diffusion_timesteps,
119
+ lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
120
+ )
121
+ else:
122
+ raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
123
+
124
+
125
+ def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
126
+ """
127
+ Create a beta schedule that discretizes the given alpha_t_bar function,
128
+ which defines the cumulative product of (1-beta) over time from t = [0,1].
129
+ :param num_diffusion_timesteps: the number of betas to produce.
130
+ :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
131
+ produces the cumulative product of (1-beta) up to that
132
+ part of the diffusion process.
133
+ :param max_beta: the maximum beta to use; use values lower than 1 to
134
+ prevent singularities.
135
+ """
136
+ betas = []
137
+ for i in range(num_diffusion_timesteps):
138
+ t1 = i / num_diffusion_timesteps
139
+ t2 = (i + 1) / num_diffusion_timesteps
140
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
141
+ return np.array(betas)
142
+
143
+
144
+ class GaussianDiffusion:
145
+ """
146
+ Utilities for training and sampling diffusion models.
147
+ Original ported from this codebase:
148
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42
149
+ :param betas: a 1-D numpy array of betas for each diffusion timestep,
150
+ starting at T and going to 1.
151
+ """
152
+
153
+ def __init__(
154
+ self,
155
+ *,
156
+ betas,
157
+ model_mean_type,
158
+ model_var_type,
159
+ loss_type
160
+ ):
161
+
162
+ self.model_mean_type = model_mean_type
163
+ self.model_var_type = model_var_type
164
+ self.loss_type = loss_type
165
+
166
+ # Use float64 for accuracy.
167
+ betas = np.array(betas, dtype=np.float64)
168
+ self.betas = betas
169
+ assert len(betas.shape) == 1, "betas must be 1-D"
170
+ assert (betas > 0).all() and (betas <= 1).all()
171
+
172
+ self.num_timesteps = int(betas.shape[0])
173
+
174
+ alphas = 1.0 - betas
175
+ self.alphas_cumprod = np.cumprod(alphas, axis=0)
176
+ self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
177
+ self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
178
+ assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)
179
+
180
+ # calculations for diffusion q(x_t | x_{t-1}) and others
181
+ self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
182
+ self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
183
+ self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
184
+ self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
185
+ self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)
186
+
187
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
188
+ self.posterior_variance = (
189
+ betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
190
+ )
191
+ # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
192
+ self.posterior_log_variance_clipped = np.log(
193
+ np.append(self.posterior_variance[1], self.posterior_variance[1:])
194
+ ) if len(self.posterior_variance) > 1 else np.array([])
195
+
196
+ self.posterior_mean_coef1 = (
197
+ betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
198
+ )
199
+ self.posterior_mean_coef2 = (
200
+ (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod)
201
+ )
202
+
203
+ def q_mean_variance(self, x_start, t):
204
+ """
205
+ Get the distribution q(x_t | x_0).
206
+ :param x_start: the [N x C x ...] tensor of noiseless inputs.
207
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
208
+ :return: A tuple (mean, variance, log_variance), all of x_start's shape.
209
+ """
210
+ mean = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
211
+ variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
212
+ log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
213
+ return mean, variance, log_variance
214
+
215
+ def q_sample(self, x_start, t, noise=None):
216
+ """
217
+ Diffuse the data for a given number of diffusion steps.
218
+ In other words, sample from q(x_t | x_0).
219
+ :param x_start: the initial data batch.
220
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
221
+ :param noise: if specified, the split-out normal noise.
222
+ :return: A noisy version of x_start.
223
+ """
224
+ if noise is None:
225
+ noise = th.randn_like(x_start)
226
+ assert noise.shape == x_start.shape
227
+ return (
228
+ _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
229
+ + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
230
+ )
231
+
232
+ def q_posterior_mean_variance(self, x_start, x_t, t):
233
+ """
234
+ Compute the mean and variance of the diffusion posterior:
235
+ q(x_{t-1} | x_t, x_0)
236
+ """
237
+ assert x_start.shape == x_t.shape
238
+ posterior_mean = (
239
+ _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
240
+ + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
241
+ )
242
+ posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
243
+ posterior_log_variance_clipped = _extract_into_tensor(
244
+ self.posterior_log_variance_clipped, t, x_t.shape
245
+ )
246
+ assert (
247
+ posterior_mean.shape[0]
248
+ == posterior_variance.shape[0]
249
+ == posterior_log_variance_clipped.shape[0]
250
+ == x_start.shape[0]
251
+ )
252
+ return posterior_mean, posterior_variance, posterior_log_variance_clipped
253
+
254
+ def p_mean_variance(self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None):
255
+ """
256
+ Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
257
+ the initial x, x_0.
258
+ :param model: the model, which takes a signal and a batch of timesteps
259
+ as input.
260
+ :param x: the [N x C x ...] tensor at time t.
261
+ :param t: a 1-D Tensor of timesteps.
262
+ :param clip_denoised: if True, clip the denoised signal into [-1, 1].
263
+ :param denoised_fn: if not None, a function which applies to the
264
+ x_start prediction before it is used to sample. Applies before
265
+ clip_denoised.
266
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
267
+ pass to the model. This can be used for conditioning.
268
+ :return: a dict with the following keys:
269
+ - 'mean': the model mean output.
270
+ - 'variance': the model variance output.
271
+ - 'log_variance': the log of 'variance'.
272
+ - 'pred_xstart': the prediction for x_0.
273
+ """
274
+ if model_kwargs is None:
275
+ model_kwargs = {}
276
+
277
+ B, C = x.shape[:2]
278
+ assert t.shape == (B,)
279
+ model_output = model(x, t, **model_kwargs)
280
+ if isinstance(model_output, tuple):
281
+ model_output, extra = model_output
282
+ else:
283
+ extra = None
284
+
285
+ if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
286
+ assert model_output.shape == (B, C * 2, *x.shape[2:])
287
+ model_output, model_var_values = th.split(model_output, C, dim=1)
288
+ min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape)
289
+ max_log = _extract_into_tensor(np.log(self.betas), t, x.shape)
290
+ # The model_var_values is [-1, 1] for [min_var, max_var].
291
+ frac = (model_var_values + 1) / 2
292
+ model_log_variance = frac * max_log + (1 - frac) * min_log
293
+ model_variance = th.exp(model_log_variance)
294
+ else:
295
+ model_variance, model_log_variance = {
296
+ # for fixedlarge, we set the initial (log-)variance like so
297
+ # to get a better decoder log likelihood.
298
+ ModelVarType.FIXED_LARGE: (
299
+ np.append(self.posterior_variance[1], self.betas[1:]),
300
+ np.log(np.append(self.posterior_variance[1], self.betas[1:])),
301
+ ),
302
+ ModelVarType.FIXED_SMALL: (
303
+ self.posterior_variance,
304
+ self.posterior_log_variance_clipped,
305
+ ),
306
+ }[self.model_var_type]
307
+ model_variance = _extract_into_tensor(model_variance, t, x.shape)
308
+ model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)
309
+
310
+ def process_xstart(x):
311
+ if denoised_fn is not None:
312
+ x = denoised_fn(x)
313
+ if clip_denoised:
314
+ return x.clamp(-1, 1)
315
+ return x
316
+
317
+ if self.model_mean_type == ModelMeanType.START_X:
318
+ pred_xstart = process_xstart(model_output)
319
+ else:
320
+ pred_xstart = process_xstart(
321
+ self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
322
+ )
323
+ model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t)
324
+
325
+ assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
326
+ return {
327
+ "mean": model_mean,
328
+ "variance": model_variance,
329
+ "log_variance": model_log_variance,
330
+ "pred_xstart": pred_xstart,
331
+ "extra": extra,
332
+ }
333
+
334
+ def _predict_xstart_from_eps(self, x_t, t, eps):
335
+ assert x_t.shape == eps.shape
336
+ return (
337
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
338
+ - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
339
+ )
340
+
341
+ def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
342
+ return (
343
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart
344
+ ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
345
+
346
+ def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
347
+ """
348
+ Compute the mean for the previous step, given a function cond_fn that
349
+ computes the gradient of a conditional log probability with respect to
350
+ x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
351
+ condition on y.
352
+ This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
353
+ """
354
+ gradient = cond_fn(x, t, **model_kwargs)
355
+ new_mean = p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
356
+ return new_mean
357
+
358
+ def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
359
+ """
360
+ Compute what the p_mean_variance output would have been, should the
361
+ model's score function be conditioned by cond_fn.
362
+ See condition_mean() for details on cond_fn.
363
+ Unlike condition_mean(), this instead uses the conditioning strategy
364
+ from Song et al (2020).
365
+ """
366
+ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
367
+
368
+ eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
369
+ eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, t, **model_kwargs)
370
+
371
+ out = p_mean_var.copy()
372
+ out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
373
+ out["mean"], _, _ = self.q_posterior_mean_variance(x_start=out["pred_xstart"], x_t=x, t=t)
374
+ return out
375
+
376
+ def p_sample(
377
+ self,
378
+ model,
379
+ x,
380
+ t,
381
+ clip_denoised=True,
382
+ denoised_fn=None,
383
+ cond_fn=None,
384
+ model_kwargs=None,
385
+ ):
386
+ """
387
+ Sample x_{t-1} from the model at the given timestep.
388
+ :param model: the model to sample from.
389
+ :param x: the current tensor at x_{t-1}.
390
+ :param t: the value of t, starting at 0 for the first diffusion step.
391
+ :param clip_denoised: if True, clip the x_start prediction to [-1, 1].
392
+ :param denoised_fn: if not None, a function which applies to the
393
+ x_start prediction before it is used to sample.
394
+ :param cond_fn: if not None, this is a gradient function that acts
395
+ similarly to the model.
396
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
397
+ pass to the model. This can be used for conditioning.
398
+ :return: a dict containing the following keys:
399
+ - 'sample': a random sample from the model.
400
+ - 'pred_xstart': a prediction of x_0.
401
+ """
402
+ out = self.p_mean_variance(
403
+ model,
404
+ x,
405
+ t,
406
+ clip_denoised=clip_denoised,
407
+ denoised_fn=denoised_fn,
408
+ model_kwargs=model_kwargs,
409
+ )
410
+ noise = th.randn_like(x)
411
+ nonzero_mask = (
412
+ (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
413
+ ) # no noise when t == 0
414
+ if cond_fn is not None:
415
+ out["mean"] = self.condition_mean(cond_fn, out, x, t, model_kwargs=model_kwargs)
416
+ sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
417
+ return {"sample": sample, "pred_xstart": out["pred_xstart"]}
418
+
419
+ def p_sample_loop(
420
+ self,
421
+ model,
422
+ shape,
423
+ noise=None,
424
+ clip_denoised=True,
425
+ denoised_fn=None,
426
+ cond_fn=None,
427
+ model_kwargs=None,
428
+ device=None,
429
+ progress=False,
430
+ ):
431
+ """
432
+ Generate samples from the model.
433
+ :param model: the model module.
434
+ :param shape: the shape of the samples, (N, C, H, W).
435
+ :param noise: if specified, the noise from the encoder to sample.
436
+ Should be of the same shape as `shape`.
437
+ :param clip_denoised: if True, clip x_start predictions to [-1, 1].
438
+ :param denoised_fn: if not None, a function which applies to the
439
+ x_start prediction before it is used to sample.
440
+ :param cond_fn: if not None, this is a gradient function that acts
441
+ similarly to the model.
442
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
443
+ pass to the model. This can be used for conditioning.
444
+ :param device: if specified, the device to create the samples on.
445
+ If not specified, use a model parameter's device.
446
+ :param progress: if True, show a tqdm progress bar.
447
+ :return: a non-differentiable batch of samples.
448
+ """
449
+ final = None
450
+ for sample in self.p_sample_loop_progressive(
451
+ model,
452
+ shape,
453
+ noise=noise,
454
+ clip_denoised=clip_denoised,
455
+ denoised_fn=denoised_fn,
456
+ cond_fn=cond_fn,
457
+ model_kwargs=model_kwargs,
458
+ device=device,
459
+ progress=progress,
460
+ ):
461
+ final = sample
462
+ return final["sample"]
463
+
464
+ def p_sample_loop_progressive(
465
+ self,
466
+ model,
467
+ shape,
468
+ noise=None,
469
+ clip_denoised=True,
470
+ denoised_fn=None,
471
+ cond_fn=None,
472
+ model_kwargs=None,
473
+ device=None,
474
+ progress=False,
475
+ ):
476
+ """
477
+ Generate samples from the model and yield intermediate samples from
478
+ each timestep of diffusion.
479
+ Arguments are the same as p_sample_loop().
480
+ Returns a generator over dicts, where each dict is the return value of
481
+ p_sample().
482
+ """
483
+ if device is None:
484
+ device = next(model.parameters()).device
485
+ assert isinstance(shape, (tuple, list))
486
+ if noise is not None:
487
+ img = noise
488
+ else:
489
+ img = th.randn(*shape, device=device)
490
+ indices = list(range(self.num_timesteps))[::-1]
491
+
492
+ if progress:
493
+ # Lazy import so that we don't depend on tqdm.
494
+ from tqdm.auto import tqdm
495
+
496
+ indices = tqdm(indices)
497
+
498
+ for i in indices:
499
+ t = th.tensor([i] * shape[0], device=device)
500
+ with th.no_grad():
501
+ out = self.p_sample(
502
+ model,
503
+ img,
504
+ t,
505
+ clip_denoised=clip_denoised,
506
+ denoised_fn=denoised_fn,
507
+ cond_fn=cond_fn,
508
+ model_kwargs=model_kwargs,
509
+ )
510
+ yield out
511
+ img = out["sample"]
512
+
513
+ def ddim_sample(
514
+ self,
515
+ model,
516
+ x,
517
+ t,
518
+ clip_denoised=True,
519
+ denoised_fn=None,
520
+ cond_fn=None,
521
+ model_kwargs=None,
522
+ eta=0.0,
523
+ ):
524
+ """
525
+ Sample x_{t-1} from the model using DDIM.
526
+ Same usage as p_sample().
527
+ """
528
+ out = self.p_mean_variance(
529
+ model,
530
+ x,
531
+ t,
532
+ clip_denoised=clip_denoised,
533
+ denoised_fn=denoised_fn,
534
+ model_kwargs=model_kwargs,
535
+ )
536
+ if cond_fn is not None:
537
+ out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
538
+
539
+ # Usually our model outputs epsilon, but we re-derive it
540
+ # in case we used x_start or x_prev prediction.
541
+ eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
542
+
543
+ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
544
+ alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
545
+ sigma = (
546
+ eta
547
+ * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
548
+ * th.sqrt(1 - alpha_bar / alpha_bar_prev)
549
+ )
550
+ # Equation 12.
551
+ noise = th.randn_like(x)
552
+ mean_pred = (
553
+ out["pred_xstart"] * th.sqrt(alpha_bar_prev)
554
+ + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps
555
+ )
556
+ nonzero_mask = (
557
+ (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
558
+ ) # no noise when t == 0
559
+ sample = mean_pred + nonzero_mask * sigma * noise
560
+ return {"sample": sample, "pred_xstart": out["pred_xstart"]}
561
+
562
+ def ddim_reverse_sample(
563
+ self,
564
+ model,
565
+ x,
566
+ t,
567
+ clip_denoised=True,
568
+ denoised_fn=None,
569
+ cond_fn=None,
570
+ model_kwargs=None,
571
+ eta=0.0,
572
+ ):
573
+ """
574
+ Sample x_{t+1} from the model using DDIM reverse ODE.
575
+ """
576
+ assert eta == 0.0, "Reverse ODE only for deterministic path"
577
+ out = self.p_mean_variance(
578
+ model,
579
+ x,
580
+ t,
581
+ clip_denoised=clip_denoised,
582
+ denoised_fn=denoised_fn,
583
+ model_kwargs=model_kwargs,
584
+ )
585
+ if cond_fn is not None:
586
+ out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
587
+ # Usually our model outputs epsilon, but we re-derive it
588
+ # in case we used x_start or x_prev prediction.
589
+ eps = (
590
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x
591
+ - out["pred_xstart"]
592
+ ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape)
593
+ alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)
594
+
595
+ # Equation 12. reversed
596
+ mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_next) + th.sqrt(1 - alpha_bar_next) * eps
597
+
598
+ return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]}
599
+
600
+ def ddim_sample_loop(
601
+ self,
602
+ model,
603
+ shape,
604
+ noise=None,
605
+ clip_denoised=True,
606
+ denoised_fn=None,
607
+ cond_fn=None,
608
+ model_kwargs=None,
609
+ device=None,
610
+ progress=False,
611
+ eta=0.0,
612
+ ):
613
+ """
614
+ Generate samples from the model using DDIM.
615
+ Same usage as p_sample_loop().
616
+ """
617
+ final = None
618
+ for sample in self.ddim_sample_loop_progressive(
619
+ model,
620
+ shape,
621
+ noise=noise,
622
+ clip_denoised=clip_denoised,
623
+ denoised_fn=denoised_fn,
624
+ cond_fn=cond_fn,
625
+ model_kwargs=model_kwargs,
626
+ device=device,
627
+ progress=progress,
628
+ eta=eta,
629
+ ):
630
+ final = sample
631
+ return final["sample"]
632
+
633
+ def ddim_sample_loop_progressive(
634
+ self,
635
+ model,
636
+ shape,
637
+ noise=None,
638
+ clip_denoised=True,
639
+ denoised_fn=None,
640
+ cond_fn=None,
641
+ model_kwargs=None,
642
+ device=None,
643
+ progress=False,
644
+ eta=0.0,
645
+ ):
646
+ """
647
+ Use DDIM to sample from the model and yield intermediate samples from
648
+ each timestep of DDIM.
649
+ Same usage as p_sample_loop_progressive().
650
+ """
651
+ if device is None:
652
+ device = next(model.parameters()).device
653
+ assert isinstance(shape, (tuple, list))
654
+ if noise is not None:
655
+ img = noise
656
+ else:
657
+ img = th.randn(*shape, device=device)
658
+ indices = list(range(self.num_timesteps))[::-1]
659
+
660
+ if progress:
661
+ # Lazy import so that we don't depend on tqdm.
662
+ from tqdm.auto import tqdm
663
+
664
+ indices = tqdm(indices)
665
+
666
+ for i in indices:
667
+ t = th.tensor([i] * shape[0], device=device)
668
+ with th.no_grad():
669
+ out = self.ddim_sample(
670
+ model,
671
+ img,
672
+ t,
673
+ clip_denoised=clip_denoised,
674
+ denoised_fn=denoised_fn,
675
+ cond_fn=cond_fn,
676
+ model_kwargs=model_kwargs,
677
+ eta=eta,
678
+ )
679
+ yield out
680
+ img = out["sample"]
681
+
682
+ def _vb_terms_bpd(
683
+ self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None
684
+ ):
685
+ """
686
+ Get a term for the variational lower-bound.
687
+ The resulting units are bits (rather than nats, as one might expect).
688
+ This allows for comparison to other papers.
689
+ :return: a dict with the following keys:
690
+ - 'output': a shape [N] tensor of NLLs or KLs.
691
+ - 'pred_xstart': the x_0 predictions.
692
+ """
693
+ true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(
694
+ x_start=x_start, x_t=x_t, t=t
695
+ )
696
+ out = self.p_mean_variance(
697
+ model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs
698
+ )
699
+ kl = normal_kl(
700
+ true_mean, true_log_variance_clipped, out["mean"], out["log_variance"]
701
+ )
702
+ kl = mean_flat(kl) / np.log(2.0)
703
+
704
+ decoder_nll = -discretized_gaussian_log_likelihood(
705
+ x_start, means=out["mean"], log_scales=0.5 * out["log_variance"]
706
+ )
707
+ assert decoder_nll.shape == x_start.shape
708
+ decoder_nll = mean_flat(decoder_nll) / np.log(2.0)
709
+
710
+ # At the first timestep return the decoder NLL,
711
+ # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
712
+ output = th.where((t == 0), decoder_nll, kl)
713
+ return {"output": output, "pred_xstart": out["pred_xstart"]}
714
+
715
+ def training_losses(self, model, x_start, t, model_kwargs=None, noise=None):
716
+ """
717
+ Compute training losses for a single timestep.
718
+ :param model: the model to evaluate loss on.
719
+ :param x_start: the [N x C x ...] tensor of inputs.
720
+ :param t: a batch of timestep indices.
721
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
722
+ pass to the model. This can be used for conditioning.
723
+ :param noise: if specified, the specific Gaussian noise to try to remove.
724
+ :return: a dict with the key "loss" containing a tensor of shape [N].
725
+ Some mean or variance settings may also have other keys.
726
+ """
727
+ if model_kwargs is None:
728
+ model_kwargs = {}
729
+ if noise is None:
730
+ noise = th.randn_like(x_start)
731
+ x_t = self.q_sample(x_start, t, noise=noise)
732
+
733
+ terms = {}
734
+
735
+ if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
736
+ terms["loss"] = self._vb_terms_bpd(
737
+ model=model,
738
+ x_start=x_start,
739
+ x_t=x_t,
740
+ t=t,
741
+ clip_denoised=False,
742
+ model_kwargs=model_kwargs,
743
+ )["output"]
744
+ if self.loss_type == LossType.RESCALED_KL:
745
+ terms["loss"] *= self.num_timesteps
746
+ elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
747
+ model_output = model(x_t, t, **model_kwargs)
748
+
749
+ if self.model_var_type in [
750
+ ModelVarType.LEARNED,
751
+ ModelVarType.LEARNED_RANGE,
752
+ ]:
753
+ B, C = x_t.shape[:2]
754
+ assert model_output.shape == (B, C * 2, *x_t.shape[2:])
755
+ model_output, model_var_values = th.split(model_output, C, dim=1)
756
+ # Learn the variance using the variational bound, but don't let
757
+ # it affect our mean prediction.
758
+ frozen_out = th.cat([model_output.detach(), model_var_values], dim=1)
759
+ terms["vb"] = self._vb_terms_bpd(
760
+ model=lambda *args, r=frozen_out: r,
761
+ x_start=x_start,
762
+ x_t=x_t,
763
+ t=t,
764
+ clip_denoised=False,
765
+ )["output"]
766
+ if self.loss_type == LossType.RESCALED_MSE:
767
+ # Divide by 1000 for equivalence with initial implementation.
768
+ # Without a factor of 1/1000, the VB term hurts the MSE term.
769
+ terms["vb"] *= self.num_timesteps / 1000.0
770
+
771
+ target = {
772
+ ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance(
773
+ x_start=x_start, x_t=x_t, t=t
774
+ )[0],
775
+ ModelMeanType.START_X: x_start,
776
+ ModelMeanType.EPSILON: noise,
777
+ }[self.model_mean_type]
778
+ assert model_output.shape == target.shape == x_start.shape
779
+ terms["mse"] = mean_flat((target - model_output) ** 2)
780
+ if "vb" in terms:
781
+ terms["loss"] = terms["mse"] + terms["vb"]
782
+ else:
783
+ terms["loss"] = terms["mse"]
784
+ else:
785
+ raise NotImplementedError(self.loss_type)
786
+
787
+ return terms
788
+
789
+ def _prior_bpd(self, x_start):
790
+ """
791
+ Get the prior KL term for the variational lower-bound, measured in
792
+ bits-per-dim.
793
+ This term can't be optimized, as it only depends on the encoder.
794
+ :param x_start: the [N x C x ...] tensor of inputs.
795
+ :return: a batch of [N] KL values (in bits), one per batch element.
796
+ """
797
+ batch_size = x_start.shape[0]
798
+ t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
799
+ qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
800
+ kl_prior = normal_kl(
801
+ mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0
802
+ )
803
+ return mean_flat(kl_prior) / np.log(2.0)
804
+
805
+ def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None):
806
+ """
807
+ Compute the entire variational lower-bound, measured in bits-per-dim,
808
+ as well as other related quantities.
809
+ :param model: the model to evaluate loss on.
810
+ :param x_start: the [N x C x ...] tensor of inputs.
811
+ :param clip_denoised: if True, clip denoised samples.
812
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
813
+ pass to the model. This can be used for conditioning.
814
+ :return: a dict containing the following keys:
815
+ - total_bpd: the total variational lower-bound, per batch element.
816
+ - prior_bpd: the prior term in the lower-bound.
817
+ - vb: an [N x T] tensor of terms in the lower-bound.
818
+ - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep.
819
+ - mse: an [N x T] tensor of epsilon MSEs for each timestep.
820
+ """
821
+ device = x_start.device
822
+ batch_size = x_start.shape[0]
823
+
824
+ vb = []
825
+ xstart_mse = []
826
+ mse = []
827
+ for t in list(range(self.num_timesteps))[::-1]:
828
+ t_batch = th.tensor([t] * batch_size, device=device)
829
+ noise = th.randn_like(x_start)
830
+ x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise)
831
+ # Calculate VLB term at the current timestep
832
+ with th.no_grad():
833
+ out = self._vb_terms_bpd(
834
+ model,
835
+ x_start=x_start,
836
+ x_t=x_t,
837
+ t=t_batch,
838
+ clip_denoised=clip_denoised,
839
+ model_kwargs=model_kwargs,
840
+ )
841
+ vb.append(out["output"])
842
+ xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2))
843
+ eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"])
844
+ mse.append(mean_flat((eps - noise) ** 2))
845
+
846
+ vb = th.stack(vb, dim=1)
847
+ xstart_mse = th.stack(xstart_mse, dim=1)
848
+ mse = th.stack(mse, dim=1)
849
+
850
+ prior_bpd = self._prior_bpd(x_start)
851
+ total_bpd = vb.sum(dim=1) + prior_bpd
852
+ return {
853
+ "total_bpd": total_bpd,
854
+ "prior_bpd": prior_bpd,
855
+ "vb": vb,
856
+ "xstart_mse": xstart_mse,
857
+ "mse": mse,
858
+ }
859
+
860
+
861
+ def _extract_into_tensor(arr, timesteps, broadcast_shape):
862
+ """
863
+ Extract values from a 1-D numpy array for a batch of indices.
864
+ :param arr: the 1-D numpy array.
865
+ :param timesteps: a tensor of indices into the array to extract.
866
+ :param broadcast_shape: a larger shape of K dimensions with the batch
867
+ dimension equal to the length of timesteps.
868
+ :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
869
+ """
870
+ res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
871
+ while len(res.shape) < len(broadcast_shape):
872
+ res = res[..., None]
873
+ return res + th.zeros(broadcast_shape, device=timesteps.device)
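To tie the pieces together, a minimal runnable sketch of one training step and one sampling call. The zero-returning `model` below is a dummy stand-in for the conditional network defined in `core/models.py`; with the default `learn_sigma=True`, the callable must output `2*C` channels (predicted epsilon plus variance-range values):

```python
import torch as th
from core.diffusion import create_diffusion

diffusion = create_diffusion(timestep_respacing="")  # full 1000-step chain

def model(x, t, **kwargs):
    # Dummy network: predicts zero epsilon and zero variance-range values.
    return th.zeros(x.shape[0], x.shape[1] * 2)

x_start = th.randn(8, 16)  # placeholder batch of normalized parameter vectors
t = th.randint(0, diffusion.num_timesteps, (8,))
losses = diffusion.training_losses(model, x_start, t)  # keys: "loss", "mse", "vb"
loss = losses["loss"].mean()

# Sampling iterates t = T-1 ... 0 under th.no_grad(); device is passed
# explicitly because the dummy model has no parameters to infer it from.
ddim = create_diffusion(timestep_respacing="ddim50")
samples = ddim.ddim_sample_loop(model, shape=(8, 16), device="cpu")
```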
core/diffusion/respace.py ADDED
@@ -0,0 +1,129 @@
1
+ # Modified from OpenAI's diffusion repos
2
+ # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3
+ # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4
+ # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5
+
6
+ import numpy as np
7
+ import torch as th
8
+
9
+ from .gaussian_diffusion import GaussianDiffusion
10
+
11
+
12
+ def space_timesteps(num_timesteps, section_counts):
13
+ """
14
+ Create a list of timesteps to use from an original diffusion process,
15
+ given the number of timesteps we want to take from equally-sized portions
16
+ of the original process.
17
+ For example, if there's 300 timesteps and the section counts are [10,15,20]
18
+ then the first 100 timesteps are strided to be 10 timesteps, the second 100
19
+ are strided to be 15 timesteps, and the final 100 are strided to be 20.
20
+ If the stride is a string starting with "ddim", then the fixed striding
21
+ from the DDIM paper is used, and only one section is allowed.
22
+ :param num_timesteps: the number of diffusion steps in the original
23
+ process to divide up.
24
+ :param section_counts: either a list of numbers, or a string containing
25
+ comma-separated numbers, indicating the step count
26
+ per section. As a special case, use "ddimN" where N
27
+ is a number of steps to use the striding from the
28
+ DDIM paper.
29
+ :return: a set of diffusion steps from the original process to use.
30
+ """
31
+ if isinstance(section_counts, str):
32
+ if section_counts.startswith("ddim"):
33
+ desired_count = int(section_counts[len("ddim") :])
34
+ for i in range(1, num_timesteps):
35
+ if len(range(0, num_timesteps, i)) == desired_count:
36
+ return set(range(0, num_timesteps, i))
37
+ raise ValueError(
38
+ f"cannot create exactly {num_timesteps} steps with an integer stride"
39
+ )
40
+ section_counts = [int(x) for x in section_counts.split(",")]
41
+ size_per = num_timesteps // len(section_counts)
42
+ extra = num_timesteps % len(section_counts)
43
+ start_idx = 0
44
+ all_steps = []
45
+ for i, section_count in enumerate(section_counts):
46
+ size = size_per + (1 if i < extra else 0)
47
+ if size < section_count:
48
+ raise ValueError(
49
+ f"cannot divide section of {size} steps into {section_count}"
50
+ )
51
+ if section_count <= 1:
52
+ frac_stride = 1
53
+ else:
54
+ frac_stride = (size - 1) / (section_count - 1)
55
+ cur_idx = 0.0
56
+ taken_steps = []
57
+ for _ in range(section_count):
58
+ taken_steps.append(start_idx + round(cur_idx))
59
+ cur_idx += frac_stride
60
+ all_steps += taken_steps
61
+ start_idx += size
62
+ return set(all_steps)
63
+
64
+
65
+ class SpacedDiffusion(GaussianDiffusion):
66
+ """
67
+ A diffusion process which can skip steps in a base diffusion process.
68
+ :param use_timesteps: a collection (sequence or set) of timesteps from the
69
+ original diffusion process to retain.
70
+ :param kwargs: the kwargs to create the base diffusion process.
71
+ """
72
+
73
+ def __init__(self, use_timesteps, **kwargs):
74
+ self.use_timesteps = set(use_timesteps)
75
+ self.timestep_map = []
76
+ self.original_num_steps = len(kwargs["betas"])
77
+
78
+ base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa
79
+ last_alpha_cumprod = 1.0
80
+ new_betas = []
81
+ for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
82
+ if i in self.use_timesteps:
83
+ new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
84
+ last_alpha_cumprod = alpha_cumprod
85
+ self.timestep_map.append(i)
86
+ kwargs["betas"] = np.array(new_betas)
87
+ super().__init__(**kwargs)
88
+
89
+ def p_mean_variance(
90
+ self, model, *args, **kwargs
91
+ ): # pylint: disable=signature-differs
92
+ return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
93
+
94
+ def training_losses(
95
+ self, model, *args, **kwargs
96
+ ): # pylint: disable=signature-differs
97
+ return super().training_losses(self._wrap_model(model), *args, **kwargs)
98
+
99
+ def condition_mean(self, cond_fn, *args, **kwargs):
100
+ return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)
101
+
102
+ def condition_score(self, cond_fn, *args, **kwargs):
103
+ return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)
104
+
105
+ def _wrap_model(self, model):
106
+ if isinstance(model, _WrappedModel):
107
+ return model
108
+ return _WrappedModel(
109
+ model, self.timestep_map, self.original_num_steps
110
+ )
111
+
112
+ def _scale_timesteps(self, t):
113
+ # Scaling is done by the wrapped model.
114
+ return t
115
+
116
+
117
+ class _WrappedModel:
118
+ def __init__(self, model, timestep_map, original_num_steps):
119
+ self.model = model
120
+ self.timestep_map = timestep_map
121
+ # self.rescale_timesteps = rescale_timesteps
122
+ self.original_num_steps = original_num_steps
123
+
124
+ def __call__(self, x, ts, **kwargs):
125
+ map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
126
+ new_ts = map_tensor[ts]
127
+ # if self.rescale_timesteps:
128
+ # new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
129
+ return self.model(x, new_ts, **kwargs)
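To make the respacing concrete: `space_timesteps` returns the subset of original timesteps to keep, and `SpacedDiffusion` then recomputes betas so the shortened chain reproduces the original cumulative alphas at those steps:

```python
from core.diffusion.respace import space_timesteps

space_timesteps(1000, "ddim50")     # {0, 20, 40, ..., 980}: fixed DDIM stride of 20
space_timesteps(1000, [10])         # 10 steps spread evenly across all 1000
space_timesteps(300, [10, 15, 20])  # per-section counts, as in the docstring above
```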
core/diffusion/timestep_sampler.py ADDED
@@ -0,0 +1,150 @@
+ # Modified from OpenAI's diffusion repos
+ # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
+ # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
+ # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
+
+ from abc import ABC, abstractmethod
+
+ import numpy as np
+ import torch as th
+ import torch.distributed as dist
+
+
+ def create_named_schedule_sampler(name, diffusion):
+     """
+     Create a ScheduleSampler from a library of pre-defined samplers.
+     :param name: the name of the sampler.
+     :param diffusion: the diffusion object to sample for.
+     """
+     if name == "uniform":
+         return UniformSampler(diffusion)
+     elif name == "loss-second-moment":
+         return LossSecondMomentResampler(diffusion)
+     else:
+         raise NotImplementedError(f"unknown schedule sampler: {name}")
+
+
+ class ScheduleSampler(ABC):
+     """
+     A distribution over timesteps in the diffusion process, intended to reduce
+     variance of the objective.
+     By default, samplers perform unbiased importance sampling, in which the
+     objective's mean is unchanged.
+     However, subclasses may override sample() to change how the resampled
+     terms are reweighted, allowing for actual changes in the objective.
+     """
+
+     @abstractmethod
+     def weights(self):
+         """
+         Get a numpy array of weights, one per diffusion step.
+         The weights needn't be normalized, but must be positive.
+         """
+
+     def sample(self, batch_size, device):
+         """
+         Importance-sample timesteps for a batch.
+         :param batch_size: the number of timesteps.
+         :param device: the torch device to save to.
+         :return: a tuple (timesteps, weights):
+                  - timesteps: a tensor of timestep indices.
+                  - weights: a tensor of weights to scale the resulting losses.
+         """
+         w = self.weights()
+         p = w / np.sum(w)
+         indices_np = np.random.choice(len(p), size=(batch_size,), p=p)
+         indices = th.from_numpy(indices_np).long().to(device)
+         weights_np = 1 / (len(p) * p[indices_np])
+         weights = th.from_numpy(weights_np).float().to(device)
+         return indices, weights
+
+
+ class UniformSampler(ScheduleSampler):
+     def __init__(self, diffusion):
+         self.diffusion = diffusion
+         self._weights = np.ones([diffusion.num_timesteps])
+
+     def weights(self):
+         return self._weights
+
+
+ class LossAwareSampler(ScheduleSampler):
+     def update_with_local_losses(self, local_ts, local_losses):
+         """
+         Update the reweighting using losses from a model.
+         Call this method from each rank with a batch of timesteps and the
+         corresponding losses for each of those timesteps.
+         This method will perform synchronization to make sure all of the ranks
+         maintain the exact same reweighting.
+         :param local_ts: an integer Tensor of timesteps.
+         :param local_losses: a 1D Tensor of losses.
+         """
+         batch_sizes = [
+             th.tensor([0], dtype=th.int32, device=local_ts.device)
+             for _ in range(dist.get_world_size())
+         ]
+         dist.all_gather(
+             batch_sizes,
+             th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device),
+         )
+
+         # Pad all_gather batches to the maximum batch size.
+         batch_sizes = [x.item() for x in batch_sizes]
+         max_bs = max(batch_sizes)
+
+         timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes]
+         loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes]
+         dist.all_gather(timestep_batches, local_ts)
+         dist.all_gather(loss_batches, local_losses)
+         timesteps = [
+             x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs]
+         ]
+         losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]]
+         self.update_with_all_losses(timesteps, losses)
+
+     @abstractmethod
+     def update_with_all_losses(self, ts, losses):
+         """
+         Update the reweighting using losses from a model.
+         Sub-classes should override this method to update the reweighting
+         using losses from the model.
+         This method directly updates the reweighting without synchronizing
+         between workers. It is called by update_with_local_losses from all
+         ranks with identical arguments. Thus, it should have deterministic
+         behavior to maintain state across workers.
+         :param ts: a list of int timesteps.
+         :param losses: a list of float losses, one per timestep.
+         """
+
+
+ class LossSecondMomentResampler(LossAwareSampler):
+     def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001):
+         self.diffusion = diffusion
+         self.history_per_term = history_per_term
+         self.uniform_prob = uniform_prob
+         self._loss_history = np.zeros(
+             [diffusion.num_timesteps, history_per_term], dtype=np.float64
+         )
+         self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int64)
+
+     def weights(self):
+         if not self._warmed_up():
+             return np.ones([self.diffusion.num_timesteps], dtype=np.float64)
+         weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1))
+         weights /= np.sum(weights)
+         weights *= 1 - self.uniform_prob
+         weights += self.uniform_prob / len(weights)
+         return weights
+
+     def update_with_all_losses(self, ts, losses):
+         for t, loss in zip(ts, losses):
+             if self._loss_counts[t] == self.history_per_term:
+                 # Shift out the oldest loss term.
+                 self._loss_history[t, :-1] = self._loss_history[t, 1:]
+                 self._loss_history[t, -1] = loss
+             else:
+                 self._loss_history[t, self._loss_counts[t]] = loss
+                 self._loss_counts[t] += 1
+
+     def _warmed_up(self):
+         return (self._loss_counts == self.history_per_term).all()
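The samplers above implement the IDDPM importance-sampling trick: `sample()` draws timesteps from `weights()` and returns `1 / (len(p) * p)` correction factors so the loss estimate stays unbiased, while `LossSecondMomentResampler` biases sampling toward timesteps whose recent losses have a high second moment. A minimal sketch of how a training loop typically wires this up; `diffusion`, `model`, and `x_start` are hypothetical stand-ins, and `training_losses` is assumed to follow the IDDPM convention of returning a dict with a per-sample "loss" entry:

    sampler = create_named_schedule_sampler("loss-second-moment", diffusion)

    t, weights = sampler.sample(batch_size=x_start.shape[0], device=x_start.device)
    losses = diffusion.training_losses(model, x_start, t)      # dict with per-sample "loss"
    if isinstance(sampler, LossAwareSampler):
        # Keep the resampler's loss history in sync (requires torch.distributed
        # to be initialized, since it all_gathers losses across ranks).
        sampler.update_with_local_losses(t, losses["loss"].detach())
    loss = (losses["loss"] * weights).mean()                   # importance-weighted objective
    loss.backward()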
core/models.py ADDED
@@ -0,0 +1,331 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+ # --------------------------------------------------------
+ # References:
+ # GLIDE: https://github.com/openai/glide-text2im
+ # MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
+ # --------------------------------------------------------
+
+ import torch
+ import torch.nn as nn
+ import numpy as np
+ import math
+ from timm.models.vision_transformer import PatchEmbed, Attention, Mlp
+ import xformers.ops
+
+
+ def modulate(x, shift, scale):
+     return x * (1 + scale) + shift
+
+
+ #################################################################################
+ #               Embedding Layers for Timesteps and Class Labels                 #
+ #################################################################################
+
+ class TimestepEmbedder(nn.Module):
+     """
+     Embeds scalar timesteps into vector representations.
+     """
+     def __init__(self, hidden_size, frequency_embedding_size=256):
+         super().__init__()
+         self.mlp = nn.Sequential(
+             nn.Linear(frequency_embedding_size, hidden_size, bias=True),
+             nn.SiLU(),
+             nn.Linear(hidden_size, hidden_size, bias=True),
+         )
+         self.frequency_embedding_size = frequency_embedding_size
+
+     @staticmethod
+     def timestep_embedding(t, dim, max_period=10000):
+         """
+         Create sinusoidal timestep embeddings.
+         :param t: a 1-D Tensor of N indices, one per batch element.
+                   These may be fractional.
+         :param dim: the dimension of the output.
+         :param max_period: controls the minimum frequency of the embeddings.
+         :return: an (N, D) Tensor of positional embeddings.
+         """
+         # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
+         half = dim // 2
+         freqs = torch.exp(
+             -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
+         ).to(device=t.device)
+         args = t[:, None].float() * freqs[None]
+         embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+         if dim % 2:
+             embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+         return embedding
+
+     def forward(self, t):
+         t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
+         t_emb = self.mlp(t_freq)
+         return t_emb
+
+
+ class LabelEmbedder(nn.Module):
+     """
+     Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
+     """
+     def __init__(self, num_classes, hidden_size, dropout_prob):
+         super().__init__()
+         use_cfg_embedding = dropout_prob > 0
+         self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
+         self.num_classes = num_classes
+         self.dropout_prob = dropout_prob
+
+     def token_drop(self, labels, force_drop_ids=None):
+         """
+         Drops labels to enable classifier-free guidance.
+         """
+         if force_drop_ids is None:
+             drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
+         else:
+             drop_ids = force_drop_ids == 1
+         labels = torch.where(drop_ids, self.num_classes, labels)
+         return labels
+
+     def forward(self, labels, train, force_drop_ids=None):
+         use_dropout = self.dropout_prob > 0
+         if (train and use_dropout) or (force_drop_ids is not None):
+             labels = self.token_drop(labels, force_drop_ids)
+         embeddings = self.embedding_table(labels)
+         return embeddings
+
+
+ class MultiHeadCrossAttention(nn.Module):
+     def __init__(self, d_model, num_heads, attn_drop=0., proj_drop=0., **block_kwargs):
+         super(MultiHeadCrossAttention, self).__init__()
+         assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
+
+         self.d_model = d_model
+         self.num_heads = num_heads
+         self.head_dim = d_model // num_heads
+
+         self.q_linear = nn.Linear(d_model, d_model)
+         self.kv_linear = nn.Linear(d_model, d_model * 2)
+         self.attn_drop = nn.Dropout(attn_drop)
+         self.proj = nn.Linear(d_model, d_model)
+         self.proj_drop = nn.Dropout(proj_drop)
+
+     def forward(self, x, cond, mask=None):
+         # query: img tokens; key/value: condition; mask: per-sample condition lengths (handles padding)
+         B, N, C = x.shape
+
+         q = self.q_linear(x).view(1, -1, self.num_heads, self.head_dim)
+         kv = self.kv_linear(cond).view(1, -1, 2, self.num_heads, self.head_dim)
+         k, v = kv.unbind(2)
+         attn_bias = None
+         if mask is not None:
+             attn_bias = xformers.ops.fmha.BlockDiagonalMask.from_seqlens([N] * B, mask)
+         x = xformers.ops.memory_efficient_attention(q, k, v, p=self.attn_drop.p, attn_bias=attn_bias)
+         x = x.view(B, -1, C)
+         x = self.proj(x)
+         x = self.proj_drop(x)
+
+         return x
+
+
+ #################################################################################
+ #                                 Core DiT Model                                #
+ #################################################################################
+
+ class DiTBlock(nn.Module):
+     """
+     A DiT block with cross attention for conditioning. Adapted from the PixArt implementation.
+     """
+     def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
+         super().__init__()
+         self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+         self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
+         self.cross_attn = MultiHeadCrossAttention(hidden_size, num_heads, **block_kwargs)
+         self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+         mlp_hidden_dim = int(hidden_size * mlp_ratio)
+         approx_gelu = lambda: nn.GELU(approximate="tanh")
+         self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
+         # PixArt-style per-block shift/scale/gate table, replacing vanilla DiT's adaLN MLP.
+         self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size ** 0.5)
+
+     def forward(self, x, y, t, mask=None):
+         B, N, C = x.shape
+
+         shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None] + t.reshape(B, 6, -1)).chunk(6, dim=1)
+         x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa)).reshape(B, N, C)
+         x = x + self.cross_attn(x, y, mask)
+         x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
+         return x
+
+
+ class FinalLayer(nn.Module):
+     """
+     The final layer of DiT.
+     """
+     def __init__(self, hidden_size, out_channels):
+         super().__init__()
+         self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+         self.linear = nn.Linear(hidden_size, out_channels, bias=True)
+         self.scale_shift_table = nn.Parameter(torch.randn(2, hidden_size) / hidden_size ** 0.5)
+         self.out_channels = out_channels
+
+     def forward(self, x, t):
+         shift, scale = (self.scale_shift_table[None] + t[:, None]).chunk(2, dim=1)
+         x = modulate(self.norm_final(x), shift, scale)
+         x = self.linear(x)
+         return x
+
+
+ class DiT(nn.Module):
+     """
+     Diffusion model with a Transformer backbone.
+     """
+     def __init__(
+         self,
+         input_size=32,
+         in_channels=1,
+         hidden_size=128,
+         depth=12,
+         num_heads=6,
+         mlp_ratio=4.0,
+         condition_channels=768,
+         learn_sigma=True,
+     ):
+         super().__init__()
+         self.learn_sigma = learn_sigma
+         self.input_size = input_size
+         self.in_channels = in_channels
+         self.out_channels = in_channels * 2 if learn_sigma else in_channels
+         self.num_heads = num_heads
+
+         self.x_embedder = nn.Linear(in_channels, hidden_size, bias=True)
+         self.t_embedder = TimestepEmbedder(hidden_size)
+         approx_gelu = lambda: nn.GELU(approximate="tanh")
+         self.t_block = nn.Sequential(
+             nn.SiLU(),
+             nn.Linear(hidden_size, 6 * hidden_size, bias=True)
+         )
+         self.y_embedder = Mlp(in_features=condition_channels, hidden_features=hidden_size, out_features=hidden_size, act_layer=approx_gelu, drop=0)
+         # Will use fixed sin-cos embedding:
+         self.pos_embed = nn.Parameter(torch.zeros(1, input_size, hidden_size), requires_grad=False)
+
+         self.blocks = nn.ModuleList([
+             DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth)
+         ])
+         self.final_layer = FinalLayer(hidden_size, self.out_channels)
+         self.initialize_weights()
+
+     def initialize_weights(self):
+         # Initialize transformer layers:
+         def _basic_init(module):
+             if isinstance(module, nn.Linear):
+                 torch.nn.init.xavier_uniform_(module.weight)
+                 if module.bias is not None:
+                     nn.init.constant_(module.bias, 0)
+         self.apply(_basic_init)
+
+         # Initialize (and freeze) pos_embed by sin-cos embedding:
+         grid_1d = np.arange(self.input_size, dtype=np.float32)
+         pos_embed = get_1d_sincos_pos_embed_from_grid(self.pos_embed.shape[-1], grid_1d)
+         self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
+
+         # Initialize x_embedder (a plain nn.Linear over the param channels):
+         nn.init.xavier_uniform_(self.x_embedder.weight)
+         nn.init.constant_(self.x_embedder.bias, 0)
+
+         # Initialize condition embedding MLP:
+         nn.init.normal_(self.y_embedder.fc1.weight, std=0.02)
+         nn.init.normal_(self.y_embedder.fc2.weight, std=0.02)
+
+         # Initialize timestep embedding MLP:
+         nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
+         nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
+
+         # Zero-out the cross-attention output projections in the DiT blocks:
+         for block in self.blocks:
+             nn.init.constant_(block.cross_attn.proj.weight, 0)
+             nn.init.constant_(block.cross_attn.proj.bias, 0)
+
+         # Zero-out output layers:
+         nn.init.constant_(self.final_layer.linear.weight, 0)
+         nn.init.constant_(self.final_layer.linear.bias, 0)
+
+     def ckpt_wrapper(self, module):
+         def ckpt_forward(*inputs):
+             outputs = module(*inputs)
+             return outputs
+         return ckpt_forward
+
+     def forward(self, x, t, y):
+         """
+         Forward pass of DiT.
+         x: (N, 1, T) tensor of PCG params
+         t: (N,) tensor of diffusion timesteps
+         y: (N, 1, C) or (N, M, C) tensor of condition image features
+         """
+         x = x.permute(0, 2, 1)
+         x = self.x_embedder(x) + self.pos_embed  # (N, T, D), where T is the number of input tokens (params)
+         t = self.t_embedder(t)                   # (N, D)
+         t0 = self.t_block(t)
+         y = self.y_embedder(y)                   # (N, M, D)
+
+         # mask for batched cross-attention
+         y_lens = [y.shape[1]] * y.shape[0]
+         y = y.view(1, -1, x.shape[-1])
+         for block in self.blocks:
+             x = torch.utils.checkpoint.checkpoint(self.ckpt_wrapper(block), x, y, t0, y_lens)  # (N, T, D)
+         x = self.final_layer(x, t)               # (N, T, out_channels)
+         return x.permute(0, 2, 1)
+
+
+ #################################################################################
+ #                   Sine/Cosine Positional Embedding Functions                  #
+ #################################################################################
+ # https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
+
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
+     """
+     embed_dim: output dimension for each position
+     pos: a list of positions to be encoded: size (M,)
+     out: (M, D)
+     """
+     assert embed_dim % 2 == 0
+     omega = np.arange(embed_dim // 2, dtype=np.float64)
+     omega /= embed_dim / 2.
+     omega = 1. / 10000**omega  # (D/2,)
+
+     pos = pos.reshape(-1)  # (M,)
+     out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product
+
+     emb_sin = np.sin(out)  # (M, D/2)
+     emb_cos = np.cos(out)  # (M, D/2)
+
+     emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
+     return emb
+
+
+ #################################################################################
+ #                                  DiT Configs                                  #
+ #################################################################################
+
+ def DiT_S(**kwargs):
+     # 39M params
+     return DiT(depth=16, hidden_size=384, num_heads=6, **kwargs)
+
+ def DiT_mini(**kwargs):
+     # 7.6M params
+     return DiT(depth=12, hidden_size=192, num_heads=6, **kwargs)
+
+ def DiT_tiny(**kwargs):
+     # 1.3M params
+     return DiT(depth=8, hidden_size=96, num_heads=6, **kwargs)
+
+
+ DiT_models = {
+     'DiT_S': DiT_S,
+     'DiT_mini': DiT_mini,
+     'DiT_tiny': DiT_tiny
+ }
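To make the shapes in the `forward()` docstring concrete, here is a minimal smoke test of the model above. The sizes are illustrative, not the repo's training configuration, and `xformers.ops.memory_efficient_attention` generally requires a CUDA device:

    import torch

    model = DiT_tiny(input_size=32, in_channels=1, condition_channels=768).cuda()
    x = torch.randn(4, 1, 32, device="cuda")         # N=4 parameter vectors of length T=32
    t = torch.randint(0, 1000, (4,), device="cuda")  # diffusion timesteps
    y = torch.randn(4, 1, 768, device="cuda")        # one image-feature token per sample
    out = model(x, t, y)                             # (4, 2, 32): mean + sigma channels (learn_sigma=True)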
core/utils/__pycache__/camera.cpython-310.pyc ADDED
Binary file (2.1 kB)
 
core/utils/__pycache__/dinov2.cpython-310.pyc ADDED
Binary file (2.07 kB)
 
core/utils/__pycache__/io.cpython-310.pyc ADDED
Binary file (2.07 kB)