eliphatfs commited on
Commit
7a4df11
·
1 Parent(s): 7ee7303
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ __pycache__
2
+ *.egg-info
openshape/__init__.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from huggingface_hub import hf_hub_download
4
+ from .ppat_rgb import Projected, PointPatchTransformer
5
+
6
+
7
+ def module(state_dict: dict, name):
8
+ return {'.'.join(k.split('.')[1:]): v for k, v in state_dict.items() if k.startswith(name + '.')}
9
+
10
+
11
+ def G14(s):
12
+ model = Projected(
13
+ PointPatchTransformer(512, 12, 8, 512*3, 256, 384, 0.2, 64, 6),
14
+ nn.Linear(512, 1280)
15
+ )
16
+ model.load_state_dict(module(s['state_dict'], 'module'))
17
+ return model
18
+
19
+
20
+ def L14(s):
21
+ model = Projected(
22
+ PointPatchTransformer(512, 12, 8, 1024, 128, 64, 0.4, 256, 6),
23
+ nn.Linear(512, 768)
24
+ )
25
+ model.load_state_dict(module(s, 'pc_encoder'))
26
+ return model
27
+
28
+
29
+ def B32(s):
30
+ model = PointPatchTransformer(512, 12, 8, 1024, 128, 64, 0.4, 256, 6)
31
+ model.load_state_dict(module(s, 'pc_encoder'))
32
+ return model
33
+
34
+
35
+ model_list = {
36
+ "openshape-pointbert-vitb32-rgb": B32,
37
+ "openshape-pointbert-vitl14-rgb": L14,
38
+ "openshape-pointbert-vitg14-rgb": G14,
39
+ }
40
+
41
+
42
+ def load_pc_encoder(name):
43
+ s = torch.load(hf_hub_download("OpenShape/" + name, "model.pt", token=True), map_location='cpu')
44
+ model = model_list[name](s).eval()
45
+ if torch.cuda.is_available():
46
+ model.cuda()
47
+ return model
openshape/demo/__init__.py ADDED
File without changes
openshape/demo/caption.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import nn
2
+ import numpy as np
3
+ import torch
4
+ from typing import Tuple, List, Union, Optional
5
+ from transformers import GPT2Tokenizer, GPT2LMHeadModel
6
+ from huggingface_hub import hf_hub_download
7
+
8
+
9
+ N = type(None)
10
+ V = np.array
11
+ ARRAY = np.ndarray
12
+ ARRAYS = Union[Tuple[ARRAY, ...], List[ARRAY]]
13
+ VS = Union[Tuple[V, ...], List[V]]
14
+ VN = Union[V, N]
15
+ VNS = Union[VS, N]
16
+ T = torch.Tensor
17
+ TS = Union[Tuple[T, ...], List[T]]
18
+ TN = Optional[T]
19
+ TNS = Union[Tuple[TN, ...], List[TN]]
20
+ TSN = Optional[TS]
21
+ TA = Union[T, ARRAY]
22
+
23
+
24
+ D = torch.device
25
+
26
+
27
+ class MLP(nn.Module):
28
+
29
+ def forward(self, x: T) -> T:
30
+ return self.model(x)
31
+
32
+ def __init__(self, sizes: Tuple[int, ...], bias=True, act=nn.Tanh):
33
+ super(MLP, self).__init__()
34
+ layers = []
35
+ for i in range(len(sizes) -1):
36
+ layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=bias))
37
+ if i < len(sizes) - 2:
38
+ layers.append(act())
39
+ self.model = nn.Sequential(*layers)
40
+
41
+
42
+ class ClipCaptionModel(nn.Module):
43
+
44
+ #@functools.lru_cache #FIXME
45
+ def get_dummy_token(self, batch_size: int, device: D) -> T:
46
+ return torch.zeros(batch_size, self.prefix_length, dtype=torch.int64, device=device)
47
+
48
+ def forward(self, tokens: T, prefix: T, mask: Optional[T] = None, labels: Optional[T] = None):
49
+ embedding_text = self.gpt.transformer.wte(tokens)
50
+ prefix_projections = self.clip_project(prefix).view(-1, self.prefix_length, self.gpt_embedding_size)
51
+ #print(embedding_text.size()) #torch.Size([5, 67, 768])
52
+ #print(prefix_projections.size()) #torch.Size([5, 1, 768])
53
+ embedding_cat = torch.cat((prefix_projections, embedding_text), dim=1)
54
+ if labels is not None:
55
+ dummy_token = self.get_dummy_token(tokens.shape[0], tokens.device)
56
+ labels = torch.cat((dummy_token, tokens), dim=1)
57
+ out = self.gpt(inputs_embeds=embedding_cat, labels=labels, attention_mask=mask)
58
+ return out
59
+
60
+ def __init__(self, prefix_length: int, prefix_size: int = 512):
61
+ super(ClipCaptionModel, self).__init__()
62
+ self.prefix_length = prefix_length
63
+ self.gpt = GPT2LMHeadModel.from_pretrained('gpt2')
64
+ self.gpt_embedding_size = self.gpt.transformer.wte.weight.shape[1]
65
+ if prefix_length > 10: # not enough memory
66
+ self.clip_project = nn.Linear(prefix_size, self.gpt_embedding_size * prefix_length)
67
+ else:
68
+ self.clip_project = MLP((prefix_size, (self.gpt_embedding_size * prefix_length) // 2, self.gpt_embedding_size * prefix_length))
69
+
70
+
71
+ class ClipCaptionPrefix(ClipCaptionModel):
72
+
73
+ def parameters(self, recurse: bool = True):
74
+ return self.clip_project.parameters()
75
+
76
+ def train(self, mode: bool = True):
77
+ super(ClipCaptionPrefix, self).train(mode)
78
+ self.gpt.eval()
79
+ return self
80
+
81
+
82
+ def generate2(
83
+ model,
84
+ tokenizer,
85
+ tokens=None,
86
+ prompt=None,
87
+ embed=None,
88
+ entry_count=1,
89
+ entry_length=67, # maximum number of words
90
+ top_p=0.8,
91
+ temperature=1.,
92
+ stop_token: str = '.',
93
+ ):
94
+ model.eval()
95
+ generated_num = 0
96
+ generated_list = []
97
+ stop_token_index = tokenizer.encode(stop_token)[0]
98
+ filter_value = -float("Inf")
99
+ device = next(model.parameters()).device
100
+ score_col = []
101
+ with torch.no_grad():
102
+
103
+ for entry_idx in range(entry_count):
104
+ if embed is not None:
105
+ generated = embed
106
+ else:
107
+ if tokens is None:
108
+ tokens = torch.tensor(tokenizer.encode(prompt))
109
+ tokens = tokens.unsqueeze(0).to(device)
110
+
111
+ generated = model.gpt.transformer.wte(tokens)
112
+
113
+ for i in range(entry_length):
114
+
115
+ outputs = model.gpt(inputs_embeds=generated)
116
+ logits = outputs.logits
117
+ logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
118
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
119
+ cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
120
+ sorted_indices_to_remove = cumulative_probs > top_p
121
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
122
+ ..., :-1
123
+ ].clone()
124
+ sorted_indices_to_remove[..., 0] = 0
125
+
126
+ indices_to_remove = sorted_indices[sorted_indices_to_remove]
127
+ logits[:, indices_to_remove] = filter_value
128
+ next_token = torch.argmax(torch.softmax(logits, dim=-1), -1).reshape(1, 1)
129
+ score = torch.softmax(logits, dim=-1).reshape(-1)[next_token.item()].item()
130
+ score_col.append(score)
131
+ next_token_embed = model.gpt.transformer.wte(next_token)
132
+ if tokens is None:
133
+ tokens = next_token
134
+ else:
135
+ tokens = torch.cat((tokens, next_token), dim=1)
136
+ generated = torch.cat((generated, next_token_embed), dim=1)
137
+ if stop_token_index == next_token.item():
138
+ break
139
+
140
+ output_list = list(tokens.squeeze(0).cpu().numpy())
141
+ output_text = tokenizer.decode(output_list)
142
+ generated_list.append(output_text)
143
+ return generated_list[0]
144
+
145
+
146
+ @torch.no_grad()
147
+ def pc_caption(pc_encoder: torch.nn.Module, pc, cond_scale):
148
+ ref_dev = next(pc_encoder.parameters()).device
149
+ prefix = pc_encoder(torch.tensor(pc.T[None], device=ref_dev))
150
+ prefix = prefix.float() * cond_scale
151
+ prefix_embed = model.clip_project(prefix).reshape(1, prefix_length, -1)
152
+ text = generate2(model, tokenizer, embed=prefix_embed)
153
+ return text
154
+
155
+
156
+ tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
157
+ prefix_length = 10
158
+ model = ClipCaptionModel(prefix_length)
159
+ # print(model.gpt_embedding_size)
160
+ model.load_state_dict(torch.load(hf_hub_download('OpenShape/clipcap-cc', 'conceptual_weights.pt', token=True), map_location='cpu'))
161
+ model.eval()
162
+ if torch.cuda.is_available():
163
+ model = model.cuda()
openshape/demo/classification.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from collections import OrderedDict
4
+ from . import lvis
5
+
6
+
7
+ @torch.no_grad()
8
+ def pred_lvis_sims(pc_encoder: torch.nn.Module, pc):
9
+ ref_dev = next(pc_encoder.parameters()).device
10
+ enc = pc_encoder(torch.tensor(pc[:, [0, 2, 1, 3, 4, 5]].T[None], device=ref_dev)).cpu()
11
+ sim = torch.matmul(F.normalize(lvis.feats, dim=-1), F.normalize(enc, dim=-1).squeeze())
12
+ argsort = torch.argsort(sim, descending=True)
13
+ return OrderedDict((lvis.categories[i], sim[i]) for i in argsort if i < len(lvis.categories))
openshape/demo/lvis.py ADDED
@@ -0,0 +1,1162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+
4
+
5
+ feats = torch.load(os.path.join(os.path.dirname(__file__), 'lvis_cats.pt'))
6
+ categories = [
7
+ 'Band_Aid',
8
+ 'Bible',
9
+ 'CD_player',
10
+ 'Christmas_tree',
11
+ 'Dixie_cup',
12
+ 'Ferris_wheel',
13
+ 'Lego',
14
+ 'Rollerblade',
15
+ 'Sharpie',
16
+ 'Tabasco_sauce',
17
+ 'aerosol_can',
18
+ 'air_conditioner',
19
+ 'airplane',
20
+ 'alarm_clock',
21
+ 'alcohol',
22
+ 'alligator',
23
+ 'almond',
24
+ 'ambulance',
25
+ 'amplifier',
26
+ 'anklet',
27
+ 'antenna',
28
+ 'apple',
29
+ 'apricot',
30
+ 'apron',
31
+ 'aquarium',
32
+ 'arctic_(type_of_shoe)',
33
+ 'armband',
34
+ 'armchair',
35
+ 'armoire',
36
+ 'armor',
37
+ 'army_tank',
38
+ 'artichoke',
39
+ 'ashtray',
40
+ 'asparagus',
41
+ 'atomizer',
42
+ 'automatic_washer',
43
+ 'avocado',
44
+ 'award',
45
+ 'awning',
46
+ 'ax',
47
+ 'baboon',
48
+ 'baby_buggy',
49
+ 'backpack',
50
+ 'bagel',
51
+ 'baguet',
52
+ 'bait',
53
+ 'ball',
54
+ 'ballet_skirt',
55
+ 'balloon',
56
+ 'bamboo',
57
+ 'banana',
58
+ 'bandage',
59
+ 'bandanna',
60
+ 'banjo',
61
+ 'banner',
62
+ 'barbell',
63
+ 'barge',
64
+ 'barrel',
65
+ 'barrow',
66
+ 'baseball',
67
+ 'baseball_bat',
68
+ 'baseball_cap',
69
+ 'baseball_glove',
70
+ 'basket',
71
+ 'basketball',
72
+ 'basketball_backboard',
73
+ 'bass_horn',
74
+ 'bat_(animal)',
75
+ 'bath_mat',
76
+ 'bath_towel',
77
+ 'bathrobe',
78
+ 'bathtub',
79
+ 'battery',
80
+ 'beachball',
81
+ 'bead',
82
+ 'beanbag',
83
+ 'beanie',
84
+ 'bear',
85
+ 'bed',
86
+ 'bedpan',
87
+ 'bedspread',
88
+ 'beef_(food)',
89
+ 'beeper',
90
+ 'beer_bottle',
91
+ 'beer_can',
92
+ 'beetle',
93
+ 'bell',
94
+ 'bell_pepper',
95
+ 'belt',
96
+ 'belt_buckle',
97
+ 'bench',
98
+ 'beret',
99
+ 'bicycle',
100
+ 'billboard',
101
+ 'binder',
102
+ 'binoculars',
103
+ 'bird',
104
+ 'birdbath',
105
+ 'birdcage',
106
+ 'birdfeeder',
107
+ 'birdhouse',
108
+ 'birthday_cake',
109
+ 'birthday_card',
110
+ 'blackberry',
111
+ 'blackboard',
112
+ 'blanket',
113
+ 'blazer',
114
+ 'blender',
115
+ 'blimp',
116
+ 'blouse',
117
+ 'blueberry',
118
+ 'boat',
119
+ 'bob',
120
+ 'bobbin',
121
+ 'boiled_egg',
122
+ 'bolo_tie',
123
+ 'bolt',
124
+ 'bonnet',
125
+ 'book',
126
+ 'bookcase',
127
+ 'booklet',
128
+ 'bookmark',
129
+ 'boom_microphone',
130
+ 'boot',
131
+ 'bottle',
132
+ 'bottle_cap',
133
+ 'bottle_opener',
134
+ 'bouquet',
135
+ 'bow-tie',
136
+ 'bow_(decorative_ribbons)',
137
+ 'bow_(weapon)',
138
+ 'bowl',
139
+ 'bowler_hat',
140
+ 'bowling_ball',
141
+ 'box',
142
+ 'boxing_glove',
143
+ 'bracelet',
144
+ 'brass_plaque',
145
+ 'brassiere',
146
+ 'bread',
147
+ 'bread-bin',
148
+ 'breechcloth',
149
+ 'bridal_gown',
150
+ 'briefcase',
151
+ 'broach',
152
+ 'broccoli',
153
+ 'broom',
154
+ 'brownie',
155
+ 'brussels_sprouts',
156
+ 'bubble_gum',
157
+ 'bucket',
158
+ 'bulldog',
159
+ 'bulldozer',
160
+ 'bullet_train',
161
+ 'bulletin_board',
162
+ 'bulletproof_vest',
163
+ 'bullhorn',
164
+ 'bun',
165
+ 'bunk_bed',
166
+ 'buoy',
167
+ 'burrito',
168
+ 'bus_(vehicle)',
169
+ 'business_card',
170
+ 'butter',
171
+ 'butterfly',
172
+ 'button',
173
+ 'cab_(taxi)',
174
+ 'cabana',
175
+ 'cabin_car',
176
+ 'cabinet',
177
+ 'cake',
178
+ 'calculator',
179
+ 'calendar',
180
+ 'calf',
181
+ 'camcorder',
182
+ 'camel',
183
+ 'camera',
184
+ 'camera_lens',
185
+ 'camper_(vehicle)',
186
+ 'can',
187
+ 'can_opener',
188
+ 'candle',
189
+ 'candle_holder',
190
+ 'candy_bar',
191
+ 'candy_cane',
192
+ 'canister',
193
+ 'canoe',
194
+ 'cantaloup',
195
+ 'canteen',
196
+ 'cap_(headwear)',
197
+ 'cape',
198
+ 'cappuccino',
199
+ 'car_(automobile)',
200
+ 'car_battery',
201
+ 'card',
202
+ 'cardigan',
203
+ 'cargo_ship',
204
+ 'carnation',
205
+ 'carrot',
206
+ 'cart',
207
+ 'carton',
208
+ 'cash_register',
209
+ 'casserole',
210
+ 'cassette',
211
+ 'cast',
212
+ 'cat',
213
+ 'cauliflower',
214
+ 'cayenne_(spice)',
215
+ 'celery',
216
+ 'cellular_telephone',
217
+ 'chair',
218
+ 'chaise_longue',
219
+ 'chalice',
220
+ 'chandelier',
221
+ 'checkbook',
222
+ 'checkerboard',
223
+ 'cherry',
224
+ 'chessboard',
225
+ 'chicken_(animal)',
226
+ 'chili_(vegetable)',
227
+ 'chime',
228
+ 'chinaware',
229
+ 'chocolate_bar',
230
+ 'chocolate_cake',
231
+ 'chocolate_milk',
232
+ 'chocolate_mousse',
233
+ 'choker',
234
+ 'chopping_board',
235
+ 'chopstick',
236
+ 'cider',
237
+ 'cigar_box',
238
+ 'cigarette',
239
+ 'cigarette_case',
240
+ 'cincture',
241
+ 'cistern',
242
+ 'clarinet',
243
+ 'clasp',
244
+ 'cleansing_agent',
245
+ 'cleat_(for_securing_rope)',
246
+ 'clementine',
247
+ 'clip',
248
+ 'clipboard',
249
+ 'clippers_(for_plants)',
250
+ 'cloak',
251
+ 'clock',
252
+ 'clock_tower',
253
+ 'clothes_hamper',
254
+ 'clothespin',
255
+ 'clutch_bag',
256
+ 'coaster',
257
+ 'coat',
258
+ 'coat_hanger',
259
+ 'coatrack',
260
+ 'cock',
261
+ 'cockroach',
262
+ 'cocoa_(beverage)',
263
+ 'coconut',
264
+ 'coffee_maker',
265
+ 'coffee_table',
266
+ 'coffeepot',
267
+ 'coil',
268
+ 'coin',
269
+ 'colander',
270
+ 'coloring_material',
271
+ 'combination_lock',
272
+ 'comic_book',
273
+ 'compass',
274
+ 'computer_keyboard',
275
+ 'condiment',
276
+ 'cone',
277
+ 'control',
278
+ 'convertible_(automobile)',
279
+ 'cooker',
280
+ 'cookie',
281
+ 'cooking_utensil',
282
+ 'cooler_(for_food)',
283
+ 'cork_(bottle_plug)',
284
+ 'corkboard',
285
+ 'corkscrew',
286
+ 'cornbread',
287
+ 'cornet',
288
+ 'cornice',
289
+ 'cornmeal',
290
+ 'corset',
291
+ 'costume',
292
+ 'cougar',
293
+ 'cover',
294
+ 'coverall',
295
+ 'cow',
296
+ 'cowbell',
297
+ 'cowboy_hat',
298
+ 'crab_(animal)',
299
+ 'crabmeat',
300
+ 'cracker',
301
+ 'crape',
302
+ 'crate',
303
+ 'crawfish',
304
+ 'crayon',
305
+ 'cream_pitcher',
306
+ 'crescent_roll',
307
+ 'crib',
308
+ 'crisp_(potato_chip)',
309
+ 'crossbar',
310
+ 'crouton',
311
+ 'crow',
312
+ 'crowbar',
313
+ 'crown',
314
+ 'crucifix',
315
+ 'cruise_ship',
316
+ 'crutch',
317
+ 'cub_(animal)',
318
+ 'cube',
319
+ 'cucumber',
320
+ 'cufflink',
321
+ 'cup',
322
+ 'cupboard',
323
+ 'cupcake',
324
+ 'curtain',
325
+ 'cushion',
326
+ 'cylinder',
327
+ 'cymbal',
328
+ 'dagger',
329
+ 'dalmatian',
330
+ 'dartboard',
331
+ 'date_(fruit)',
332
+ 'deadbolt',
333
+ 'deck_chair',
334
+ 'deer',
335
+ 'desk',
336
+ 'detergent',
337
+ 'diaper',
338
+ 'diary',
339
+ 'die',
340
+ 'dinghy',
341
+ 'dining_table',
342
+ 'dirt_bike',
343
+ 'dish',
344
+ 'dish_antenna',
345
+ 'dishrag',
346
+ 'dishtowel',
347
+ 'dishwasher',
348
+ 'dishwasher_detergent',
349
+ 'dispenser',
350
+ 'dog',
351
+ 'dog_collar',
352
+ 'doll',
353
+ 'dollar',
354
+ 'dollhouse',
355
+ 'dolphin',
356
+ 'domestic_ass',
357
+ 'doorknob',
358
+ 'doormat',
359
+ 'doughnut',
360
+ 'dove',
361
+ 'dragonfly',
362
+ 'drawer',
363
+ 'dress',
364
+ 'dress_hat',
365
+ 'dress_suit',
366
+ 'dresser',
367
+ 'drill',
368
+ 'drone',
369
+ 'drum_(musical_instrument)',
370
+ 'drumstick',
371
+ 'duck',
372
+ 'duckling',
373
+ 'duct_tape',
374
+ 'duffel_bag',
375
+ 'dumbbell',
376
+ 'dumpster',
377
+ 'dustpan',
378
+ 'eagle',
379
+ 'earphone',
380
+ 'earplug',
381
+ 'earring',
382
+ 'easel',
383
+ 'eclair',
384
+ 'edible_corn',
385
+ 'eel',
386
+ 'egg',
387
+ 'egg_roll',
388
+ 'egg_yolk',
389
+ 'eggbeater',
390
+ 'eggplant',
391
+ 'elephant',
392
+ 'elevator_car',
393
+ 'elk',
394
+ 'envelope',
395
+ 'eraser',
396
+ 'escargot',
397
+ 'eyepatch',
398
+ 'falcon',
399
+ 'fan',
400
+ 'faucet',
401
+ 'fedora',
402
+ 'ferret',
403
+ 'ferry',
404
+ 'fig_(fruit)',
405
+ 'fighter_jet',
406
+ 'figurine',
407
+ 'file_(tool)',
408
+ 'file_cabinet',
409
+ 'fire_alarm',
410
+ 'fire_engine',
411
+ 'fire_extinguisher',
412
+ 'fire_hose',
413
+ 'fireplace',
414
+ 'fireplug',
415
+ 'first-aid_kit',
416
+ 'fish',
417
+ 'fish_(food)',
418
+ 'fishbowl',
419
+ 'fishing_rod',
420
+ 'flag',
421
+ 'flagpole',
422
+ 'flamingo',
423
+ 'flannel',
424
+ 'flap',
425
+ 'flash',
426
+ 'flashlight',
427
+ 'fleece',
428
+ 'flip-flop_(sandal)',
429
+ 'flipper_(footwear)',
430
+ 'flower_arrangement',
431
+ 'flowerpot',
432
+ 'flute_glass',
433
+ 'foal',
434
+ 'folding_chair',
435
+ 'food_processor',
436
+ 'football_(American)',
437
+ 'football_helmet',
438
+ 'footstool',
439
+ 'fork',
440
+ 'forklift',
441
+ 'freight_car',
442
+ 'freshener',
443
+ 'frisbee',
444
+ 'frog',
445
+ 'fruit_juice',
446
+ 'frying_pan',
447
+ 'fume_hood',
448
+ 'funnel',
449
+ 'futon',
450
+ 'gameboard',
451
+ 'garbage',
452
+ 'garbage_truck',
453
+ 'garden_hose',
454
+ 'gargle',
455
+ 'gargoyle',
456
+ 'garlic',
457
+ 'gasmask',
458
+ 'gazelle',
459
+ 'gelatin',
460
+ 'gemstone',
461
+ 'generator',
462
+ 'giant_panda',
463
+ 'gift_wrap',
464
+ 'ginger',
465
+ 'giraffe',
466
+ 'glass_(drink_container)',
467
+ 'globe',
468
+ 'glove',
469
+ 'goat',
470
+ 'goggles',
471
+ 'goldfish',
472
+ 'golf_club',
473
+ 'golfcart',
474
+ 'gondola_(boat)',
475
+ 'goose',
476
+ 'gorilla',
477
+ 'gourd',
478
+ 'grape',
479
+ 'grater',
480
+ 'gravestone',
481
+ 'gravy_boat',
482
+ 'green_bean',
483
+ 'green_onion',
484
+ 'grill',
485
+ 'grits',
486
+ 'grizzly',
487
+ 'grocery_bag',
488
+ 'guitar',
489
+ 'gull',
490
+ 'gun',
491
+ 'hair_dryer',
492
+ 'hairbrush',
493
+ 'hairnet',
494
+ 'halter_top',
495
+ 'ham',
496
+ 'hamburger',
497
+ 'hammer',
498
+ 'hammock',
499
+ 'hamper',
500
+ 'hamster',
501
+ 'hand_glass',
502
+ 'hand_towel',
503
+ 'handbag',
504
+ 'handcart',
505
+ 'handcuff',
506
+ 'handkerchief',
507
+ 'handle',
508
+ 'handsaw',
509
+ 'hardback_book',
510
+ 'harmonium',
511
+ 'hat',
512
+ 'hatbox',
513
+ 'headband',
514
+ 'headboard',
515
+ 'headlight',
516
+ 'headscarf',
517
+ 'headset',
518
+ 'headstall_(for_horses)',
519
+ 'heart',
520
+ 'heater',
521
+ 'helicopter',
522
+ 'helmet',
523
+ 'heron',
524
+ 'highchair',
525
+ 'hinge',
526
+ 'hippopotamus',
527
+ 'hockey_stick',
528
+ 'hog',
529
+ 'honey',
530
+ 'hook',
531
+ 'hookah',
532
+ 'horned_cow',
533
+ 'hornet',
534
+ 'horse',
535
+ 'horse_buggy',
536
+ 'horse_carriage',
537
+ 'hose',
538
+ 'hot-air_balloon',
539
+ 'hot_sauce',
540
+ 'hotplate',
541
+ 'hourglass',
542
+ 'houseboat',
543
+ 'hummingbird',
544
+ 'iPod',
545
+ 'ice_maker',
546
+ 'ice_pack',
547
+ 'ice_skate',
548
+ 'icecream',
549
+ 'identity_card',
550
+ 'igniter',
551
+ 'inhaler',
552
+ 'inkpad',
553
+ 'iron_(for_clothing)',
554
+ 'ironing_board',
555
+ 'jacket',
556
+ 'jam',
557
+ 'jar',
558
+ 'jean',
559
+ 'jeep',
560
+ 'jersey',
561
+ 'jet_plane',
562
+ 'jewel',
563
+ 'jewelry',
564
+ 'joystick',
565
+ 'jumpsuit',
566
+ 'kayak',
567
+ 'keg',
568
+ 'kennel',
569
+ 'kettle',
570
+ 'key',
571
+ 'keycard',
572
+ 'kilt',
573
+ 'kimono',
574
+ 'kitchen_sink',
575
+ 'kitchen_table',
576
+ 'kite',
577
+ 'kitten',
578
+ 'kiwi_fruit',
579
+ 'knee_pad',
580
+ 'knife',
581
+ 'knitting_needle',
582
+ 'knob',
583
+ 'knocker_(on_a_door)',
584
+ 'koala',
585
+ 'lab_coat',
586
+ 'ladder',
587
+ 'ladle',
588
+ 'ladybug',
589
+ 'lamb-chop',
590
+ 'lamb_(animal)',
591
+ 'lamp',
592
+ 'lamppost',
593
+ 'lampshade',
594
+ 'lantern',
595
+ 'laptop_computer',
596
+ 'lasagna',
597
+ 'latch',
598
+ 'lawn_mower',
599
+ 'leather',
600
+ 'legging_(clothing)',
601
+ 'legume',
602
+ 'lemon',
603
+ 'lemonade',
604
+ 'lettuce',
605
+ 'license_plate',
606
+ 'life_buoy',
607
+ 'life_jacket',
608
+ 'lightbulb',
609
+ 'lightning_rod',
610
+ 'lime',
611
+ 'limousine',
612
+ 'lion',
613
+ 'lip_balm',
614
+ 'liquor',
615
+ 'lizard',
616
+ 'locker',
617
+ 'log',
618
+ 'lollipop',
619
+ 'loveseat',
620
+ 'machine_gun',
621
+ 'magazine',
622
+ 'magnet',
623
+ 'mail_slot',
624
+ 'mailbox_(at_home)',
625
+ 'mallard',
626
+ 'mallet',
627
+ 'mammoth',
628
+ 'manatee',
629
+ 'mandarin_orange',
630
+ 'manger',
631
+ 'manhole',
632
+ 'map',
633
+ 'marker',
634
+ 'martini',
635
+ 'mascot',
636
+ 'mashed_potato',
637
+ 'mask',
638
+ 'mast',
639
+ 'mat_(gym_equipment)',
640
+ 'matchbox',
641
+ 'mattress',
642
+ 'measuring_cup',
643
+ 'measuring_stick',
644
+ 'meatball',
645
+ 'medicine',
646
+ 'melon',
647
+ 'microphone',
648
+ 'microscope',
649
+ 'microwave_oven',
650
+ 'milestone',
651
+ 'milk',
652
+ 'milk_can',
653
+ 'milkshake',
654
+ 'minivan',
655
+ 'mint_candy',
656
+ 'mirror',
657
+ 'mitten',
658
+ 'mixer_(kitchen_tool)',
659
+ 'money',
660
+ 'monitor_(computer_equipment) computer_monitor',
661
+ 'monkey',
662
+ 'mop',
663
+ 'motor',
664
+ 'motor_scooter',
665
+ 'motor_vehicle',
666
+ 'motorcycle',
667
+ 'mound_(baseball)',
668
+ 'mouse_(computer_equipment)',
669
+ 'mousepad',
670
+ 'muffin',
671
+ 'mug',
672
+ 'mushroom',
673
+ 'music_stool',
674
+ 'musical_instrument',
675
+ 'nailfile',
676
+ 'napkin',
677
+ 'neckerchief',
678
+ 'necklace',
679
+ 'necktie',
680
+ 'needle',
681
+ 'nest',
682
+ 'newspaper',
683
+ 'newsstand',
684
+ 'nightshirt',
685
+ 'notebook',
686
+ 'notepad',
687
+ 'nut',
688
+ 'nutcracker',
689
+ 'oar',
690
+ 'octopus_(animal)',
691
+ 'octopus_(food)',
692
+ 'oil_lamp',
693
+ 'olive_oil',
694
+ 'omelet',
695
+ 'onion',
696
+ 'orange_(fruit)',
697
+ 'orange_juice',
698
+ 'ostrich',
699
+ 'ottoman',
700
+ 'oven',
701
+ 'overalls_(clothing)',
702
+ 'owl',
703
+ 'pacifier',
704
+ 'packet',
705
+ 'paddle',
706
+ 'padlock',
707
+ 'paintbrush',
708
+ 'painting',
709
+ 'pajamas',
710
+ 'palette',
711
+ 'pan_(for_cooking)',
712
+ 'pan_(metal_container)',
713
+ 'pancake',
714
+ 'papaya',
715
+ 'paper_plate',
716
+ 'paper_towel',
717
+ 'paperback_book',
718
+ 'paperweight',
719
+ 'parachute',
720
+ 'parakeet',
721
+ 'parasail_(sports)',
722
+ 'parasol',
723
+ 'parchment',
724
+ 'parka',
725
+ 'parking_meter',
726
+ 'parrot',
727
+ 'passenger_car_(part_of_a_train)',
728
+ 'passenger_ship',
729
+ 'passport',
730
+ 'pastry',
731
+ 'patty_(food)',
732
+ 'pea_(food)',
733
+ 'peach',
734
+ 'peanut_butter',
735
+ 'pear',
736
+ 'peeler_(tool_for_fruit_and_vegetables)',
737
+ 'pegboard',
738
+ 'pelican',
739
+ 'pen',
740
+ 'pencil',
741
+ 'pencil_box',
742
+ 'pencil_sharpener',
743
+ 'pendulum',
744
+ 'penguin',
745
+ 'pennant',
746
+ 'penny_(coin)',
747
+ 'pepper',
748
+ 'pepper_mill',
749
+ 'perfume',
750
+ 'persimmon',
751
+ 'person',
752
+ 'pet',
753
+ 'pew_(church_bench)',
754
+ 'phonebook',
755
+ 'phonograph_record',
756
+ 'piano',
757
+ 'pickle',
758
+ 'pickup_truck',
759
+ 'pie',
760
+ 'pigeon',
761
+ 'piggy_bank',
762
+ 'pillow',
763
+ 'pineapple',
764
+ 'pinecone',
765
+ 'ping-pong_ball',
766
+ 'pinwheel',
767
+ 'pipe',
768
+ 'pipe_bowl',
769
+ 'pirate_flag',
770
+ 'pistol',
771
+ 'pita_(bread)',
772
+ 'pitcher_(vessel_for_liquid)',
773
+ 'pitchfork',
774
+ 'pizza',
775
+ 'place_mat',
776
+ 'plastic_bag',
777
+ 'plate',
778
+ 'platter',
779
+ 'playpen',
780
+ 'pliers',
781
+ 'plow_(farm_equipment)',
782
+ 'plume',
783
+ 'pocket_watch',
784
+ 'pocketknife',
785
+ 'poker_(fire_stirring_tool)',
786
+ 'poker_chip',
787
+ 'polar_bear',
788
+ 'pole',
789
+ 'police_cruiser',
790
+ 'polo_shirt',
791
+ 'poncho',
792
+ 'pony',
793
+ 'pool_table',
794
+ 'pop_(soda)',
795
+ 'popsicle',
796
+ 'postbox_(public)',
797
+ 'postcard',
798
+ 'poster',
799
+ 'pot',
800
+ 'potato',
801
+ 'potholder',
802
+ 'pottery',
803
+ 'pouch',
804
+ 'power_shovel',
805
+ 'prawn',
806
+ 'pretzel',
807
+ 'printer',
808
+ 'projectile_(weapon)',
809
+ 'projector',
810
+ 'propeller',
811
+ 'prune',
812
+ 'pudding',
813
+ 'puffer_(fish)',
814
+ 'puffin',
815
+ 'pug-dog',
816
+ 'pumpkin',
817
+ 'puncher',
818
+ 'puppet',
819
+ 'puppy',
820
+ 'quesadilla',
821
+ 'quiche',
822
+ 'quilt',
823
+ 'rabbit',
824
+ 'race_car',
825
+ 'racket',
826
+ 'radar',
827
+ 'radiator',
828
+ 'radio_receiver',
829
+ 'radish',
830
+ 'raft',
831
+ 'rag_doll',
832
+ 'railcar_(part_of_a_train)',
833
+ 'raincoat',
834
+ 'ram_(animal)',
835
+ 'raspberry',
836
+ 'rat',
837
+ 'reamer_(juicer)',
838
+ 'rearview_mirror',
839
+ 'receipt',
840
+ 'recliner',
841
+ 'record_player',
842
+ 'reflector',
843
+ 'refrigerator',
844
+ 'remote_control',
845
+ 'rhinoceros',
846
+ 'rib_(food)',
847
+ 'rifle',
848
+ 'ring',
849
+ 'river_boat',
850
+ 'road_map',
851
+ 'robe',
852
+ 'rocking_chair',
853
+ 'rodent',
854
+ 'roller_skate',
855
+ 'rolling_pin',
856
+ 'root_beer',
857
+ 'router_(computer_equipment)',
858
+ 'rubber_band',
859
+ 'runner_(carpet)',
860
+ 'saddle_(on_an_animal)',
861
+ 'saddle_blanket',
862
+ 'saddlebag',
863
+ 'safety_pin',
864
+ 'sail',
865
+ 'salad',
866
+ 'salad_plate',
867
+ 'salami',
868
+ 'salmon_(fish)',
869
+ 'salmon_(food)',
870
+ 'salsa',
871
+ 'saltshaker',
872
+ 'sandal_(type_of_shoe)',
873
+ 'sandwich',
874
+ 'satchel',
875
+ 'saucepan',
876
+ 'saucer',
877
+ 'sausage',
878
+ 'sawhorse',
879
+ 'saxophone',
880
+ 'scale_(measuring_instrument)',
881
+ 'scarecrow',
882
+ 'scarf',
883
+ 'school_bus',
884
+ 'scissors',
885
+ 'scoreboard',
886
+ 'scraper',
887
+ 'screwdriver',
888
+ 'scrubbing_brush',
889
+ 'sculpture',
890
+ 'seabird',
891
+ 'seahorse',
892
+ 'seaplane',
893
+ 'seashell',
894
+ 'sewing_machine',
895
+ 'shaker',
896
+ 'shampoo',
897
+ 'shark',
898
+ 'sharpener',
899
+ 'shaver_(electric)',
900
+ 'shaving_cream',
901
+ 'shawl',
902
+ 'shears',
903
+ 'sheep',
904
+ 'shepherd_dog',
905
+ 'sherbert',
906
+ 'shield',
907
+ 'shirt',
908
+ 'shoe',
909
+ 'shopping_bag',
910
+ 'shopping_cart',
911
+ 'short_pants',
912
+ 'shot_glass',
913
+ 'shoulder_bag',
914
+ 'shovel',
915
+ 'shower_cap',
916
+ 'shower_curtain',
917
+ 'shower_head',
918
+ 'shredder_(for_paper)',
919
+ 'signboard',
920
+ 'silo',
921
+ 'sink',
922
+ 'skateboard',
923
+ 'skewer',
924
+ 'ski',
925
+ 'ski_boot',
926
+ 'ski_parka',
927
+ 'ski_pole',
928
+ 'skirt',
929
+ 'skullcap',
930
+ 'sled',
931
+ 'sleeping_bag',
932
+ 'slide',
933
+ 'slipper_(footwear)',
934
+ 'smoothie',
935
+ 'snake',
936
+ 'snowboard',
937
+ 'snowman',
938
+ 'snowmobile',
939
+ 'soap',
940
+ 'soccer_ball',
941
+ 'sock',
942
+ 'sofa',
943
+ 'sofa_bed',
944
+ 'softball',
945
+ 'solar_array',
946
+ 'sombrero',
947
+ 'soup',
948
+ 'soup_bowl',
949
+ 'soupspoon',
950
+ 'soya_milk',
951
+ 'space_shuttle',
952
+ 'sparkler_(fireworks)',
953
+ 'spatula',
954
+ 'speaker_(stero_equipment)',
955
+ 'spear',
956
+ 'spectacles',
957
+ 'spice_rack',
958
+ 'spider',
959
+ 'sponge',
960
+ 'spoon',
961
+ 'sportswear',
962
+ 'spotlight',
963
+ 'squid_(food)',
964
+ 'squirrel',
965
+ 'stagecoach',
966
+ 'stapler_(stapling_machine)',
967
+ 'starfish',
968
+ 'statue_(sculpture)',
969
+ 'steak_(food)',
970
+ 'steak_knife',
971
+ 'steering_wheel',
972
+ 'step_stool',
973
+ 'stepladder',
974
+ 'stereo_(sound_system)',
975
+ 'stew',
976
+ 'stirrer',
977
+ 'stirrup',
978
+ 'stool',
979
+ 'stop_sign',
980
+ 'stove',
981
+ 'strainer',
982
+ 'strap',
983
+ 'straw_(for_drinking)',
984
+ 'strawberry',
985
+ 'street_sign',
986
+ 'streetlight',
987
+ 'string_cheese',
988
+ 'stylus',
989
+ 'subwoofer',
990
+ 'sugar_bowl',
991
+ 'sugarcane_(plant)',
992
+ 'suit_(clothing)',
993
+ 'suitcase',
994
+ 'sunflower',
995
+ 'sunglasses',
996
+ 'sunhat',
997
+ 'surfboard',
998
+ 'sushi',
999
+ 'suspenders',
1000
+ 'sweat_pants',
1001
+ 'sweatband',
1002
+ 'sweater',
1003
+ 'sweatshirt',
1004
+ 'sweet_potato',
1005
+ 'swimsuit',
1006
+ 'sword',
1007
+ 'syringe',
1008
+ 'table',
1009
+ 'table-tennis_table',
1010
+ 'table_lamp',
1011
+ 'tablecloth',
1012
+ 'tachometer',
1013
+ 'taco',
1014
+ 'tag',
1015
+ 'taillight',
1016
+ 'tambourine',
1017
+ 'tank_(storage_vessel)',
1018
+ 'tank_top_(clothing)',
1019
+ 'tape_(sticky_cloth_or_paper)',
1020
+ 'tape_measure',
1021
+ 'tapestry',
1022
+ 'tarp',
1023
+ 'tartan',
1024
+ 'tassel',
1025
+ 'teacup',
1026
+ 'teakettle',
1027
+ 'teapot',
1028
+ 'teddy_bear',
1029
+ 'telephone',
1030
+ 'telephone_booth',
1031
+ 'telephone_pole',
1032
+ 'telephoto_lens',
1033
+ 'television_camera',
1034
+ 'television_set',
1035
+ 'tennis_ball',
1036
+ 'tennis_racket',
1037
+ 'tequila',
1038
+ 'thermometer',
1039
+ 'thermos_bottle',
1040
+ 'thermostat',
1041
+ 'thimble',
1042
+ 'thread',
1043
+ 'thumbtack',
1044
+ 'tiara',
1045
+ 'tiger',
1046
+ 'tights_(clothing)',
1047
+ 'timer',
1048
+ 'tinfoil',
1049
+ 'tinsel',
1050
+ 'tissue_paper',
1051
+ 'toast_(food)',
1052
+ 'toaster',
1053
+ 'toaster_oven',
1054
+ 'tobacco_pipe',
1055
+ 'toilet',
1056
+ 'toilet_tissue',
1057
+ 'tomato',
1058
+ 'tongs',
1059
+ 'toolbox',
1060
+ 'toothbrush',
1061
+ 'toothpaste',
1062
+ 'toothpick',
1063
+ 'tortilla',
1064
+ 'tote_bag',
1065
+ 'tow_truck',
1066
+ 'towel',
1067
+ 'towel_rack',
1068
+ 'toy',
1069
+ 'tractor_(farm_equipment)',
1070
+ 'traffic_light',
1071
+ 'trailer_truck',
1072
+ 'train_(railroad_vehicle)',
1073
+ 'trampoline',
1074
+ 'trash_can',
1075
+ 'tray',
1076
+ 'trench_coat',
1077
+ 'triangle_(musical_instrument)',
1078
+ 'tricycle',
1079
+ 'tripod',
1080
+ 'trophy_cup',
1081
+ 'trousers',
1082
+ 'truck',
1083
+ 'truffle_(chocolate)',
1084
+ 'trunk',
1085
+ 'turban',
1086
+ 'turkey_(food)',
1087
+ 'turnip',
1088
+ 'turtle',
1089
+ 'turtleneck_(clothing)',
1090
+ 'tux',
1091
+ 'typewriter',
1092
+ 'umbrella',
1093
+ 'underdrawers',
1094
+ 'underwear',
1095
+ 'unicycle',
1096
+ 'urinal',
1097
+ 'urn',
1098
+ 'vacuum_cleaner',
1099
+ 'vase',
1100
+ 'veil',
1101
+ 'vending_machine',
1102
+ 'vent',
1103
+ 'vest',
1104
+ 'videotape',
1105
+ 'vinegar',
1106
+ 'violin',
1107
+ 'visor',
1108
+ 'vodka',
1109
+ 'volleyball',
1110
+ 'vulture',
1111
+ 'waffle',
1112
+ 'waffle_iron',
1113
+ 'wagon',
1114
+ 'walking_cane',
1115
+ 'walking_stick',
1116
+ 'wall_clock',
1117
+ 'wall_socket',
1118
+ 'wallet',
1119
+ 'walrus',
1120
+ 'wardrobe',
1121
+ 'washbasin',
1122
+ 'watch',
1123
+ 'water_bottle',
1124
+ 'water_cooler',
1125
+ 'water_faucet',
1126
+ 'water_gun',
1127
+ 'water_heater',
1128
+ 'water_jug',
1129
+ 'water_scooter',
1130
+ 'water_ski',
1131
+ 'water_tower',
1132
+ 'watering_can',
1133
+ 'watermelon',
1134
+ 'weathervane',
1135
+ 'webcam',
1136
+ 'wedding_cake',
1137
+ 'wedding_ring',
1138
+ 'wet_suit',
1139
+ 'wheel',
1140
+ 'wheelchair',
1141
+ 'whipped_cream',
1142
+ 'wig',
1143
+ 'wind_chime',
1144
+ 'windmill',
1145
+ 'window_box_(for_plants)',
1146
+ 'windsock',
1147
+ 'wine_bottle',
1148
+ 'wine_bucket',
1149
+ 'wineglass',
1150
+ 'wok',
1151
+ 'wolf',
1152
+ 'wooden_leg',
1153
+ 'wooden_spoon',
1154
+ 'wreath',
1155
+ 'wrench',
1156
+ 'wristband',
1157
+ 'wristlet',
1158
+ 'yacht',
1159
+ 'yogurt',
1160
+ 'zebra',
1161
+ 'zucchini'
1162
+ ]
openshape/demo/lvis_cats.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:71baf2d3f89884a082f1db75d0e94ac9a3b8036553877a3fdd98861cd01c4aec
3
+ size 5919467
openshape/demo/misc_utils.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy
2
+ import trimesh
3
+ import trimesh.sample
4
+ import trimesh.visual
5
+ import trimesh.proximity
6
+ import objaverse
7
+ import streamlit as st
8
+ import plotly.graph_objects as go
9
+ import matplotlib.pyplot as plotlib
10
+
11
+
12
+ def get_bytes(x: str):
13
+ import io, requests
14
+ return io.BytesIO(requests.get(x).content)
15
+
16
+
17
+ def get_image(x: str):
18
+ try:
19
+ return plotlib.imread(get_bytes(x), 'auto')
20
+ except Exception:
21
+ raise ValueError("Invalid image", x)
22
+
23
+
24
+ def model_to_pc(mesh: trimesh.Trimesh, n_sample_points=10000):
25
+ f32 = numpy.float32
26
+ rad = numpy.sqrt(mesh.area / (3 * n_sample_points))
27
+ for _ in range(24):
28
+ pcd, face_idx = trimesh.sample.sample_surface_even(mesh, n_sample_points, rad)
29
+ rad *= 0.85
30
+ if len(pcd) == n_sample_points:
31
+ break
32
+ else:
33
+ raise ValueError("Bad geometry, cannot finish sampling.", mesh.area)
34
+ if isinstance(mesh.visual, trimesh.visual.ColorVisuals):
35
+ rgba = mesh.visual.face_colors[face_idx]
36
+ elif isinstance(mesh.visual, trimesh.visual.TextureVisuals):
37
+ bc = trimesh.proximity.points_to_barycentric(mesh.triangles[face_idx], pcd)
38
+ if mesh.visual.uv is None or len(mesh.visual.uv) < mesh.faces[face_idx].max():
39
+ uv = numpy.zeros([len(bc), 2])
40
+ st.warning("Invalid UV, filling with zeroes")
41
+ else:
42
+ uv = numpy.einsum('ntc,nt->nc', mesh.visual.uv[mesh.faces[face_idx]], bc)
43
+ material = mesh.visual.material
44
+ if hasattr(material, 'materials'):
45
+ if len(material.materials) == 0:
46
+ rgba = numpy.ones_like(pcd) * 0.8
47
+ texture = None
48
+ st.warning("Empty MultiMaterial found, falling back to light grey")
49
+ else:
50
+ material = material.materials[0]
51
+ if hasattr(material, 'image'):
52
+ texture = material.image
53
+ if texture is None:
54
+ rgba = numpy.zeros([len(uv), len(material.main_color)]) + material.main_color
55
+ elif hasattr(material, 'baseColorTexture'):
56
+ texture = material.baseColorTexture
57
+ if texture is None:
58
+ rgba = numpy.zeros([len(uv), len(material.main_color)]) + material.main_color
59
+ else:
60
+ texture = None
61
+ rgba = numpy.ones_like(pcd) * 0.8
62
+ st.warning("Unknown material, falling back to light grey")
63
+ if texture is not None:
64
+ rgba = trimesh.visual.uv_to_interpolated_color(uv, texture)
65
+ if rgba.max() > 1:
66
+ if rgba.max() > 255:
67
+ rgba = rgba.astype(f32) / rgba.max()
68
+ else:
69
+ rgba = rgba.astype(f32) / 255.0
70
+ return numpy.concatenate([numpy.array(pcd, f32), numpy.array(rgba, f32)[:, :3]], axis=-1)
71
+
72
+
73
+ def trimesh_to_pc(scene_or_mesh):
74
+ if isinstance(scene_or_mesh, trimesh.Scene):
75
+ meshes = []
76
+ for node_name in scene_or_mesh.graph.nodes_geometry:
77
+ # which geometry does this node refer to
78
+ transform, geometry_name = scene_or_mesh.graph[node_name]
79
+
80
+ # get the actual potential mesh instance
81
+ geometry = scene_or_mesh.geometry[geometry_name].copy()
82
+ if not hasattr(geometry, 'triangles'):
83
+ continue
84
+ geometry: trimesh.Trimesh
85
+ geometry = geometry.apply_transform(transform)
86
+ meshes.append(geometry)
87
+ total_area = sum(geometry.area for geometry in meshes)
88
+ if total_area < 1e-6:
89
+ raise ValueError("Bad geometry: total area too small (< 1e-6)")
90
+ pcs = []
91
+ for geometry in meshes:
92
+ pcs.append(model_to_pc(geometry, max(1, round(geometry.area / total_area * 10000))))
93
+ if not len(pcs):
94
+ raise ValueError("Unsupported mesh object: no triangles found")
95
+ return numpy.concatenate(pcs)
96
+ else:
97
+ assert isinstance(scene_or_mesh, trimesh.Trimesh)
98
+ return model_to_pc(scene_or_mesh, 10000)
99
+
100
+
101
+ def input_3d_shape():
102
+ objaid = st.text_input("Enter an Objaverse ID")
103
+ model = st.file_uploader("Or upload a model (.glb/.obj/.ply)")
104
+ npy = st.file_uploader("Or upload a point cloud numpy array (.npy of Nx3 XYZ or Nx6 XYZRGB)")
105
+ swap_yz_axes = st.checkbox("Swap Y/Z axes of input (Y is up for OpenShape)")
106
+ f32 = numpy.float32
107
+
108
+ def load_data(prog):
109
+ # load the model
110
+ prog.progress(0.05, "Preparing Point Cloud")
111
+ if npy is not None:
112
+ pc: numpy.ndarray = numpy.load(npy)
113
+ elif model is not None:
114
+ pc = trimesh_to_pc(trimesh.load(model, model.name.split(".")[-1]))
115
+ elif objaid:
116
+ prog.progress(0.1, "Downloading Objaverse Object")
117
+ objamodel = objaverse.load_objects([objaid])[objaid]
118
+ prog.progress(0.2, "Preparing Point Cloud")
119
+ pc = trimesh_to_pc(trimesh.load(objamodel))
120
+ else:
121
+ raise ValueError("You have to supply 3D input!")
122
+ prog.progress(0.25, "Preprocessing Point Cloud")
123
+ assert pc.ndim == 2, "invalid pc shape: ndim = %d != 2" % pc.ndim
124
+ assert pc.shape[1] in [3, 6], "invalid pc shape: should have 3/6 channels, got %d" % pc.shape[1]
125
+ if swap_yz_axes:
126
+ pc[:, [1, 2]] = pc[:, [2, 1]]
127
+ pc[:, :3] = pc[:, :3] - numpy.mean(pc[:, :3], axis=0)
128
+ pc[:, :3] = pc[:, :3] / numpy.linalg.norm(pc[:, :3], axis=-1).max()
129
+ if pc.shape[1] == 3:
130
+ pc = numpy.concatenate([pc, numpy.ones_like(pc)], axis=-1)
131
+ prog.progress(0.3, "Preprocessed Point Cloud")
132
+ return pc.astype(f32)
133
+
134
+ return load_data
135
+
136
+
137
+ def render_pc(pc):
138
+ rand = numpy.random.permutation(len(pc))[:2048]
139
+ pc = pc[rand]
140
+ rgb = (pc[:, 3:] * 255).astype(numpy.uint8)
141
+ g = go.Scatter3d(
142
+ x=pc[:, 0], y=pc[:, 1], z=pc[:, 2],
143
+ mode='markers',
144
+ marker=dict(size=2, color=[f'rgb({rgb[i, 0]}, {rgb[i, 1]}, {rgb[i, 2]})' for i in range(len(pc))]),
145
+ )
146
+ fig = go.Figure(data=[g])
147
+ fig.update_layout(scene_camera=dict(up=dict(x=0, y=1, z=0)))
148
+ fig.update_scenes(aspectmode="data")
149
+ col1, col2 = st.columns(2)
150
+ with col1:
151
+ st.plotly_chart(fig, use_container_width=True)
152
+ # st.caption("Point Cloud Preview")
153
+ return col2
openshape/demo/retrieval.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from huggingface_hub import hf_hub_download
5
+
6
+
7
+ meta = json.load(
8
+ open(hf_hub_download("OpenShape/openshape-objaverse-embeddings", "objaverse_meta.json", token=True, repo_type='dataset'))
9
+ )
10
+ # {
11
+ # "u": "94db219c315742909fee67deeeacae15",
12
+ # "name": "knife", "like": 0, "view": 35,
13
+ # "tags": ["game-ready", "damascus", "damascus_steel", "kabar-knife", "knife", "blender", "blender3d", "gameready"],
14
+ # "cats": ["weapons-military"],
15
+ # "img": "https://media.sketchfab.com/models/94db219c315742909fee67deeeacae15/thumbnails/c0bbbd475d264ff2a92972f5115564ee/0cd28a130ebd4d9c9ef73190f24d9a42.jpeg",
16
+ # "desc": "", "faces": 1724, "size": 11955, "lic": "by",
17
+ # "glb": "glbs/000-000/94db219c315742909fee67deeeacae15.glb"
18
+ # }
19
+ meta = {x['u']: x for x in meta['entries']}
20
+ deser = torch.load(
21
+ hf_hub_download("OpenShape/openshape-objaverse-embeddings", "objaverse.pt", token=True, repo_type='dataset'), map_location='cpu'
22
+ )
23
+ us = deser['us']
24
+ feats = deser['feats']
25
+
26
+
27
+ def retrieve(embedding, top):
28
+ sims = []
29
+ embedding = F.normalize(embedding.detach().cpu(), dim=-1).squeeze()
30
+ for chunk in torch.split(feats, 10240):
31
+ sims.append(embedding @ F.normalize(chunk.float(), dim=-1).T)
32
+ sims = torch.cat(sims)
33
+ sims, idx = torch.topk(sims, top * 2)
34
+ results = []
35
+ for i, sim in zip(idx, sims):
36
+ if us[i] in meta:
37
+ results.append(dict(meta[us[i]], sim=sim))
38
+ if len(results) >= top:
39
+ break
40
+ return results
openshape/demo/sd_pc2img.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch_redstone as rst
3
+ import transformers
4
+ from diffusers import StableUnCLIPImg2ImgPipeline
5
+
6
+
7
+ class Wrapper(transformers.modeling_utils.PreTrainedModel):
8
+ def __init__(self) -> None:
9
+ super().__init__(transformers.configuration_utils.PretrainedConfig())
10
+ self.param = torch.nn.Parameter(torch.tensor(0.))
11
+
12
+ def forward(self, x):
13
+ return rst.ObjectProxy(image_embeds=x)
14
+
15
+
16
+ pipe = StableUnCLIPImg2ImgPipeline.from_pretrained(
17
+ "diffusers/stable-diffusion-2-1-unclip-i2i-l",
18
+ image_encoder = Wrapper()
19
+ )
20
+ if torch.cuda.is_available():
21
+ pipe = pipe.to('cuda:' + str(torch.cuda.current_device()))
22
+ pipe.enable_model_cpu_offload(torch.cuda.current_device())
23
+
24
+
25
+ @torch.no_grad()
26
+ def pc_to_image(pc_encoder: torch.nn.Module, pc, prompt, noise_level, width, height, cfg_scale, num_steps, callback):
27
+ ref_dev = next(pc_encoder.parameters()).device
28
+ enc = pc_encoder(torch.tensor(pc.T[None], device=ref_dev))
29
+ return pipe(
30
+ prompt="best quality, super high resolution, " + prompt,
31
+ negative_prompt="cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry",
32
+ image=torch.nn.functional.normalize(enc, dim=-1) * (768 ** 0.5) / 2,
33
+ width=width, height=height,
34
+ guidance_scale=cfg_scale,
35
+ noise_level=noise_level,
36
+ callback=callback,
37
+ num_inference_steps=num_steps
38
+ ).images[0]
openshape/pointnet_util.py ADDED
@@ -0,0 +1,323 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from time import time
5
+ import numpy as np
6
+ import dgl.geometry
7
+
8
+ def timeit(tag, t):
9
+ print("{}: {}s".format(tag, time() - t))
10
+ return time()
11
+
12
+ def pc_normalize(pc):
13
+ l = pc.shape[0]
14
+ centroid = np.mean(pc, axis=0)
15
+ pc = pc - centroid
16
+ m = np.max(np.sqrt(np.sum(pc**2, axis=1)))
17
+ pc = pc / m
18
+ return pc
19
+
20
+ def square_distance(src, dst):
21
+ """
22
+ Calculate Euclid distance between each two points.
23
+
24
+ src^T * dst = xn * xm + yn * ym + zn * zm;
25
+ sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn;
26
+ sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm;
27
+ dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2
28
+ = sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst
29
+
30
+ Input:
31
+ src: source points, [B, N, C]
32
+ dst: target points, [B, M, C]
33
+ Output:
34
+ dist: per-point square distance, [B, N, M]
35
+ """
36
+ B, N, _ = src.shape
37
+ _, M, _ = dst.shape
38
+ dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
39
+ dist += torch.sum(src ** 2, -1).view(B, N, 1)
40
+ dist += torch.sum(dst ** 2, -1).view(B, 1, M)
41
+ return dist
42
+
43
+
44
+ def index_points(points, idx):
45
+ """
46
+
47
+ Input:
48
+ points: input points data, [B, N, C]
49
+ idx: sample index data, [B, S]
50
+ Return:
51
+ new_points:, indexed points data, [B, S, C]
52
+ """
53
+ device = points.device
54
+ B = points.shape[0]
55
+ view_shape = list(idx.shape)
56
+ view_shape[1:] = [1] * (len(view_shape) - 1)
57
+ repeat_shape = list(idx.shape)
58
+ repeat_shape[0] = 1
59
+ batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape)
60
+ new_points = points[batch_indices, idx, :]
61
+ return new_points
62
+
63
+
64
+ def farthest_point_sample(xyz, npoint):
65
+ """
66
+ Input:
67
+ xyz: pointcloud data, [B, N, 3]
68
+ npoint: number of samples
69
+ Return:
70
+ centroids: sampled pointcloud index, [B, npoint]
71
+ """
72
+ return dgl.geometry.farthest_point_sampler(xyz, npoint)
73
+ device = xyz.device
74
+ B, N, C = xyz.shape
75
+ centroids = torch.zeros(B, npoint, dtype=torch.long).to(device)
76
+ distance = torch.ones(B, N).to(device) * 1e10
77
+ farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device)
78
+ batch_indices = torch.arange(B, dtype=torch.long).to(device)
79
+ for i in range(npoint):
80
+ centroids[:, i] = farthest
81
+ centroid = xyz[batch_indices, farthest, :].view(B, 1, 3)
82
+ dist = torch.sum((xyz - centroid) ** 2, -1)
83
+ mask = dist < distance
84
+ distance[mask] = dist[mask]
85
+ farthest = torch.max(distance, -1)[1]
86
+ return centroids
87
+
88
+
89
+ def query_ball_point(radius, nsample, xyz, new_xyz):
90
+ """
91
+ Input:
92
+ radius: local region radius
93
+ nsample: max sample number in local region
94
+ xyz: all points, [B, N, 3]
95
+ new_xyz: query points, [B, S, 3]
96
+ Return:
97
+ group_idx: grouped points index, [B, S, nsample]
98
+ """
99
+ device = xyz.device
100
+ B, N, C = xyz.shape
101
+ _, S, _ = new_xyz.shape
102
+ group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1])
103
+ sqrdists = square_distance(new_xyz, xyz)
104
+ group_idx[sqrdists > radius ** 2] = N
105
+ group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample]
106
+ group_first = group_idx[..., :1].repeat([1, 1, nsample])
107
+ mask = group_idx == N
108
+ group_idx[mask] = group_first[mask]
109
+ return group_idx
110
+
111
+
112
+ def sample_and_group(npoint, radius, nsample, xyz, points, returnfps=False):
113
+ """
114
+ Input:
115
+ npoint:
116
+ radius:
117
+ nsample:
118
+ xyz: input points position data, [B, N, 3]
119
+ points: input points data, [B, N, D]
120
+ Return:
121
+ new_xyz: sampled points position data, [B, npoint, nsample, 3]
122
+ new_points: sampled points data, [B, npoint, nsample, 3+D]
123
+ """
124
+ B, N, C = xyz.shape
125
+ S = npoint
126
+ fps_idx = farthest_point_sample(xyz, npoint) # [B, npoint, C]
127
+ # torch.cuda.empty_cache()
128
+ new_xyz = index_points(xyz, fps_idx)
129
+ # torch.cuda.empty_cache()
130
+ idx = query_ball_point(radius, nsample, xyz, new_xyz)
131
+ # torch.cuda.empty_cache()
132
+ grouped_xyz = index_points(xyz, idx) # [B, npoint, nsample, C]
133
+ # torch.cuda.empty_cache()
134
+ grouped_xyz_norm = grouped_xyz - new_xyz.view(B, S, 1, C)
135
+ # torch.cuda.empty_cache()
136
+
137
+ if points is not None:
138
+ grouped_points = index_points(points, idx)
139
+ new_points = torch.cat([grouped_xyz_norm, grouped_points], dim=-1) # [B, npoint, nsample, C+D]
140
+ else:
141
+ new_points = grouped_xyz_norm
142
+ if returnfps:
143
+ return new_xyz, new_points, grouped_xyz, fps_idx
144
+ else:
145
+ return new_xyz, new_points
146
+
147
+
148
+ def sample_and_group_all(xyz, points):
149
+ """
150
+ Input:
151
+ xyz: input points position data, [B, N, 3]
152
+ points: input points data, [B, N, D]
153
+ Return:
154
+ new_xyz: sampled points position data, [B, 1, 3]
155
+ new_points: sampled points data, [B, 1, N, 3+D]
156
+ """
157
+ device = xyz.device
158
+ B, N, C = xyz.shape
159
+ new_xyz = torch.zeros(B, 1, C).to(device)
160
+ grouped_xyz = xyz.view(B, 1, N, C)
161
+ if points is not None:
162
+ new_points = torch.cat([grouped_xyz, points.view(B, 1, N, -1)], dim=-1)
163
+ else:
164
+ new_points = grouped_xyz
165
+ return new_xyz, new_points
166
+
167
+
168
+ class PointNetSetAbstraction(nn.Module):
169
+ def __init__(self, npoint, radius, nsample, in_channel, mlp, group_all):
170
+ super(PointNetSetAbstraction, self).__init__()
171
+ self.npoint = npoint
172
+ self.radius = radius
173
+ self.nsample = nsample
174
+ self.mlp_convs = nn.ModuleList()
175
+ self.mlp_bns = nn.ModuleList()
176
+ last_channel = in_channel
177
+ for out_channel in mlp:
178
+ self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1))
179
+ self.mlp_bns.append(nn.BatchNorm2d(out_channel))
180
+ last_channel = out_channel
181
+ self.group_all = group_all
182
+
183
+ def forward(self, xyz, points):
184
+ """
185
+ Input:
186
+ xyz: input points position data, [B, C, N]
187
+ points: input points data, [B, D, N]
188
+ Return:
189
+ new_xyz: sampled points position data, [B, C, S]
190
+ new_points_concat: sample points feature data, [B, D', S]
191
+ """
192
+ xyz = xyz.permute(0, 2, 1)
193
+ if points is not None:
194
+ points = points.permute(0, 2, 1)
195
+
196
+ if self.group_all:
197
+ new_xyz, new_points = sample_and_group_all(xyz, points)
198
+ else:
199
+ new_xyz, new_points = sample_and_group(self.npoint, self.radius, self.nsample, xyz, points)
200
+ # new_xyz: sampled points position data, [B, npoint, C]
201
+ # new_points: sampled points data, [B, npoint, nsample, C+D]
202
+ new_points = new_points.permute(0, 3, 2, 1) # [B, C+D, nsample,npoint]
203
+ for i, conv in enumerate(self.mlp_convs):
204
+ bn = self.mlp_bns[i]
205
+ new_points = F.relu(bn(conv(new_points)))
206
+
207
+ new_points = torch.max(new_points, 2)[0]
208
+ new_xyz = new_xyz.permute(0, 2, 1)
209
+ return new_xyz, new_points
210
+
211
+
212
+ class PointNetSetAbstractionMsg(nn.Module):
213
+ def __init__(self, npoint, radius_list, nsample_list, in_channel, mlp_list):
214
+ super(PointNetSetAbstractionMsg, self).__init__()
215
+ self.npoint = npoint
216
+ self.radius_list = radius_list
217
+ self.nsample_list = nsample_list
218
+ self.conv_blocks = nn.ModuleList()
219
+ self.bn_blocks = nn.ModuleList()
220
+ for i in range(len(mlp_list)):
221
+ convs = nn.ModuleList()
222
+ bns = nn.ModuleList()
223
+ last_channel = in_channel + 3
224
+ for out_channel in mlp_list[i]:
225
+ convs.append(nn.Conv2d(last_channel, out_channel, 1))
226
+ bns.append(nn.BatchNorm2d(out_channel))
227
+ last_channel = out_channel
228
+ self.conv_blocks.append(convs)
229
+ self.bn_blocks.append(bns)
230
+
231
+ def forward(self, xyz, points):
232
+ """
233
+ Input:
234
+ xyz: input points position data, [B, C, N]
235
+ points: input points data, [B, D, N]
236
+ Return:
237
+ new_xyz: sampled points position data, [B, C, S]
238
+ new_points_concat: sample points feature data, [B, D', S]
239
+ """
240
+ xyz = xyz.permute(0, 2, 1)
241
+ if points is not None:
242
+ points = points.permute(0, 2, 1)
243
+
244
+ B, N, C = xyz.shape
245
+ S = self.npoint
246
+ new_xyz = index_points(xyz, farthest_point_sample(xyz, S))
247
+ new_points_list = []
248
+ for i, radius in enumerate(self.radius_list):
249
+ K = self.nsample_list[i]
250
+ group_idx = query_ball_point(radius, K, xyz, new_xyz)
251
+ grouped_xyz = index_points(xyz, group_idx)
252
+ grouped_xyz -= new_xyz.view(B, S, 1, C)
253
+ if points is not None:
254
+ grouped_points = index_points(points, group_idx)
255
+ grouped_points = torch.cat([grouped_points, grouped_xyz], dim=-1)
256
+ else:
257
+ grouped_points = grouped_xyz
258
+
259
+ grouped_points = grouped_points.permute(0, 3, 2, 1) # [B, D, K, S]
260
+ for j in range(len(self.conv_blocks[i])):
261
+ conv = self.conv_blocks[i][j]
262
+ bn = self.bn_blocks[i][j]
263
+ grouped_points = F.relu(bn(conv(grouped_points)))
264
+ new_points = torch.max(grouped_points, 2)[0] # [B, D', S]
265
+ new_points_list.append(new_points)
266
+
267
+ new_xyz = new_xyz.permute(0, 2, 1)
268
+ new_points_concat = torch.cat(new_points_list, dim=1)
269
+ return new_xyz, new_points_concat
270
+
271
+
272
+ class PointNetFeaturePropagation(nn.Module):
273
+ def __init__(self, in_channel, mlp):
274
+ super(PointNetFeaturePropagation, self).__init__()
275
+ self.mlp_convs = nn.ModuleList()
276
+ self.mlp_bns = nn.ModuleList()
277
+ last_channel = in_channel
278
+ for out_channel in mlp:
279
+ self.mlp_convs.append(nn.Conv1d(last_channel, out_channel, 1))
280
+ self.mlp_bns.append(nn.BatchNorm1d(out_channel))
281
+ last_channel = out_channel
282
+
283
+ def forward(self, xyz1, xyz2, points1, points2):
284
+ """
285
+ Input:
286
+ xyz1: input points position data, [B, C, N]
287
+ xyz2: sampled input points position data, [B, C, S]
288
+ points1: input points data, [B, D, N]
289
+ points2: input points data, [B, D, S]
290
+ Return:
291
+ new_points: upsampled points data, [B, D', N]
292
+ """
293
+ xyz1 = xyz1.permute(0, 2, 1)
294
+ xyz2 = xyz2.permute(0, 2, 1)
295
+
296
+ points2 = points2.permute(0, 2, 1)
297
+ B, N, C = xyz1.shape
298
+ _, S, _ = xyz2.shape
299
+
300
+ if S == 1:
301
+ interpolated_points = points2.repeat(1, N, 1)
302
+ else:
303
+ dists = square_distance(xyz1, xyz2)
304
+ dists, idx = dists.sort(dim=-1)
305
+ dists, idx = dists[:, :, :3], idx[:, :, :3] # [B, N, 3]
306
+
307
+ dist_recip = 1.0 / (dists + 1e-8)
308
+ norm = torch.sum(dist_recip, dim=2, keepdim=True)
309
+ weight = dist_recip / norm
310
+ interpolated_points = torch.sum(index_points(points2, idx) * weight.view(B, N, 3, 1), dim=2)
311
+
312
+ if points1 is not None:
313
+ points1 = points1.permute(0, 2, 1)
314
+ new_points = torch.cat([points1, interpolated_points], dim=-1)
315
+ else:
316
+ new_points = interpolated_points
317
+
318
+ new_points = new_points.permute(0, 2, 1)
319
+ for i, conv in enumerate(self.mlp_convs):
320
+ bn = self.mlp_bns[i]
321
+ new_points = F.relu(bn(conv(new_points)))
322
+ return new_points
323
+
openshape/ppat_rgb.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch_redstone as rst
4
+ from einops import rearrange
5
+ from .pointnet_util import PointNetSetAbstraction
6
+
7
+
8
+ class PreNorm(nn.Module):
9
+ def __init__(self, dim, fn):
10
+ super().__init__()
11
+ self.norm = nn.LayerNorm(dim)
12
+ self.fn = fn
13
+ def forward(self, x, *extra_args, **kwargs):
14
+ return self.fn(self.norm(x), *extra_args, **kwargs)
15
+
16
+ class FeedForward(nn.Module):
17
+ def __init__(self, dim, hidden_dim, dropout = 0.):
18
+ super().__init__()
19
+ self.net = nn.Sequential(
20
+ nn.Linear(dim, hidden_dim),
21
+ nn.GELU(),
22
+ nn.Dropout(dropout),
23
+ nn.Linear(hidden_dim, dim),
24
+ nn.Dropout(dropout)
25
+ )
26
+ def forward(self, x):
27
+ return self.net(x)
28
+
29
+ class Attention(nn.Module):
30
+ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., rel_pe = False):
31
+ super().__init__()
32
+ inner_dim = dim_head * heads
33
+ project_out = not (heads == 1 and dim_head == dim)
34
+
35
+ self.heads = heads
36
+ self.scale = dim_head ** -0.5
37
+
38
+ self.attend = nn.Softmax(dim = -1)
39
+ self.dropout = nn.Dropout(dropout)
40
+
41
+ self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
42
+
43
+ self.to_out = nn.Sequential(
44
+ nn.Linear(inner_dim, dim),
45
+ nn.Dropout(dropout)
46
+ ) if project_out else nn.Identity()
47
+
48
+ self.rel_pe = rel_pe
49
+ if rel_pe:
50
+ self.pe = nn.Sequential(nn.Conv2d(3, 64, 1), nn.ReLU(), nn.Conv2d(64, 1, 1))
51
+
52
+ def forward(self, x, centroid_delta):
53
+ qkv = self.to_qkv(x).chunk(3, dim = -1)
54
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
55
+
56
+ pe = self.pe(centroid_delta) if self.rel_pe else 0
57
+ dots = (torch.matmul(q, k.transpose(-1, -2)) + pe) * self.scale
58
+
59
+ attn = self.attend(dots)
60
+ attn = self.dropout(attn)
61
+
62
+ out = torch.matmul(attn, v)
63
+ out = rearrange(out, 'b h n d -> b n (h d)')
64
+ return self.to_out(out)
65
+
66
+
67
+ class Transformer(nn.Module):
68
+ def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., rel_pe = False):
69
+ super().__init__()
70
+ self.layers = nn.ModuleList([])
71
+ for _ in range(depth):
72
+ self.layers.append(nn.ModuleList([
73
+ PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout, rel_pe = rel_pe)),
74
+ PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
75
+ ]))
76
+ def forward(self, x, centroid_delta):
77
+ for attn, ff in self.layers:
78
+ x = attn(x, centroid_delta) + x
79
+ x = ff(x) + x
80
+ return x
81
+
82
+
83
+ class PointPatchTransformer(nn.Module):
84
+ def __init__(self, dim, depth, heads, mlp_dim, sa_dim, patches, prad, nsamp, in_dim=3, dim_head=64, rel_pe=False, patch_dropout=0) -> None:
85
+ super().__init__()
86
+ self.patches = patches
87
+ self.patch_dropout = patch_dropout
88
+ self.sa = PointNetSetAbstraction(npoint=patches, radius=prad, nsample=nsamp, in_channel=in_dim + 3, mlp=[64, 64, sa_dim], group_all=False)
89
+ self.lift = nn.Sequential(nn.Conv1d(sa_dim + 3, dim, 1), rst.Lambda(lambda x: torch.permute(x, [0, 2, 1])), nn.LayerNorm([dim]))
90
+ self.cls_token = nn.Parameter(torch.randn(dim))
91
+ self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, 0.0, rel_pe)
92
+
93
+ def forward(self, features):
94
+ self.sa.npoint = self.patches
95
+ if self.training:
96
+ self.sa.npoint -= self.patch_dropout
97
+ # print("input", features.shape)
98
+ centroids, feature = self.sa(features[:, :3], features)
99
+ # print("f", feature.shape, 'c', centroids.shape)
100
+ x = self.lift(torch.cat([centroids, feature], dim=1))
101
+
102
+ x = rst.supercat([self.cls_token, x], dim=-2)
103
+ centroids = rst.supercat([centroids.new_zeros(1), centroids], dim=-1)
104
+
105
+ centroid_delta = centroids.unsqueeze(-1) - centroids.unsqueeze(-2)
106
+ x = self.transformer(x, centroid_delta)
107
+
108
+ return x[:, 0]
109
+
110
+
111
+ class Projected(nn.Module):
112
+ def __init__(self, ppat, proj) -> None:
113
+ super().__init__()
114
+ self.ppat = ppat
115
+ self.proj = proj
116
+
117
+ def forward(self, features: torch.Tensor):
118
+ return self.proj(self.ppat(features))
setup.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import setuptools
2
+
3
+
4
+ def packages():
5
+ return setuptools.find_packages()
6
+
7
+
8
+ setuptools.setup(
9
+ name="openshape",
10
+ version="0.1",
11
+ author="flandre.info",
12
+ author_email="[email protected]",
13
+ description="Support library for OpenShape Demos.",
14
+ packages=packages(),
15
+ classifiers=[
16
+ "Programming Language :: Python :: 3 :: Only",
17
+ "License :: OSI Approved :: Apache Software License",
18
+ "Operating System :: OS Independent",
19
+ ],
20
+ python_requires='~=3.7',
21
+ )