Source code for malet.plot_utils.plot_drawer

from typing import Literal

import numpy as np
import pandas as pd
from matplotlib.axes import Axes


def ax_draw_curve(
    ax: Axes,
    df: pd.DataFrame,
    label: str,
    annotate=True,
    annotate_field=[],
    std_plot: Literal["none", "fill", "bar"] = "fill",
    unif_xticks=False,
    color="orange",
    linewidth=4,
    marker="D",
    markersize=10,
    markevery=20,
    linestyle="-",
    **_,
) -> list:
    """Draw the first column of ``df`` as a curve over the dataframe index.

    The metric (``y_field``) is the first column of ``df`` and the x values are
    the index entries. If a ``f"{y_field}_std"`` column is present it is drawn
    as error bars or as a shaded band, depending on ``std_plot``. A dataframe
    with a single index entry is drawn as a horizontal line (with a horizontal
    band for the std) instead of a curve.

    Args:
        ax: Matplotlib axes to draw on.
        df: DataFrame whose first column is the metric and whose index holds
            the x values.
        label: Legend label attached to the curve.
        annotate: If true, annotate the marked points (and always the last
            point) with the metric value, the std when available, the x value
            and the ``annotate_field`` columns.
        annotate_field: Extra columns of ``df`` to print in each annotation.
            They may not be ``"total_steps"``, ``"step"``, the metric column or
            its std column.
        std_plot: ``"bar"`` for error bars, ``"fill"`` for a shaded band,
            ``"none"`` to draw no std at all.
        unif_xticks: If true, place the points at ``0..n-1`` and use the x
            values only as tick labels.
        color: Color of the curve, markers and std drawing.
        linewidth: Line width of the curve.
        marker: Marker style; disabled when there are more than 100 points.
        markersize: Marker size.
        markevery: Marker/annotation stride. Forced to 1 for at most 20
            points and to 20 for 21-100 points.
        linestyle: Line style of the curve.
        **_: Ignored; lets callers pass a shared keyword set.

    Returns:
        list: The matplotlib artists created by this call.
    """
    assert std_plot in {"bar", "fill", "none"}, 'std_plot should be one of {"bar","fill","none"}'

    def _abbreviate(name):
        # Short names stay as they are; snake_case names become initials;
        # anything else keeps its first letter and drops later vowels.
        if len(name) <= 3:
            return name
        if "_" in name:
            # Skip empty segments ("a__b", "_lr") instead of indexing into "".
            return "".join(part[0] for part in name.split("_") if part)
        return name[0] + "".join(ch for ch in name[1:] if ch not in "aeiou")

    y_field = list(df)[0]
    std_field = f"{y_field}_std"
    has_std = std_field in df

    x_values, metric_values = map(np.array, zip(*dict(df[y_field]).items()))
    assert not isinstance(metric_values[0], pd.Series), f"y_field should have only {1} values for each index."

    # Keep the std as an array in every branch so the annotation code below
    # can always index it per point.
    metric_std = np.array(list(dict(df[std_field]).values())) if has_std else None

    n_points = len(x_values)
    if n_points > 100:
        marker = None
        if unif_xticks:
            ax.locator_params(tight=True, nbins=5)
    elif n_points > 20:
        markevery = 20
        if unif_xticks:
            ax.locator_params(tight=True, nbins=5)
    else:
        markevery = 1

    tick_values = x_values
    last_idx = n_points - 1

    artists = []
    if n_points == 1:
        # A single point cannot form a curve: draw a horizontal reference line.
        artists.append(ax.axhline(metric_values[0], linewidth=linewidth, color=color, label=label))
        if has_std and std_plot != "none":
            single_std = float(metric_std[0])
            artists.append(
                ax.axhspan(metric_values[0] + single_std, metric_values[0] - single_std, alpha=0.3, color=color)
            )
    else:
        if unif_xticks:
            tick_values = np.arange(n_points)
            ax.set_xticks(tick_values, x_values, fontsize=10)

        artists += ax.plot(
            tick_values,
            metric_values,
            label=label,
            color=color,
            linewidth=linewidth,
            marker=marker,
            markersize=markersize,
            markevery=markevery,
            linestyle=linestyle,
        )
        # Markers land on indices 0, markevery, 2*markevery, ...; add one on
        # the final point when the stride does not already hit it.
        if last_idx % markevery != 0:
            artists += ax.plot(tick_values[-1], metric_values[-1], color=color, marker=marker, markersize=markersize)

        if has_std:
            if std_plot == "bar":
                artists.append(ax.errorbar(tick_values, metric_values, yerr=metric_std, color=color, elinewidth=3))
            elif std_plot == "fill":
                artists.append(
                    ax.fill_between(
                        tick_values, metric_values + metric_std, metric_values - metric_std, alpha=0.3, color=color
                    )
                )

    if annotate:
        allowed = set(df) - {"total_steps", "step", y_field, std_field}
        unknown = set(annotate_field) - allowed
        assert not unknown, f"Annotation field: {unknown} are not in dataframe field: {allowed}"

        # Keep the caller's order (deduplicated) so annotations are reproducible.
        fields = [name for name in dict.fromkeys(annotate_field) if name in allowed]
        short_names = [_abbreviate(name) for name in fields]

        for i, (x, y, t) in enumerate(zip(x_values, metric_values, tick_values)):
            if i % markevery and i != last_idx:
                continue

            value_txt = f"{y:.5f}"
            if has_std and pd.notna(metric_std[i]):
                value_txt += rf"$\pm${metric_std[i]:.5f}"

            lines = [value_txt, str(x)]
            lines += [f"{short}={df.loc[x][name]}" for short, name in zip(short_names, fields)]
            artists.append(
                ax.annotate("\n".join(lines), (t, y), textcoords="offset points", xytext=(0, 10), ha="center")
            )

    ax.tick_params(axis="both", which="major", labelsize=17, direction="in", length=5)

    return artists
def ax_draw_best_stared_curve(
    ax: Axes,
    df: pd.DataFrame,
    label: str,
    annotate=True,
    annotate_field=[],
    std_plot: Literal["none", "fill", "bar"] = "fill",
    best_at_max=True,
    unif_xticks=False,
    color="orange",
    linewidth=4,
    marker="D",
    markersize=10,
    markevery=20,
    linestyle="-",
    **_,
) -> list:
    """Draw the first column of ``df`` as a curve and mark its best point with a star.

    The metric (``y_field``) is the first column of ``df`` and the x values are
    the index entries. If a ``f"{y_field}_std"`` column is present it is drawn
    as error bars or as a shaded band, depending on ``std_plot``. The best
    point (maximum or minimum, NaN ignored, first one on ties) gets a green
    star; other points get the regular marker every ``markevery`` entries.
    A dataframe with a single index entry is drawn as a horizontal line.

    Args:
        ax: Matplotlib axes to draw on.
        df: DataFrame whose first column is the metric and whose index holds
            the x values.
        label: Legend label attached to the curve.
        annotate: If true, annotate the marked points and the best point with
            the metric value, the std when available, the x value and the
            ``annotate_field`` columns.
        annotate_field: Extra columns of ``df`` to print in each annotation.
            They may not be ``"total_steps"``, ``"step"``, the metric column or
            its std column.
        std_plot: ``"bar"`` for error bars, ``"fill"`` for a shaded band,
            ``"none"`` to draw no std at all.
        best_at_max: If true the best point is the maximum, otherwise the
            minimum.
        unif_xticks: If true, place the points at ``0..n-1`` and use the x
            values only as tick labels.
        color: Color of the curve, regular markers and std drawing.
        linewidth: Line width of the curve.
        marker: Regular marker style; disabled when there are more than 100
            points.
        markersize: Regular marker size; the star is drawn 10 larger.
        markevery: Marker/annotation stride. Forced to 1 for at most 20
            points and to 20 for 21-100 points.
        linestyle: Line style of the curve.
        **_: Ignored; lets callers pass a shared keyword set.

    Returns:
        list: The matplotlib artists created by this call.
    """
    assert std_plot in {"bar", "fill", "none"}, 'std_plot should be one of {"bar","fill","none"}'

    def _abbreviate(name):
        # Short names stay as they are; snake_case names become initials;
        # anything else keeps its first letter and drops later vowels.
        if len(name) <= 3:
            return name
        if "_" in name:
            # Skip empty segments ("a__b", "_lr") instead of indexing into "".
            return "".join(part[0] for part in name.split("_") if part)
        return name[0] + "".join(ch for ch in name[1:] if ch not in "aeiou")

    y_field = list(df)[0]
    std_field = f"{y_field}_std"
    has_std = std_field in df

    x_values, metric_values = map(np.array, zip(*dict(df[y_field]).items()))
    assert not isinstance(metric_values[0], pd.Series), f"y_field should have only {1} values for each index."

    # Keep the std as an array in every branch so the annotation code below
    # can always index it per point.
    metric_std = np.array(list(dict(df[std_field]).values())) if has_std else None

    n_points = len(x_values)
    if n_points > 100:
        marker = None
        if unif_xticks:
            ax.locator_params(tight=True, nbins=5)
    elif n_points > 20:
        markevery = 20
        if unif_xticks:
            ax.locator_params(tight=True, nbins=5)
    else:
        markevery = 1

    # Best point among the non-missing values; None when every value is missing.
    valid_idx = np.flatnonzero(~pd.isna(metric_values))
    if valid_idx.size:
        pick = np.argmax if best_at_max else np.argmin
        best_idx = int(valid_idx[pick(metric_values[valid_idx])])
    else:
        best_idx = None

    tick_values = x_values

    artists = []
    if n_points == 1:
        # A single point cannot form a curve: draw a horizontal reference line.
        artists.append(ax.axhline(metric_values[0], linewidth=linewidth, color=color, label=label))
        if has_std and std_plot != "none":
            single_std = float(metric_std[0])
            artists.append(
                ax.axhspan(metric_values[0] + single_std, metric_values[0] - single_std, alpha=0.3, color=color)
            )
    else:
        if unif_xticks:
            tick_values = np.arange(n_points)
            ax.set_xticks(tick_values, x_values, fontsize=10, rotation=45)

        # The line carries the legend label and the requested line style;
        # markers are drawn point by point below so the best one can differ.
        artists += ax.plot(
            tick_values, metric_values, label=label, color=color, linewidth=linewidth, linestyle=linestyle
        )

        if has_std:
            if std_plot == "bar":
                artists.append(ax.errorbar(tick_values, metric_values, yerr=metric_std, color=color, elinewidth=3))
            elif std_plot == "fill":
                artists.append(
                    ax.fill_between(
                        tick_values, metric_values + metric_std, metric_values - metric_std, alpha=0.3, color=color
                    )
                )

        for i, (t, y) in enumerate(zip(tick_values, metric_values)):
            if i == best_idx:
                # Always star the best point, even when it is off the marker stride.
                artists += ax.plot(t, y, color="green", marker="*", markersize=markersize + 10)
            elif i % markevery == 0:
                artists += ax.plot(
                    t,
                    y,
                    color=color,
                    marker=marker,
                    markersize=markersize,
                    markevery=markevery,
                    linestyle=linestyle,
                )

    if annotate:
        allowed = set(df) - {"total_steps", "step", y_field, std_field}
        unknown = set(annotate_field) - allowed
        assert not unknown, f"Annotation field: {unknown} are not in dataframe field: {allowed}"

        # Keep the caller's order (deduplicated) so annotations are reproducible.
        fields = [name for name in dict.fromkeys(annotate_field) if name in allowed]
        short_names = [_abbreviate(name) for name in fields]

        for i, (x, y, t) in enumerate(zip(x_values, metric_values, tick_values)):
            if i % markevery and i != best_idx:
                continue

            value_txt = f"{y:.5f}"
            if has_std and pd.notna(metric_std[i]):
                value_txt += rf"$\pm${metric_std[i]:.5f}"

            lines = [value_txt, str(x)]
            lines += [f"{short}={df.loc[x][name]}" for short, name in zip(short_names, fields)]
            artists.append(
                ax.annotate("\n".join(lines), (t, y), textcoords="offset points", xytext=(0, 10), ha="center")
            )

    ax.tick_params(axis="both", which="major", labelsize=17, direction="in", length=5)

    return artists
def ax_draw_bar(
    ax: Axes,
    df: pd.DataFrame,
    label: str,
    annotate=True,
    annotate_field=[],
    std_plot=True,
    unif_xticks=False,
    color="orange",
    **_,
) -> list:
    """Draw the first column of ``df`` as a bar graph over the dataframe index.

    The metric (``y_field``) is the first column of ``df``; bars are placed at
    ``0..n-1`` and labelled with the index entries. If a ``f"{y_field}_std"``
    column is present and ``std_plot`` is true, it is drawn as error bars.

    Args:
        ax: Matplotlib axes to draw on.
        df: DataFrame whose first column is the metric and whose index holds
            the x values.
        label: Legend label attached to the bars.
        annotate: If true, annotate every bar with the metric value, the std
            when the std column exists, the x value and the ``annotate_field``
            columns.
        annotate_field: Extra columns of ``df`` to print in each annotation.
            They may not be ``"total_steps"``, ``"step"``, the metric column or
            its std column.
        std_plot: If true, draw the std column (when present) as error bars.
        unif_xticks: If true, the x value line of each annotation is left
            empty.
        color: Bar color.
        **_: Ignored; lets callers pass a shared keyword set.

    Returns:
        list: The matplotlib artists created by this call.
    """

    def _abbreviate(name):
        # Short names stay as they are; snake_case names become initials;
        # anything else keeps its first letter and drops later vowels.
        if len(name) <= 3:
            return name
        if "_" in name:
            # Skip empty segments ("a__b", "_lr") instead of indexing into "".
            return "".join(part[0] for part in name.split("_") if part)
        return name[0] + "".join(ch for ch in name[1:] if ch not in "aeiou")

    y_field = list(df)[0]
    std_field = f"{y_field}_std"
    has_std = std_field in df

    x_values, metric_values = map(np.array, zip(*dict(df[y_field]).items()))
    assert not isinstance(metric_values[0], pd.Series), f"y_field should have only {1} values for each index."

    # Load the std whenever the column exists: the annotations use it even
    # when std_plot is false and no error bars are drawn.
    metric_std = np.array(list(dict(df[std_field]).values())) if has_std else None

    tick_values = np.arange(len(x_values))
    ax.set_xticks(tick_values, x_values, fontsize=10, rotation=45)

    bar_kwargs = {"label": label, "color": color}
    if std_plot and has_std:
        bar_kwargs["yerr"] = metric_std
    artists = [ax.bar(tick_values, metric_values, **bar_kwargs)]

    if annotate:
        allowed = set(df) - {"total_steps", "step", y_field, std_field}
        unknown = set(annotate_field) - allowed
        assert not unknown, f"Annotation field: {unknown} are not in dataframe field: {allowed}"

        # Keep the caller's order (deduplicated) so annotations are reproducible.
        fields = [name for name in dict.fromkeys(annotate_field) if name in allowed]
        short_names = [_abbreviate(name) for name in fields]

        for i, (x, y, t) in enumerate(zip(x_values, metric_values, tick_values)):
            value_txt = f"{y:.5f}"
            if has_std and pd.notna(metric_std[i]):
                value_txt += rf"$\pm${metric_std[i]:.5f}"

            lines = [value_txt, "" if unif_xticks else str(x)]
            lines += [f"{short}={df.loc[x][name]}" for short, name in zip(short_names, fields)]
            artists.append(
                ax.annotate("\n".join(lines), (t, y), textcoords="offset points", xytext=(0, 10), ha="center")
            )

    ax.tick_params(axis="both", which="major", labelsize=17, direction="in", length=5)

    return artists
def ax_draw_heatmap(
    ax: Axes, df: pd.DataFrame, cmap="magma", annotate=True, annotate_field=[], norm=None, **_
) -> list:
    """Draw the first column of ``df`` as a heatmap over a two-level index.

    The metric (``y_field``) is the first column of ``df``. The first index
    level is laid out along the x axis and the second along the y axis.

    Args:
        ax: Matplotlib axes to draw on.
        df: DataFrame with a multi-index; its first column is the metric.
        cmap: Colormap passed to ``ax.pcolor``.
        annotate: If true, write the metric value into every cell, followed by
            the std when a ``f"{y_field}_std"`` column has a value for that
            cell, and by the ``annotate_field`` columns when the cell exists in
            ``df``.
        annotate_field: Extra columns of ``df`` to print in each cell. They may
            not be ``"total_steps"``, ``"step"``, the metric column or its std
            column.
        norm: Normalization passed to ``ax.pcolor``.
        **_: Ignored; lets callers pass a shared keyword set.

    Returns:
        list: The matplotlib artists created by this call.
    """

    def _abbreviate(name):
        # Short names stay as they are; snake_case names become initials;
        # anything else keeps its first letter and drops later vowels.
        if len(name) <= 3:
            return name
        if "_" in name:
            # Skip empty segments ("a__b", "_lr") instead of indexing into "".
            return "".join(part[0] for part in name.split("_") if part)
        return name[0] + "".join(ch for ch in name[1:] if ch not in "aeiou")

    y_field = list(df)[0]
    std_field = f"{y_field}_std"
    has_std = std_field in df

    y_field_df = df.drop(columns=list(df)[1:])
    x_fields = y_field_df.index.names

    # Rows follow the second index level, columns the first.
    grid_df = y_field_df.reset_index().pivot(index=x_fields[1], columns=x_fields[0])

    artists = [ax.pcolor(grid_df, cmap=cmap, edgecolors="w", norm=norm)]

    x_values = [sorted(set(y_field_df.index.get_level_values(level))) for level in x_fields]

    # Cell (i, j) spans [i, i+1] x [j, j+1], so ticks sit at the cell centers.
    ax.set_xticks(np.arange(0.5, len(x_values[0]), 1), x_values[0], fontsize=10, rotation=45)
    ax.set_yticks(np.arange(0.5, len(x_values[1]), 1), x_values[1], fontsize=10)

    if annotate:
        allowed = set(df) - {"total_steps", "step", y_field, std_field}
        unknown = set(annotate_field) - allowed
        assert not unknown, f"Annotation field: {unknown} are not in dataframe field: {allowed}"

        # Keep the caller's order (deduplicated) so annotations are reproducible.
        fields = [name for name in dict.fromkeys(annotate_field) if name in allowed]
        short_names = [_abbreviate(name) for name in fields]

        if has_std:
            std_grid_df = df[[std_field]].reset_index().pivot(index=x_fields[1], columns=x_fields[0])

        row_keys = list(grid_df.index.get_level_values(0))
        for col, (mtc, x) in enumerate(list(grid_df)):
            for row, y in enumerate(row_keys):
                cell_txt = f"{grid_df.loc[y, (mtc, x)]:.5f}"
                if has_std:
                    cell_std = std_grid_df.loc[y, (f"{mtc}_std", x)]
                    if pd.notna(cell_std):
                        cell_txt += f"\n$\\pm${cell_std:.5f}"

                lines = [cell_txt]
                # Extra fields can only be looked up for (x, y) pairs that exist in df.
                if fields and df.index.isin([(x, y)]).any():
                    lines += [f"{short}={df.loc[(x, y), name]}" for short, name in zip(short_names, fields)]

                artists.append(
                    ax.text(
                        col + 0.5, row + 0.5, "\n".join(lines), c="dimgrey", ha="center", va="center", weight="bold"
                    )
                )

    return artists
def ax_draw_scatter(
    ax: Axes, df: pd.DataFrame, y_fields: list, color="orange", marker="D", markersize=30, **_
) -> list:
    """Draws a 2D scatter plot of two metrics from a melted dataframe.

    Args:
        ax: Matplotlib axes to draw on.
        df: DataFrame with 'metric', 'total_steps' and 'step' levels in the
            index and a 'metric_value' column.
        y_fields: List of exactly two metric names to plot as x and y axes.
        color: Marker face color.
        marker: Marker style.
        markersize: Base marker size (scaled by 20x internally).
        **_: Ignored; lets callers pass a shared keyword set.

    Returns:
        list: List of matplotlib artist objects created.
    """
    metrics = set(df.index.get_level_values("metric"))
    assert len(metrics) == 2, f"There should be {2} metrics in the dataframe, got {metrics}."
    assert metrics == set(y_fields), "y_fields should be the same as metrics in the dataframe."
    # A repeated name passes the set comparison but cannot be unpacked into x and y.
    assert len(y_fields) == 2, f"y_fields should list each of the {2} metrics exactly once, got {y_fields}."

    df = df.reset_index(["total_steps", "step"], drop=True)
    metric_level = df.index.get_level_values("metric")

    def _metric_column(metric):
        # Rows of one metric, with 'metric_value' renamed to the metric's own name.
        rows = df.loc[metric_level == metric]
        return rows.reset_index("metric", drop=True).rename(columns={"metric_value": metric})

    # Revert the melted metrics back into one column per metric.
    wide = pd.concat([_metric_column(metric) for metric in y_fields], axis=1)

    xs, ys = (list(wide[metric]) for metric in y_fields)

    return [ax.scatter(xs, ys, color=color, marker=marker, s=markersize * 20, edgecolors="black")]
def ax_draw_scatter_heat(
    ax: Axes, df: pd.DataFrame, y_fields: list, cmap="magma", marker="D", markersize=30, norm=None, **_
) -> list:
    """Draws a scatter plot with heatmap coloring from three metrics.

    The first two metrics are used as x and y coordinates, and the third
    metric determines the color of each point.

    Args:
        ax: Matplotlib axes to draw on.
        df: DataFrame with 'metric', 'total_steps' and 'step' levels in the
            index and a 'metric_value' column.
        y_fields: List of exactly three metric names (x, y, color).
        cmap: Colormap name for the heatmap coloring.
        marker: Marker style.
        markersize: Base marker size (scaled by 20x internally).
        norm: Matplotlib normalization instance for color mapping.
        **_: Ignored; lets callers pass a shared keyword set.

    Returns:
        list: List of matplotlib artist objects created.
    """
    metrics = set(df.index.get_level_values("metric"))
    assert len(metrics) == 3, f"There should be {3} metrics in the dataframe, got {metrics}."
    assert metrics == set(y_fields), "y_fields should be the same as metrics in the dataframe."
    # A repeated name passes the set comparison but cannot be unpacked into x, y and color.
    assert len(y_fields) == 3, f"y_fields should list each of the {3} metrics exactly once, got {y_fields}."

    df = df.reset_index(["total_steps", "step"], drop=True)
    metric_level = df.index.get_level_values("metric")

    def _metric_column(metric):
        # Rows of one metric, with 'metric_value' renamed to the metric's own name.
        rows = df.loc[metric_level == metric]
        return rows.reset_index("metric", drop=True).rename(columns={"metric_value": metric})

    # Revert the melted metrics back into one column per metric.
    wide = pd.concat([_metric_column(metric) for metric in y_fields], axis=1)

    xs, ys, colors = (list(wide[metric]) for metric in y_fields)

    return [ax.scatter(xs, ys, c=colors, marker=marker, s=markersize * 20, norm=norm, cmap=cmap)]