Source code for bocoel.corpora.indices.interfaces.indices
# Copyright (c) RenChu Wang - All Rights Reserved
import abc
from typing import Any, Protocol
import numpy as np
from numpy.typing import ArrayLike, NDArray
from bocoel import common
from .boundaries import Boundary
from .distances import Distance
from .results import InternalResult, SearchResultBatch
[docs]
class Index(Protocol):
    """
    Index is responsible for fast retrieval given a vector query.

    Implementations provide `data`, `batch`, `distance` and `_search`;
    this protocol supplies input validation, batching of queries and
    convenience accessors on top of them.
    """

    def __init__(
        self, embeddings: NDArray, distance: str | Distance, **kwargs: Any
    ) -> None:
        # Included s.t. constructors of Index can be used.
        ...

    def __repr__(self) -> str:
        # The displayed name is whatever `common.remove_base_suffix` yields
        # for this instance relative to `Index`, followed by the dimensions.
        name = common.remove_base_suffix(self, Index)
        return f"{name}({self.dims})"

    def __len__(self) -> int:
        """
        The number of items in the index.

        Returns:
            The number of items.
        """

        return len(self.data)

    def __getitem__(self, idx: int) -> NDArray:
        """
        Get the item at the given index.

        Parameters:
            idx: The index of the item.

        Returns:
            The item.
        """

        return self.data[idx]

    def search(self, query: ArrayLike, k: int = 1) -> SearchResultBatch:
        """
        Calls the search function and performs some checks.

        The query is split into chunks of at most `self.batch` rows, each
        chunk is passed to `_search`, and the per-chunk results are
        concatenated along the first axis.

        Parameters:
            query: The query vector. Must be of shape `[batch, query_dims]`.
            k: The number of nearest neighbors to return.

        Returns:
            A `SearchResultBatch` instance. See `SearchResultBatch` for details.

        Raises:
            ValueError: If the query is not 2D, is empty, has a last
                dimension different from `self.dims`, or if `k < 1`.
        """

        query = np.array(query)

        if (ndim := query.ndim) != 2:
            raise ValueError(
                f"Expected query to be a 2D vector, got a vector of dim {ndim}."
            )

        if (dim := query.shape[1]) != self.dims:
            raise ValueError(f"Expected query to have dimension {self.dims}, got {dim}")

        if k < 1:
            raise ValueError(f"Expected k to be at least 1, got {k}")

        # Without this guard an empty query reaches `np.concatenate` with an
        # empty list, which raises a ValueError with an unrelated message.
        if len(query) == 0:
            raise ValueError("Expected query to contain at least one vector, got 0.")

        batch = self.batch
        results: list[InternalResult] = [
            self._search(query[start : start + batch], k=k)
            for start in range(0, len(query), batch)
        ]

        indices = np.concatenate([res.indices for res in results], axis=0)
        distances = np.concatenate([res.distances for res in results], axis=0)

        # Look up the stored vectors that correspond to the returned indices.
        vectors = self.data[indices]

        return SearchResultBatch(
            query=query, vectors=vectors, distances=distances, indices=indices
        )

    def in_range(self, query: NDArray) -> bool:
        """
        Check whether every component of the query lies within the boundary.

        Parameters:
            query: The query vector(s). The last axis must be broadcastable
                against `self.lower` / `self.upper`.

        Returns:
            True if all components are within `[lower, upper]` (inclusive).
        """

        # Each comparison must be parenthesised: `&` binds tighter than
        # `>=` / `<=`, so the unparenthesised form is parsed as a chained
        # comparison around `lower & query`.
        within = (query >= self.lower[None, :]) & (query <= self.upper[None, :])

        # `np.all` reduces over every axis; the builtin `all` would iterate
        # over rows, whose truth values are ambiguous.
        return bool(np.all(within))

    @property
    @abc.abstractmethod
    def data(self) -> NDArray:
        """
        The underlying data that the index is used for searching.

        NOTE:
            This has the shape of [n, dims], where dims is the transformed space.

        Returns:
            The data.
        """

        ...

    @property
    @abc.abstractmethod
    def batch(self) -> int:
        """
        The batch size used for searching.

        Returns:
            The batch size.
        """

        ...

    @property
    def boundary(self) -> Boundary:
        """
        The boundary of the queries.
        This is used to check if the query is in range.
        By default, this is [-1, 1] for all dimensions,
        since embeddings are normalized.

        Returns:
            The boundary of the input.
        """

        return Boundary.fixed(lower=-1, upper=1, dims=self.dims)

    @property
    @abc.abstractmethod
    def distance(self) -> Distance:
        """
        The distance metric used by the index.

        Returns:
            The distance metric.
        """

        ...

    @abc.abstractmethod
    def _search(self, query: NDArray, k: int = 1) -> InternalResult:
        """
        Search the index with a given batch of queries.

        This is called by `search` once per chunk of at most `self.batch`
        rows, after the query has been validated.

        Parameters:
            query: The query vectors. Must be of shape `[batch, query_dims]`.
            k: The number of nearest neighbors to return.

        Returns:
            An `InternalResult` whose `indices` and `distances` attributes
            hold the results for this chunk; `search` concatenates them
            along the first axis.
        """

        ...

    @property
    def dims(self) -> int:
        """
        The number of dimensions that the query vector should be.

        Returns:
            The number of dimensions.
        """

        return self.data.shape[-1]

    @property
    def lower(self) -> NDArray:
        """
        The lower bounds of the boundary.

        Returns:
            `self.boundary.lower`.
        """

        return self.boundary.lower

    @property
    def upper(self) -> NDArray:
        """
        The upper bounds of the boundary.

        Returns:
            `self.boundary.upper`.
        """

        return self.boundary.upper