Source code for bocoel.corpora.indices.interfaces.results

# Copyright (c) RenChu Wang - All Rights Reserved

import dataclasses as dcls
from typing import NamedTuple

from numpy.typing import NDArray


@dcls.dataclass(eq=True, frozen=True)
class _SearchResult:
    """
    Immutable container shared by the batched and non-batched search results.

    Fields are stored as given; shape validation is left to subclasses.
    """

    query: NDArray
    """
    The vector(s) that were searched for.
    Shape [batch, dims] when batched, [dims] otherwise.
    """

    vectors: NDArray
    """
    The nearest neighbors that were found.
    Shape [batch, k, dims] when batched, [k, dims] otherwise.
    """

    distances: NDArray
    """
    Distance from the query to each neighbor.
    Shape [batch, k] when batched, [k] otherwise.
    """

    indices: NDArray
    """
    Position of each neighbor in the original embeddings (expected to be integers).
    Shape [batch, k] when batched, [k] otherwise.
    """


@dcls.dataclass(frozen=True)
class SearchResultBatch(_SearchResult):
    """
    A batched version of search result.

    Shapes are validated on construction:
    ``query`` must be ``[batch, dims]``, ``vectors`` must be ``[batch, k, dims]``,
    and ``distances`` / ``indices`` must be ``[batch, k]``.

    Raises:
        ValueError: If any field does not have the batched rank, or if the
            batch size, the neighbor count ``k``, or the embedding dimension
            ``dims`` disagree between fields.
    """

    def __post_init__(self) -> None:
        if self.query.ndim != 2:
            raise ValueError(f"Query should be batched. Got shape {self.query.shape}")

        if self.vectors.ndim != 3:
            raise ValueError(
                f"Vectors should be batched. Got shape {self.vectors.shape}."
            )

        if self.distances.ndim != 2:
            raise ValueError(
                f"Distances should be batched. Got shape {self.distances.shape}."
            )

        if self.indices.ndim != 2:
            raise ValueError(
                f"Indices should be batched. Got shape {self.indices.shape}."
            )

        batches = {
            self.query.shape[0],
            self.vectors.shape[0],
            self.distances.shape[0],
            self.indices.shape[0],
        }

        if len(batches) != 1:
            raise ValueError(
                "Batched results should have the same batch size. "
                f"Got {len(self.query)}, {len(self.vectors)}, "
                f"{len(self.distances)}, {len(self.indices)}."
            )

        ks = {
            self.vectors.shape[1],
            self.distances.shape[1],
            self.indices.shape[1],
        }

        if len(ks) != 1:
            raise ValueError(
                "Batched results should have the same number of neighbors. "
                f"Got {self.vectors.shape[1]}, {self.distances.shape[1]}, "
                f"{self.indices.shape[1]}."
            )

        # The query ([batch, dims]) and its neighbors ([batch, k, dims]) are
        # documented to share the trailing ``dims`` axis; enforce it so a
        # mismatched result is rejected here instead of failing downstream.
        if self.query.shape[1] != self.vectors.shape[2]:
            raise ValueError(
                "Query and vectors should have the same dimensions. "
                f"Got {self.query.shape[1]} and {self.vectors.shape[2]}."
            )
@dcls.dataclass(frozen=True)
class SearchResult(_SearchResult):
    """
    A non-batched version of search result.

    Shapes are validated on construction:
    ``query`` must be ``[dims]``, ``vectors`` must be ``[k, dims]``,
    and ``distances`` / ``indices`` must be ``[k]``.

    Raises:
        ValueError: If any field has a batched (or otherwise wrong) rank, or if
            the neighbor count ``k`` or the embedding dimension ``dims``
            disagree between fields.
    """

    def __post_init__(self) -> None:
        if self.query.ndim != 1:
            raise ValueError(
                f"Query should not be batched. Got shape {self.query.shape}."
            )

        if self.vectors.ndim != 2:
            raise ValueError(
                f"Vectors should not be batched. Got shape {self.vectors.shape}."
            )

        if self.distances.ndim != 1:
            raise ValueError(
                f"Distances should not be batched. Got shape {self.distances.shape}."
            )

        if self.indices.ndim != 1:
            raise ValueError(
                f"Indices should not be batched. Got shape {self.indices.shape}."
            )

        ks = {
            self.vectors.shape[0],
            self.distances.shape[0],
            self.indices.shape[0],
        }

        if len(ks) != 1:
            raise ValueError(
                "Non-batched results should have the same number of neighbors. "
                f"Got {self.vectors.shape[0]}, {self.distances.shape[0]}, "
                f"{self.indices.shape[0]}."
            )

        # The query ([dims]) and its neighbors ([k, dims]) are documented to
        # share the trailing ``dims`` axis; enforce it so a mismatched result
        # is rejected here instead of failing downstream.
        if self.query.shape[0] != self.vectors.shape[1]:
            raise ValueError(
                "Query and vectors should have the same dimensions. "
                f"Got {self.query.shape[0]} and {self.vectors.shape[1]}."
            )
class InternalResult(NamedTuple):
    """
    A lightweight ``(distances, indices)`` pair that can be unpacked like a tuple.
    """

    distances: NDArray
    """
    Distances that were computed for each neighbor.
    """

    indices: NDArray
    """
    Position of each neighbor in the original embeddings (expected to be integers).
    """