Spaces:
Sleeping
Sleeping
import json | |
import torch | |
import torch.nn.functional as F | |
from huggingface_hub import hf_hub_download | |
# --- Object metadata -------------------------------------------------------
# Download (or reuse the cached copy of) the metadata JSON from the Hub.
_meta_path = hf_hub_download(
    "OpenShape/openshape-objaverse-embeddings", "objaverse_meta.json", repo_type='dataset'
)
# Use a context manager so the file handle is closed deterministically, and
# decode explicitly as UTF-8 (the JSON interchange encoding) instead of
# relying on the platform's locale encoding.
with open(_meta_path, encoding="utf-8") as _meta_file:
    meta = json.load(_meta_file)
# Example of one item of meta['entries'], as recorded by the original author:
# {
#     "u": "94db219c315742909fee67deeeacae15",
#     "name": "knife",
#     "like": 0,
#     "view": 35,
#     "anims": 0,
#     "tags": ["game-ready"],
#     "cats": ["weapons-military"],
#     "img": "https://media.sketchfab.com/models/94db219c315742909fee67deeeacae15/thumbnails/c0bbbd475d264ff2a92972f5115564ee/0cd28a130ebd4d9c9ef73190f24d9a42.jpeg",
#     "desc": "",
#     "faces": 1724,
#     "size": 11955,
#     "lic": "by",
#     "glb": "glbs/000-000/94db219c315742909fee67deeeacae15.glb"
# }
# Re-key the entries by their "u" field so retrieve() can look one up directly.
meta = {x['u']: x for x in meta['entries']}

# --- Precomputed embeddings ------------------------------------------------
# Alternative embeddings source, kept for reference (disabled):
# deser = torch.load(
#     hf_hub_download("OpenShape/openshape-objaverse-embeddings", "objaverse.pt", repo_type='dataset'), map_location='cpu'
# )

# NOTE(security): torch.load can unpickle arbitrary Python objects, so this
# executes whatever the downloaded file contains. Only point it at a trusted
# repository. TODO(review): consider weights_only=True once it is confirmed
# that the checkpoint contains only tensors and plain containers.
deser = torch.load(
    hf_hub_download("TripletMix/tripletmix-objaverse-embeddings", "objaverse.pt", repo_type='dataset'),
    map_location='cpu',
)
# us[i] is the key into `meta` for the i-th stored feature vector in `feats`
# (this is how retrieve() pairs them up).
us = deser['us']
feats = deser['feats']
def retrieve(embedding, top, sim_th=0.0, filter_fn=None, chunk_size=10240):
    """Return metadata for the stored objects most similar to ``embedding``.

    The query and every stored feature vector in the module-level ``feats``
    are L2-normalised and compared by dot product (cosine similarity). Hits
    are visited from most to least similar; each one is mapped to its id via
    ``us`` and then to its metadata via ``meta``.

    Args:
        embedding: Tensor holding a single query vector. Singleton dimensions
            are squeezed away; it is detached, moved to CPU and cast to
            float32 before use.
        top: Maximum number of results to return. Values <= 0 yield ``[]``.
        sim_th: Only hits whose similarity is strictly greater than this
            threshold are considered.
        filter_fn: Optional predicate called with a metadata dict; hits for
            which it returns a falsy value are skipped.
        chunk_size: Number of stored vectors scored per matrix product. This
            bounds the size of the temporary float32 copy of ``feats``.

    Returns:
        A list of at most ``top`` dicts in descending similarity order. Each
        is a shallow copy of the ``meta`` entry with an extra ``"sim"`` key
        holding the similarity as a 0-d tensor.
    """
    # Previously top <= 0 still produced one result because the limit was
    # only checked after the first append.
    if top <= 0:
        return []

    query = F.normalize(embedding.detach().cpu().float(), dim=-1).squeeze()

    # Score the stored vectors in batches: one matrix product per chunk
    # instead of one Python-level product per row.
    scores = []
    for start in range(0, len(feats), chunk_size):
        block = feats[start:start + chunk_size]
        if not isinstance(block, torch.Tensor):
            # `feats` is a sequence of per-object tensors; flatten each to a
            # row so (D,) and (1, D) items are handled alike.
            block = torch.stack([row.reshape(-1) for row in block])
        block = block.reshape(block.shape[0], -1).float()
        scores.append(F.normalize(block, dim=-1) @ query)
    if not scores:
        return []

    sims, idx = torch.sort(torch.cat(scores), descending=True)
    keep = sims > sim_th
    sims = sims[keep]
    idx = idx[keep]

    results = []
    # idx.tolist() yields plain ints, which index `us` directly.
    for i, sim in zip(idx.tolist(), sims):
        entry = meta.get(us[i])
        if entry is None:
            # No metadata known for this embedding; skip it.
            continue
        if filter_fn is not None and not filter_fn(entry):
            continue
        results.append(dict(entry, sim=sim))
        if len(results) >= top:
            break
    return results