|
import torch |
|
import torch.nn.functional as F |
|
from collections import OrderedDict |
|
from . import lvis |
|
|
|
|
|
@torch.no_grad()
def pred_lvis_sims(pc_encoder: torch.nn.Module, pc):
    """Rank LVIS categories by cosine similarity to a point-cloud embedding.

    Delegates to ``pred_custom_sims`` with ``lvis.categories`` as the names
    and ``lvis.feats`` as the per-category feature matrix. The LVIS and
    custom entry points therefore share one implementation and cannot drift
    apart.

    Args:
        pc_encoder: Module that embeds the point cloud. The input is placed on
            the device of the encoder's first parameter.
        pc: Point-cloud array indexable as ``pc[:, cols]`` with at least six
            columns. Columns 1 and 2 are swapped and the result is transposed
            to channel-first with a leading batch dimension before encoding.

    Returns:
        OrderedDict mapping LVIS category name -> 0-d similarity tensor,
        ordered by descending cosine similarity. Feature rows whose index is
        >= ``len(lvis.categories)`` are skipped.
    """
    return pred_custom_sims(pc_encoder, pc, lvis.categories, lvis.feats)
|
|
|
|
|
@torch.no_grad()
def pred_custom_sims(pc_encoder: torch.nn.Module, pc, cats, feats):
    """Rank custom categories by cosine similarity to a point-cloud embedding.

    Args:
        pc_encoder: Module that embeds the point cloud.
            - Device: the input is placed on the device of the encoder's
              first parameter.
            - Dtype: the input is cast to that parameter's dtype when it is
              floating point.
            - A parameterless encoder runs on CPU with torch's default dtype.
        pc: Point-cloud array (numpy array or tensor) indexable as
            ``pc[:, cols]`` with at least six columns. Columns are reordered
            to ``[0, 2, 1, 3, 4, 5]`` (columns 1 and 2 swapped), then
            transposed to channel-first with a leading batch dimension.
            NOTE(review): presumably xyz + rgb with a y/z axis swap; confirm.
        cats: Sequence of category names, indexed by row of ``feats``.
        feats: Category feature tensor, assumed (K, D) with D matching the
            encoder's output width.

    Returns:
        OrderedDict mapping category name -> 0-d similarity tensor, ordered
        by descending cosine similarity. Rows of ``feats`` whose index is
        >= ``len(cats)`` are skipped.
    """
    ref_param = next(pc_encoder.parameters(), None)
    if ref_param is None:
        # Parameterless encoder: there is no parameter to infer device/dtype from.
        in_device, in_dtype = torch.device("cpu"), torch.get_default_dtype()
    else:
        in_device = ref_param.device
        # Match the encoder's dtype. Otherwise a float64 numpy input hits a
        # dtype-mismatch error inside float32 layers.
        in_dtype = (
            ref_param.dtype if ref_param.is_floating_point() else torch.get_default_dtype()
        )

    x = torch.as_tensor(pc[:, [0, 2, 1, 3, 4, 5]].T[None], dtype=in_dtype, device=in_device)
    enc = pc_encoder(x)

    # matmul needs both operands on the same device and dtype. Compute on
    # feats' device in the promoted dtype rather than forcing everything to CPU.
    common = torch.promote_types(enc.dtype, feats.dtype)
    enc = enc.to(device=feats.device, dtype=common)
    feats = feats.to(common)

    sim = torch.matmul(F.normalize(feats, dim=-1), F.normalize(enc, dim=-1).squeeze())
    order = torch.argsort(sim, descending=True).tolist()
    return OrderedDict((cats[i], sim[i]) for i in order if i < len(cats))
|
|