# Source code for bocoel.corpora.indices.backend.faiss

# Copyright (c) RenChu Wang - All Rights Reserved

import functools
import warnings
from types import ModuleType
from typing import Any

import numpy as np
from numpy.typing import NDArray

from bocoel.corpora.indices import utils
from bocoel.corpora.indices.interfaces import Distance, Index, InternalResult


@functools.cache
def _faiss() -> ModuleType:
    # Optional dependency.
    # Faiss also spits out deprecation warnings.
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=DeprecationWarning)

        import faiss

    return faiss


[docs] class FaissIndex(Index): """ Faiss index. Uses the faiss library. """
[docs] def __init__( self, embeddings: NDArray, distance: str | Distance, *, normalize: bool = True, index_string: str, cuda: bool = False, batch_size: int = 64, ) -> None: """ Initializes the Faiss index. Parameters: embeddings: The embeddings to index. distance: The distance metric to use. index_string: The index string to use. cuda: Whether to use CUDA. batch_size: The batch size to use for searching. """ if normalize: embeddings = utils.normalize(embeddings) self.__embeddings = embeddings self._batch_size = batch_size self._dist = Distance.lookup(distance) self._index_string = index_string self._init_index(index_string=index_string, cuda=cuda)
def __repr__(self) -> str: return f"{type(self).__name__}({self._index_string}, {self.dims})" @property def batch(self) -> int: return self._batch_size @property def data(self) -> NDArray: return self.__embeddings @property def distance(self) -> Distance: return self._dist @property def dims(self) -> int: return self.__embeddings.shape[1] def _search(self, query: NDArray, k: int = 1) -> InternalResult: results = [ self._index.search(query[i : i + self._batch_size], k) for i in range(0, len(query), self._batch_size) ] distances, indices = map(np.concatenate, zip(*results)) return InternalResult(distances=distances, indices=indices) def _init_index(self, index_string: str, cuda: bool) -> None: metric = self._faiss_metric(self.distance) # Using Any as type hint to prevent errors coming up in add / search. # Faiss is not type check ready yet. # https://github.com/facebookresearch/faiss/issues/2891 index: Any = _faiss().index_factory(self.dims, index_string, metric) index.train(self.data) index.add(self.data) if cuda: index = _faiss().index_cpu_to_all_gpus(index) self._index = index @staticmethod def _faiss_metric(distance: Distance) -> Any: match distance: case Distance.L2: return _faiss().METRIC_L2 case Distance.INNER_PRODUCT: return _faiss().METRIC_INNER_PRODUCT