Source code for bocoel.corpora.embedders.ensemble

# Copyright (c) RenChu Wang - All Rights Reserved

import os
from collections.abc import Sequence

import torch
from torch import Tensor

from bocoel.corpora.embedders.interfaces import Embedder


class EnsembleEmbedder(Embedder):
    """
    An ensemble of embedders. The embeddings are concatenated together.

    Every member's output is moved to CPU and concatenated along the last
    dimension, so ``dims`` is the sum of the members' ``dims``.
    """

    def __init__(self, embedders: Sequence[Embedder], sequential: bool = False) -> None:
        """
        Parameters:
            embedders: The embedders to use. Must contain at least one embedder.
            sequential: Whether to use sequential processing.
                NOTE(review): this flag is stored but not read by ``_encode``,
                which always runs the member embedders one after another.

        Raises:
            ValueError: If ``embedders`` is empty, or if the embedders have
                different batch sizes.
        """

        # Snapshot the members so later mutation of the caller's sequence
        # cannot desync the cached batch size from the actual ensemble.
        members = tuple(embedders)

        # Validate before indexing: an empty sequence would otherwise surface
        # as an opaque IndexError from ``members[0]``.
        if not members:
            raise ValueError("At least one embedder is required")

        # The ensemble exposes a single batch size, so all members must agree.
        if len({emb.batch for emb in members}) != 1:
            raise ValueError("All embedders must have the same batch size")

        self._embedders = members
        self._batch_size = members[0].batch
        self._sequential = sequential

        # os.cpu_count() returns None when the count cannot be determined;
        # fall back to 1 rather than relying on an assert (stripped under -O).
        self._cpus = os.cpu_count() or 1

    def __repr__(self) -> str:
        return f"Ensemble({[str(emb) for emb in self._embedders]})"

    @property
    def batch(self) -> int:
        """The shared batch size of all member embedders."""
        return self._batch_size

    @property
    def dims(self) -> int:
        """Total embedding width: the sum of the members' ``dims``."""
        return sum(emb.dims for emb in self._embedders)

    def _encode(self, texts: Sequence[str]) -> Tensor:
        """
        Encode ``texts`` with every member and concatenate the results.

        Parameters:
            texts: The texts to encode.

        Returns:
            The members' embeddings joined along the last dimension, on CPU.
        """

        # Move each result to CPU as soon as it is produced, so tensors from
        # different devices can be concatenated and device memory is released
        # early. Assumes every member returns one row per text — TODO confirm.
        results = [emb._encode(texts).cpu() for emb in self._embedders]
        return torch.cat(results, dim=-1)