Source code for bocoel.corpora.embedders.sberts

# Copyright (c) RenChu Wang - All Rights Reserved

import typing
from collections.abc import Sequence

from torch import Tensor

from bocoel.corpora.embedders.interfaces import Embedder


[docs] class SbertEmbedder(Embedder): """ Sentence-BERT embedder. Uses the sentence_transformers library. """
[docs] def __init__( self, model_name: str = "all-mpnet-base-v2", device: str = "cpu", batch_size: int = 64, ) -> None: """ Initializes the Sbert embedder. Parameters: model_name: The model name to use. device: The device to use. batch_size: The batch size for encoding. Raises: ImportError: If sentence_transformers is not installed. """ # Optional dependency. from sentence_transformers import SentenceTransformer self._name = model_name self._sbert = SentenceTransformer(model_name, device=device) self._batch_size = batch_size
def __repr__(self) -> str: return f"Sbert({self._name})" @property def batch(self) -> int: return self._batch_size @property def dims(self) -> int: d = self._sbert.get_sentence_embedding_dimension() assert isinstance(d, int) return d def _encode(self, texts: Sequence[str]) -> Tensor: texts = list(texts) return typing.cast( Tensor, self._sbert.encode(texts, batch_size=len(texts), convert_to_tensor=True), )