Source code for bocoel.models.adaptors.bigbench.matching
# Copyright (c) RenChu Wang - All Rights Reserved
from collections.abc import Mapping, Sequence
from typing import Any

import structlog
import typeguard
from numpy.typing import NDArray

from bocoel.common import StrEnum
from bocoel.models.lms import GenerativeModel
from bocoel.models.scores import (
    ExactMatch,
    NltkBleuScore,
    RougeScore,
    RougeScore2,
    SacreBleuScore,
    Score,
)

from .interfaces import BigBenchAdaptor

LOGGER = structlog.get_logger()
class BigBenchMatchType(StrEnum):
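    """
    The matching functions a BIG-bench adaptor can use to compare generated
    text against reference targets. The ``score`` property maps each member
    to its corresponding ``Score`` implementation.
    """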
    EXACT = "EXACT"
    NLTK_BLEU = "NLTK_BLEU"
    SACRE_BLEU = "SACRE_BLEU"
    ROUGE_1 = "ROUGE_1"
    ROUGE_2 = "ROUGE_2"
    ROUGE_L = "ROUGE_L"
    ROUGE_SCORE_1 = "ROUGE_SCORE_1"
    ROUGE_SCORE_2 = "ROUGE_SCORE_2"
    ROUGE_SCORE_L = "ROUGE_SCORE_L"

    @property
    def score(self) -> Score:
        match self:
            case BigBenchMatchType.EXACT:
                return ExactMatch()
            case BigBenchMatchType.NLTK_BLEU:
                return NltkBleuScore()
            case BigBenchMatchType.SACRE_BLEU:
                return SacreBleuScore()
            case BigBenchMatchType.ROUGE_L:
                return RougeScore("rouge-l")
            case BigBenchMatchType.ROUGE_1:
                return RougeScore("rouge-1")
            case BigBenchMatchType.ROUGE_2:
                return RougeScore("rouge-2")
            case BigBenchMatchType.ROUGE_SCORE_L:
                return RougeScore2("rougeL")
            case BigBenchMatchType.ROUGE_SCORE_1:
                return RougeScore2("rouge1")
            case BigBenchMatchType.ROUGE_SCORE_2:
                return RougeScore2("rouge2")
class BigBenchQuestionAnswer(BigBenchAdaptor):
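    """
    Adaptor for BIG-bench question answering tasks. Prompts the given
    generative language model with the ``inputs`` column of a dataset and
    scores each generation against the ``targets`` column using the selected
    ``BigBenchMatchType``.
    """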
    def __init__(
        self,
        lm: GenerativeModel,
        inputs: str = "inputs",
        targets: str = "targets",
        matching_type: str | BigBenchMatchType = BigBenchMatchType.EXACT,
    ) -> None:
        self.lm = lm
        self.inputs = inputs
        self.targets = targets

        self._matching_type = BigBenchMatchType.lookup(matching_type)
        self._score_fn = self._matching_type.score

    def __repr__(self) -> str:
        return f"BigBenchQA({self.lm}, {self.inputs}, {self.targets}, {self._matching_type})"

    def evaluate(self, data: Mapping[str, Sequence[Any]]) -> Sequence[float] | NDArray:
        # Retrieve the input prompts and reference targets from the data.
        inputs = data[self.inputs]
        targets = data[self.targets]
        LOGGER.debug("Evaluating", inputs=inputs, targets=targets)

        # Check data: each input is a string, each target is a list of references.
        typeguard.check_type(inputs, Sequence[str])
        typeguard.check_type(targets, Sequence[list[str]])

        generated = self.lm.generate(inputs)
        LOGGER.debug("Generated by lm", generated=generated)

        # Score each generation against its corresponding reference targets.
        return [self._score_fn(g, t) for g, t in zip(generated, targets)]
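

# Illustrative usage sketch (not part of the module): given some GenerativeModel
# instance -- here a hypothetical `lm` -- the adaptor generates answers for the
# "inputs" column and scores them against the "targets" column. The column
# names, match type, and data below are example values only.
#
#     adaptor = BigBenchQuestionAnswer(lm=lm, matching_type=BigBenchMatchType.ROUGE_L)
#     scores = adaptor.evaluate(
#         {
#             "inputs": ["Q: What is 2 + 2? A:"],
#             "targets": [["4", "four"]],
#         }
#     )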