# Copyright (c) RenChu Wang - All Rights Reserved
import datetime as dt
import hashlib
import itertools
from collections import OrderedDict
from collections.abc import Generator, Mapping
from pathlib import Path
from typing import Any
import alive_progress as ap
import pandas as pd
import structlog
from pandas import DataFrame
from bocoel.core.optim import Optimizer
from bocoel.corpora import Corpus, Embedder
from bocoel.models import Adaptor, ClassifierModel, GenerativeModel
from . import columns
from .examinators import Examinator
# Module-level structured logger (structlog) shared by the Manager methods below.
LOGGER = structlog.get_logger()
class Manager:
    """
    The manager for running and saving evaluations.
    """

    _examinator: Examinator
    """
    The examinator that would perform evaluations on the results.
    """

    def __init__(self, root: str | Path | None = None, skip_rerun: bool = True) -> None:
        """
        Initializes the manager, creating the root directory if needed.

        Parameters:
            root: The path to save the scores to. If `None`, scores are not saved.
            skip_rerun: Whether to skip rerunning the optimizer if the scores already exist.

        Raises:
            ValueError: If the path exists but is not a directory.
        """

        if root is not None:
            root = Path(root)

            if root.exists() and not root.is_dir():
                raise ValueError(f"{root} is not a directory")

            root.mkdir(parents=True, exist_ok=True)

            # Prevent data from being tracked by git.
            gitignore = root / ".gitignore"
            if not gitignore.exists():
                gitignore.write_text("# Automatically generated by BoCoEL.\n*")

        # Timestamp of construction; used as the per-run csv file name in `save`.
        self._start = self.current()
        self._examinator = Examinator.presets()

        # Public attributes. Can be overwritten at any time.
        self.root = root
        self.skip_rerun = skip_rerun

    def run(
        self,
        steps: int | None = None,
        *,
        optimizer: Optimizer,
        embedder: Embedder,
        corpus: Corpus,
        model: GenerativeModel | ClassifierModel,
        adaptor: Adaptor,
    ) -> DataFrame:
        """
        Runs the optimizer until the end.
        If the root path is set in the constructor,
        the scores are saved to the path.

        Parameters:
            steps: The number of steps to run the optimizer for.
                If `None`, runs until the optimizer is exhausted.
            optimizer: The optimizer to run.
            embedder: The embedder to run the optimizer with.
            corpus: The corpus to run the optimizer on.
            model: The model to run the optimizer with.
            adaptor: The adaptor to run the optimizer with.

        Returns:
            The examined scores as a DataFrame,
            one row per query, with identifier columns attached by `save`.
        """

        md5 = self.md5(
            optimizer=optimizer,
            embedder=embedder,
            corpus=corpus,
            model=model,
            adaptor=adaptor,
        )

        # Reuse previously saved scores for an identical configuration.
        if self.skip_rerun and self.root is not None and (self.root / md5).exists():
            LOGGER.warning("Previous scores found. Skip", md5=md5)
            return self.load(self.root / md5)

        # Run the optimizer and collect the results.
        LOGGER.info("Running the optimizer", steps=steps)
        results: OrderedDict[int, float] = OrderedDict()
        for res in self._launch(optimizer=optimizer, steps=steps):
            results.update(res)

        # Examine the results.
        LOGGER.info("Examining the results")
        scores = self._examinator.examine(index=corpus.index, results=results)

        self.save(
            scores=scores,
            optimizer=optimizer,
            corpus=corpus,
            model=model,
            adaptor=adaptor,
            embedder=embedder,
            md5=md5,
        )
        return scores

    def save(
        self,
        *,
        scores: DataFrame,
        optimizer: Optimizer,
        corpus: Corpus,
        model: GenerativeModel | ClassifierModel,
        adaptor: Adaptor,
        embedder: Embedder,
        md5: str,
    ) -> None:
        """
        Saves the scores to the root path set in the constructor.
        If the root path is not set, a warning is logged and nothing is saved.

        Parameters:
            scores: The scores to save.
            optimizer: The optimizer used to generate the scores.
            corpus: The corpus used to generate the scores.
            model: The model used to generate the scores.
            adaptor: The adaptor used to generate the scores.
            embedder: The embedder used to generate the scores.
            md5: The md5 hash of the identifier columns.
        """

        if self.root is None:
            LOGGER.warning("No path set to save the scores. Skip")
            return

        # Attach identifier columns so that individual runs remain
        # distinguishable after `load` concatenates multiple csv files.
        scores = self.with_cols(
            scores,
            {
                columns.OPTIMIZER: optimizer,
                columns.MODEL: model,
                columns.ADAPTOR: adaptor,
                columns.INDEX: corpus.index,
                columns.STORAGE: corpus.storage,
                columns.EMBEDDER: embedder,
                columns.TIME: self._start,
                columns.MD5: md5,
            },
        )

        (self.root / md5).mkdir(exist_ok=True)
        scores.to_csv(self.root / md5 / f"{self._start}.csv", index=False)

    def with_cols(self, df: DataFrame, columns: dict[str, Any]) -> DataFrame:
        """
        Adds identifier columns to a copy of the DataFrame.

        Parameters:
            df: The DataFrame to add the columns to.
            columns: Maps column names to values; each value is stringified
                and broadcast over every row.
                (NOTE: this parameter shadows the `columns` module,
                which is not used inside this method. Kept for
                backward compatibility with keyword callers.)

        Returns:
            A copy of the DataFrame with the columns added.
        """

        df = df.copy()
        for key, value in columns.items():
            df[key] = [str(value)] * len(df)
        return df

    @staticmethod
    def _launch(
        optimizer: Optimizer, steps: int | None = None
    ) -> Generator[Mapping[int, float], None, None]:
        "Launches the optimizer as a generator, yielding one step's results at a time."

        # An unbounded counter is used when no step limit is given.
        steps_range = range(steps) if steps is not None else itertools.count()
        for _ in ap.alive_it(steps_range, title="Running the optimizer"):
            # `step` raises StopIteration when done; left unhandled inside a
            # generator it would become a RuntimeError per PEP 479.
            try:
                results = optimizer.step()
            except StopIteration:
                break

            yield results

    @staticmethod
    def load(path: str | Path) -> DataFrame:
        """
        Loads and concatenates the scores from all csv files under the path.

        Parameters:
            path: The path to load the scores from.

        Returns:
            The loaded scores.

        Raises:
            ValueError: If no csv files are found under the path
                (including the case where the path does not exist).
        """

        # Iterate over all csv files under the path, recursively.
        dfs = [pd.read_csv(csv) for csv in Path(path).rglob("*.csv")]

        if not dfs:
            raise ValueError(f"No csv files found in {path}")

        return pd.concat(dfs)

    @staticmethod
    def md5(
        *,
        optimizer: Optimizer,
        embedder: Embedder,
        corpus: Corpus,
        model: GenerativeModel | ClassifierModel,
        adaptor: Adaptor,
    ) -> str:
        """
        Generates an md5 hash identifying the given configuration,
        computed over the string representations of its components.

        Parameters:
            optimizer: The optimizer used to generate the scores.
            embedder: The embedder used to generate the scores.
            corpus: The corpus used to generate the scores.
            model: The model used to generate the scores.
            adaptor: The adaptor used to generate the scores.

        Returns:
            The md5 hash of the given data.
        """

        data = [optimizer, embedder, corpus.index, corpus.storage, model, adaptor]
        return hashlib.md5(" ".join(str(item) for item in data).encode()).hexdigest()

    @staticmethod
    def current() -> str:
        "Returns the current local time formatted as a filesystem-safe string."
        return dt.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")