# Source code for bocoel.corpora.indices.backend.hnswlib

# Copyright (c) RenChu Wang - All Rights Reserved

from typing import Literal

from numpy.typing import NDArray

from bocoel.corpora.indices import utils
from bocoel.corpora.indices.interfaces import Distance, Index, InternalResult

_HnswlibDist = Literal["l2", "ip", "cosine"]


[docs] class HnswlibIndex(Index): """ HNSWLIB index. Uses the hnswlib library. Score is calculated slightly differently https://github.com/nmslib/hnswlib#supported-distances """
[docs] def __init__( self, embeddings: NDArray, distance: str | Distance, *, normalize: bool = True, threads: int = -1, batch_size: int = 64, ) -> None: """ Initializes the HNSWLIB index. Parameters: embeddings: The embeddings to index. distance: The distance metric to use. normalize: Whether to normalize the embeddings. threads: The number of threads to use. batch_size: The batch size to use for searching. Raises: ValueError: If the distance is not supported. """ if normalize: embeddings = utils.normalize(embeddings) self.__embeddings = embeddings # Would raise ValueError if not a valid distance. self._dist = Distance.lookup(distance) self._batch_size = batch_size # A public attribute because this can be changed at anytime. self.threads = threads self._init_index()
@property def batch(self) -> int: return self._batch_size @property def data(self) -> NDArray: return self.__embeddings @property def distance(self) -> Distance: return self._dist def _search(self, query: NDArray, k: int = 1) -> InternalResult: indices, distances = self._index.knn_query(query, k=k, num_threads=self.threads) return InternalResult(indices=indices, distances=distances) def _init_index(self) -> None: # Optional dependency. from hnswlib import Index as _HnswlibIndex space = self._hnswlib_space(self.distance) self._index = _HnswlibIndex(space=space, dim=self.dims) self._index.init_index(max_elements=len(self.data)) self._index.add_items(self.data, num_threads=self.threads) @staticmethod def _hnswlib_space(distance: Distance) -> _HnswlibDist: match distance: case Distance.L2: return "l2" case Distance.INNER_PRODUCT: return "ip"