Source code for bocoel.models.lms.huggingface.logits
# Copyright (c) RenChu Wang - All Rights Reserved
from collections.abc import Sequence

import torch
from numpy.typing import NDArray

from bocoel.models.lms.interfaces import ClassifierModel

from .causal import HuggingfaceCausalLM

class HuggingfaceLogitsLM(HuggingfaceCausalLM, ClassifierModel):
"""
Logits classification model backed by huggingface's transformers library.
This means that the model would use the logits of ['1', '2', '3', '4', '5'] as the output,
if `choices = 5`, for the current batch of inputs.
"""
    def __init__(
        self,
        model_path: str,
        batch_size: int,
        device: str,
        choices: Sequence[str],
        add_sep_token: bool = False,
    ) -> None:
"""
Parameters:
model_path: The path to the model.
batch_size: The batch size to use.
device: The device to use.
choices: The choices to classify.
add_sep_token: Whether to add the sep token.
"""
        super().__init__(
            model_path=model_path,
            batch_size=batch_size,
            device=device,
            add_sep_token=add_sep_token,
        )

        self._choices = choices
        self._encoded_choices = self._encode_tokens(self._choices)

    @property
    def choices(self) -> Sequence[str]:
        return self._choices

    @torch.no_grad()
    def _classify(self, prompts: Sequence[str], /) -> NDArray:
        # The tokenizer wrapper is expected to return tensors on the model's device.
        tokenized = self._tokenizer(prompts)
        output = self._model(**tokenized)

        # Logits have the shape [batch_size, seq_len, vocab_size].
        logits = output.logits

        # Select the logits of the encoded choice tokens at the last position.
        result = logits[:, -1, self._encoded_choices]
        return result.cpu().numpy()

    def _encode_tokens(self, tokens: Sequence[str]) -> Sequence[int]:
        result: list[int] = []

        for tok in tokens:
            # Only the first token id is kept: each choice is scored by the
            # logit of the token that would begin its generation.
            result.append(self._tokenizer.encode(tok, add_special_tokens=False)[0])

        assert len(result) == len(tokens)

        if len(result) != len(set(result)):
            decoded = self._tokenizer.decode(self._tokenizer.encode(tokens))
            raise ValueError(
                "Each token must be converted to 1 unique id. "
                f"Got {tokens}, encoded into {decoded}."
            )

        return result
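

# A minimal usage sketch (assumptions: the "gpt2" checkpoint is available, and
# ClassifierModel exposes a public classify() wrapper around _classify):
#
#     lm = HuggingfaceLogitsLM(
#         model_path="gpt2",
#         batch_size=4,
#         device="cpu",
#         choices=["1", "2", "3", "4", "5"],
#     )
#     scores = lm.classify(["Rate this movie from 1 to 5: loved it!"])
#     # One row per prompt, one column per choice.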