Source code for bocoel.factories.embedders

# Copyright (c) RenChu Wang - All Rights Reserved

from bocoel import Embedder, EnsembleEmbedder, HuggingfaceEmbedder, SbertEmbedder
from bocoel.common import StrEnum

from . import common


class EmbedderName(StrEnum):
    """
    Identifiers accepted by the `embedder` factory to select an embedder type.
    """

    #: Corresponds to `SbertEmbedder`.
    SBERT = "SBERT"

    #: Corresponds to `HuggingfaceEmbedder`.
    HUGGINGFACE = "HUGGINGFACE"

    #: Corresponds to `EnsembleEmbedder` concatenating `HuggingfaceEmbedder`.
    HUGGINGFACE_ENSEMBLE = "HUGGINGFACE_ENSEMBLE"
[docs] def embedder( name: str | EmbedderName, /, *, model_name: str | list[str], device: str = "auto", batch_size: int, ) -> Embedder: """ Create an embedder. Parameters: name: The name of the embedder. model_name: The model name to use. device: The device to use. batch_size: The batch size to use. Returns: The embedder instance. Raises: ValueError: If the name is unknown. TypeError: If the model name is not a string for SBERT or Huggingface, or not a list of strings for HuggingfaceEnsemble. """ match EmbedderName.lookup(name): case EmbedderName.SBERT: if not isinstance(model_name, str): raise TypeError( "SbertEmbedder requires a single model name. " f"Got {model_name} instead." ) return common.correct_kwargs(SbertEmbedder)( model_name=model_name, device=common.auto_device(device), batch_size=batch_size, ) case EmbedderName.HUGGINGFACE: if not isinstance(model_name, str): raise TypeError( "HuggingfaceEmbedder requires a single model name. " f"Got {model_name} instead." ) return common.correct_kwargs(HuggingfaceEmbedder)( path=model_name, device=common.auto_device(device), batch_size=batch_size, ) case EmbedderName.HUGGINGFACE_ENSEMBLE: if not isinstance(model_name, list): raise TypeError( "HuggingfaceEnsembleEmbedder requires a list of model names. " f"Got {model_name} instead." ) device_list = common.auto_device_list(device, len(model_name)) return common.correct_kwargs(EnsembleEmbedder)( [ HuggingfaceEmbedder(path=model, device=dev, batch_size=batch_size) for model, dev in zip(model_name, device_list) ] ) case _: raise ValueError(f"Unknown embedder name: {name}")