from typing import Any

import numpy as np
import structlog
from numpy.typing import ArrayLike, NDArray

from bocoel.corpora.indices import utils
from bocoel.corpora.indices.interfaces import Boundary, Distance, Index, InternalResult

# Module-level structlog logger used by the code in this file.
LOGGER = structlog.get_logger()

class PolarIndex(Index):
    """
    Index that uses N-sphere coordinates as interfaces.

    See wikipedia linked below for details.

    Converting the spatial indices into spherical coordinates has the following benefits:

    - Since the coordinates are normalized, the radius is always 1.
    - The search region is rectangular in spherical coordinates,
        ideal for bayesian optimization.

    [Wikipedia link on N-sphere](https://en.wikipedia.org/wiki/N-sphere#Spherical_coordinates)
    """

    def __init__(
        self,
        embeddings: NDArray,
        distance: str | Distance,
        *,
        polar_backend: type[Index],
        **backend_kwargs: Any,
    ) -> None:
        """
        Parameters:
            embeddings: The embeddings to index.
            distance: The distance metric to use.
            polar_backend: The backend to use for indexing.
            **backend_kwargs: The backend specific keyword arguments.
        """

        # Normalized before indexing; the polar interface only exposes angles.
        embeddings = utils.normalize(embeddings)

        self._index = polar_backend(
            embeddings=embeddings,
            distance=distance,
            **backend_kwargs,
        )

        # D spatial dimensions are described by D - 1 angles.
        dims = self._index.dims - 1
        self._boundary = self._polar_boundary(dims)
        self._data = self._polar_coordinates()

    def _search(self, query: NDArray, k: int = 1) -> InternalResult:
        """
        Search the backend index with queries given as angles.

        Parameters:
            query: The angles to search for. Passed as `theta`
                to `polar_to_spatial`, so it must be 2D.
            k: The number of results requested from the backend.

        Returns:
            Whatever the backend's `_search` returns for the converted queries.
        """

        # Ignores the length of the query. Only direction is preserved.
        spatial = self.polar_to_spatial(np.ones([len(query)]), query)

        assert spatial.shape[1] == self.dims + 1, (
            "Spatial dimensions do not match embeddings. "
            f"Expected {self.dims + 1}. Got {spatial.shape[1]}."
        )

        return self._index._search(spatial, k=k)

    @property
    def batch(self) -> int:
        """The batch size, delegated to the backend index."""

        return self._index.batch

    @property
    def data(self) -> NDArray:
        """The polar coordinates computed at construction time."""

        return self._data

    @property
    def distance(self) -> Distance:
        """The distance metric, delegated to the backend index."""

        return self._index.distance

    @property
    def boundary(self) -> Boundary:
        """The boundary of the angular search space."""

        return self._boundary

    def _polar_boundary(self, dims: int) -> Boundary:
        """
        The boundary of the queries.
        For polar coordinate it is [0, pi] for all dimensions
        except the last one which is [0, 2 * pi].

        Parameters:
            dims: The number of angles.

        Returns:
            The boundary of the input.
        """

        # See wikipedia linked in the class documentation for details.
        upper = np.concatenate([[np.pi] * (dims - 1), [2 * np.pi]])
        lower = np.zeros_like(upper)
        return Boundary(np.stack([lower, upper], axis=-1))

    def _polar_coordinates(self) -> NDArray:
        """
        Convert the backend's embeddings to angles, one batch at a time.

        Returns:
            The angles of every embedding, stacked along the first axis.
        """

        LOGGER.info(
            "Converting embeddings to polar coordinates.", batch_size=self.batch
        )

        # NOTE(review): assumes the backend's `data` holds the (normalized)
        # spatial embeddings it was constructed with -- confirm.
        embeddings = self._index.data

        results = []
        for idx in range(0, len(embeddings), self.batch):
            batch = embeddings[idx : idx + self.batch]
            # The radius is discarded; only the angles are stored.
            _, polar = self.spatial_to_polar(batch)
            results.append(polar)

        transformed = np.concatenate(results, axis=0)
        assert (
            transformed.shape[1] == self._index.dims - 1
        ), "Polar dimensions do not match embeddings."
        return transformed

    @staticmethod
    def polar_to_spatial(r: ArrayLike, theta: ArrayLike) -> NDArray:
        """
        Convert an N-sphere coordinates to cartesian coordinates.
        See wikipedia linked in the class documentation for details.

        Parameters:
            r: The radius of the N-sphere. Has the shape [N].
            theta: The angles of the N-sphere. Has the shape [N, D].

        Returns:
            The cartesian coordinates of the N-sphere. Has the shape [N, D + 1].

        Raises:
            ValueError: If `r` is not 1D, `theta` is not 2D,
                or their lengths differ.
        """

        r = np.array(r)
        theta = np.array(theta)

        if r.ndim != 1:
            raise ValueError(f"Expected r to be 1D, got {r.ndim}")

        if theta.ndim != 2:
            raise ValueError(f"Expected theta to be 2D, got {theta.ndim}")

        if r.shape[0] != theta.shape[0]:
            raise ValueError(
                "Expected r and theta to have the same length, "
                f"got {r.shape[0]} and {theta.shape[0]}"
            )

        # Add 1 dimension to the front because spherical coordinate's first dimension is r.
        sin = np.concatenate([np.ones([len(r), 1]), np.sin(theta)], axis=1)
        # Column i becomes sin(theta_1) * ... * sin(theta_i); column 0 stays 1.
        sin = np.cumprod(sin, axis=1)
        # The last coordinate has no cosine factor, hence the trailing 1.
        cos = np.concatenate([np.cos(theta), np.ones([len(r), 1])], axis=1)
        return sin * cos * r[:, None]

    @staticmethod
    def spatial_to_polar(x: ArrayLike) -> tuple[NDArray, NDArray]:
        """
        Convert cartesian coordinates to N-sphere coordinates.
        See wikipedia linked in the class documentation for details.

        This is the inverse of `polar_to_spatial`:
        all angles lie in [0, pi] except the last one, which lies in [0, 2 * pi).

        Parameters:
            x: The cartesian coordinates. Has the shape [N, D].

        Returns:
            A tuple. The radius (shape [N]) and the angles (shape [N, D - 1])
            of the N-sphere.

        Raises:
            ValueError: If `x` is not 2D.
        """

        x = np.array(x)

        if x.ndim != 2:
            raise ValueError(f"Expected x to be 2D, got {x.ndim}")

        # Since the function requires a lot of sum of squares, cache it.
        x_2 = x**2

        # The radius is the euclidean norm over ALL coordinates.
        r = np.sqrt(x_2.sum(axis=1))

        # Suffix sums: column i holds x_i^2 + x_{i+1}^2 + ... + x_{D-1}^2.
        cumsum_back = np.cumsum(x_2[:, ::-1], axis=1)[:, ::-1]

        # theta_i = atan2(sqrt(x_{i+1}^2 + ... + x_{D-1}^2), x_i).
        # The tail must exclude x_i itself, hence the shift by one column.
        theta = np.arctan2(np.sqrt(cumsum_back[:, 1:]), x[:, :-1])

        if theta.shape[1] > 0:
            # The last angle keeps the sign of the last coordinate,
            # which is what lets it cover the full [0, 2 * pi) range.
            theta[:, -1] = np.mod(np.arctan2(x[:, -1], x[:, -2]), 2 * np.pi)

        return r, theta