Source code for bocoel.corpora.indices.backend.faiss
import functools
import warnings
from typing import Any
from numpy.typing import NDArray
from bocoel.corpora.indices import utils
from bocoel.corpora.indices.interfaces import Distance, Index, InternalResult
@functools.cache
def _faiss():
# Optional dependency.
# Faiss also spits out deprecation warnings.
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=DeprecationWarning)
import faiss
return faiss
[docs]
class FaissIndex(Index):
"""
Faiss index. Uses the faiss library.
"""
[docs]
def __init__(
self,
embeddings: NDArray,
distance: str | Distance,
*,
normalize: bool = True,
index_string: str,
cuda: bool = False,
batch_size: int = 64,
) -> None:
"""
Initializes the Faiss index.
Parameters:
embeddings: The embeddings to index.
distance: The distance metric to use.
index_string: The index string to use.
cuda: Whether to use CUDA.
batch_size: The batch size to use for searching.
"""
if normalize:
embeddings = utils.normalize(embeddings)
self.__embeddings = embeddings
self._batch_size = batch_size
self._dist = Distance.lookup(distance)
self._index_string = index_string
self._init_index(index_string=index_string, cuda=cuda)
def __repr__(self) -> str:
return f"{type(self).__name__}({self._index_string}, {self.dims})"
@property
def batch(self) -> int:
return self._batch_size
@property
def data(self) -> NDArray:
return self.__embeddings
@property
def distance(self) -> Distance:
return self._dist
@property
def dims(self) -> int:
return self.__embeddings.shape[1]
def _search(self, query: NDArray, k: int = 1) -> InternalResult:
distances, indices = self._index.search(query, k)
return InternalResult(distances=distances, indices=indices)
def _init_index(self, index_string: str, cuda: bool) -> None:
metric = self._faiss_metric(self.distance)
# Using Any as type hint to prevent errors coming up in add / search.
# Faiss is not type check ready yet.
# https://github.com/facebookresearch/faiss/issues/2891
index: Any = _faiss().index_factory(self.dims, index_string, metric)
index.train(self.data)
index.add(self.data)
if cuda:
index = _faiss().index_cpu_to_all_gpus(index)
self._index = index
@staticmethod
def _faiss_metric(distance: Distance) -> Any:
match distance:
case Distance.L2:
return _faiss().METRIC_L2
case Distance.INNER_PRODUCT:
return _faiss().METRIC_INNER_PRODUCT