Source code for bocoel.factories.lms
# Copyright (c) RenChu Wang - All Rights Reserved
from collections.abc import Sequence
from bocoel import (
ClassifierModel,
GenerativeModel,
HuggingfaceGenerativeLM,
HuggingfaceLogitsLM,
HuggingfaceSequenceLM,
)
from bocoel.common import StrEnum
from . import common
[docs]
class GeneratorName(StrEnum):
    """
    Names accepted by the `generative` factory to select a generator backend.
    """

    HUGGINGFACE_GENERATIVE = "HUGGINGFACE_GENERATIVE"
    "Selects the `HuggingfaceGenerativeLM` backend."
[docs]
class ClassifierName(StrEnum):
    """
    Names accepted by the `classifier` factory to select a classifier backend.
    """

    HUGGINGFACE_LOGITS = "HUGGINGFACE_LOGITS"
    "Selects the `HuggingfaceLogitsLM` backend."

    HUGGINGFACE_SEQUENCE = "HUGGINGFACE_SEQUENCE"
    "Selects the `HuggingfaceSequenceLM` backend."
[docs]
def generative(
    name: str | GeneratorName,
    /,
    *,
    model_path: str,
    batch_size: int,
    device: str = "auto",
    add_sep_token: bool = False,
) -> GenerativeModel:
    """
    Build a generative language model from its registered name.

    Parameters:
        name: Which generator backend to build, either a `GeneratorName`
            member or its string form (resolved via `GeneratorName.lookup`).
        model_path: The path to the model.
        batch_size: The batch size forwarded to the model.
        device: The device to use; passed through `common.auto_device`
            before being forwarded, so the default "auto" is resolved there.
        add_sep_token: Whether to add the sep token.

    Returns:
        The generative model instance (a `HuggingfaceGenerativeLM` for
        `GeneratorName.HUGGINGFACE_GENERATIVE`).

    Raises:
        ValueError: If the name is unknown.
    """
    # Device resolution deliberately precedes the name lookup.
    resolved_device = common.auto_device(device)
    backend = GeneratorName.lookup(name)

    if backend == GeneratorName.HUGGINGFACE_GENERATIVE:
        build = common.correct_kwargs(HuggingfaceGenerativeLM)
        return build(
            model_path=model_path,
            batch_size=batch_size,
            device=resolved_device,
            add_sep_token=add_sep_token,
        )

    raise ValueError(f"Unknown LM name {name}")
def classifier(
name: str | ClassifierName,
/,
*,
model_path: str,
batch_size: int,
choices: Sequence[str],
device: str = "auto",
add_sep_token: bool = False,
) -> ClassifierModel:
device = common.auto_device(device)
match ClassifierName.lookup(name):
case ClassifierName.HUGGINGFACE_LOGITS:
return common.correct_kwargs(HuggingfaceLogitsLM)(
model_path=model_path,
batch_size=batch_size,
device=device,
choices=choices,
add_sep_token=add_sep_token,
)
case ClassifierName.HUGGINGFACE_SEQUENCE:
return common.correct_kwargs(HuggingfaceSequenceLM)(
model_path=model_path,
device=device,
choices=choices,
add_sep_token=add_sep_token,
)
case _:
raise ValueError(f"Unknown LM name {name}")