# Copyright (c) RenChu Wang - All Rights Reserved
from typing import Any
import numpy as np
import structlog
from numpy.typing import ArrayLike, NDArray
from bocoel.corpora.indices import utils
from bocoel.corpora.indices.interfaces import Boundary, Distance, Index, InternalResult
# Module-level structlog logger; `PolarIndex._polar_coordinates` logs through it.
LOGGER = structlog.get_logger()
class PolarIndex(Index):
    """
    Index that uses N-sphere coordinates as interfaces.

    See the Wikipedia article linked below for details.

    Converting the spatial indices into spherical coordinates has the following benefits:

    - Since the coordinates are normalized, the radius is always 1.
    - The search region is rectangular in spherical coordinates,
      ideal for bayesian optimization.

    [Wikipedia link on N-sphere](https://en.wikipedia.org/wiki/N-sphere#Spherical_coordinates)
    """

    def __init__(
        self,
        embeddings: NDArray,
        distance: str | Distance,
        *,
        polar_backend: type[Index],
        **backend_kwargs: Any,
    ) -> None:
        """
        Parameters:
            embeddings: The embeddings to index.
            distance: The distance metric to use.
            polar_backend: The backend to use for indexing.
            **backend_kwargs: The backend specific keyword arguments.
        """

        # NOTE(review): presumably `utils.normalize` scales every embedding to unit
        # length, which is what makes dropping the radius safe — confirm.
        embeddings = utils.normalize(embeddings)

        self._index = polar_backend(
            embeddings=embeddings,
            distance=distance,
            **backend_kwargs,
        )

        # A point with D cartesian coordinates has D - 1 angular coordinates.
        dims = self._index.dims - 1
        self._boundary = self._polar_boundary(dims)
        self._data = self._polar_coordinates()

    def _search(self, query: NDArray, k: int = 1) -> InternalResult:
        # Queries carry angles only; they are placed on the unit sphere (radius 1)
        # before being handed to the cartesian backend.
        spatial = self.polar_to_spatial(np.ones([len(query)]), query)

        assert spatial.shape[1] == self.dims + 1, (
            "Spatial dimensions do not match embeddings. "
            f"Expected {self.dims + 1}. Got {spatial.shape[1]}."
        )

        return self._index._search(spatial, k=k)

    @property
    def batch(self) -> int:
        """The batch size of the backend index."""
        return self._index.batch

    @property
    def data(self) -> NDArray:
        """The indexed embeddings expressed as polar angles."""
        return self._data

    @property
    def distance(self) -> Distance:
        """The distance metric of the backend index."""
        return self._index.distance

    @property
    def boundary(self) -> Boundary:
        """The search boundary, in polar coordinates."""
        return self._boundary

    def _polar_boundary(self, dims: int) -> Boundary:
        """
        The boundary of the queries.

        For polar coordinates it is [0, pi] for all dimensions
        except the last one, which is [0, 2 * pi].

        Parameters:
            dims: The number of angular dimensions.

        Returns:
            The boundary of the input.
        """

        # See the Wikipedia article linked in the class documentation for details.
        upper = np.concatenate([[np.pi] * (dims - 1), [2 * np.pi]])
        lower = np.zeros_like(upper)
        return Boundary(np.stack([lower, upper], axis=-1))

    def _polar_coordinates(self) -> NDArray:
        """
        Convert the backend's embeddings to polar angles.

        Returns:
            The angles of every embedding, with shape [N, D - 1],
            where D is the dimensionality of the backend index.
        """

        LOGGER.info(
            "Converting embeddings to polar coordinates.", batch_size=self.batch
        )

        embeddings = self._index.data
        num_angles = self._index.dims - 1

        # Convert `self.batch` rows at a time so the intermediate arrays built by
        # `spatial_to_polar` stay batch-sized.
        results = [
            self.spatial_to_polar(embeddings[idx : idx + self.batch])[1]
            for idx in range(0, len(embeddings), self.batch)
        ]

        # `np.concatenate` rejects an empty list, so an empty corpus is handled here.
        if not results:
            return np.empty((0, num_angles))

        transformed = np.concatenate(results, axis=0)

        assert (
            transformed.shape[1] == num_angles
        ), "Polar dimensions do not match embeddings."
        return transformed

    @staticmethod
    def polar_to_spatial(r: ArrayLike, theta: ArrayLike) -> NDArray:
        """
        Convert N-sphere coordinates to cartesian coordinates.

        See the Wikipedia article linked in the class documentation for details.

        Parameters:
            r: The radius of the N-sphere. Has the shape [N].
            theta: The angles of the N-sphere. Has the shape [N, D].

        Returns:
            The cartesian coordinates of the N-sphere. Has the shape [N, D + 1].

        Raises:
            ValueError: If r is not 1D, theta is not 2D, or their lengths differ.
        """

        r = np.array(r)
        theta = np.array(theta)

        if r.ndim != 1:
            raise ValueError(f"Expected r to be 1D, got {r.ndim}")

        if theta.ndim != 2:
            raise ValueError(f"Expected theta to be 2D, got {theta.ndim}")

        if r.shape[0] != theta.shape[0]:
            raise ValueError(
                f"Expected r and theta to have the same length, got {r.shape[0]} and {theta.shape[0]}"
            )

        # x_i = r * sin(theta_1) * ... * sin(theta_{i-1}) * cos(theta_i).
        # Prepending a column of ones makes the cumulative product start with the
        # empty product, which is the factor for x_1.
        sin = np.concatenate([np.ones([len(r), 1]), np.sin(theta)], axis=1)
        sin = np.cumprod(sin, axis=1)

        # Appending a column of ones leaves the last coordinate without a cosine.
        cos = np.concatenate([np.cos(theta), np.ones([len(r), 1])], axis=1)

        return sin * cos * r[:, None]

    @staticmethod
    def spatial_to_polar(x: ArrayLike) -> tuple[NDArray, NDArray]:
        """
        Convert cartesian coordinates to N-sphere coordinates.

        This is the inverse of `polar_to_spatial`. Following the Wikipedia article
        linked in the class documentation, for a point (x_1, ..., x_n):

        - r = sqrt(x_1^2 + ... + x_n^2)
        - theta_i = atan2(sqrt(x_{i+1}^2 + ... + x_n^2), x_i) for i < n - 1,
          which lies in [0, pi].
        - theta_{n-1} = atan2(x_n, x_{n-1}), shifted into [0, 2 * pi).

        Parameters:
            x: The cartesian coordinates. Has the shape [N, D].

        Returns:
            A tuple. The radius, with shape [N], and the angles of the N-sphere,
            with shape [N, D - 1].

        Raises:
            ValueError: If x is not 2D.
        """

        x = np.asarray(x)

        if x.ndim != 2:
            raise ValueError(f"Expected x to be 2D, got {x.ndim}")

        # Since the function requires a lot of sums of squares, cache them.
        squares = x**2

        # The radius covers every coordinate, including the first one.
        r = np.sqrt(squares.sum(axis=1))

        # tail[:, j] is the sum of squares of columns j, j + 1, ..., D - 1.
        tail = np.cumsum(squares[:, ::-1], axis=1)[:, ::-1]

        # Angle j pairs column j with the norm of all the columns after it.
        theta = np.arctan2(np.sqrt(tail[:, 1:]), x[:, :-1])

        # The last angle must keep the sign of the final coordinate, otherwise half
        # of the circle is unreachable. atan2 returns (-pi, pi]; shift it into
        # [0, 2 * pi) to match the query boundary.
        if theta.shape[1] > 0:
            theta[:, -1] = np.mod(np.arctan2(x[:, -1], x[:, -2]), 2 * np.pi)

        return r, theta