Source code for bocoel.corpora.storages.concat
# Copyright (c) RenChu Wang - All Rights Reserved
from collections.abc import Collection, Iterable, Mapping, Sequence
from typing import Any
import numpy as np
import structlog
from bocoel.corpora.storages.interfaces import Storage
LOGGER = structlog.get_logger()
[docs]
class ConcatStorage(Storage):
    """
    Storage that concatenates multiple storages together.

    Concatenation is done on the first dimension.
    The resulting storage is read-only and has length
    equal to the sum of the lengths of the storages.
    """

    def __init__(self, storages: Sequence[Storage], /) -> None:
        """
        Build a concatenated view over `storages`, in the order given.

        Parameters:
            storages: The storages to concatenate. All of them must expose
                the same set of keys.

        Raises:
            ValueError: If `storages` is empty, or if the storages do not
                all share the same set of keys.
        """

        if len(storages) < 1:
            raise ValueError("At least one storage is required")

        # Each storage's keys are reduced to a frozenset so that ordering is
        # ignored; more than one distinct frozenset means the keys disagree.
        diff_keys = {frozenset(store.keys()) for store in storages}
        if len(diff_keys) > 1:
            raise ValueError("Keys are not equal")

        # Unpack the only key in `diff_keys`.
        (self._keys,) = diff_keys

        self._storages = tuple(storages)
        LOGGER.info("Concat storage created", storages=storages, keys=diff_keys)

        storage_lengths = [len(store) for store in self._storages]
        # `_prefix_sum[i]` is the exclusive end of storage `i` in the global
        # index space, i.e. the total length of storages `0..i`.
        self._prefix_sum = np.cumsum(storage_lengths).tolist()
        self._length = sum(storage_lengths)

    def __repr__(self) -> str:
        """Return a debug representation listing the wrapped storages."""

        return f"Concat({list(self._storages)})"

    def keys(self) -> Collection[str]:
        """Return the keys shared by every wrapped storage."""

        return self._keys

    def __len__(self) -> int:
        """Return the sum of the lengths of the wrapped storages."""

        return self._length

    def _getitem(self, idx: int) -> Mapping[str, Any]:
        """
        Fetch the row at global position `idx`.

        Parameters:
            idx: Global index into the concatenation. Negative values count
                from the end.

        Returns:
            The row of the wrapped storage that owns `idx`.

        Raises:
            IndexError: If `idx` is outside `[-len(self), len(self))`.
        """

        if not -len(self) <= idx < len(self):
            raise IndexError(
                f"Index {idx} is out of bounds. Storage length is {len(self)}"
            )

        # Normalize negative indices into `[0, len(self))`.
        if idx < 0:
            idx %= len(self)

        # `side="right"` finds the first storage whose exclusive end is
        # strictly greater than `idx`. The default (`side="left"`) would map
        # an index sitting exactly on a boundary to the previous storage,
        # and would also select empty storages.
        found = np.searchsorted(self._prefix_sum, idx, side="right").item()

        # Global index at which the owning storage starts.
        start = self._prefix_sum[found - 1] if found > 0 else 0

        # Non-negative offset inside the owning storage, so sub-storages are
        # not required to support negative indexing.
        sub_idx = idx - start

        assert start <= idx < self._prefix_sum[found], {
            "sub_idx": sub_idx,
            "idx": idx,
            "found": found,
            "prefix_sum": self._prefix_sum,
        }

        return self._storages[found][sub_idx]

    @classmethod
    def join(cls, storages: Iterable[Storage], /) -> Storage:
        """
        Combine `storages` into a single storage.

        Parameters:
            storages: The storages to combine, in order.

        Returns:
            The storage itself when exactly one is given (no wrapper is
            created), otherwise a `ConcatStorage` over all of them.

        Raises:
            ValueError: If `storages` is empty or the keys differ
                (raised by the `ConcatStorage` constructor).
        """

        storages = list(storages)

        # A single storage needs no concatenation layer.
        if len(storages) == 1:
            return storages[0]

        return ConcatStorage(storages)