Source code for bocoel.core.exams.managers

import datetime as dt
import hashlib
import itertools
from collections import OrderedDict
from collections.abc import Generator, Mapping
from pathlib import Path
from typing import Any

import alive_progress as ap
import pandas as pd
import structlog
from pandas import DataFrame

from bocoel.core.optim import Optimizer
from bocoel.corpora import Corpus, Embedder
from bocoel.models import Adaptor, ClassifierModel, GenerativeModel

from . import columns
from .examinators import Examinator

# Module-level structlog logger; ``Manager`` uses it to report progress
# ("Running the optimizer") and skipped work ("Previous scores found. Skip").
LOGGER = structlog.get_logger()


class Manager:
    """
    The manager for running and saving evaluations.
    """

    _examinator: Examinator
    """
    The examinator that would perform evaluations on the results.
    """

    def __init__(self, root: str | Path | None = None, skip_rerun: bool = True) -> None:
        """
        Parameters:
            root: The path to save the scores to. If ``None``, scores are not saved.
                Otherwise the directory (and any missing parents) is created, and a
                ``.gitignore`` that ignores everything is written into it if absent.
            skip_rerun: Whether to skip rerunning the optimizer if the scores already exist.

        Raises:
            ValueError: If the path exists but is not a directory.
        """
        if root is not None:
            root = Path(root)
            if root.exists() and not root.is_dir():
                raise ValueError(f"{root} is not a directory")
            root.mkdir(parents=True, exist_ok=True)

            # Prevent data from being tracked by git.
            gitignore = root / ".gitignore"
            if not gitignore.exists():
                gitignore.write_text(
                    "# Automatically generated by BoCoEL.\n*", encoding="utf-8"
                )

        # Creation timestamp of this manager; ``save`` uses it as the csv file name
        # and as the value of the time identifier column.
        self._start = self.current()
        self._examinator = Examinator.presets()

        # Public attributes. Can be overwritten at any time.
        self.root = root
        self.skip_rerun = skip_rerun

    def run(
        self,
        steps: int | None = None,
        *,
        optimizer: Optimizer,
        embedder: Embedder,
        corpus: Corpus,
        model: GenerativeModel | ClassifierModel,
        adaptor: Adaptor,
    ) -> DataFrame:
        """
        Runs the optimizer until the end.
        If the root path is set, the scores are saved under ``root / <md5>``, where
        ``<md5>`` identifies the combination of components (see ``md5``).

        Parameters:
            steps: The number of steps to run the optimizer for.
                If ``None``, run until the optimizer raises ``StopIteration``.
            optimizer: The optimizer to run.
            embedder: The embedder to run the optimizer with.
            corpus: The corpus to run the optimizer on.
            model: The model to run the optimizer with.
            adaptor: The adaptor to run the optimizer with.

        Returns:
            The scores produced by the examinator from the optimizer's results.
            If ``skip_rerun`` is set and scores for the same components already exist
            under the root path, those saved scores (which also carry the identifier
            columns added by ``save``) are loaded and returned instead.
        """
        md5 = self.md5(
            optimizer=optimizer,
            embedder=embedder,
            corpus=corpus,
            model=model,
            adaptor=adaptor,
        )

        if self.skip_rerun and self.root is not None:
            # ``root`` is public and may have been reassigned to a plain string.
            previous = Path(self.root) / md5
            if previous.exists():
                LOGGER.warning("Previous scores found. Skip", md5=md5)
                return self.load(previous)

        # Run the optimizer and collect the results.
        LOGGER.info("Running the optimizer", steps=steps)
        results: OrderedDict[int, float] = OrderedDict()
        for res in self._launch(optimizer=optimizer, steps=steps):
            results.update(res)

        # Examine the results.
        LOGGER.info("Examing the results")
        scores = self._examinator.examine(index=corpus.index, results=results)

        self.save(
            scores=scores,
            optimizer=optimizer,
            corpus=corpus,
            model=model,
            adaptor=adaptor,
            embedder=embedder,
            md5=md5,
        )
        return scores

    def save(
        self,
        *,
        scores: DataFrame,
        optimizer: Optimizer,
        corpus: Corpus,
        model: GenerativeModel | ClassifierModel,
        adaptor: Adaptor,
        embedder: Embedder,
        md5: str,
    ) -> None:
        """
        Saves the scores to ``root / md5 / <start time>.csv``, with identifier columns added.
        If the root path is not set, a warning is logged and nothing is written.

        Parameters:
            scores: The scores to save.
            optimizer: The optimizer used to generate the scores.
            corpus: The corpus used to generate the scores.
            model: The model used to generate the scores.
            adaptor: The adaptor used to generate the scores.
            embedder: The embedder used to generate the scores.
            md5: The md5 hash of the identifier columns.
        """
        if self.root is None:
            LOGGER.warning("No path set to save the scores. Skip")
            return

        scores = self.with_cols(
            scores,
            {
                columns.OPTIMIZER: optimizer,
                columns.MODEL: model,
                columns.ADAPTOR: adaptor,
                columns.INDEX: corpus.index,
                columns.STORAGE: corpus.storage,
                columns.EMBEDDER: embedder,
                columns.TIME: self._start,
                columns.MD5: md5,
            },
        )

        # ``root`` is public and may have been reassigned after ``__init__`` (to a str
        # or a directory that does not exist yet), so normalize it and create parents.
        target = Path(self.root) / md5
        target.mkdir(parents=True, exist_ok=True)
        scores.to_csv(target / f"{self._start}.csv", index=False)

    def with_cols(self, df: DataFrame, columns: dict[str, Any]) -> DataFrame:
        """
        Adds identifier columns to a copy of the DataFrame.

        Parameters:
            df: The DataFrame to add the columns to. It is not modified.
            columns: Mapping of column name to value. Each value is converted with
                ``str`` and repeated for every row.

        Returns:
            A copy of ``df`` with the columns added.
        """
        # NOTE: the parameter name shadows the module-level ``columns`` import inside
        # this method only; it is kept for keyword-argument compatibility.
        df = df.copy()
        for key, value in columns.items():
            df[key] = [str(value)] * len(df)
        return df

    @staticmethod
    def _launch(
        optimizer: Optimizer, steps: int | None = None
    ) -> Generator[Mapping[int, float], None, None]:
        "Launches the optimizer as a generator."

        steps_range = range(steps) if steps is not None else itertools.count()

        for _ in ap.alive_it(steps_range, title="Running the optimizer"):
            # ``step`` raises StopIteration when done. Letting it escape a generator
            # would turn it into a RuntimeError (PEP 479), so stop explicitly.
            try:
                results = optimizer.step()
            except StopIteration:
                break

            yield results

    @staticmethod
    def load(path: str | Path) -> DataFrame:
        """
        Loads the scores from the path.

        Every ``*.csv`` file under ``path`` (searched recursively) is read, in sorted
        path order, and the results are concatenated into one DataFrame with a fresh
        ``RangeIndex``.

        Parameters:
            path: The path to load the scores from.

        Returns:
            The loaded scores.

        Raises:
            ValueError: If the path does not exist or is not a directory.
            ValueError: If no csv files are found in the path.
        """
        directory = Path(path)
        if not directory.is_dir():
            raise ValueError(f"{directory} does not exist or is not a directory")

        # ``rglob`` yields in filesystem order; sort so the row order is deterministic.
        dfs = [pd.read_csv(csv) for csv in sorted(directory.rglob("*.csv"))]

        if not dfs:
            raise ValueError(f"No csv files found in {path}")

        # Files are written with ``index=False``, so per-file indices are meaningless;
        # renumber rather than keep duplicated labels.
        return pd.concat(dfs, ignore_index=True)

    @staticmethod
    def md5(
        *,
        optimizer: Optimizer,
        embedder: Embedder,
        corpus: Corpus,
        model: GenerativeModel | ClassifierModel,
        adaptor: Adaptor,
    ) -> str:
        """
        Generates an md5 hash from the given components.

        The hash is computed over the space-joined ``str`` of the optimizer,
        embedder, corpus index, corpus storage, model and adaptor, in that order.

        Parameters:
            optimizer: The optimizer used to generate the scores.
            embedder: The embedder used to generate the scores.
            corpus: The corpus used to generate the scores.
            model: The model used to generate the scores.
            adaptor: The adaptor used to generate the scores.

        Returns:
            The md5 hex digest of the given data.
        """
        # NOTE(review): the hash is only stable across processes if each component has
        # a deterministic ``__str__`` (the default includes the object's address).
        data = [optimizer, embedder, corpus.index, corpus.storage, model, adaptor]
        text = " ".join(str(item) for item in data)

        # The digest is an identifier, not a security primitive; flagging it lets it
        # work on FIPS-restricted builds where plain md5 is disabled.
        return hashlib.md5(text.encode(), usedforsecurity=False).hexdigest()

    @staticmethod
    def current() -> str:
        "Returns the current local time formatted as ``YYYY-MM-DD-HH-MM-SS``."

        return dt.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")