# Source code for malet.plot_utils.data_processor

from __future__ import annotations

from typing import Any, Dict, Iterable, List, Sequence, Set, Tuple, Union

import numpy as np
import pandas as pd

ValueLike = Union[Any, List[Any], Tuple[Any, ...], Set[Any], np.ndarray, pd.Index]


def _as_list(v: ValueLike) -> List[Any]:
    """Normalize a filter value into a python list."""
    if isinstance(v, list):
        return v
    if isinstance(v, (tuple, set, np.ndarray, pd.Index)):
        return list(v)
    return [v]


def _ensure_index_levels(df: pd.DataFrame, keys: Iterable[str]) -> None:
    idx_names = set(df.index.names)
    missing = set(keys) - idx_names
    if missing:
        raise KeyError(f"filt_dict keys {missing} is not in df.index.names={df.index.names}")


def _subindex_on_levels(index: pd.MultiIndex, keep_levels: Sequence[str]) -> pd.MultiIndex:
    """Return a MultiIndex containing only keep_levels, in that order."""
    drop_levels = [lvl for lvl in index.names if lvl not in set(keep_levels)]
    sub_idx = index.droplevel(drop_levels) if drop_levels else index
    if list(sub_idx.names) != list(keep_levels):
        sub_idx = sub_idx.reorder_levels(list(keep_levels))
    return sub_idx


def select_df(
    df: pd.DataFrame,
    filt_dict: Dict[str, ValueLike],
    *exclude_fields: str,
    equal: bool = True,
    drop: bool = False,
    validate: bool = True,
) -> pd.DataFrame:
    """Select df rows with matching values from ``filt_dict`` except ``exclude_fields``.

    Builds one boolean mask over the MultiIndex and slices once, instead of
    repeated ``df.loc`` calls.

    Args:
        df (pandas.DataFrame): DataFrame with MultiIndex.
        filt_dict (Dict[str, Any]): Mapping from index level to allowed values.
        exclude_fields (str): Index levels to exclude from filtering.
        equal (bool): If True, keep matching rows; otherwise exclude them.
        drop (bool): If True, drop filtered index levels.
        validate (bool): If True, run key/value existence checks.

    Returns:
        pandas.DataFrame: Filtered DataFrame.

    Raises:
        KeyError: If a ``filt_dict`` key is not an index level (when validating).
        AssertionError: If df is empty, a value is absent from its level, or an
            intermediate filter yields an empty result (when validating).
    """
    assert not df.empty, "Given dataframe is empty."
    if not filt_dict:
        return df

    if validate:
        # Key-existence check (kept inline so the block is self-contained).
        missing = set(filt_dict.keys()) - set(df.index.names)
        if missing:
            raise KeyError(f"filt_dict keys {missing} is not in df.index.names={df.index.names}")

    filt_keys = [k for k in filt_dict if k not in set(exclude_fields)]
    idx = df.index
    mask = np.ones(len(df), dtype=bool)

    for i, k in enumerate(filt_keys):
        v = filt_dict[k]
        # Normalize the filter value into a plain list (scalar -> [scalar]).
        if isinstance(v, list):
            values = v
        elif isinstance(v, (tuple, set, np.ndarray, pd.Index)):
            values = list(v)
        else:
            values = [v]

        if validate:
            vs = pd.Index(idx.get_level_values(k).unique())
            bad = set(values) - set(vs.tolist())
            assert not bad, f"Values {bad} are not in field '{k}': {sorted(vs.tolist())}"

        fltr = idx.get_level_values(k).isin(values)
        mask &= fltr if equal else ~fltr

        # Fail fast with the partial filter applied so far for easier debugging.
        if validate and not mask.any():
            partial = {kk: filt_dict[kk] for kk in filt_keys[: i + 1]}
            raise AssertionError(f"Filter {k}:{values} return empty dataframe. \nInspect {partial}")

    out = df.loc[mask]
    if drop and filt_keys:
        out = out.reset_index([*filt_keys], drop=True)
    return out
def homogenize_df(
    df: pd.DataFrame,
    ref_df: pd.DataFrame,
    filt_dict: Dict[str, ValueLike],
    *exclude_fields: str,
    validate: bool = True,
) -> pd.DataFrame:
    """Homogenize index values of ``df`` with reference to ``select_df(ref_df, filt_dict)``.

    Aligns ``df`` so that its remaining index grid matches the grid induced by
    ``select_df(ref_df, filt_dict, drop=True)``, using a single vectorized
    MultiIndex membership test instead of per-row selection and concat.

    Caveats (preserved from the original implementation):
    - grid should be complete, else some fields in filt_dict will be missing.
    - also, when metric in filt_dict, step and total_steps can be
      metric-dependent and could return empty df.

    Args:
        df (pandas.DataFrame): DataFrame to homogenize.
        ref_df (pandas.DataFrame): Reference DataFrame.
        filt_dict (Dict[str, Any]): Filter used to define the reference grid.
        exclude_fields (str): Index levels excluded from filtering.
        validate (bool): Run validation checks.

    Returns:
        pandas.DataFrame: Homogenized DataFrame (empty if the reference grid is empty).
    """
    ref_idx = select_df(ref_df, filt_dict, *exclude_fields, drop=True, validate=validate).index
    if len(ref_idx) == 0:
        # No reference rows: return an empty frame with df's schema intact.
        return df.iloc[0:0]
    keep_levels = list(ref_idx.names)
    # Project df's index onto the reference levels, then keep matching rows.
    sub_idx = _subindex_on_levels(df.index, keep_levels)
    return df.loc[sub_idx.isin(ref_idx)]
def avgbest_df(
    df: pd.DataFrame,
    metric_field: str,
    avg_over: Set[str] | None = None,
    best_over: Set[str] | None = None,
    best_of: Dict[str, Any] | None = None,
    best_at_max: bool = True,
    validate: bool = True,
) -> pd.DataFrame:
    """Average over ``avg_over`` and get best result over ``best_over``.

    Semantics:
    - ``avg_over``: aggregate (mean + SEM) over these index levels.
    - ``best_over``: choose hyperparameter values yielding best ``metric_field``.
    - ``best_of``: restrict best search to a fixed subset of index values, then
      apply the chosen hyperparameter globally.
    - ``best_at_max`` controls argmax vs argmin selection.

    Internal logic:
    '''
    - aggregate index : avg_over, best_over
    - key index       : best_of, others
    '''

    Args:
        df (pandas.DataFrame): Base dataframe to operate over.
        metric_field (str): Metric used to select best hyperparameter.
        avg_over (Set[str] | None): MultiIndex levels to average over.
        best_over (Set[str] | None): MultiIndex levels to select best over.
        best_of (Dict[str, Any] | None): Fixed index values for best selection.
        best_at_max (bool): True if larger metric is better.
        validate (bool): Enable validation checks.

    Returns:
        pandas.DataFrame: Processed DataFrame.
    """
    assert not df.empty, "Given dataframe is empty."
    # None sentinels instead of mutable defaults (set()/dict()); behavior unchanged.
    avg_over = set() if avg_over is None else avg_over
    best_over = set() if best_over is None else best_over
    best_of = {} if best_of is None else best_of

    if validate and metric_field not in df.columns:
        raise KeyError(f"metric_field='{metric_field}' not in df.columns={list(df.columns)}")

    df_fields = set(df.index.names)

    # avg over avg_over
    if avg_over:
        if validate:
            _ensure_index_levels(df, avg_over)
        df_fields -= set(avg_over)
        g = df.groupby([*df_fields], dropna=True, sort=False)
        df = g.mean(numeric_only=True)
        # SEM of the metric across the averaged-over levels.
        df[metric_field + "_std"] = g.sem(numeric_only=True)[metric_field]
        df_fields = set(df.index.names)

    # best result over best_over
    if best_over:
        if validate:
            _ensure_index_levels(df, best_over)
            _ensure_index_levels(df, best_of.keys())
        df_fields -= set(best_over)

        # Search for the best hyperparameters only within the best_of slice.
        best_df = select_df(df, best_of, validate=validate)
        if df_fields:
            g = best_df.groupby([*df_fields], dropna=True, sort=False)[metric_field]
            idx = g.idxmax() if best_at_max else g.idxmin()
            best_df = best_df.loc[idx]
        else:
            # No key levels left: a single global best row.
            idx = best_df[metric_field].idxmax() if best_at_max else best_df[metric_field].idxmin()
            best_df = best_df.loc[[idx]]

        df_fields -= set(best_of)
        # Apply the chosen hyperparameters to the full dataframe.
        df = homogenize_df(df, best_df, best_of, validate=validate)

    return df