Spaces:
Runtime error
Runtime error
import faiss | |
import faiss.contrib.torch_utils | |
import time | |
import logging | |
import torch | |
import numpy as np | |
# Passed as the sub-quantizer count to faiss.IndexIVFPQ in Datastore.train_index,
# together with 8 bits per sub-quantizer (so, presumably, 64 bytes per stored code).
# NOTE(review): FAISS expects the vector dimension to be divisible by this value - confirm.
code_size = 64
class DatastoreBatch:
    """A fixed-size collection of ``Datastore`` objects, one per batch element.

    Every method fans the call out to the per-example datastores in order,
    pairing the i-th datastore with the i-th entry of the batched argument.
    """

    def __init__(self, dim, batch_size, flat_index=False, gpu_index=False, verbose=False, index_device=None) -> None:
        """Create ``batch_size`` datastores that all share one configuration.

        Args:
            dim: dimensionality of the key/query vectors.
            batch_size: number of per-example datastores to create.
            flat_index: forwarded to each ``Datastore`` as ``use_flat_index``.
            gpu_index: forwarded to each ``Datastore``; also selects the
                default device ('cuda' when true, 'cpu' otherwise).
            verbose: forwarded to each ``Datastore``.
            index_device: explicit device; overrides the ``gpu_index`` default.
        """
        self.batch_size = batch_size
        if index_device is not None:
            self.device = index_device
        else:
            self.device = torch.device('cuda' if gpu_index else 'cpu')
        self.indices = [
            Datastore(dim, use_flat_index=flat_index, gpu_index=gpu_index, verbose=verbose, device=self.device)
            for _ in range(batch_size)
        ]

    def move_to_gpu(self):
        """Ask every per-example datastore to move its index to the GPU."""
        for slot in range(self.batch_size):
            store = self.indices[slot]
            store.move_to_gpu()

    def add_keys(self, keys, num_keys_to_add_at_a_time=100000):
        """Hand ``keys[i]`` to the i-th datastore, with the given chunk size."""
        for slot in range(self.batch_size):
            store = self.indices[slot]
            store.add_keys(keys[slot], num_keys_to_add_at_a_time)

    def train_index(self, keys):
        """Train each datastore on its own example's keys (paired via ``zip``)."""
        for store, per_example_keys in zip(self.indices, keys):
            store.train_index(per_example_keys)

    def search(self, queries, k):
        """Search each datastore with ``queries[i]`` and stack the results.

        Returns:
            ``(scores, values)``, each stacked along a new leading batch axis.
        """
        hits = [self.indices[slot].search(queries[slot], k) for slot in range(self.batch_size)]
        all_scores = [scores for scores, _ in hits]
        all_values = [values for _, values in hits]
        return torch.stack(all_scores, dim=0), torch.stack(all_values, dim=0)

    def search_and_reconstruct(self, queries, k):
        """Like ``search`` but also returns the vectors reported by each datastore.

        Returns:
            ``(scores, values, vectors)``, each stacked along a new leading
            batch axis.
        """
        hits = [
            self.indices[slot].search_and_reconstruct(queries[slot], k)
            for slot in range(self.batch_size)
        ]
        all_scores = [scores for scores, _, _ in hits]
        all_values = [values for _, values, _ in hits]
        all_vectors = [vectors for _, _, vectors in hits]
        return (
            torch.stack(all_scores, dim=0),
            torch.stack(all_values, dim=0),
            torch.stack(all_vectors, dim=0),
        )
class Datastore:
    """Key store for one sequence with inner-product nearest-neighbour search.

    Two modes, selected by ``use_flat_index``:

    * flat: the keys are kept as-is in ``self.keys`` and searched with the
      brute-force ``faiss.knn`` / ``faiss.knn_gpu`` helpers;
    * otherwise: a ``faiss.IndexIVFPQ`` is trained on the keys (with a
      ``faiss.IndexFlatIP`` as its first constructor argument) and searched
      through ``self.index``.
    """

    def __init__(self, dim, use_flat_index=False, gpu_index=False, verbose=False, device=None) -> None:
        """Set up the datastore; no keys are stored until ``train_index``/``add_keys``.

        Args:
            dim: dimensionality of key and query vectors.
            use_flat_index: use brute-force search over the raw keys instead
                of a trained FAISS index.
            gpu_index: whether search should run on GPU; also selects the
                default device ('cuda' when true, 'cpu' otherwise).
            verbose: accepted for interface compatibility; unused here.
            device: explicit device; overrides the ``gpu_index`` default.
        """
        self.dimension = dim
        self.device = device if device is not None else torch.device('cuda' if gpu_index else 'cpu')
        self.logger = logging.getLogger('index_building')
        self.logger.setLevel(logging.INFO)
        self.use_flat_index = use_flat_index
        self.gpu_index = gpu_index

        # TODO: is preprocessing efficient enough to spend time on?
        if not use_flat_index:
            # inner product index because we use IP attention
            self.index = faiss.IndexFlatIP(self.dimension)

        # Number of vectors pushed into ``self.index`` so far.
        self.index_size = 0

    def move_to_gpu(self):
        """Clone the FAISS index onto ``self.device``; a no-op in flat mode."""
        if self.use_flat_index:
            return
        co = faiss.GpuClonerOptions()
        co.useFloat16 = True
        self.index = faiss.index_cpu_to_gpu(faiss.StandardGpuResources(), self.device.index, self.index, co)

    def train_index(self, keys):
        """Build the datastore from ``keys``.

        In flat mode this only stores the keys. Otherwise the keys are moved
        to CPU float32, an IVFPQ index is trained on them, the keys are added
        to it, and the index is moved to GPU when ``gpu_index`` is set.

        Args:
            keys: tensor whose first axis enumerates the key vectors.
        """
        if self.use_flat_index:
            self.add_keys(keys=keys, index_is_trained=True)
            return

        keys = keys.cpu().float()
        # One centroid per 128 keys.
        # NOTE(review): this is 0 for fewer than 128 keys - confirm callers never do that.
        ncentroids = int(keys.shape[0] / 128)
        self.index = faiss.IndexIVFPQ(self.index, self.dimension,
                                      ncentroids, code_size, 8)
        self.index.nprobe = min(32, ncentroids)

        self.logger.info('Training index')
        start_time = time.time()
        self.index.train(keys)
        self.logger.info(f'Training took {time.time() - start_time} s')
        self.add_keys(keys=keys, index_is_trained=True)
        if self.gpu_index:
            self.move_to_gpu()

    def add_keys(self, keys, num_keys_to_add_at_a_time=1000000, index_is_trained=False):
        """Remember ``keys`` and, for a trained non-flat index, add them in chunks.

        ``self.keys`` is always replaced by ``keys``. The FAISS index is only
        touched when the datastore is not flat and ``index_is_trained`` is true.

        Args:
            keys: tensor whose first axis enumerates the key vectors.
            num_keys_to_add_at_a_time: maximum number of keys per ``index.add`` call.
            index_is_trained: whether ``self.index`` is ready to accept vectors.

        Raises:
            ValueError: if chunked adding is requested with a non-positive
                chunk size (which would otherwise never terminate).
        """
        self.keys = keys
        if self.use_flat_index or not index_is_trained:
            return
        if num_keys_to_add_at_a_time <= 0:
            raise ValueError('num_keys_to_add_at_a_time must be positive')

        total = keys.shape[0]
        start = 0
        while start < total:
            end = min(total, start + num_keys_to_add_at_a_time)
            self.index.add(keys[start:end])
            self.index_size += end - start
            # Advance to the end of the chunk just added. The cursor must be
            # assigned, not incremented by ``end``, or later chunks are skipped.
            start = end
            if (start % 1000000) == 0:
                self.logger.info(f'Added {start} tokens so far')

    def search_and_reconstruct(self, queries, k):
        """Return ``(scores, values, vectors)`` for the ``k`` nearest keys per query.

        A 1-D query is treated as a single query vector.
        """
        if len(queries.shape) == 1:  # searching for only 1 vector, add one extra dim
            self.logger.info("Searching for a single vector; unsqueezing")
            queries = queries.unsqueeze(0)
        assert queries.shape[-1] == self.dimension  # query vectors are same shape as "key" vectors
        # NOTE(review): ``self.index.index`` looks like a leftover from an IndexIDMap
        # wrapper that is no longer applied - confirm the attribute exists on the index in use.
        scores, values, vectors = self.index.index.search_and_reconstruct(queries.cpu().detach(), k)
        return scores, values, vectors

    def search(self, queries, k):
        """Return ``(scores, values)`` for the ``k`` highest inner-product keys per query.

        A 1-D query is treated as a single query vector. Ids that are
        negative or not smaller than the number of stored keys are replaced
        by 0 so callers can always use them to index into the keys.
        """
        if len(queries.shape) == 1:  # searching for only 1 vector, add one extra dim
            self.logger.info("Searching for a single vector; unsqueezing")
            queries = queries.unsqueeze(0)
        assert queries.shape[-1] == self.dimension  # query vectors are same shape as "key" vectors

        if self.use_flat_index:
            if self.gpu_index:
                scores, values = faiss.knn_gpu(faiss.StandardGpuResources(), queries, self.keys, k,
                                               metric=faiss.METRIC_INNER_PRODUCT, device=self.device.index)
            else:
                scores, values = faiss.knn(queries, self.keys, k, metric=faiss.METRIC_INNER_PRODUCT)
                scores = torch.from_numpy(scores).to(queries.dtype)
                values = torch.from_numpy(values)
        else:
            scores, values = self.index.search(queries.float(), k)

        # avoid returning -1 as a value
        # TODO: get a handle on the attention mask and mask the values that were -1
        out_of_range = torch.logical_or(values < 0, values >= self.keys.shape[0])
        values = torch.where(out_of_range, torch.zeros_like(values), values)
        return scores, values