# project_charles / prototypes.py
# Last commit by sohojoe (149eeaf): "refactor: use more of an MDP-style structure"
import os
import torch
from clip_transform import CLIPTransform
from PIL import Image
from torch.nn import functional as F
class Prototypes:
    """Nearest-prototype classifier over CLIP image embeddings.

    One prototype per class ("person", "no_person") is built by averaging the
    CLIP embeddings of example images in ``PROTOTYPES_FOLDER`` and scaling the
    mean to unit L2 norm. The stacked prototypes (one row per key, in
    ``prototype_keys`` order) are cached in ``CACHE_PATH`` and reused on later
    runs. Both paths are relative to the current working directory.
    """

    # Folder of example images; files must be named "<key>-<anything>.<ext>".
    PROTOTYPES_FOLDER = 'prototypes'
    # Cache file for the stacked prototype tensor; delete it to force a rebuild.
    CACHE_PATH = 'prototypes.pt'
    # Class labels, in the row order of ``self.prototypes``.
    _PROTOTYPE_KEYS = ("person", "no_person")
    # Image extensions accepted by load_images_from_folder (matched case-insensitively).
    _SUPPORTED_FILETYPES = ('.jpg', '.png', '.jpeg')

    def __init__(self):
        self._clip_transform = CLIPTransform()
        self._load_prototypes()

    @staticmethod
    def _mean_normalized(embeddings):
        """Average a list of embedding tensors and scale the mean to unit L2 norm.

        Each element may be 1-D ``(D,)`` or 2-D ``(k, D)``; 1-D tensors are
        promoted to ``(1, D)`` so they concatenate row-wise. Returns a ``(D,)``
        tensor.
        """
        stacked = torch.cat([torch.atleast_2d(e) for e in embeddings])
        mean = stacked.mean(dim=0)
        return mean / mean.norm(dim=-1, keepdim=True)

    def _prepare_prototypes(self):
        """Build prototypes from the example images and write them to CACHE_PATH.

        Raises:
            ValueError: if the folder has no supported images, or if any
                class key has no "<key>-*" image.
        """
        folder = self.PROTOTYPES_FOLDER
        image_embeddings = self.load_images_from_folder(folder)
        if not image_embeddings:
            raise ValueError(f"no image embeddings found in '{folder}'")

        prototypes = []
        for key in self._PROTOTYPE_KEYS:
            # Trailing '-' keeps "person-" from matching "no_person-" files.
            prefix = f"{key}-"
            # Sorted so the float summation order (and the cache) is deterministic.
            names = sorted(name for name in image_embeddings if name.startswith(prefix))
            if not names:
                raise ValueError(f"no '{prefix}*' images found in '{folder}'")
            prototypes.append(self._mean_normalized([image_embeddings[n] for n in names]))

        self.prototype_keys = list(self._PROTOTYPE_KEYS)
        self.prototypes = torch.stack(prototypes)
        torch.save(self.prototypes, self.CACHE_PATH)

    def _load_prototypes(self):
        """Load prototypes from CACHE_PATH, building and caching them first if missing."""
        if not os.path.exists(self.CACHE_PATH):
            # Sets self.prototypes / self.prototype_keys and writes the cache,
            # so there is no need to read the file straight back.
            self._prepare_prototypes()
            return
        self.prototypes = torch.load(self.CACHE_PATH)
        self.prototype_keys = list(self._PROTOTYPE_KEYS)

    def load_images_from_folder(self, folder):
        """Embed every supported image file directly inside ``folder``.

        Returns:
            dict mapping file name to the tensor returned by
            ``CLIPTransform.pil_image_to_embeddings`` (presumably ``(1, D)``,
            TODO confirm).
        """
        image_embeddings = {}
        for filename in os.listdir(folder):
            if not filename.lower().endswith(self._SUPPORTED_FILETYPES):
                continue
            # Context manager closes the underlying file handle once embedded.
            with Image.open(os.path.join(folder, filename)) as image:
                image_embeddings[filename] = self._clip_transform.pil_image_to_embeddings(image)
        return image_embeddings

    def get_distances(self, embeddings):
        """Score one embedding against every prototype.

        Args:
            embeddings: a single embedding, shaped ``(D,)`` or ``(1, D)``.
                Expected to be unit-normalized (see below).

        Returns:
            ``(distances, closest_item_key, debug_str)``. ``distances`` is the
            raw ``embeddings @ prototypes.T`` result, so its shape follows the
            input. These are similarities: higher means closer.
            ``closest_item_key`` is the key with the highest score.
            ``debug_str`` is formatted as "key: score, " for each key.

        Raises:
            ValueError: if more than one embedding is passed.
        """
        # Prototypes are unit-normalized, so this dot product equals cosine
        # similarity as long as `embeddings` is unit-normalized too.
        distances = embeddings @ self.prototypes.T
        # Flatten so (D,) and (1, D) inputs yield one score per key.
        scores = distances.reshape(-1)
        if scores.numel() != len(self.prototype_keys):
            raise ValueError(
                f"expected a single embedding, got scores of shape {tuple(distances.shape)}"
            )
        closest_item_key = self.prototype_keys[scores.argmax().item()]
        debug_str = "".join(
            f"{key}: {value.item():.2f}, " for key, value in zip(self.prototype_keys, scores)
        )
        return distances, closest_item_key, debug_str
if __name__ == "__main__":
    # Smoke test: build (or load cached) prototypes, report each prototype's
    # embedding length, then score the "person" prototype against all of them.
    # The person prototype should normally come back as its own closest match.
    protos = Prototypes()
    print("prototypes:")
    for label, vector in zip(protos.prototype_keys, protos.prototypes):
        print(f"{label}: {len(vector)}")
    probe = protos.prototypes[0]
    _scores, nearest_key, summary = protos.get_distances(probe)
    print(f"closest_item_key: {nearest_key}")
    print(f"distances: {summary}")
    print("done")