Source code for malet.plot_utils.utils
import os
import shutil
from operator import attrgetter
from matplotlib.axes import Axes
[docs]
def create_dir(dir):
"""Creates a directory, clearing all contents if it already exists.
Args:
dir: Path of the directory to create or clear.
"""
if os.path.exists(dir):
for f in os.listdir(dir):
if os.path.isdir(os.path.join(dir, f)):
shutil.rmtree(os.path.join(dir, f))
else:
os.remove(os.path.join(dir, f))
else:
os.makedirs(dir)
[docs]
def merge_dict(base: dict, other: dict):
"""Merge plot_config dict (priority: ``base``)"""
for k in set(base) & set(other):
if isinstance(base[k], list):
if base[k] and isinstance(base[k][-1], dict):
base[k] = base[k][:-1] + other[k][:-1] + [merge_dict(base[k][-1], other[k][-1])]
elif isinstance(base[k], dict):
base[k] = merge_dict(base[k], other[k])
for k in set(other) - set(base):
base[k] = other[k]
return base
default_style = {
"annotate": False,
"std_plot": "fill",
"ax_style": {
"frame_width": 2.5,
"fig_size": 7,
"legend": [{"fontsize": 20}],
"grid": [True, {"linestyle": "--"}],
"tick_params": [{"axis": "both", "which": "major", "labelsize": 25, "direction": "in", "length": 5}],
},
"line_style": {
"linewidth": 4,
"marker": "D",
"markersize": 10,
"markevery": 1,
},
}
[docs]
def ax_styler(ax: Axes, **style_dict):
"""Applies matplotlib styling to axes from a style dictionary.
Handles special keys 'fig_size' and 'frame_width' directly, then
delegates remaining entries to the corresponding axes methods
(``set_<name>`` or direct attribute for tick_params, legend, grid).
Args:
ax: Matplotlib axes to style.
**style_dict: Style parameters where each value is a list of
positional args followed by a keyword-arg dict as the last element.
"""
if (n := "fig_size") in style_dict:
dim = style_dict.pop(n)
if isinstance(dim, int):
w = h = dim
else:
w, h = dim
ax.figure.set_figwidth(w)
ax.figure.set_figheight(h)
if (n := "frame_width") in style_dict:
fw = style_dict.pop(n)
for axis in ["top", "bottom", "left", "right"]:
ax.spines[axis].set_linewidth(fw)
non_set = ["tick_params", "legend", "grid"]
for name, (*arg_pos, arg_kw) in style_dict.items():
attr_name = name if name in non_set else f"set_{name}"
attrgetter(attr_name)(ax)(*arg_pos, **arg_kw)