Source code for malet.plot

import os
import re
from functools import partial
from itertools import product

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import yaml
from absl import app, flags
from matplotlib import animation, cm, colors, lines, style
from rich import print
from rich.align import Align
from rich.columns import Columns
from rich.panel import Panel
from rich.progress import Progress
from rich.status import Status

from .experiment import Experiment, ExperimentLog
from .plot_utils.data_processor import avgbest_df, homogenize_df, select_df
from .plot_utils.plot_drawer import (
    ax_draw_bar,
    ax_draw_best_stared_curve,
    ax_draw_curve,
    ax_draw_heatmap,
    ax_draw_scatter,
    ax_draw_scatter_heat,
)
from .plot_utils.utils import ax_styler, create_dir, default_style, merge_dict
from .utils import df2richtable, get_wandb_sweep_exp_dir, str2value

FLAGS = flags.FLAGS


[docs] def get_plot_config(plot_config: dict, plot_args: dict): """Retrieve and merge the plot configuration based on the provided mode and arguments. Args: plot_config (dict): A dictionary containing configuration settings for different plot modes. It should include a 'default_style' key for default configurations. plot_args (dict): A dictionary containing arguments for the plot, including a 'mode' key that specifies the desired plot mode. Returns: dict: A merged dictionary containing the final plot configuration. Notes: - If the mode in `plot_args` does not contain a hyphen ('-'), it is treated as an alias mode. In this case, the function retrieves the base configuration for the alias mode, merges it with `plot_args` and the default style, and then merges it with the alias-specific configuration. - If the mode contains a hyphen, the function directly merges `plot_args` with the mode-specific configuration. """ assert plot_args["mode"] in plot_config, f"Mode: {plot_args['mode']} does not exist." alias_mode = "-" not in plot_args["mode"] p_cfg = plot_config[plot_args["mode"]] if alias_mode: p_cfg_base = plot_config.get(p_cfg["mode"], dict()) p_cfg_base = merge_dict(p_cfg_base, plot_args) p_cfg_base = merge_dict(p_cfg_base, plot_config["default_style"]) return merge_dict(p_cfg, p_cfg_base) else: return {**plot_args, **p_cfg}
def __save_name_builder(pflt, pmlf, pcrf, pcfg, pani, save_name=""): """Builds a descriptive save filename from plot configuration parameters. Args: pflt: Filter dictionary mapping field names to filter values. pmlf: List of multi-line field names. pcrf: List of column-row field names for multi-plot layout. pcfg: Plot configuration dictionary containing best_ref and best_at_max settings. pani: Animation field name, or empty string if no animation. save_name: Optional prefix for the resulting filename. Returns: str: Constructed filename string with encoded plot settings. """ sn = [] if pmlf: sn.append(f"mlf({'-'.join(pmlf)})") if pcrf: sn.append(f"crf({'-'.join(pcrf)})") if pflt: sn.append(f"flt({'-'.join(['_'.join([k, *v]) for k, v in pflt.items()])})") if pani: sn.append(f"ani({pani})") fields = ["x_fields", "metric_field", "ml_fields"] if any([pcfg[f"best_ref_{k}"] for k in fields]): sn.append(f"bref({pcfg['best_ref_x_fields']}, {pcfg['best_ref_metric_field']}, {pcfg['best_ref_ml_fields']})") sn.append("max" if pcfg["best_at_max"] else "-min") return save_name + "-".join(sn)
[docs] def draw_metric(tsv_file, plot_config, save_name="", preprcs_df=lambda *x: x): """Some rules step: - Step in [filter] and [best_ref_x_fields] (when step is in x_fields) can have special values {'last', 'best'} - if step is not in x_fields and filter, the last step is chosen - if step is in filter, best_ref_x_fields is automatically set to the last step metric: - (if metric is in multi_line_fields, best_ref_metric_field is automatically set to the best metric) """ pcfg = plot_config ############################# Preprocess plot mode ############################# # Parse mode string mode, x_fields, metrics = pcfg["mode"].split("-") # ex) {sam}-{epoch}-{train_loss} x_fields = xs if (xs := x_fields.split(" ")) != [""] else [] metrics = metrics.split(" ") pflt, pcrf, pmlf, pani = map(pcfg.get, ["filter", "multi_plot_fields", "multi_line_fields", "animation_field"]) pflt = { fk: fvs for fk, *fvs in map(lambda flt: re.split("(?<!,) ", flt.strip()), pflt.split(" / ")) if fk } # split ' ' except ', ' # ---Set default pmlf if not pmlf: pcfg["ax_style"].pop("legend", None) # ---preprocess best_ref_x_fields and automatically set best_ref_x_field of step pcfg["best_ref_x_fields"] = [*map(str2value, pcfg["best_ref_x_fields"])] if "step" in x_fields and not pcfg["best_ref_x_fields"]: pcfg["best_ref_x_fields"] = ["" for _ in x_fields] pcfg["best_ref_x_fields"][x_fields.index("step")] = "last" # choose plot mode if mode in {"curve", "curve_best", "bar"}: assert len(x_fields) == 1, ( f"Number of x_fields shoud be {1} when using curve mode, but you passed {len(x_fields)}." ) assert len(metrics) == 1, f"Number of metric shoud be {1} when using curve mode, but you passed {metrics}." assert len(pmlf) <= 3, f"Number of multi_line_fields should be less than {3}, but you passed {len(pmlf)}" ax_draw = {"curve": ax_draw_curve, "curve_best": ax_draw_best_stared_curve, "bar": ax_draw_bar}[mode] ax_draw = partial( ax_draw, annotate=pcfg["annotate"], annotate_field=pcfg["annotate_field"], best_at_max=pcfg["best_at_max"] ) x_label, y_label = (f.replace("_", " ").capitalize() for f in (x_fields[0], metrics[0])) elif mode == "heatmap": assert len(x_fields) == 2, ( f"Number of x_fields shoud be {2} when using heatmap mode, but you passed {len(x_fields)}." ) assert len(metrics) == 1, f"Number of metric shoud be {1} when using heatmap mode, but you passed {metrics}." assert not pmlf, f"No multi_line_fields are allowed in heatmap mode, but you passed {pmlf}." ax_draw = partial( ax_draw_heatmap, annotate=pcfg["annotate"], annotate_field=pcfg["annotate_field"], best_at_max=pcfg["best_at_max"], ) x_label, y_label = (f.replace("_", " ").capitalize() for f in x_fields) elif mode == "scatter": assert not x_fields, f"No x_fields are allowed in scatter mode, but you passed {x_fields}." assert len(metrics) == 2, f"Number of metric shoud be {2} when using scatter mode, but you passed {metrics}." assert len(pmlf) <= 2, f"Number of multi_line_fields should be less than {2}, but you passed {len(pmlf)}" ax_draw = partial(ax_draw_scatter, y_fields=metrics) x_label, y_label = (f.replace("_", " ").capitalize() for f in metrics) elif mode == "scatter_heat": assert not x_fields, f"No x_fields are allowed in scatter_heat mode, but you passed {x_fields}." assert len(metrics) == 3, ( f"Number of metric shoud be {3} when using scatter_heat mode, but you passed {metrics}." ) assert len(pmlf) <= 1, f"Number of multi_line_fields should be less than {1}, but you passed {len(pmlf)}" ax_draw = partial(ax_draw_scatter_heat, y_fields=metrics) x_label, y_label = (f.replace("_", " ").capitalize() for f in metrics[:-1]) # build save name save_name = __save_name_builder(pflt, pmlf, pcrf, pcfg, pani, save_name=save_name) with Status("") as status: ############################# Get and filter dataframe ############################# status.update(f"Loading {tsv_file}") status.start() # get dataframe, drop unused metrics for efficient process log = ExperimentLog.from_tsv(tsv_file) post_melt_k = {"step", "total_steps", "metric"} assert all(x in log.df.index.names for x in x_fields if x not in post_melt_k), ( f"X-field {[x for x in x_fields if x not in log.df.index.names]} not in log. Choose between {set(log.df.index.names) | post_melt_k}" ) assert all(kk in log.df.index.names for k in pflt if (kk := k[:-1] if "!" in k else k) not in post_melt_k), ( f"Filter keys {[k for k in pflt if k not in log.df.index.names]} not in log. Choose between {set(log.df.index.names) | post_melt_k}" ) assert all(k in log.df.index.names for k in pcrf if k not in post_melt_k), ( f"Column-row fields {[k for k in pcrf if k not in log.df.index.names]} not in log. Choose between {set(log.df.index.names) | post_melt_k}" ) assert all(k in log.df.index.names for k in pmlf if k not in post_melt_k), ( f"Multi-line (style) fields {[k for k in pmlf if k not in log.df.index.names]} not in log. Choose between {set(log.df.index.names) | post_melt_k}" ) assert not pani or pani in {*log.df.index.names} | post_melt_k, ( f"Animate field {pani} not in log. Choose between {set(log.df.index.names) | post_melt_k}" ) assert all(m in log.df for m in metrics), ( f"Metric {[m for m in metrics if m not in log.df]} not in log. Choose between {list(log.df)}" ) # --- initial filter for df according to FLAGS.filter (except step and metric) if pflt: log.df = select_df( log.df, {fk: [*map(str2value, fvs)] for fk, fvs in pflt.items() if fk[-1] != "!" and fk not in post_melt_k}, ) log.df = select_df( log.df, { fk[:-1]: [*map(str2value, fvs)] for fk, fvs in pflt.items() if fk[-1] == "!" and fk[:-1] not in post_melt_k }, equal=False, ) # --- melt and explode metric in log.df if "metric" not in x_fields + pmlf + pcrf: log.df = log.df.drop(list(set(log.df) - {*metrics, pcfg["best_ref_metric_field"]}), axis=1) df = log.melt_and_explode_metric( step=None if (("step" in {*x_fields, pani}) or (pflt.get("step", "last") != "last")) else -1, dropna=(mode not in {"scatter", "scatter_heat"}), ) assert not df.empty, ( f"Metrics {metrics}" + ( f" and best_ref_metric_field {pcfg['best_ref_metric_field']} are" if pcfg["best_ref_metric_field"] else " is" ) + f" NaN in given dataframe: \n{log.df}" ) # ---filter df according to FLAGS.filter step and metrics if pflt: pflt = { k: v for k, v in pflt.items() if (k, v) != ("step", ["best"]) } # Let `avgbest_df` handle 'best' step, remove from pflt e_rng = lambda fvs: ( [*range(*map(int, fvs[0].split(":")))] if (len(fvs) == 1 and ":" in fvs[0]) else fvs ) # CNG 'a:b' step filter later df = select_df( df, {fk: [*map(str2value, e_rng(fvs))] for fk, fvs in pflt.items() if fk[-1] != "!" and fk in post_melt_k}, ) df = select_df( df, { fk[:-1]: [*map(str2value, e_rng(fvs))] for fk, fvs in pflt.items() if fk[-1] == "!" and fk[:-1] in post_melt_k }, equal=False, ) status.stop() ############################# Prepare dataframe ############################# specified_field = {k for k in {*df.index.names} if len({*df.index.get_level_values(k)}) == 1} if mode in {"scatter", "scatter_heat"}: key_field = {*df.index.names} - specified_field avg_field = optimized_field = set() else: key_field = {*x_fields, *pmlf, *pcrf, *pani.split(), "metric"} - specified_field avg_field = ({"seed"} if "seed" in df.index.names else set()) - specified_field - key_field optimized_field = {*df.index.names} - specified_field - key_field - avg_field # Report selected plot configs and field handling statistics print( "\n\n", Align( Columns( [ Panel( "\n".join( [ f"- {k}: {pcfg[k]}" for k in ( "mode", "multi_line_fields", "multi_plot_fields", "filter", "best_at_max", "best_ref_x_fields", "best_ref_metric_field", "best_ref_ml_fields", ) if pcfg[k] ] ), title="Plot configuration", padding=(1, 3), ), Panel( f"- Key field (has multiple values): {[*key_field]} ({len(key_field)})\n" + f"- Specified field: {[*specified_field]} ({len(specified_field)})\n" + f"- Averaged field: {[*avg_field]} ({len(avg_field)})\n" + f"- Optimized field: {[*optimized_field]} ({len(optimized_field)})", title="Field handling statistics", padding=(1, 3), ), ] ), align="center", ), ) status.update("Processing dataframe...") status.start() if mode in {"scatter", "scatter_heat"}: # no processing best_df = df else: # change field name and avg over seed and get best result over best_over best_of = {} if pcfg["best_ref_x_fields"]: # same hyperparameter over all points in line best_of.update(dict(zip(x_fields, pcfg["best_ref_x_fields"]))) if pcfg[ "best_ref_metric_field" ]: # Optimize in terms of reference metric, and apply those hyperparameters to original best_of["metric"] = pcfg["best_ref_metric_field"] if pcfg["best_ref_ml_fields"]: # same hyperparameter over all line in multi_line_fields best_of.update(dict(zip(pmlf, pcfg["best_ref_ml_fields"]))) # do best_of operation on 'step' and 'metric' seperatly after avgbest_df sm_bestof = {} if best_of.get("step") in {"last", "best"}: sm_bestof["step"] = best_of.pop("step") if sm_bestof["step"] == "best": # add step to best_over to compute optimal step optimized_field |= {"step"} if "metric" in best_of: sm_bestof["metric"] = best_of.pop("metric") # avgbest without 'step' and 'metric' in best_of best_df = avgbest_df( df, "metric_value", avg_over=avg_field, best_over=optimized_field, best_of=best_of, best_at_max=pcfg["best_at_max"], ) # process 'step' and 'metric' if sm_bestof: avg_df = avgbest_df(df, "metric_value", avg_over=avg_field) sm_df = best_df if "step" in sm_bestof: if sm_bestof["step"] == "last": sm_df = sm_df.loc[ sm_df.index.get_level_values("step") == sm_df.index.get_level_values("total_steps") ] # remove duplicates over total_step (pick best performing, might change later) # e.g., step 50 best for config with total_step 50, and likewise for step 100, pick step 100 if better than 50 # in other words, train run stopped at step 50 that performed better than other experiments at step 50 might remain and cause problems. sm_df = avgbest_df(sm_df, "metric_value", best_over=optimized_field | {"step"}) sm_df = sm_df.reset_index(["step"], drop=True) if "metric" in sm_bestof: sm_df = select_df(sm_df, {"metric": sm_bestof["metric"]}, drop=True) # homogenize best_df with sm_df, exclude 'total_steps' best_df = homogenize_df(avg_df, sm_df, {}, "total_steps") # check if there is any duplicate key_field configs in the best_df assert ( not best_df.reset_index(list({*best_df.index.names} - key_field - {"metric"})).index.duplicated().any() ), "Duplicate values in found in dataframe: \n" + str( best_df.reset_index(best_df.index.names).groupby([*key_field]).size().loc[lambda x: x > 1] ) status.stop() ############################# Print best_df ############################# show_field_order = ( (["metric"] if "metric" not in specified_field else []) + sorted(key_field - {"metric", "step", "seed"}) + sorted(optimized_field - {"metric", "step", "seed"}) + (["seed"] if not avg_field else []) + (["step"] if "step" not in specified_field else []) ) show_df = ( best_df.reset_index() .reindex(show_field_order + ["metric_value", "metric_value_std"], axis=1) .sort_values(by=show_field_order, ignore_index=True) ) print( "\n", Align("Metric Summary Table", align="center"), Align( Columns( [ Panel( "\n".join( [ f"- {k:{max(map(len, specified_field))}s} : {best_df.index.get_level_values(k)[0]}" for k in sorted(specified_field) ] ), padding=(1, 3), ), df2richtable(show_df, max_row_len=50), ] ), align="center", ), ) if "metric" not in x_fields + pmlf + pcrf: best_df = select_df(best_df, {"metric": metrics}) ############################# Plot ############################# status.update("Prepare plot...") status.start() pford = pcfg["field_orders"] pford = { fk: fvs for fk, *fvs in map(lambda od: re.split("(?<!,) ", od.strip()), pford.split(" / ")) if fk } # split ' ' except ', ' assert not ( nmtch := { k: (set(v), set(best_df.index.get_level_values(k))) for k, v in pford.items() if set(v) != set(best_df.index.get_level_values(k)) } ), f"Field order does not match with the dataframe: {nmtch} (field_order, dataframe)" get_field_values = lambda f: ( [""] if not f else pford[f] if f in pford else sorted(set(best_df.index.get_level_values(f)), key=str2value) ) # col-row, animation, and multi_line fields col_vs, row_vs = map(get_field_values, (pcrfn := pcrf + [""] * (2 - len(pcrf)))) ani_vs = get_field_values(pani) mlines = [*product(*map(get_field_values, pmlf))] # scale axis size according to number of col and row plots if isinstance(fig_size := pcfg["ax_style"]["fig_size"], (float, int)): fig_size = [fig_size] * 2 pcfg["ax_style"]["fig_size"] = [p * l for p, l in zip([len(col_vs), len(row_vs)], fig_size)] legend_style = pcfg["ax_style"].pop("legend", [{}])[0] # set style types per mode has_cbar = False if mode in {"curve", "curve_best"}: style_types = ["color", "linestyle", "marker"] # set unif_xticks for curve if pcfg["xscale"] == "unif": pcfg["ax_style"].pop("xscale", None) pcfg["line_style"]["unif_xticks"] = True elif mode == "bar": style_types = ["color"] elif mode == "heatmap": style_types = [] has_cbar = True elif mode == "scatter": style_types = ["color", "marker"] elif mode == "scatter_heat": style_types = ["marker"] has_cbar = True if has_cbar: norm_df = best_df[best_df.index.get_level_values("metric") == metrics[-1]] if pcfg["ax_style"].pop("zscale", [{}])[0] == "log": pcfg["line_style"]["norm"] = colors.LogNorm( norm_df["metric_value"].min() + 1e-15, norm_df["metric_value"].max() ) best_df[best_df.index.get_level_values("metric") == metrics[-1]] += 1e-15 else: pcfg["line_style"]["norm"] = colors.Normalize( norm_df["metric_value"].min(), norm_df["metric_value"].max() ) pcfg["line_style"]["cmap"] = "magma" if pcfg["colors"][0] == "default" else pcfg["colors"][0] rep, skp, sft = map(int, pcfg["colors_rep_skip_shift"]) style_dict = { "color": [ c for i, c in enumerate( sum(map(sns.color_palette, [None if c == "default" else c for c in pcfg["colors"]]), []) * rep ) if not (i - sft) % (skp + 1) ], "marker": ["D", "o", ">", "X", "s", "v", "^", "<", "p", "P", "*", "+", "x", "h", "H", "|", "_"], "linestyle": ["-", ":", "-.", "--"] * 3, } # use continuous colormap when the color multi-line field has numeric values if "color" in style_types and pmlf: color_idx = style_types.index("color") if color_idx < len(pmlf): color_field_vals = get_field_values(pmlf[color_idx]) numeric_vals = [str2value(v) for v in color_field_vals] if all(isinstance(v, (int, float)) for v in numeric_vals) and len(numeric_vals) > 1: cmap_name = "viridis" if pcfg["colors"][0] == "default" else pcfg["colors"][0] cont_cmap = plt.cm.get_cmap(cmap_name) vmin, vmax = min(numeric_vals), max(numeric_vals) style_dict["color"] = [cont_cmap((v - vmin) / (vmax - vmin)) for v in numeric_vals] styles = [ *product( *[[*style_dict[s]][: len(set(k))] for s, k in zip(style_types, map(df.index.get_level_values, pmlf))] ) ] # prepare plot, set figure and axes for multiple plots, and leave place for cmap fig, axs = plt.subplots( len(row_vs), len(col_vs) + int(has_cbar), sharex=True, sharey=True, gridspec_kw={"width_ratios": [1] * len(col_vs) + ([0.1] if has_cbar else [])}, ) # for 2d indexing of axs if isinstance(axs, plt.Axes): axs = np.array([[axs]]) for _ in range(2 - len(axs.shape)): axs = axs[None] status.stop() with Progress() as progress: total_prgs, init_prgs = len(ani_vs) * len(col_vs) * len(row_vs) * len(mlines), 0 task = progress.add_task(f"[green]Drawing plot... [{init_prgs}/{total_prgs}]", total=total_prgs) ani_artists = [] for aniv in ani_vs: ani_artists.append([]) for ci, col_v in enumerate(col_vs): for ri, row_v in enumerate(row_vs[::-1]): ax = axs[ri, ci] for mlvs, st in zip(mlines, styles): try: p_df = select_df( best_df, {k: v for k, v in zip([*pmlf, *pcrfn, pani], [*mlvs, col_v, row_v, aniv]) if k}, *x_fields, ) except: # for log with incomplete grid continue legend = ",".join( [ (v if isinstance(v, str) else f"{f} {v}").replace("_", " ") for f, v in zip(pmlf, mlvs) ] ) p_df, legend, mlvs = preprcs_df(p_df, legend, mlvs) # remove unnessacery fields if mode not in {"scatter", "scatter_heat"}: p_df = p_df.reset_index([*(set(p_df.index.names) - set(x_fields))], drop=False) if len(x_fields) > 1: p_df = p_df.reorder_levels(x_fields) p_df = p_df.sort_index(key=lambda s: [*map(str2value, s)]).reindex( ["metric_value", *(set(p_df) - {"metric_value"})], axis=1 ) # set line style for stp, s in zip(style_types, st): pcfg["line_style"][stp] = s ani_artists[-1] += ax_draw(ax, p_df, label=legend, **pcfg["line_style"]) init_prgs += 1 progress.update( task, description=f"[green]Drawing plot... [{init_prgs}/{total_prgs}]", advance=1 ) # add current animation field value as text on first axes if pani and ri + ci == 0: ani_artists[-1].append( ax.text( 0.0, 1.01, f"{pani}={aniv}", fontsize=pcfg["font_size"], ha="left", va="bottom", transform=ax.transAxes, ) ) ax_styler(ax, **pcfg["ax_style"]) pcfg["line_style"].pop("unif_xticks", None) if has_cbar: c_cmap, c_norm = (pcfg["line_style"].pop(k) for k in ["cmap", "norm"]) for ax in axs[:, -1]: ax.remove() cax = fig.add_subplot(axs[0, 0].get_gridspec()[:, -1]) cbar = fig.colorbar(cm.ScalarMappable(norm=c_norm, cmap=c_cmap), cax=cax) cbar.ax.tick_params(labelsize=pcfg["font_size"]) z_label = pcfg["zlabel"] or metrics[-1].replace("_", " ").capitalize() cbar.ax.set_ylabel(z_label, fontsize=pcfg["font_size"]) # add a big axis, hide frame, tick, and tick label big_ax = fig.add_subplot(111, frameon=False) big_ax.tick_params(labelcolor="none", which="both", top=False, bottom=False, left=False, right=False) # set title and x, y axis labels if title := pcfg.get("title"): big_ax.set_title(title, fontsize=pcfg["font_size"]) x_label = pcfg["xlabel"] or x_label y_label = pcfg["ylabel"] or y_label if len(pcrf) > 0: fig.supxlabel(x_label, fontsize=pcfg["font_size"]) for ci, col_v in enumerate(col_vs): ax = axs[-1, ci] ax.set_xlabel(f"{pcrf[0]}={col_v}", size=pcfg["font_size"]) else: ax.set_xlabel(x_label, fontsize=pcfg["font_size"]) if len(pcrf) > 1: fig.supylabel(y_label, fontsize=pcfg["font_size"]) for ri, row_v in enumerate(row_vs[::-1]): axs[ri, 0].set_ylabel(f"{pcrf[1]}={row_v}", size=pcfg["font_size"]) else: ax = axs[0, 0] ax.set_ylabel(y_label, fontsize=pcfg["font_size"]) # set legend, improve later if pmlf: ax = axs[0, 0] base_styles = {"color": "gray", "marker": "", "linestyle": "-"} first_styles = { k: v[0] for k, v in style_dict.items() if k in (style_types[len(pmlf) :] if len(pmlf) < 3 else []) } # when plmf is not full, legend are style with the first values of unused style_types is_wide = len(pmlf) < 3 max_row = max([len(set(df.index.get_level_values(k))) for k in pmlf]) legendlines, legendlabels = [], [] for s, k in zip(style_types, pmlf): extra = {"linewidth": 0} if s == "marker" else {} vs = sorted(set(df.index.get_level_values(k))) vs = [v.replace("_", " ").capitalize() if isinstance(v, str) else v for v in vs] legendlines += ( [lines.Line2D([], [], alpha=0)] + [ lines.Line2D([], [], **{**pcfg["line_style"], **base_styles, **first_styles, **extra, **{s: ss}}) for ss in [*style_dict[s]][: len(vs)] ] + ([lines.Line2D([], [], alpha=0) for _ in range(max_row - len(vs))] if is_wide else []) ) legendlabels += [f"[{k.replace('_', ' ').capitalize()}]", *vs] + ( ["" for _ in range(max_row - len(vs))] if is_wide else [] ) ax.legend( handles=legendlines, labels=legendlabels, **legend_style, # **pcfg['ax_style'].pop('legend', [{}])[0], ncol=len(pmlf) if is_wide else 1, columnspacing=0.8, handlelength=None if len(pmlf) == 1 else 1.5, ) return best_df, fig, ani_artists, save_name
[docs] def plot_run(flag_dict, preprcs_df=lambda *x: x): plot_config = {**default_style, **flag_dict} if flag_dict["plot_config"] != "": with open(flag_dict["plot_config"]) as f: plot_config = yaml.safe_load(f.read()) plot_config = get_plot_config(plot_config, flag_dict) # set ax_style related arguments ax_st = plot_config["ax_style"] if fig_size := flag_dict.pop("fig_size"): if len(fig_size) == 1: fig_size = fig_size * 2 fig_size = [*map(float, fig_size)] ax_st["fig_size"] = fig_size if xscale := flag_dict.pop("xscale"): ax_st["xscale"] = [xscale, {}] if yscale := flag_dict.pop("yscale"): ax_st["yscale"] = [yscale, {}] if zscale := flag_dict.pop("zscale"): ax_st["zscale"] = [zscale, {}] # set ax_style related arguments l_st = plot_config["line_style"] if msize := flag_dict.pop("marker_size"): l_st["markersize"] = msize # set style style.use(plot_config["style"]) # get paths if plot_config["wandb_sweep_id"]: base_dir = plot_config["exp_folder"] wandb_sweep_id = plot_config["wandb_sweep_id"] assert wandb_sweep_id.count("/") == 2, ( f"wandb_sweep_id should be in the form of entity/project/sweep_id, but you passed {wandb_sweep_id}." ) entity, project, sweep_id = wandb_sweep_id.split("/") plot_config["exp_folder"] = get_wandb_sweep_exp_dir(base_dir, entity, project, sweep_id) if not os.path.exists(plot_config["exp_folder"]): create_dir(plot_config["exp_folder"]) _, tsv_file, fig_dir = Experiment.get_paths(plot_config["exp_folder"]) if not os.path.exists(tsv_file): log = ExperimentLog.from_wandb_sweep(entity, project, [sweep_id], tsv_file, get_all_steps=True) log.to_tsv() _, tsv_file, fig_dir = Experiment.get_paths(plot_config["exp_folder"]) save_dir = os.path.join(fig_dir, plot_config["mode"]) assert plot_config["mode"].split("-")[0] in {"curve", "curve_best", "bar", "heatmap", "scatter", "scatter_heat"}, ( f"Mode: {plot_config['mode']} does not exist." ) df, fig, ani_artists, save_name = draw_metric(tsv_file, plot_config, preprcs_df=preprcs_df) save_name = save_name.replace("/", "_") fig.tight_layout() # save figure and dataframe if not os.path.exists(save_dir): create_dir(save_dir) if plot_config["animation_field"]: with Progress() as progress: task = progress.add_task(f"[green]Saving animation frames [0/{len(ani_artists)}]", total=len(ani_artists)) ani = animation.ArtistAnimation(fig, ani_artists, interval=400) gifwriter = animation.PillowWriter(fps=10) ani.save( os.path.join(save_dir, f"{save_name}.gif"), writer=gifwriter, progress_callback=lambda i, n: progress.update( task, description=f"[green]Saving animation frames [{i + 1}/{len(ani_artists)}]", advance=1 ), ) else: fig.savefig(os.path.join(save_dir, f"{save_name}.pdf"), format="pdf") df.to_csv(os.path.join(save_dir, f"{save_name}.tsv"), sep="\t") img_tail = "gif" if plot_config["animation_field"] else "pdf" print( "\n", Align( Panel( f"save {{plot, table}} at: {fig_dir}/[bold blue_violet]{plot_config['mode']}[/bold blue_violet]/[bold spring_green1]{save_name}[/bold spring_green1].{{{img_tail}, tsv}}", title="Plot complete", padding=(1, 3), expand=False, ), align="center", ), "\n", )
[docs] def run(argv, preprcs_df): if len(argv) > 2: raise app.UsageError("Too many command-line arguments.") # Preprocess plot_config flag_dict = FLAGS.flag_values_dict() plot_run(flag_dict, preprcs_df=preprcs_df)
[docs] def main(preprcs_df=lambda *x: x): flags.DEFINE_string("exp_folder", "", "Experiment folder path (plot save paths for wandb).") flags.DEFINE_string("wandb_sweep_id", "", "wandb entity/project/sweep_id.") flags.DEFINE_string("mode", "curve-step-val_loss", "Plot mode.") # data processing flags.DEFINE_string("filter", "", "Filter values. (e.g., 'step 0:100 / lr 0.01 0.1 / wd! 0.0')") flags.DEFINE_spaceseplist("multi_line_fields", "", "List of fields to plot multiple lines over.") flags.DEFINE_spaceseplist("multi_plot_fields", "", "Column and row fields for multiple plots.") flags.DEFINE_string("animation_field", "", "Fields to animate over.") flags.DEFINE_string("field_orders", "", "Order of string fields (e.g., 'lr_schedule constant linear cosine')") flags.DEFINE_spaceseplist("best_ref_x_fields", "", "Reference x_field-values to evaluate optimal hyperparameters.") flags.DEFINE_string( "best_ref_metric_field", "", "Reference metric_field-values to evaluate optimal hyperparameters." ) flags.DEFINE_spaceseplist( "best_ref_ml_fields", "", "Reference multi_line_fields-value to evaluate optimal hyperparameters." ) flags.DEFINE_bool("best_at_max", False, "Whether the bese metric value is the maximum value.") # plot estehtics flags.DEFINE_string("plot_config", "", "Yaml file path for various plot setups.") flags.DEFINE_string("style", "default", "Matplotlib style.") flags.DEFINE_spaceseplist( "colors", ["default"], "color type (e.g., default, light:#9467bd, Blues, rocket, crest, magma)." ) flags.DEFINE_spaceseplist( "colors_rep_skip_shift", [1, 0, 0], "Skip n colors and shift color list ([c for i, c in enumerate(cs) if not i%(skp+1)+sft]).", ) flags.DEFINE_spaceseplist("fig_size", "", "Figure size.") flags.DEFINE_string("xscale", "", "Scale of x-axis (linear, log, unif).") flags.DEFINE_string("yscale", "", "Scale of y-axis (linear, log).") flags.DEFINE_string("zscale", "", "Scale of z-axis or colorbar (linear, log).") flags.DEFINE_string("title", "", "Title.") flags.DEFINE_bool("annotate", True, "Run multiple plot according to given config.") flags.DEFINE_spaceseplist("annotate_field", "", "List of fields to include in annotation.") flags.DEFINE_string("xlabel", "", "Label of x-axis.") flags.DEFINE_string("ylabel", "", "Label of y-axis.") flags.DEFINE_string("zlabel", "", "Label of z-axis or colorbar.") flags.DEFINE_integer("font_size", 22, "Font size of title and label.") flags.DEFINE_float("marker_size", 10, "Size of marker.") flags.mark_flag_as_required("exp_folder") app.run(partial(run, preprcs_df=preprcs_df))
if __name__ == "__main__": main()