"""This module provides classes and utilities for managing and executing
experiments with structured configurations, logging, and checkpointing.
"""
import copy
import glob
import io
import os
import re
import shutil
import traceback
import warnings
from dataclasses import dataclass
from datetime import datetime, timedelta
from functools import reduce, wraps
from itertools import chain, product
from typing import Any, Callable, ClassVar, Dict, List, Mapping, Optional, Sequence, Tuple, Union

import numpy as np
import pandas as pd
import wandb
import yaml
from absl import logging
from git import Repo
from ml_collections.config_dict import ConfigDict
from rich import print
from rich.align import Align
from rich.panel import Panel
from rich.progress import track
from rich.table import Table

from .utils import (
    FuncTimeoutError,
    QueuedFileLock,
    df2richtable,
    list2tuple,
    path_common_decomposition,
    settimeout_func,
    str2value,
)
# NOTE(review): this silences ALL warnings process-wide for every importer of
# this module — deliberate here, but consider scoping with
# `warnings.catch_warnings` if the side effect is unwanted.
warnings.simplefilter(action="ignore")

# An experiment entry point: receives the run config, and optionally the
# running `Experiment` instance, and returns a dict of resulting metrics.
ExpFunc = Union[Callable[[ConfigDict], dict], Callable[[ConfigDict, "Experiment"], dict]]
class ConfigIter:
    """Iterator over experiment configurations defined in a structured YAML file.

    Reads a YAML file containing both static parameters and parameter grids,
    then generates all combinations of configurations by expanding the
    specified grid fields. Each configuration is returned as a ConfigDict,
    suitable for iterating in experiment loops.

    Attributes:
        static_configs (dict): Configuration values that remain constant across
            all runs.
        grid_fields (list of str): Ordered list of grid field names, specifying
            expansion order (first-occurrence order in the YAML ``grid`` section).
        grid (list of dict): Raw grid specifications parsed from the YAML file.
        grid_iter (list of dict): List of fully expanded configuration
            dictionaries.

    Example:
        ```python
        >>> for config in ConfigIter('exp_file.yaml'):
        ...     train_func(config)
        ```

    ## YAML Schema Example:
        The YAML file should define top-level static fields and a `grid` field.
        Each entry under `grid` defines a parameter sweep. Nested `group` fields
        are expanded via Cartesian product (list of groups) or zipped
        co-indexed (single group dict) before merging with other fields.
        ```yaml
        model: ResNet32
        dataset: cifar10
        grid:
          - optimizer: [sgd]
            group:
              pai_type: [[random], [snip]]
              pai_scope: [[local], [global]]
            rho: [0.05]
            seed: [1, 2, 3]
          - optimizer: [sam]
            pai_type: [random, snip, lth]
            pai_sparsity: [0, 0.9, 0.95, 0.98, 0.99]
            rho: [0.05, 0.1, 0.2, 0.3]
            seed: [1, 2, 3]
        ```

    Raises:
        FileNotFoundError: If the YAML file path is invalid.
        ValueError: If the YAML structure does not conform to the expected
            schema.
    """

    def __init__(self, exp_config_path: str):
        """Initializes the iterator by loading and expanding the config file.

        Args:
            exp_config_path (str): The file path to the experiment
                configuration file.
        """
        with open(exp_config_path) as f:
            cnfg_str = self.__sub_cmd(f.read())
        self.static_configs = yaml.safe_load(cnfg_str)
        # Grid field order is parsed from the raw text: yaml dicts would not
        # preserve a reliable, documented order across duplicate entries.
        self.grid_fields = self.__extract_grid_order(cnfg_str)
        self.grid = self.static_configs.pop("grid", {})
        assert not (f := set(self.static_configs) & set(self.grid_fields)), (
            f"Overlapping fields {f} in Static configs and grid fields."
        )
        self.grid_iter = self.__get_iter()

    def filter_iter(self, filt_fn: Callable[[int, dict], bool]):
        """Filters ConfigIter with ``filt_fn`` which has (idx, dict) as arguments.

        Args:
            filt_fn (Callable[[int, dict], bool]): Predicate deciding which
                (index, config) entries to keep.
        """
        self.grid_iter = [d for i, d in enumerate(self.grid_iter) if filt_fn(i, d)]

    @property
    def grid_dict(self) -> dict:
        """Dictionary of all values per grid field, cast to the field's type.

        Note: values are accumulated through set union, so their order within
        each list is arbitrary.
        """
        if not self.grid_fields:
            return dict()
        first, *rest = self.grid
        # Lists are unhashable; cast them to tuples before set-union.
        castl2t = lambda vs: map(lambda v: tuple(v) if isinstance(v, list) else v, vs)
        # NOTE: assumes every grid dict shares the same keys in the same
        # order (zip over .items()) — guaranteed by the schema above.
        acc = lambda a, d: {k: [*{*castl2t(va), *castl2t(vd)}] for (k, va), (_, vd) in zip(a.items(), d.items())}
        return {k: [*map(self.field_type(k), vs)] for k, vs in reduce(acc, rest, first).items()}

    def field_type(self, field: str):
        """Returns type of a field in grid.

        Args:
            field (str): Name of the field.

        Returns:
            Any: Type of the field, inferred from the first configuration.
        """
        return type(self[0][field])

    @staticmethod
    def __sub_cmd(cfg_str):
        """Compiles list comprehension syntax in YAML config strings.

        Expands bracket expressions like ``[expr;start:step:stop]`` into
        evaluated Python lists, supporting multiple range segments joined by
        ``--``. For example ``[i*2;0:2:6]`` becomes ``[0, 4, 8]``.

        Args:
            cfg_str (str): Raw YAML config string potentially containing
                list comprehension expressions.

        Returns:
            str: Config string with all list comprehension expressions
                replaced by their evaluated list representations.
        """
        # \[__;0:2:10--11:5:20]
        # -> [__ for i in range(0, 10, 2)] + [__ for i in range(11, 20, 5)]
        for entry in re.finditer(p := "\\[[^;\n]+;(\\d+:\\d+:\\d+(--)?)+\\]", cfg_str):
            f, rngs = entry.group()[1:-1].split(";")
            # Each segment is start:step:stop.
            sdes = [map(int, rng.split(":")) for rng in rngs.split("--")]
            assert "i" in f, f"Variable i should be in the expression '{entry}'"
            # NOTE(security): `f` is eval'ed below; the character allow-list
            # restricts it to arithmetic over `i`, but config files must
            # still be trusted input.
            assert re.sub(r"[^\+\-\*/\[\]i\d\(\), ]", "", f) == f, (
                f"Cannot use alphabet other than 'i' in expression '{entry}'"
            )
            rep = sum([eval(f"[{f} for i in {range(s, e, d)}]") for s, d, e in sdes], start=[])
            # Replace only the first (i.e. current) occurrence each round.
            cfg_str = re.sub(p, str(rep), cfg_str, count=1)
        return cfg_str

    @staticmethod
    def __extract_grid_order(cfg_str) -> List[str]:
        """Parses the ordered list of grid field names from a raw config string.

        Extracts field names appearing under the ``grid:`` section of the
        YAML string, preserving their first-occurrence order and excluding
        the ``group`` keyword.

        Args:
            cfg_str (str): Raw YAML config string.

        Returns:
            List[str]: Deduplicated, ordered list of grid field names
                (empty when there is no grid section).
        """
        if "grid" not in cfg_str:
            return []
        grid = re.split("grid ?:", cfg_str)[1]
        names = re.findall(r"[\w_]+(?= ?:)", grid)
        dupless_names = []
        for n in names:
            if n in dupless_names or n == "group":
                continue
            dupless_names.append(n)
        return dupless_names

    @staticmethod
    def __ravel_group(grid):
        """Flattens grouped fields in a grid into individual config dicts.

        Expands ``group`` entries by zipping co-indexed values (for a single
        group dict) or computing the Cartesian product (for a list of group
        dicts), then merges each combination with the remaining grid fields.

        Args:
            grid (dict): A single grid specification dict, potentially
                containing a ``group`` key.

        Returns:
            Iterable[dict]: An iterable of grid dicts with groups unraveled.

        Raises:
            ValueError: If ``group`` is neither a dict nor a list of dicts.
        """
        # No 'group': the grid is already flat.
        if "group" not in grid:
            return [grid]
        group = grid["group"]

        # Ravel one group dict into (key, value) pair sequences, one per
        # co-indexed position.
        def grid_g(g):
            g_len = [*map(len, g.values())]
            assert all([l == g_len[0] for l in g_len]), (
                f"Grouped fields should have same length, got fields with length {dict(zip(g.keys(), g_len))}"
            )
            return ([*zip(g.keys(), value)] for value in zip(*g.values()))

        if isinstance(group, dict):
            r_groups = grid_g(group)
        elif isinstance(group, list):
            r_groups = (chain(*gs) for gs in product(*map(grid_g, group)))
        else:
            # Previously fell through with `r_groups` unbound (NameError).
            raise ValueError(f"'group' should be a dict or a list of dicts, got {type(group).__name__}.")
        grid.pop("group")
        return ({**grid, **dict(g)} for g in r_groups)

    def __get_iter(self):
        """Builds the full grid iterator from the parsed config.

        Unravels grouped fields, then computes the Cartesian product of all
        grid field values to produce a list of config dicts.

        Returns:
            list[dict]: One dict per combination, mapping grid field names
                to a specific combination of values.
        """
        # No grid fields: a single run with only static configs.
        # (The previous `is None` check was dead code — __extract_grid_order
        # always returns a list.)
        if not self.grid_fields:
            return [dict()]
        # Normalize `grid` into a flat list of group-free grid dicts.
        if isinstance(self.grid, dict):
            self.grid = [*self.__ravel_group(self.grid)]
        elif isinstance(self.grid, list):
            self.grid = [*chain(*map(self.__ravel_group, self.grid))]
        grid_s = lambda s: product(*map(s.get, self.grid_fields))
        grid_iter = chain(*map(grid_s, self.grid))
        return [dict(zip(self.grid_fields, values)) for values in grid_iter]

    def __getitem__(self, idx):
        """Returns a ConfigDict for an int index, or a sliced copy for a slice.

        Args:
            idx (int | slice): Index or slice into the grid iterator.

        Returns:
            ConfigDict | ConfigIter: A single merged config dict, or a new
                ConfigIter with a sliced grid iterator.

        Raises:
            TypeError: If ``idx`` is neither an int nor a slice (previously
                returned None silently).
        """
        if isinstance(idx, int):
            return ConfigDict({**self.static_configs, **self.grid_iter[idx]})
        elif isinstance(idx, slice):
            new_ci = copy.deepcopy(self)
            new_ci.grid_iter = new_ci.grid_iter[idx]
            return new_ci
        raise TypeError(f"Indices should be int or slice, got {type(idx).__name__}.")

    def __len__(self):
        """Returns the number of configurations in the grid iterator."""
        return len(self.grid_iter)
# Only a temporary measure for empty grid_fields: patch DataFrame index
# operations so that an empty key list / unnamed default index behaves as a
# no-op instead of raising, which simplifies the grid-less code paths below.
pd.DataFrame.old_set_index = pd.DataFrame.set_index
pd.DataFrame.old_reset_index = pd.DataFrame.reset_index
pd.DataFrame.old_drop = pd.DataFrame.drop


def _set_index_or_self(self, idx, *args, **kwargs):
    """set_index that returns the frame untouched for an empty key list."""
    return self.old_set_index(idx, *args, **kwargs) if idx else self


def _reset_index_or_self(self, *args, **kwargs):
    """reset_index that is a no-op when the index is the unnamed default."""
    return self.old_reset_index(*args, **kwargs) if self.index.names != [None] else self


def _drop_or_empty(self, *args, axis=0, **kwargs):
    """drop that empties a <=1-row unnamed-index frame instead of raising."""
    if self.index.names == [None] and len(self) < 2 and axis == 0:
        return pd.DataFrame(columns=self.columns)
    return self.old_drop(*args, axis=axis, **kwargs)


pd.DataFrame.set_index = _set_index_or_self
pd.DataFrame.reset_index = _reset_index_or_self
pd.DataFrame.drop = _drop_or_empty
[docs]
@dataclass
class ExperimentLog:
    """Logging class for experiment results.

    Logs all configs for reproduction, and resulting pre-defined metrics from
    experiment run as DataFrame. Changing configs are stored as multiindex and
    metrics are stored as columns. Other static configs are passed in and
    stored as dictionary.
    These can be written to tsv file with yaml header, and loaded back from it.
    Filelocks can be used to prevent race conditions when multiple processes
    are reading or writing to the same log file.

    Attributes:
        df (pd.DataFrame): DataFrame of experiment results (grid fields as
            index, metric fields as columns).
        static_configs (dict): Dictionary of static configs of the experiment.
        logs_file (str): File path to tsv file.
        use_filelock (bool, optional): Whether to use file lock for
            reading/writing log file. Defaults to False.
    """

    # Results table: grid fields as (multi)index, metric fields as columns.
    df: pd.DataFrame
    # Configs shared by every run; serialized into the tsv's yaml header.
    static_configs: dict
    # Path of the tsv file this log is persisted to.
    logs_file: str
    # Guard concurrent readers/writers of `logs_file` with a file lock.
    use_filelock: bool = False
    # Separator line between the yaml header and the tsv body.
    __sep: ClassVar[str] = "-" * 45 + "\n"

    def __post_init__(self):
        # The lock is only constructed when requested; 3-hour timeout since
        # long-running experiments may hold the log for a while.
        if self.use_filelock:
            self.filelock = QueuedFileLock(self.logs_file + ".lock", timeout=3 * 60 * 60)
@property
def grid_fields(self):
return list(self.df.index.names) if self.df.index.names != [None] else []
@property
def metric_fields(self):
return list(self.df)
@property
def all_fields(self):
"""Get all static, grid, and metric fields in the log."""
return list(self.static_configs) + self.grid_fields + self.metric_fields
[docs]
def grid_dict(self) -> Dict[str, Any]:
"""Get all values for each index field in the log.
This is useful for getting all possible values for each field in the log.
For example, if the log.df has 3 fields: 'a', 'b', 'c', and the values are:
```
optimizer lr weight_decay
sgd 0.1 0.01
sgd 0.1 0.001
adam 0.01 0.01
adam 0.01 0.001
```
Then the output will be:
```python
{"optimizer": ["sgd", "adam"], "lr": [0.1, 0.01], "weight_decay": [0.01, 0.001]}
```
Returns:
dict: Dictionary of all values for each index field in the log.
"""
def __safe_sorted(l):
try:
return sorted(l)
except:
return l
return {k: __safe_sorted(set(self.df.index.get_level_values(k))) for k in self.grid_fields}
# Constructors.
# -----------------------------------------------------------------------------
[docs]
@classmethod
def from_fields(
cls, grid_fields: list, metric_fields: list, static_configs: dict, logs_file: str, use_filelock: bool = False
) -> "ExperimentLog":
"""Create ExperimentLog from grid and metric fields.
Args:
grid_fields (list): Field names of configs to be grid-searched.
metric_fields (list): Field names of metrics to be logged from
experiment results.
static_configs (dict): Other static configs of the experiment.
logs_file (str): File path to tsv file.
use_filelock (bool, optional): Whether to use file lock for
reading/writing log file. Defaults to False.
Returns:
ExperimentLog: New experiment log object.
"""
assert metric_fields is not None, "Specify the metric fields of the experiment."
assert not (f := set(grid_fields) & set(metric_fields)), (
f"Overlapping field names {f} in grid_fields and metric_fields. Remove one of them."
)
return cls(
pd.DataFrame(columns=grid_fields + metric_fields).set_index(grid_fields),
static_configs,
logs_file=logs_file,
use_filelock=use_filelock,
)
@classmethod
def from_config_iter(
    cls, config_iter: ConfigIter, metric_fields: list, logs_file: str, use_filelock: bool = False
) -> "ExperimentLog":
    """Create ExperimentLog from ConfigIter object.

    Args:
        config_iter (ConfigIter): ConfigIter object to reference
            static_configs and grid_fields.
        metric_fields (list): List of metric fields.
        logs_file (str): File path to tsv file.
        use_filelock (bool, optional): Whether to use file lock for
            reading/writing log file. Defaults to False.

    Returns:
        ExperimentLog: New experiment log object.
    """
    # Delegate to from_fields, pulling the field layout from the iterator.
    return cls.from_fields(
        grid_fields=config_iter.grid_fields,
        metric_fields=metric_fields,
        static_configs=config_iter.static_configs,
        logs_file=logs_file,
        use_filelock=use_filelock,
    )
@classmethod
def from_tsv(cls, logs_file: str, use_filelock: bool = False, parse_str=True) -> "ExperimentLog":
    """Create ExperimentLog from tsv file with yaml header.

    Args:
        logs_file (str): File path to tsv file.
        use_filelock (bool, optional): Whether to use file lock for
            reading/writing log file. Defaults to False.
        parse_str (bool, optional): Whether to parse and cast string into
            speculated type. Defaults to True.

    Returns:
        ExperimentLog: New experiment log object.
    """
    if use_filelock:
        # Hold the lock only while reading; the instance re-creates it in
        # __post_init__.
        with QueuedFileLock(logs_file + ".lock", timeout=3 * 60 * 60):
            parsed = cls.parse_tsv(logs_file, parse_str=parse_str)
    else:
        parsed = cls.parse_tsv(logs_file, parse_str=parse_str)
    return cls(parsed["df"], parsed["static_configs"], logs_file, use_filelock=use_filelock)
@classmethod
def from_wandb_sweep(
    cls,
    entity: str,
    project: str,
    sweep_ids: List[str],
    logs_file: str,
    get_all_steps: bool = False,
    filter_dict: Optional[dict] = None,
    get_metrics: Optional[List[str]] = None,
) -> "ExperimentLog":
    """Create ExperimentLog from wandb sweep.

    Args:
        entity (str): wandb entity name.
        project (str): wandb project name.
        sweep_ids (List[str]): List of wandb sweep ids to load.
        logs_file (str): File path to tsv file.
        get_all_steps (bool, optional): Whether to get all steps of the
            metrics (from run history) instead of only the summary values.
            Defaults to False.
        filter_dict (dict | None, optional): Filter for runs; a run is kept
            only if run.config[k] is among the given values for every key.
            Defaults to None.
        get_metrics (List[str] | None, optional): List of metrics to get.
            Defaults to None.

    Returns:
        ExperimentLog: New experiment log object.
    """
    if filter_dict is not None:
        # Normalize scalar filter values into single-element lists.
        filter_dict = {k: (v if isinstance(v, Sequence) else [v]) for k, v in filter_dict.items()}
    api = wandb.Api()
    # Get all runs from the sweeps.
    runs = []
    for sweep_id in sweep_ids:
        sweep = api.sweep("/".join([entity, project, sweep_id]))
        runs += sweep.runs
    # Scan runs to split config keys into static vs grid fields and collect
    # metric field names.
    static_configs = {}
    grid_fields = set()
    metric_fields = set()
    filtered_runs = []
    state_stats = {}
    for run in runs:
        # Filter runs based on filter_dict.
        if filter_dict is not None and not all((run.config[k] in vs) for k, vs in filter_dict.items()):
            continue
        # Skip unfinished runs, tallying their states for the warning below.
        if (s := run.state) not in ["finished", "completed"]:
            state_stats[s] = state_stats.get(s, 0) + 1
            continue
        filtered_runs.append(run)
        # A config key that ever differs between runs is a grid field.
        for key, value in run.config.items():
            if key not in static_configs:
                static_configs[key] = value
            elif static_configs[key] != value:
                static_configs.pop(key, None)
                grid_fields.add(key)
        metric_fields.update(run.summary.keys())
    if state_stats:
        logging.warning(f"Skipping {sum(state_stats.values())}/{len(runs)} runs with states: {state_stats}")
    if get_metrics is not None:
        metric_fields &= set(get_metrics)
    # sorted() keeps field order deterministic (plain list(set) is not).
    grid_fields = sorted(grid_fields)
    metric_fields = sorted(metric_fields)

    def __preprocess_value(v):
        # Lists are unhashable and cannot live in a pandas index.
        if isinstance(v, list):
            return tuple(v)
        return v

    # Build one row per run.
    data = []
    for run in track(filtered_runs, description="Processing runs"):
        grid_configs = {key: __preprocess_value(run.config[key]) for key in grid_fields}
        row = {**grid_configs, **run.summary}
        if get_all_steps:
            history = run.history(samples=None)
            all_step_metrics = {}
            for k in metric_fields:
                if k not in history.columns:
                    continue
                s = history[k].dropna()  # drops both None and NaN
                if s.empty:
                    continue  # IMPORTANT: don't store useless [] (or [None,...])
                all_step_metrics[k] = s.tolist()
            row.update(all_step_metrics)
        data.append(row)
    # Build the frame outside the try so `df` is always bound for the
    # diagnostics below (previously a failure in the constructor raised a
    # NameError inside the except handler).
    df = pd.DataFrame(data)
    try:
        df.set_index(grid_fields, inplace=True)
    except Exception:
        # Diagnostics: list-valued cells are the usual cause of index failure.
        list_cols = [c for c in df.columns if df[c].dropna().apply(lambda x: isinstance(x, list)).any()]
        print(df[list_cols].head())
        list_grid_fields = [c for c in grid_fields if df[c].dropna().apply(lambda x: isinstance(x, list)).any()]
        print(df[list_grid_fields].head())
        raise
    return cls(df, static_configs, logs_file=logs_file, use_filelock=False)
# tsv handlers.
# ---------------------------------------------------------------------------
@classmethod
def parse_tsv(cls, log_file: str, parse_str=True) -> dict:
    """Parse tsv file into usable datas.

    Parses a file generated by ExperimentLog.to_tsv: a yaml header with the
    static configs, a separator line, then a tsv body whose first two lines
    name the columns and the (multi)index fields respectively.

    Args:
        log_file (str): File path to tsv file.
        parse_str (bool, optional): Whether to parse and cast string into
            speculated type. Defaults to True.

    Raises:
        Exception: Error while reading log file (chained to the root cause).

    Returns:
        dict: Dictionary of pandas.DataFrame, grid_fields, metric_fields,
            and static_configs.
    """
    assert os.path.exists(log_file), f'File path "{log_file}" does not exists.'
    try:
        with open(log_file) as fd:
            # Process yaml config header.
            def header():
                # Skip the "[Static Configs]" marker line, then collect
                # everything up to the separator row.
                next(fd)
                header = ""
                for s in fd:
                    if s == cls.__sep:
                        break
                    header += s
                return header

            # Get workload data from yaml header.
            static_configs = yaml.safe_load(header())
            # Get dataframe from tsv body.
            tsv_str = fd.read()
    except Exception as e:
        # Chain the original error so the root cause is not lost
        # (previously a bare `except:` discarded it).
        raise Exception(f"Error while reading log file: {log_file}") from e
    # First body line holds column names, second holds index names; merge
    # them into a single header row for read_csv.
    tsv_col, tsv_idx, *tsv_body = tsv_str.split("\n")
    col = tsv_col.strip().split("\t")
    idx = tsv_idx.strip().split("\t")
    tsv_head = "\t".join(idx + col)
    tsv_str = "\n".join([tsv_head, *tsv_body])
    df = pd.read_csv(io.StringIO(tsv_str), sep="\t")
    df = df.drop(["id"], axis=1)
    if parse_str:
        # applymap was renamed to map in newer pandas.
        df = df.applymap(str2value) if not hasattr(df, "map") else df.map(str2value)
    # Set grid_fields (everything after the leading "id") as multiindex.
    df = df.set_index(idx[1:])
    return {"df": df, "grid_fields": idx[1:], "metric_fields": col, "static_configs": static_configs}
def lock_file(func):
    """Decorator that wraps a method with filelock acquire/release.

    If ``self.use_filelock`` is True, acquires the filelock before calling
    ``func`` and releases it afterward. Otherwise calls ``func`` directly.

    Args:
        func (Callable): The method to wrap.

    Returns:
        Callable: Wrapped method with conditional file locking.
    """

    # wraps() preserves the wrapped method's name/docstring (previously lost,
    # every decorated method introspected as "wrapped").
    @wraps(func)
    def wrapped(self, *args, **kwargs):
        if self.use_filelock:
            with self.filelock:
                return func(self, *args, **kwargs)
        else:
            return func(self, *args, **kwargs)

    return wrapped
@lock_file
def load_tsv(self, logs_file: Optional[str] = None, parse_str: bool = True):
    """Load tsv with yaml header into this ExperimentLog object in place.

    Args:
        logs_file (Optional[str], optional): Specify other file path to tsv
            file. Defaults to None (use ``self.logs_file``).
        parse_str (bool, optional): Whether to parse and cast string into
            speculated type. Defaults to True.
    """
    if logs_file is not None:
        # Remember the new source path for subsequent saves/loads.
        self.logs_file = logs_file
    parsed = self.parse_tsv(self.logs_file, parse_str=parse_str)
    self.df, self.static_configs = parsed["df"], parsed["static_configs"]
@lock_file
def to_tsv(self, logs_file: Optional[str] = None):
    """Write ExperimentLog object to tsv file with yaml header.

    Args:
        logs_file (Optional[str], optional): Specify other file path to tsv
            file. Defaults to None (use ``self.logs_file``).
    """
    logs_file = self.logs_file if logs_file is None else logs_file
    logs_path, _ = os.path.split(logs_file)
    # `logs_path` is empty for bare filenames — makedirs("") would raise;
    # exist_ok avoids the check-then-create race of the previous version.
    if logs_path:
        os.makedirs(logs_path, exist_ok=True)
    # Pandas dataframe to tsv string, with a synthetic "id" index level.
    df = self.df.reset_index()
    df["id"] = [*range(len(df))]
    df = df.set_index(["id", *self.grid_fields])
    tsv_str = df.to_csv(sep="\t")
    tsv_head, *tsv_body = tsv_str.split("\n")
    tsv_head = tsv_head.split("\t")
    # Split the single header row into two aligned rows: one naming the
    # metric columns, one naming the index fields (blanks keep alignment).
    col = "\t".join([" " * len(i) if i in df.index.names else i for i in tsv_head])
    idx = "\t".join([i if i in df.index.names else " " * len(i) for i in tsv_head])
    tsv_str = "\n".join([col, idx, *tsv_body])
    # Write static_configs (yaml header) and the table of results.
    with open(logs_file, "w") as fd:
        fd.write("[Static Configs]\n")
        yaml.dump(self.static_configs, fd)
        fd.write(self.__sep)
        fd.write(tsv_str)
# Add results.
# ---------------------------------------------------------------------------
def add_result(self, configs: Mapping[str, Any], **metrics):
    """Add experiment run result to dataframe.

    Args:
        configs (Mapping[str, Any]): Dictionary or Mapping of configurations
            of the result of the experiment instance to add.
        **metrics (Any): Metrics of the result of the experiment instance to
            add.
    """
    if configs in self:
        # Overwrite: drop the existing row for this grid-value combination.
        existing_key = list2tuple([configs[k] for k in self.grid_fields])
        self.df = self.df.drop(existing_key)
    # Lists are unhashable in an index; convert to tuples.
    grid_vals = {k: list2tuple(configs[k]) for k in self.grid_fields}
    metric_vals = {k: metrics.get(k) for k in self.metric_fields}
    row_df = pd.DataFrame({k: [v] for k, v in {**grid_vals, **metric_vals}.items()}).set_index(self.grid_fields)
    self.df = pd.concat([self.df, row_df])[self.metric_fields]
# Field manipulations.
# ---------------------------------------------------------------------------
@staticmethod
def __add_column(df, new_column_name: str, fn: Callable, *fn_arg_fields: str) -> pd.DataFrame:
"""Adds a new column to a DataFrame computed from existing columns.
Applies ``fn`` row-wise to the values in ``fn_arg_fields``. If any
argument is not a primitive type, the result is set to None.
Args:
df (pd.DataFrame): DataFrame to add the column to.
new_column_name (str): Name of the new column.
fn (Callable): Function to compute the new column value.
*fn_arg_fields (str): Column names whose values are passed
as arguments to ``fn``.
Returns:
pd.DataFrame: The DataFrame with the new column added.
"""
def mapper(*args):
if all(isinstance(i, (int, float, str, tuple, list)) for i in args):
return fn(*args)
return None
df[new_column_name] = df.apply(lambda df: mapper(*[df[c] for c in fn_arg_fields]), axis=1)
return df
[docs]
def derive_field(self, new_field_name: str, fn: Callable, *fn_arg_fields: str, is_index: bool = False):
"""Add new field computed from existing fields in self.df.
Args:
new_field_name (str): Name of the new field.
fn (Callable): Function to compute new field.
*fn_arg_fields (str): Field names to be used as arguments for the
function.
is_index (bool, optional): Whether to add field as index. Defaults to
False.
"""
df = self.df.reset_index(self.grid_fields)
df = self.__add_column(df, new_field_name, fn, *fn_arg_fields)
new_grid_fields = self.grid_fields
if is_index:
new_grid_fields.append(new_field_name)
self.df = df.set_index(new_grid_fields)
[docs]
def drop_fields(self, field_names: List[str]):
"""Drop fields from the log.
Args:
field_names (List[str]): list of field names to drop.
"""
assert not (ns := set(field_names) - set(self.all_fields)), (
f"Field names {ns} not in any of static {list(self.static_configs)}, "
f"grid {self.grid_fields}, or metric {self.metric_fields} field names."
)
grid_ns, metric_ns = [], []
for fn in field_names:
if fn in self.static_configs:
del self.static_configs[fn] # remove static fields
elif fn in self.grid_fields:
grid_ns.append(fn)
elif fn in self.metric_fields:
metric_ns.append(fn)
self.df = self.df.reset_index(grid_ns, drop=True) # remove grid field
self.df = self.df.drop(columns=metric_ns) # remove metric field
[docs]
def rename_fields(self, name_map: Dict[str, str]):
"""Rename fields in the log.
Args:
name_map (Dict[str, str]): Mapping of old field names to new field
names.
"""
assert not (ns := set(name_map) - set(self.all_fields)), (
f"Field names {ns} not in any of static {list(self.static_configs)}, "
f"grid {self.grid_fields}, or metric {self.metric_fields} field names."
)
grid_l, metric_d = self.grid_fields, {}
for on, nn in name_map.items():
if on in self.static_configs:
# update static field name
self.static_configs[nn] = self.static_configs.pop(on)
elif on in self.grid_fields:
grid_l[grid_l.index(on)] = nn
elif on in self.metric_fields:
metric_d[on] = nn
self.df.index.rename(grid_l, inplace=True) # update grid field names
self.df.rename(columns=metric_d, inplace=True) # update metric field names
# Merge ExperimentLogs.
# ---------------------------------------------------------------------------
[docs]
def resolve_merge_conflicts(self, other: "ExperimentLog") -> Tuple["ExperimentLog", "ExperimentLog"]:
    """CLI to summarize merge conflicts and accept user input for resolution.

    Interactive: prints a rich summary table of fields that exist in only
    one of the two logs, then prompts (via ``input``) for how to resolve
    each one — add a value, merge into an existing field, remove, or revert.
    Mutates ``self`` and ``other`` in place.

    Args:
        other (ExperimentLog): Target log to merge with self.

    Returns:
        Tuple[ExperimentLog, ExperimentLog]: Resolved logs (self, other) —
            the same two objects, mutated.
    """
    # Nothing to do when the logs compare equal.
    # NOTE(review): relies on dataclass __eq__ comparing all fields — two
    # distinct experiments with equal contents short-circuit here; confirm
    # that is intended.
    if self == other:
        return self, other
    print("\nConflict detected between logs: ")
    print(f" - Self :{self.logs_file}")
    print(f" - Other:{other.logs_file}")
    print("Start resolving conflict...")
    # Gather, per log: static dict, grid value dict, their union, and the
    # combined field-name list.
    self_d, other_d = {}, {}
    for log, d in [(self, self_d), (other, other_d)]:
        d["sttc_d"] = log.static_configs
        d["grid_d"] = {k: sorted(set(log.df.index.get_level_values(k))) for k in log.grid_fields}
        d["dict"] = {**d["sttc_d"], **d["grid_d"]}
        d["fields"] = list(log.static_configs.keys()) + list(log.grid_fields)
    sfs, sfo = map(lambda d: set(d["fields"]), (self_d, other_d))
    same_fields = sorted(sfs & sfo)
    new_to_self = sorted(sfo - sfs)  # fields only `other` has
    new_to_othr = sorted(sfs - sfo)  # fields only `self` has
    # Column widths for aligned table rendering.
    ln_k = max([len(k) for k in same_fields + new_to_self + new_to_othr])
    ln_s = max([len(str(self_d["dict"].get(k, ""))) for k in same_fields + new_to_self + new_to_othr])
    ln_o = max([len(str(other_d["dict"].get(k, ""))) for k in same_fields + new_to_self + new_to_othr])
    ####################### Print conflict summary ############################
    # Shorten both paths to their differing suffixes; [:-4] strips ".tsv".
    _, (self_post, othr_post) = path_common_decomposition([self.logs_file, other.logs_file])
    summary_tab = Table(title="Log field conflict summary")
    summary_tab.add_column("Field", style="bold")
    summary_tab.add_column(f"[blue]Self[/blue] ({self_post[:-4]})")
    summary_tab.add_column(f"[green]Other[/green] ({othr_post[:-4]})")
    # Shared fields first; section break after the last one.
    for i, k in enumerate(same_fields):
        summary_tab.add_row(
            f"{k:{ln_k}s}",
            f"{self_d['dict'].get(k, '')!s:{ln_s}s}",
            f"{other_d['dict'].get(k, '')!s:{ln_o}s}",
            style="on bright_black" if i % 2 else "",
            end_section=(i == len(same_fields) - 1),
        )
    # rd: highlight a conflicting cell in red (shade alternates per row).
    rd = lambda s, i: f"[on {'red' if i % 2 else 'dark_red'}]{s} [/on {'red' if i % 2 else 'dark_red'}]"
    for i, k in enumerate(new_to_self):
        i += len(same_fields)  # keep row shading continuous across sections
        summary_tab.add_row(
            f"{k:{ln_k}s}",
            rd(f"{self_d['dict'].get(k, '')!s:{ln_s}s}", i),
            f"{other_d['dict'].get(k, '')!s:{ln_o}s}",
            style="on bright_black" if i % 2 else "",
            end_section=(i == len(same_fields) + len(new_to_self) - 1),
        )
    for i, k in enumerate(new_to_othr):
        i += len(same_fields) + len(new_to_self)
        summary_tab.add_row(
            f"{k:{ln_k}s}",
            f"{self_d['dict'].get(k, '')!s:{ln_s}s}",
            rd(f"{other_d['dict'].get(k, '')!s:{ln_o}s}", i),
            style="on bright_black" if i % 2 else "",
        )
    print(Align(summary_tab, align="center"))
    print(
        Align(
            Panel(
                f"Detected [bold red]{len(new_to_self + new_to_othr)}[/bold red] conflicts to resolve.",
                padding=(1, 3),
            ),
            align="center",
        )
    )
    ############################# Resolve conflicts ###########################
    i_cfl, n_cfl = 0, len(new_to_self + new_to_othr)
    logs = [
        (self, f"[blue]{self_post[:-4]}[/blue]", self_d, new_to_self),
        (other, f"[green]{othr_post[:-4]}[/green]", other_d, new_to_othr),
    ]
    # Resolve conflicts for each log: first the fields self is missing
    # (i=False indexes logs[0]), then the ones other is missing.
    for i in (False, True):
        tlog, ts, td, ntt = logs[i]  # target log (missing the fields in ntt)
        flog, fs, fd, ntf = logs[not i]  # from log (owns those fields)
        if ntt:
            print(f"\n[bold][Handle missing fields in {ts}][/bold]", f"(Default: same/first value of {fs})")
        # while (not for) so "revert" can step j backwards.
        j = 0
        while j < len(ntt):
            k = ntt[j]
            # Per-conflict table showing the missing field highlighted on
            # the side that lacks it.
            tab = Table()
            tab.add_column("Field", style="bold")
            tab.add_column(f"[blue]Self[/blue] ({self_post[:-4]})")
            tab.add_column(f"[green]Other[/green] ({othr_post[:-4]})")
            if tlog == self:
                tab.add_row(
                    f"{k:{ln_k}s}",
                    rd(f"{self_d['dict'].get(k, '')!s:{ln_s}s}", 0),
                    f"{other_d['dict'].get(k, '')!s:{ln_o}s}",
                )
            elif tlog == other:
                tab.add_row(
                    f"{k:{ln_k}s}",
                    f"{self_d['dict'].get(k, '')!s:{ln_s}s}",
                    rd(f"{other_d['dict'].get(k, '')!s:{ln_o}s}", 0),
                )
            i_cfl += 1
            print(f"│\n├─[{i_cfl}/{n_cfl}] [bold]({k})[/bold]", tab)
            # Set default value from the log that owns the field.
            dflt = False
            dflt_val = fd["dict"].get(k, "")
            # Grid fields carry a list of values; default to the first one.
            if k in fd["grid_d"]:
                dflt_val = dflt_val[0] if len(dflt_val) > 0 else None
            # Build the menu of applicable resolution modes.
            modes = ["Add new value"]
            if ntf:
                modes.append("merge with existing field")
            if k in fd["sttc_d"]:
                modes.append("remove")
            if 0 < j < len(ntt):
                modes.append("revert")
            mode = 0
            if len(modes) > 1:
                print(
                    "│ Choose process mode ("
                    + " / ".join([f"{i}: {md}" for i, md in enumerate(modes)])
                    + f"/ else: set value to {dflt_val})"
                )
                mode = str2value(input("│ ↳ "))
            # Process the chosen mode (non-int input falls through to the
            # default-value branch).
            if isinstance(mode, int):
                if modes[mode] == "Add new value":
                    print(f"│ ({mode}) Add new value")
                    new_val = str2value(input("│ ↳ "))
                    if new_val:
                        tlog.static_configs[k] = new_val
                        print(f"│ - Set to {new_val}")
                    else:
                        dflt = True
                elif modes[mode] == "merge with existing field":
                    print(f"│ ({mode}) Merge with existing field in {ts}: {ntf}")
                    # Re-prompt until a valid field (or empty = cancel).
                    while True:
                        new_field = input("│ ↳ ")
                        if new_field in ntf + [""]:
                            break
                        print(f"│ There is no field:{new_field} to merge with.", f"Choose from {ntf}")
                    if new_field:
                        flog.rename_fields({k: new_field})
                        ntf.remove(new_field)
                        n_cfl -= 1  # one fewer conflict on the other side
                        print(f"│ - Merged with {new_field}")
                    else:
                        dflt = True
                elif modes[mode] == "remove":
                    print(f"│ ({mode}) Remove field")
                    del flog.static_configs[k]
                elif modes[mode] == "revert":
                    print(f"│ ({mode}) Revert to prior field")
                    i_cfl -= 2  # net -1 after re-increment on next pass
                    j -= 1
                    continue
                else:
                    dflt = True
            else:
                dflt = True
            if dflt:
                print(f"│ - Set to {dflt_val}")
                tlog.static_configs[k] = str2value(dflt_val)
            j += 1
    print("│\n└─[[bold cyan]Done[/bold cyan]]")
    return self, other
def __merge_one(self, other: "ExperimentLog", same=True) -> "ExperimentLog":
    """Merge ``other`` into self (in place) and return self.

    Notes:
        - The order of grid_fields follows self.
        - Static fields stay static only if present with the same value in
          both logs; otherwise they are demoted to grid fields.

    Args:
        other (ExperimentLog): Log to merge into self.
        same (bool, optional): Assert the logs belong to the same
            experiment. Defaults to True.

    Returns:
        ExperimentLog: self, after merging.
    """
    if same:
        assert self == other, "Different experiments cannot be merged by default."
    sc1, sc2 = (log.static_configs for log in (self, other))
    # New static fields: present in both logs with the same value.
    new_sttc = {k: sc1[k] for k in set(sc1) & set(sc2) if sc1[k] == sc2[k]}
    # sorted() keeps the merged index-level order deterministic — plain
    # list(set) gave an arbitrary order that changed run to run.
    new_gridf = self.grid_fields + sorted(set(sc1) - set(new_sttc))
    new_mtrcf = self.metric_fields + [k for k in other.metric_fields if k not in self.metric_fields]
    # Fields moving static->grid become constant columns on each side.
    dfs = []
    for log in (self, other):
        dfs.append(log.df.reset_index())
        for k in sorted(set(log.static_configs) - set(new_sttc)):
            if k in log.grid_fields:
                continue
            dfs[-1][k] = [list2tuple(log.static_configs.get(k, np.nan))] * len(log)
    # Merge and update self.
    self.static_configs = new_sttc
    self.df = pd.concat(dfs).set_index(new_gridf)[new_mtrcf]
    return self
def merge(self, *others: "ExperimentLog", same: bool = True):
    """Merge multiple logs into self, one at a time.

    Args:
        *others (ExperimentLog): Logs to merge with self.
        same (bool, optional): Whether to raise error when logs are not of
            matching experiments. Defaults to True.
    """
    for log in others:
        # Interactively resolve field conflicts, then fold the log in.
        self, log = self.resolve_merge_conflicts(log)
        self.__merge_one(log, same=same)
@staticmethod
def merge_tsv(*log_files: str, save_path: Optional[str] = None, same: bool = True) -> "ExperimentLog":
    """Merge multiple logs into one from tsv file paths.

    Args:
        *log_files (str): Paths to the tsv log files.
        save_path (Optional[str]): Path to save merged log (skipped when
            falsy).
        same (bool, optional): Whether to raise error when logs are not of
            matching experiments. Defaults to True.

    Returns:
        ExperimentLog: The first log, with the others merged into it.
    """
    loaded = [ExperimentLog.from_tsv(path, parse_str=False) for path in log_files]
    merged, rest = loaded[0], loaded[1:]
    merged.merge(*rest, same=same)
    if save_path:
        merged.to_tsv(save_path)
    return merged
[docs]
@staticmethod
def merge_folder(logs_path: str, save_path: Optional[str] = None, same: bool = True) -> "ExperimentLog":
    """Merge multiple logs into one from tsv files in folder.

    Args:
        logs_path (str): Folder path to logs.
        save_path (Optional[str], optional): Path to save merged log. Defaults
            to None.
        same (bool, optional): Whether to raise error when logs are not of
            matching experiments. Defaults to True.

    Returns:
        ExperimentLog: Merged log over every ``*.tsv`` file in the folder.
    """
    pattern = os.path.join(logs_path, "*.tsv")
    log_files = glob.glob(pattern)
    assert log_files, f"No tsv files found in {logs_path}"
    merged = ExperimentLog.merge_tsv(*log_files, save_path=save_path, same=same)
    return merged
# Utilities.
# ----------------------------------------------------------------------------
def __cfg_match_row(self, config):
    """Finds rows in the log DataFrame matching the given config's grid fields.

    Args:
        config (Mapping[str, Any]): Configuration to match against
            the DataFrame's multi-index.

    Returns:
        pd.DataFrame: Subset of ``self.df`` whose index values match
            the config's grid field values.
    """
    if not self.grid_fields:
        return self.df
    # Intersect per-level boolean masks, one for each grid field.
    mask = None
    for field in self.grid_fields:
        wanted = config[field]
        if isinstance(wanted, list):
            # List-valued grid entries are stored stringified in the index.
            wanted = str(wanted)
        level_match = self.df.index.get_level_values(field) == wanted
        mask = level_match if mask is None else (mask & level_match)
    return self.df[mask]
[docs]
def isin(self, config: Mapping[str, Any]) -> bool:
    """Check if specific experiment config was already executed in log.

    Args:
        config (Mapping[str, Any]): Configuration instance to check if it is in
            the log.

    Returns:
        bool: Whether the config is in the log.
    """
    # An empty log contains nothing.
    if self.df.empty:
        return False
    # Every static field present in the config must agree with the log.
    statics_agree = all(config[key] == val for key, val in self.static_configs.items() if key in config)
    matches = self.__cfg_match_row(config)
    return statics_agree and not matches.empty
[docs]
def get_metric(self, config: Mapping[str, Any]) -> dict:
    """Search matching log with given config dict and return metric_dict,
    info_dict.

    Args:
        config (Mapping[str, Any]): Configuration instance to search in the log.

    Returns:
        dict: Found metric dictionary of the given config.
    """
    assert config in self, "config should be in self when using get_metric_dict."
    matched = self.__cfg_match_row(config)
    # First matching row wins; metrics with no matching rows map to None.
    metric_dict = {}
    for field in self.metric_fields:
        column = matched[field]
        metric_dict[field] = None if column.empty else column.iloc[0]
    return metric_dict
[docs]
def is_same_exp(self, other: "ExperimentLog") -> bool:
    """Check if both logs have same config fields.

    Args:
        other (ExperimentLog): Log to compare with.

    Returns:
        bool: Whether both logs have same config fields.
    """
    def config_fields(log):
        # A log's full config field set: statics plus grid fields.
        return set(log.static_configs.keys()) | set(log.grid_fields)

    return config_fields(self) == config_fields(other)
[docs]
def drop_duplicates(self):
    """Drop duplicate entries and provides CLI to resolve conflicting
    duplicates.

    Resolves duplicates with same grid fields but different metric fields.
    If the duplicates also has same metric fields, they will be automatically
    removed but one. Remaining conflicts are resolved either interactively
    (per-column row selection) or by keeping the first row of each config.
    """
    if self.df.index.nlevels == 0:
        return
    # get duplicates
    grids_dup = self.df.index.duplicated(keep=False)
    duplicates = pd.DataFrame(self.df.loc[grids_dup])
    non_duplicates = pd.DataFrame(self.df.loc[~grids_dup])
    # print duplicates
    print(
        f"\nDuplicate entries detected: {len(duplicates)}/{len(self)}",
        f"({len(duplicates.index.unique()) + (len(non_duplicates))} unique configs)",
    )

    # first change all lists to tuples for duplicate comparison (hashable)
    # and remove other unhashable types
    def __process_unhashable(x):
        try:
            hash(x)
            return x
        except TypeError:
            if isinstance(x, list):
                return list2tuple(x)
            return None

    duplicates = duplicates.map(__process_unhashable)
    # separate trivial duplicates (same metric fields) and conflict duplicates
    trivial_dups = duplicates.duplicated(keep=False)
    conflict_dups = duplicates[~trivial_dups]
    trivial_dups = duplicates[trivial_dups]
    # Automatically remove duplicates with identical metric fields
    to_remove_dups = trivial_dups.duplicated(keep="first")
    trivial_deduped = trivial_dups[~to_remove_dups]
    non_duplicates = pd.concat([non_duplicates, trivial_deduped])
    # BUGFIX: ``len(to_remove_dups)`` counted every trivial-duplicate row;
    # only the rows flagged True are actually removed.
    print(f"Removed {int(to_remove_dups.sum())} trivial duplicates.")
    # Iterate through remaining duplicates for resolution
    if len(conflict_dups) > 1:
        print(
            Align(
                Panel(
                    f"Detected [bold red]{len(conflict_dups)}[/bold red] conflicting "
                    f"duplicates to resolve. "
                    f"([bold cyan]{len(conflict_dups.index.unique())}[/bold cyan] "
                    "unique configs)",
                    padding=(1, 3),
                ),
                align="center",
            )
        )
        while True:
            conf_handle = input("Manually resolve? (y/n): ")
            if conf_handle.lower() in ["y", "yes", "n", "no"]:
                break
            print('Invalid input. Please enter "y" or "n".')
        manually_resolve = conf_handle.lower()[0] == "y"
        if manually_resolve:
            print("\n[bold][Manually resolving duplicates][/bold]")
        else:
            print("Automatically resolving conflict by keeping first row of each config.")
        conflict_config_idxs = conflict_dups.index.unique()
        for c_i, idx in enumerate(conflict_config_idxs):
            duplicate_group = conflict_dups.loc[idx]
            if manually_resolve:
                # print config
                tab = Table()
                for k in self.grid_fields:
                    tab.add_column(k)
                tab.add_row(*map(str, idx))
                print(f"│\n├─[{c_i + 1}/{len(conflict_config_idxs)}]", tab)
                # find conflicting column and resolve by user input individually
                conflict_cn = [c for c in duplicate_group.columns if duplicate_group[c].nunique() > 1]
                print(f"│ Conflicts in {len(conflict_cn)} column(s):", ", ".join(conflict_cn))
                for i, cn in enumerate(conflict_cn):
                    col = duplicate_group[cn]
                    print(f"│ Resolve column {cn} ({i + 1}/{len(conflict_cn)})")
                    print(
                        df2richtable(
                            duplicate_group,
                            highlight_columns=[cn],
                            max_col_width=3,
                            col_center=cn,
                            max_seq_value_len=3,
                            alternating_row_colors=True,
                        )
                    )
                    while True:
                        choice = input(f"│ Select row index (0-{len(col) - 1}): ")
                        if choice in [*map(str, range(len(col)))]:
                            break
                        print(f"│ Invalid input. Please enter a number between 0 and {len(col) - 1}.")
                    # BUGFIX: ``duplicate_group.iloc[0][cn] = ...`` was chained
                    # assignment — it wrote into a temporary and the chosen value
                    # was silently dropped under modern pandas (Copy-on-Write).
                    duplicate_group.iloc[0, duplicate_group.columns.get_loc(cn)] = list(col)[int(choice)]
            # update the DataFrame with resolved duplicates
            non_duplicates.loc[idx, :] = list(duplicate_group.iloc[0])
    # Update the DataFrame to remove resolved duplicates
    non_duplicates = non_duplicates.map(lambda x: list(x) if isinstance(x, tuple) else x)
    print(
        Align(
            Panel(
                f"Duplicates resolved (rows: [bold cyan]{len(self)}[/bold cyan]->"
                f"[bold cyan]{len(non_duplicates)}[/bold cyan]).",
                padding=(1, 3),
            ),
            align="center",
        )
    )
    self.df = non_duplicates
[docs]
def melt_and_explode_metric(
    self, df: Optional[pd.DataFrame] = None, step: Optional[int] = None, dropna: bool = True
) -> pd.DataFrame:
    """Melt and explode metric values in DataFrame.

    Melt column (metric) names into 'metric' field (multi-index) and their
    values into 'metric_value' columns. Explode metric with list of values into
    multiple rows with new 'step' and 'total_steps' field. If step is specified,
    only that step is selected, otherwise all steps are exploded.

    Args:
        df (Optional[pd.DataFrame], optional): Base DataFrame to operate over.
            Defaults to None (``self.df`` is used).
        step (Optional[int], optional): Specific step to select. Negative values
            count from the end (e.g. -1 selects the last step). Defaults to
            None.
        dropna (bool, optional): Whether to drop rows with NaN metric values.
            Defaults to True.

    Returns:
        pd.DataFrame: Melted and exploded DataFrame.
    """
    if df is None:
        df = self.df
    # NOTE: this lambda closes over the local ``df``, which is rebound below —
    # every call re-indexes whatever ``df`` currently holds, appending the
    # given fields to the existing index levels (if any).
    mov_to_index = lambda *fields: df.reset_index().set_index(
        (dn if (dn := df.index.names) != [None] else []) + [*fields]
    )
    # melt: one row per (config, metric) pair, values in 'metric_value'
    df = df.melt(value_vars=list(df), var_name="metric", value_name="metric_value", ignore_index=False)
    df = mov_to_index("metric")
    # Create step field and explode; a scalar metric counts as a single step.
    pseudo_len = lambda x: len(x) if isinstance(x, list) else 1
    df["total_steps"] = df["metric_value"].map(pseudo_len)
    if step is None:
        # 1-based step index for every recorded value
        df["step"] = df["metric_value"].map(lambda x: range(1, pseudo_len(x) + 1))
        # explode metric list so each step gets its own row
        df = df.explode("step")
    else:
        # single step: negative steps are converted to a 1-based position
        # counted from the end of the metric list
        df["step"] = df["metric_value"].map(lambda x: step + (pseudo_len(x) + 1 if step < 0 else 0))
    # Replace each list-valued metric by its value at the selected step.
    df["metric_value"] = df.apply(
        lambda df: df["metric_value"][df.step - 1] if isinstance(df["metric_value"], list) else df["metric_value"],
        axis=1,
    )  # list[epoch] for all fields
    df = mov_to_index("step", "total_steps")
    # delete string and NaN valued rows
    if dropna:
        df = df[pd.to_numeric(df["metric_value"], errors="coerce").notnull()].dropna().astype("float")
    return df
def __contains__(self, config: Mapping[str, Any]) -> bool:
    """``config in log`` — delegates to :meth:`isin`."""
    return self.isin(config)

def __getitem__(self, config: Mapping[str, Any]) -> dict:
    """``log[config]`` — delegates to :meth:`get_metric`."""
    return self.get_metric(config)

def __eq__(self, other: "ExperimentLog") -> bool:
    """``log1 == log2`` — delegates to :meth:`is_same_exp` (same config fields)."""
    # NOTE(review): defining __eq__ without __hash__ makes instances unhashable.
    return self.is_same_exp(other)

def __len__(self):
    """Number of logged rows in the underlying DataFrame."""
    return len(self.df)

def __str__(self):
    """Human-readable dump: static configs, a separator, then the DataFrame."""
    # ``self.__sep`` is defined elsewhere in the class (not visible in this
    # chunk) — presumably a separator string; verify against the class header.
    lines = ["[Static Configs]", *(f"{k}: {v}" for k, v in self.static_configs.items()), self.__sep, str(self.df)]
    return "\n".join(lines)
[docs]
class RunInfo:
    """Use for tracking and managing information about a specific run or
    execution, including the start time, duration, and the current Git commit
    hash.

    Attributes:
        infos (ClassVar[list]): A class-level attribute that lists the keys of the
            run information ('datetime', 'duration', 'commit_hash').
    """

    # Keys reported by get()/update_and_get().
    infos: ClassVar[list] = ["datetime", "duration", "commit_hash"]

    def __init__(self, prev_duration: timedelta = timedelta(0)):
        """Records the current time and resolves the repository commit hash.

        Args:
            prev_duration (timedelta, optional): Accumulated duration carried
                over from a previous (resumed) run. Defaults to ``timedelta(0)``.
        """
        self.__datetime = datetime.now()
        self.__duration = prev_duration
        try:
            # BUGFIX: this previously called ``Repo.init()``, which *creates* a
            # git repository in the cwd as a side effect when none exists. Open
            # the existing repository (searching parent directories) instead.
            self.__commit_hash = Repo(search_parent_directories=True).head.commit.hexsha
        except Exception:
            # Narrowed from a bare ``except:`` which also swallowed
            # SystemExit/KeyboardInterrupt.
            self.__commit_hash = None
            logging.info("No git exist in current directory.")

    def get(self):
        """Returns the current run info as a dict.

        Returns:
            dict: Dictionary with keys ``datetime``, ``duration``, and
                ``commit_hash``.
        """
        return {"datetime": self.__datetime, "duration": self.__duration, "commit_hash": self.__commit_hash}

    def update_and_get(self):
        """Updates the accumulated duration to now and returns run info.

        Adds the elapsed time since the last update (or init) to the
        stored duration, then returns the updated run info dict.

        Returns:
            dict: Dictionary with keys ``datetime``, ``duration``, and
                ``commit_hash``.
        """
        curr_t = datetime.now()
        self.__duration += curr_t - self.__datetime
        self.__datetime = curr_t
        return self.get()
[docs]
class Experiment:
    """Execute experiments based on provided configurations. It supports
    parallel-friendly experiment scheduling and provides mechanisms for logging,
    resuming, and managing experiment runs.

    ## Features:
    - Supports two approaches for parallelizing experiments:
      1. Splitting: Splits hyperparameter configs evenly across multiple
         parallel processes. While this requires predetermining the number of
         Experiment processes to use, this is more reliable when using very many
         parallel runs (many gpus).
      2. Queueing: Each parallel process checks for unexecuted hyperparameter
         configs (or queue), similarly to WandB Agents. This allows to
         dynamically allocate new resources as they become available, but can be
         less reliable when using many parallel runs. It utilizes the log file to
         share currently running configs, by logging empty metrics of running
         configurations. Here, file locking should be used to avoid race
         conditions in read/writing to log file.
    - Automatically resumes from or skips previously run configs from the
      hyperparam grid using saved logs.
    - Allows checkpointing metrics during training steps, allowing to stop and
      resume mid-training much like model checkpointing. Here, file locking
      should be used to avoid race conditions from frequent log file access.

    Attributes:
        name (str): Name of the experiment.
        exp_func (ExpFunc): Function to execute the experiment. It should receive
            the config (or self when checkpointing is enabled) and return the
            resulting metrics to log.
        configs (ConfigIter): Configuration iterator for the experiment.
        log (ExperimentLog): Log object for tracking experiment results.
        infos (list): List of information fields for logging, including status.
        configs_save (bool): Whether to save only the config (with empty metrics)
            of the currently running config to the logs file. This is used for
            queueing mode.
        checkpoint (bool): Whether mid-train checkpointing is executed in
            exp_func. When set to True, Experiment.run will pass itself in
            addition to the current config. It is expected that the user will use
            this to update the log every time a model checkpoint is saved.
        filelock (bool): Whether to use file locking for the log file. This will
            create a lock file where the log file is located. The status of which
            Experiment run is accessing or has requested access to the log
            file is logged in the lock file, letting other runs wait until the
            lock file is released.
        timeout (Optional[float]): Expected timeout for the used resource system.
            Experiment run will use this to execute necessary logging before the
            system timeout terminates the run. In configs_save mode, some configs
            might terminate while their status is set as running in the log file,
            and these won't be executed in the next continuing Experiment run.
    """

    # Status codes written to the log's "status" column.
    __RUNNING: ClassVar[str] = "R"
    __FAILED: ClassVar[str] = "F"
    __COMPLETED: ClassVar[str] = "C"
    # Info columns logged next to metrics: run info plus run status.
    infos: ClassVar[list] = [*RunInfo.infos, "status"]
def __init__(
    self,
    exp_folder_path: str,
    exp_function: ExpFunc,
    exp_metrics: Optional[list] = None,
    total_splits: Union[int, str] = 1,
    curr_split: Union[int, str] = 0,
    configs_save: bool = False,
    checkpoint: bool = False,
    filelock: bool = False,
    timeout: Optional[float] = None,
):
    """Initializes an experiment instance with the specified parameters.

    Args:
        exp_folder_path (str): Path to the experiment folder. The experiment
            folder should include a grid configuration file named 'exp_config.yaml'
            written according to ConfigIter rules. This is used to save the logs and
            other files related to the experiment.
        exp_function (ExpFunc): Function to execute the experiment. It should
            receive the config (or self when checkpointing is enabled) and return
            the resulting metrics to log.
        exp_metrics (Optional[list], optional): List of experiment metric names to
            be logged. Defaults to None (no metric columns beyond run infos).
        total_splits (Union[int, str], optional): Total number of processes for
            splitting mode, or a grid field name to split on. Defaults to 1.
        curr_split (Union[int, str], optional): Current process index (or field
            values) for the splitting mode. Defaults to 0.
        configs_save (bool, optional): Whether to save only the config (with empty
            metrics) of the currently running config to the logs file. This is
            used for queueing mode.
        checkpoint (bool, optional): Whether mid-train checkpointing is executed
            in exp_func. When set to True, Experiment.run will pass itself in
            addition to the current config. Defaults to False.
        filelock (bool, optional): Whether to use file locking for the log file.
            Defaults to False.
        timeout (Optional[float], optional): Expected timeout for the used
            resource system, used to log results before the system terminates
            the run. Defaults to None.

    Raises:
        AssertionError: If `checkpoint` is True but `filelock` is not set to True.
    """
    if checkpoint:
        # Checkpointing writes the log frequently; locking is mandatory.
        assert filelock, "Argument 'filelock' should be set to True when checkpointing."
    self.name = exp_folder_path.split("/")[-1]
    self.exp_func = exp_function
    self.configs_save = configs_save
    self.checkpoint = checkpoint
    self.filelock = filelock
    self.timeout = timeout
    # Splitting is active for >1 integer splits or any field-name split.
    do_split = (isinstance(total_splits, int) and total_splits > 1) or isinstance(total_splits, str)
    cfg_file, tsv_file, _ = self.get_paths(exp_folder_path, split=curr_split if do_split else None)
    self.configs = self.__get_and_split_configs(cfg_file, total_splits, curr_split, self.name)
    # BUGFIX: ``self.infos + exp_metrics`` raised TypeError when exp_metrics
    # was left at its documented default of None.
    self.log = self.__get_log(tsv_file, self.infos + (exp_metrics or []), filelock)
    self.__check_matching_static_configs()
@staticmethod
def __get_and_split_configs(cfg_file, exp_bs, exp_bi, name):
    """Loads configs from a YAML file and splits them for parallel runs.

    Supports two splitting modes: uniform index-based splitting when
    ``exp_bs`` is an integer, or field-value-based splitting when
    ``exp_bs`` is a grid field name.

    Args:
        cfg_file (str): Path to the YAML experiment config file.
        exp_bs (int | str): Total number of splits (int) or grid field
            name to split on (str).
        exp_bi (int | str): Current split index (int) or space-separated
            field values to select (str).
        name (str): Experiment name for logging.

    Returns:
        ConfigIter: Filtered configuration iterator for this split.
    """
    configiter = ConfigIter(cfg_file)
    valid_split = isinstance(exp_bs, int) or (exp_bs in configiter.grid_fields)
    assert valid_split, f"Enter valid splits (int | Literal{configiter.grid_fields})."
    if isinstance(exp_bs, int):
        # if total exp split is given as integer : uniformly split
        assert exp_bs > 0, "Total number of experiment splits should be larger than 0"
        assert exp_bs > exp_bi, (
            "Experiment split index should be smaller than the total number of experiment splits"
        )
        if exp_bs > 1:
            # keep every exp_bs-th config, offset by this split's index
            configiter.filter_iter(lambda idx, _: idx % exp_bs == exp_bi)
        logging.info(f"Experiment : {name} (split : {exp_bi + 1}/{exp_bs})")
    else:
        # else split across certain study field
        exp_bi = [str2value(token) for token in exp_bi.split()]
        configiter.filter_iter(lambda _, cfg: cfg[exp_bs] in exp_bi)
        logging.info(f"Experiment : {name} (split : {exp_bi}/{configiter.grid_dict[exp_bs]})")
    return configiter
def __get_log(self, logs_file, metric_fields=None, filelock=False):
    """Loads an existing experiment log from file, or creates a new one.

    If the log file exists, it is loaded and resumed. Otherwise, a new
    ``ExperimentLog`` is created from the current config iterator and
    saved to disk.

    Args:
        logs_file (str): Path to the TSV log file.
        metric_fields (list | None): Metric field names for a new log.
        filelock (bool): Whether to use file locking.

    Returns:
        ExperimentLog: The loaded or newly created experiment log.
    """
    # Configure experiment log
    if os.path.exists(logs_file):  # Check if there already is a file
        log = ExperimentLog.from_tsv(  # resumes automatically
            logs_file, use_filelock=filelock
        )
    else:  # Create new log
        logs_path, _ = os.path.split(logs_file)
        # BUGFIX: exist_ok=True avoids the exists()/makedirs() race when
        # several parallel runs create the folder; the truthiness guard also
        # prevents makedirs("") from failing for cwd-relative log files.
        if logs_path:
            os.makedirs(logs_path, exist_ok=True)
        log = ExperimentLog.from_config_iter(self.configs, metric_fields, logs_file, use_filelock=filelock)
        log.to_tsv()
    return log
def __check_matching_static_configs(self):
    """Validates static config consistency between the config iterator and log.

    Asserts that both the keys and values of ``self.configs.static_configs``
    match those in ``self.log.static_configs``.

    Raises:
        AssertionError: If keys or values differ between the config
            iterator and the experiment log.
    """
    iter_statics = self.configs.static_configs
    log_statics = self.log.static_configs
    # check matching keys (symmetric difference must be empty)
    k = set(iter_statics) ^ set(log_statics)
    assert not k, f"Found non-matching keys {k} in static config of configiter and experiement log."
    # check matching values
    non_match = {}
    for key in iter_statics:
        v1 = iter_statics[key]
        v2 = log_statics[key]
        if v1 != v2:
            non_match[key] = (v1, v2)
    assert not non_match, (
        f"Found non-matching values {non_match} in static config of configiter and experiement log."
    )
[docs]
@staticmethod
def get_paths(exp_folder, split=None):
"""Constructs and returns file paths for configuration, log, and figure
paths based on the given experiment folder and optional split identifier.
Args:
exp_folder (str): The path to the experiment folder.
split (int, optional): The split identifier for log files. If None, the
default log file path is used. Defaults to None.
Returns:
tuple: A tuple containing:
- cfg_file (str): Path to the experiment configuration file
('exp_config.yaml').
- tsv_file (str): Path to the log file. If `split` is provided, the path
corresponds to the split-specific log file
('log_splits/split_{split}.tsv'), otherwise it defaults to 'log.tsv'.
- fig_dir (str): Path to the figure directory ('figure').
"""
cfg_file = os.path.join(exp_folder, "exp_config.yaml")
if split == None:
tsv_file = os.path.join(exp_folder, "log.tsv")
else:
tsv_file = os.path.join(exp_folder, "log_splits", f"split_{split}.tsv")
fig_dir = os.path.join(exp_folder, "figure")
return cfg_file, tsv_file, fig_dir
[docs]
def get_metric_info(self, config):
    """Retrieves metric and info dictionaries for a given configuration.

    Args:
        config (Mapping): The configuration to look up in the log.

    Returns:
        tuple: A tuple containing two dictionaries:
            - metric_dict (dict): A dictionary of metrics associated with the given
              configuration, excluding NaN scalar values.
            - info_dict (dict): A dictionary of additional information extracted
              from the log, containing keys present in `self.infos` and non-NaN
              values.
    """
    if config not in self.log:
        logging.info("Log of matching config is not found. Returning empty dictionaries.")
        return {}, {}  # return empty dictionaries if no log is found
    metric_dict = self.log[config]
    # Pull info fields out of the metric dict; NaN infos are dropped.
    info_dict = {}
    for key in self.infos:
        if key in metric_dict:
            value = metric_dict.pop(key)
            if pd.notna(value):
                info_dict[key] = value
    # Drop NaN scalar metrics; non-scalar values (e.g. lists) are kept as-is.
    cleaned = {}
    for key, value in metric_dict.items():
        if np.isscalar(value) and pd.isna(value):
            continue
        cleaned[key] = value
    return cleaned, info_dict
[docs]
def update_log(self, config, status=None, **metric_dict):
    """Updates the log with the given configuration, status, and metrics.

    This method loads the current log, updates it with the provided
    configuration, metrics, and run information, and then saves the updated log.

    Args:
        config (Mapping): The current config to log.
        status (str|None, optional): The status of the current run. Defaults to
            None, which sets the status to a predefined running state.
        **metric_dict: Metrics to log.
    """
    # FIX: identity check for None instead of equality (PEP 8 / E711).
    if status is None:
        status = self.__RUNNING
    # Reload so concurrent writers' rows are not clobbered, then append and save.
    self.log.load_tsv()
    self.log.add_result(config, **metric_dict, **self.__curr_runinfo.update_and_get(), status=status)
    self.log.to_tsv()
[docs]
def run(self):
    """Executes a series of experiments based on the provided configurations
    according given execution setup.

    Iterates over every config in ``self.configs``: skips configs already
    marked running/completed in the log, optionally pre-logs an empty
    "running" row (queueing mode), executes ``exp_func`` (wrapped with the
    remaining timeout budget when ``self.timeout`` is set), and always
    writes metrics plus the final status back to the log.
    """
    logging.info("Start running experiments.")
    start_t = datetime.now()
    if self.filelock:
        logging.info((self.log.filelock.is_locked, self.log.filelock.id))
        # initially obtain filelock before running experiments
        self.log.filelock.acquire()
    # run experiment plans
    for i, config in enumerate(self.configs):
        # Reload so statuses written by parallel runs are visible.
        self.log.load_tsv()
        metric_dict, info_dict = self.get_metric_info(config)
        # skip already executed runs
        if info_dict.get("status") in {self.__RUNNING, self.__COMPLETED}:
            continue
        # new run info; carry over the accumulated duration of a resumed run
        self.__curr_runinfo = RunInfo(prev_duration=pd.to_timedelta(info_dict.get("duration", "0")))
        # if config not in self.log or status==self.__FAILED
        if self.configs_save:
            # queueing mode: claim this config by logging it as running
            self.update_log(config, **metric_dict, status=self.__RUNNING)
        # release filelock before running experiment
        if self.filelock:
            self.log.filelock.release(force=True)
        logging.info("###################################")
        logging.info(f" Experiment count : {i + 1}/{len(self.configs)}")
        logging.info("###################################")
        try:
            exp_func = self.exp_func
            if self.timeout:
                # shrink the per-run timeout by the time already spent
                exp_func = settimeout_func(
                    exp_func, timeout=self.timeout - (datetime.now() - start_t).total_seconds()
                )
            if self.checkpoint:
                # checkpointing mode: exp_func also receives this Experiment
                metric_dict = exp_func(config, self)
            else:
                metric_dict = exp_func(config)
            status = self.__COMPLETED
        except Exception as exc:
            # keep whatever metrics were checkpointed before the failure
            metric_dict, _ = self.get_metric_info(config)
            status = self.__FAILED
            if isinstance(exc, FuncTimeoutError):
                # re-raise so the process exits before the system kills it;
                # the finally block below still records the failed status
                logging.error(f"Experiment timeout ({self.timeout}s) occured:")
                raise exc
            else:
                logging.error(f"Experiment failure occured:\n{traceback.format_exc()}{exc}")
        finally:
            # obtain filelock before updating log
            if self.filelock:
                self.log.filelock.acquire()
            self.update_log(config, **metric_dict, status=status)
            logging.info("Saved experiment data to log.")
    # release filelock after running all experiments
    if self.filelock:
        self.log.filelock.release(force=True)
    logging.info("Complete experiments.")
[docs]
@staticmethod
def resplit_logs(exp_folder_path: str, target_split: int = 1, save_backup: bool = True):
    """Resplit splitted logs into ``target_split`` number of splits.

    Args:
        exp_folder_path (str): The path to the experiment folder containing the
            logs.
        target_split (int, optional): The number of splits to divide the logs
            into. Must be greater than 0. Defaults to 1.
        save_backup (bool, optional): Whether to save a backup of the merged logs
            before resplitting. Defaults to True.

    Raises:
        FileNotFoundError: If neither a ``log_splits`` folder nor a single
            ``log.tsv`` exists under ``exp_folder_path``.
    """
    assert target_split > 0, "Target split should be larger than 0"
    cfg_file, logs_file, _ = Experiment.get_paths(exp_folder_path)
    logs_folder = os.path.join(exp_folder_path, "log_splits")
    # merge original log_splits
    if os.path.exists(logs_folder):  # if log is splitted
        # FIX: glob with an absolute pattern instead of os.chdir — the chdir
        # leaked a working-directory change to the caller; sorted() makes the
        # merge order deterministic.
        split_files = sorted(glob.glob(os.path.join(logs_folder, "*.tsv")))
        base, *logs = [ExperimentLog.from_tsv(f, parse_str=False) for f in split_files]
        base.merge(*logs)
        shutil.rmtree(logs_folder)
    elif os.path.exists(logs_file):  # if only single log file exists
        base = ExperimentLog.from_tsv(logs_file, parse_str=False)
        # BUGFIX: this previously called shutil.rmtree on a regular file,
        # which raises NotADirectoryError.
        os.remove(logs_file)
    else:
        # BUGFIX: ``base`` was previously left unbound here, producing a
        # confusing NameError further down.
        raise FileNotFoundError(f"No logs found in {exp_folder_path}.")
    # save backup
    if save_backup:
        base.to_tsv(os.path.join(exp_folder_path, "logs_backup.tsv"))
    # resplit merged logs based on target_split
    if target_split == 1:
        base.to_tsv(logs_file)
    elif target_split > 1:
        # get configs
        configs = ConfigIter(cfg_file)
        for n in range(target_split):
            # empty log for the nth split
            logs = ExperimentLog.from_exp_config(
                configs.__dict__,
                os.path.join(
                    logs_folder,
                    f"split_{n}.tsv",
                ),
                base.metric_fields,
            )
            # resplitting nth split
            cfgs_temp = copy.deepcopy(configs)
            cfgs_temp.filter_iter(lambda i, _: i % target_split == n)
            for cfg in track(cfgs_temp, description=f"split: {n}/{target_split}"):
                if cfg in base:
                    metric_dict = base[cfg]
                    logs.add_result(cfg, **metric_dict)
            logs.to_tsv()
[docs]
@classmethod
def set_log_status_as_failed(cls, exp_folder_path: str):
    """Updates the status of logs in the specified experiment folder to 'FAILED'
    if their current status is 'RUNNING'.

    Args:
        exp_folder_path (str): The path to the experiment folder containing
            either a ``log_splits`` folder or a single ``log.tsv``.
    """
    _, logs_file, _ = Experiment.get_paths(exp_folder_path)
    logs_folder = os.path.join(exp_folder_path, "log_splits")
    # collect every log file to patch
    if os.path.exists(logs_folder):  # if log is splitted
        # FIX: glob with an absolute pattern instead of os.chdir, which leaked
        # a working-directory change to the caller.
        paths = glob.glob(os.path.join(logs_folder, "*.tsv"))
    elif os.path.exists(logs_file):  # if only single log file exists
        paths = [logs_file]
    else:
        # BUGFIX: ``paths`` was previously unbound when no logs existed,
        # raising NameError below; treat that case as a no-op instead.
        paths = []
    for p in paths:
        log = ExperimentLog.from_tsv(p, parse_str=False)
        # flip RUNNING rows to FAILED, leave every other status untouched
        log.df["status"] = log.df["status"].map(lambda x: cls.__FAILED if x == cls.__RUNNING else x)
        log.to_tsv()