from __future__ import print_function
import colorsys # for rgb_to_hls
import math
import os
import six
import warnings
import yaml
from copy import deepcopy
from functools import partial
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.patheffects as PathEffects
import numpy as np
from matplotlib.gridspec import GridSpec
from matplotlib.ticker import LogFormatterSciNotation
from matplotlib.colors import LogNorm, Normalize, colorConverter
from rootpy.plotting import Hist1D, Hist2D, Profile1D, Efficiency, F1
from rootpy.plotting.hist import _Hist, _Hist2D
from rootpy.plotting.profile import _ProfileBase
from .._input import InputROOT
from .._colormaps import viridis
from ._base import ContextValue, LiteralString, _ProcessorBase, _make_directory
__all__ = ['LogFormatterSciNotationForceSublabels', 'PlotProcessor']

# Make the bundled `viridis` colormap available under its usual name.
# Matplotlib releases that provide the `matplotlib.colormaps` registry already
# ship 'viridis' as a built-in colormap and refuse to re-register built-ins
# (and `pyplot.register_cmap` no longer exists in recent releases), so only
# register it there if it is missing. Older releases keep the previous call.
if hasattr(mpl, 'colormaps'):
    if 'viridis' not in mpl.colormaps:
        mpl.colormaps.register(viridis, name='viridis')
else:
    plt.register_cmap(name='viridis', cmap=viridis)
def _mplrc():
    """Apply the matplotlib rc settings used for all figures of this module.

    Sets the base font size, inward-pointing ticks on all four axes sides
    (matplotlib >= 2 only supports the per-side switches), sans-serif STIX
    mathtext without a Computer Modern fallback, and axis/legend font sizes.
    """
    mpl.rcParams.update({'font.size': 11})
    if int(mpl.__version__.split('.')[0]) >= 2:
        mpl.rc('xtick', direction='in', bottom=True, top=True)
        mpl.rc('ytick', direction='in', left=True, right=True)
    else:
        mpl.rc('xtick', direction='in')
        mpl.rc('ytick', direction='in')

    _mathtext_rc = dict(fontset='stixsans', rm='sans')
    if 'mathtext.fallback' in mpl.rcParams:
        # newer matplotlib: `mathtext.fallback_to_cm` is deprecated/removed
        # (setting it raises a KeyError once removed); `fallback=None`
        # is the equivalent of `fallback_to_cm=False`
        _mathtext_rc['fallback'] = None
    else:
        _mathtext_rc['fallback_to_cm'] = False
    mpl.rc('mathtext', **_mathtext_rc)

    mpl.rc('axes', labelsize=16)
    mpl.rc('legend', labelspacing=.1, fontsize=8)
def _mathdefault(s):
return '\\mathdefault{%s}' % s
def is_close_to_int(x):
    """Return whether `x` is finite and lies within 1e-10 of the nearest integer.

    Non-finite values (NaN, +/-inf) always yield ``False``; they are checked
    first because `round` cannot handle them.
    """
    return bool(np.isfinite(x)) and abs(x - round(x)) < 1e-10
class LogFormatterSciNotationForceSublabels(LogFormatterSciNotation):
    """Variant of `LogFormatterSciNotation` that always labels certain
    non-decade positions.

    The parent class decides in `set_locs` which minor ticks receive labels
    based on the axis span, and may hide labels at these positions. This
    subclass overrides that decision after every `set_locs` call.

    The coefficients to label (e.g. ``{1.0, 2.0, 5.0, 10.0}`` labels ticks at
    1, 2 and 5 times each power of the base) are set by calling
    ``set_locs(locs=...)``. They are remembered, so the positional
    ``set_locs(tick_values)`` calls made during drawing keep them instead of
    resetting to the default.
    """

    # coefficients labelled when none have been set explicitly
    _DEFAULT_SUBLABELS = frozenset({1.0, 2.0, 5.0, 10.0})

    def set_locs(self, *args, **kwargs):
        '''Update the parent's tick state, then force the configured sublabels.

        The keyword ``locs`` (if given and not None) is the collection of
        coefficients to label; it is stored for all later calls and not
        forwarded to the parent. All other arguments (such as the tick values
        passed positionally by matplotlib) are forwarded unchanged.

        Returns whatever the parent's `set_locs` returns.
        '''
        _forced = kwargs.pop("locs", None)
        if _forced is not None:
            self._forced_sublabels = set(_forced)
        _ret = super(LogFormatterSciNotationForceSublabels, self).set_locs(*args, **kwargs)
        # override the sublabels chosen by the parent class
        self._sublabels = set(getattr(self, '_forced_sublabels', self._DEFAULT_SUBLABELS))
        return _ret
def _plot_with_error_band(ax, *args, **kwargs):
"""display data as line. If `yerr` is given, an `y` +/- `yerr` error band is also drawn.
You can use custom `_band_kwargs` to format the error band."""
kwargs.pop('xerr', None)
_yerr = kwargs.pop('yerr', None)
_x = np.asarray(args[0])
_y = np.asarray(args[1])
# kwargs delegated to mpl fill_between
_band_kwargs = kwargs.pop('band_kwargs', None) or dict()
if _yerr is None:
return ax.plot(*args, **kwargs)
else:
return (ax.plot(*args, **kwargs),
ax.fill_between(_x, _y - _yerr[0], _y + _yerr[1], **dict(dict(kwargs, alpha=0.5, linewidth=0),
**_band_kwargs)))
def _plot_as_step(ax, *args, **kwargs):
"""display data as horizontal bars with given by `x` +/- `xerr`. `y` error bars are also drawn."""
assert len(args) == 2
_x = np.asarray(args[0])
_y = np.asarray(args[1])
_zeros = np.zeros_like(_x)
# kwarg `yerr_as_band` to display
_show_yerr_as = kwargs.pop('show_yerr_as', None)
if _show_yerr_as is not None and _show_yerr_as not in ('errorbar', 'band'):
raise ValueError("Invalid value '{}' for 'show_yerr_as'. Available: {}".format(_show_yerr_as, ('errorbar', 'band')))
assert 'xerr' in kwargs
if len(kwargs['xerr']) == 1:
_xerr_dn = _xerr_up = kwargs.pop('xerr')[0]
else:
_xerr_dn, _xerr_up = kwargs.pop('xerr')
_yerr = kwargs.pop('yerr', None)
if _yerr is not None:
if len(_yerr) == 1:
_yerr_dn = _yerr_up = _yerr[0]
else:
_yerr_dn, _yerr_up = _yerr
_yerr_dn = np.asarray(_yerr_dn)
_yerr_up = np.asarray(_yerr_up)
_xerr_dn = np.asarray(_xerr_dn)
_xerr_up = np.asarray(_xerr_up)
# replicate each point five times -> bin anchors
# 1 + + 5
# | |
# +---+---+
# 2 3 4
_x = np.vstack([_x, _x, _x, _x, _x]).T.flatten()
_y = np.vstack([_y, _y, _y, _y, _y]).T.flatten()
# stop processing y errors if they are zero
if np.allclose(_yerr, 0):
_yerr = None
# attach y errors (if any) to "bin" center
if _yerr is not None:
if _show_yerr_as == 'band':
# error band: shade across entire bin width
_yerr_dn = np.vstack([_zeros, _yerr_dn, _yerr_dn, _yerr_dn, _zeros]).T.flatten()
_yerr_up = np.vstack([_zeros, _yerr_up, _yerr_up, _yerr_up, _zeros]).T.flatten()
else:
# errorbars: only show on central point
_yerr_dn = np.vstack([_zeros, _zeros, _yerr_dn, _zeros, _zeros]).T.flatten()
_yerr_up = np.vstack([_zeros, _zeros, _yerr_up, _zeros, _zeros]).T.flatten()
_yerr = [_yerr_dn, _yerr_up]
# shift left and right replicas in x by xerr
_x += np.vstack([-_xerr_dn, -_xerr_dn, _zeros, _xerr_up, _xerr_up]).T.flatten()
# obtain indices of points with a binning discontinuity
_bin_edge_discontinuous_at = (np.flatnonzero(_x[0::5][1:] != _x[4::5][:-1]) + 1)*5
# prevent diagonal connections across bin discontinuities
if len(_bin_edge_discontinuous_at):
_x = np.insert(_x, _bin_edge_discontinuous_at, [np.nan])
_y = np.insert(_y, _bin_edge_discontinuous_at, [np.nan])
if _yerr is not None:
_yerr = np.insert(_yerr, _bin_edge_discontinuous_at, [np.nan], axis=1)
# do actual plotting
if _show_yerr_as == 'errorbar' or _show_yerr_as is None:
return ax.errorbar(_x, _y, yerr=_yerr if _show_yerr_as else None, **kwargs)
elif _show_yerr_as == 'band':
_band_alpha = kwargs.pop('band_alpha', 0.5)
_capsize = kwargs.pop('capsize', None)
_markeredgecolor = kwargs.pop('markeredgecolor', None)
if _yerr is None:
_yerr = 0, 0
return (
ax.errorbar(_x, _y, yerr=None, capsize=_capsize, markeredgecolor=_markeredgecolor, **kwargs),
ax.fill_between(_x, _y-_yerr[0], _y+_yerr[1], **dict(kwargs, alpha=_band_alpha, linewidth=0)))
class PlotProcessor(_ProcessorBase):
    """Processor for plotting objects from ROOT files.

    Objects are read through an `InputROOT` controller built from the
    ``input_files`` entry of the configuration and drawn with matplotlib.
    Figures are described by templates under the ``figures`` key. The
    expansion of the templates over the ``expansions`` contexts is presumably
    done by `_ProcessorBase` (not visible here; confirm there).

    For each figure configuration the actions listed in ``_ACTIONS`` are run:
    `_request` (request all objects used in the subplot expressions) and
    `_plot` (draw and save the figure).

    .. todo::
        API documentation.
    """

    # configuration key holding the figure templates
    CONFIG_KEY_FOR_TEMPLATES = "figures"
    # template sub-keys in which context values are substituted (by the base class, presumably)
    SUBKEYS_FOR_CONTEXT_REPLACING = ["subplots", "pads", "texts"]
    # configuration key holding the expansion contexts
    CONFIG_KEY_FOR_CONTEXTS = "expansions"

    # plot methods implemented in this module rather than on `Axes`;
    # they are called with the target `Axes` as first argument
    _EXTERNAL_PLOT_METHODS = dict(
        step = _plot_as_step,
        plot = _plot_with_error_band
    )

    # pad configuration keys applied through the named `Axes` method,
    # optionally with fixed extra keyword arguments
    _PC_KEYS_MPL_AXES_METHODS = dict(
        x_label = dict(
            method='set_xlabel',
            kwargs=dict(x=1.0, ha='right'),
        ),
        y_label = dict(
            method='set_ylabel',
            kwargs=dict(y=1.0, ha='right'),
        ),
        x_range = dict(
            method='set_xlim',
        ),
        y_range = dict(
            method='set_ylim',
        ),
        x_scale = dict(
            method='set_xscale',
        ),
        y_scale = dict(
            method='set_yscale',
        ),
        x_ticklabels = dict(
            method='set_xticklabels',
        ),
        y_ticklabels = dict(
            method='set_yticklabels',
        ),
        x_ticks = dict(
            method='set_xticks',
        ),
        y_ticks = dict(
            method='set_yticks',
        ),
    )

    # defaults for `Axes.legend`, overridable per pad via 'legend_kwargs'
    _DEFAULT_LEGEND_KWARGS = dict(
        ncol=1, numpoints=1, fontsize=12, frameon=False,
        loc='upper right'
    )

    # defaults for lines drawn via the pad keys 'axhlines'/'axvlines'
    _DEFAULT_LINE_KWARGS = dict(
        linestyle='--', color='gray', linewidth=1, zorder=-99
    )
def __init__(self, config, output_folder):
super(PlotProcessor, self).__init__(config, output_folder)
self._input_controller = InputROOT(
files_spec=self._config['input_files']
)
self._figures = {}
self._global_request_params = self._config.get("global_request_params", {})
# introduce pseudo-context for accessing input file content
self._config[self.CONFIG_KEY_FOR_CONTEXTS].update(
_input_controller=[self._input_controller]
)
# -- helper methods
def _get_figure(self, figure_name, figsize=None):
if figure_name not in self._figures:
self._figures[figure_name] = plt.figure(figsize=figsize)
return self._figures[figure_name]
@staticmethod
def _merge_legend_handles_labels(handles, labels):
'''merge handles for identical labels'''
_seen_labels = []
_seen_label_handles = []
_new_label_indices = []
for _ihl, (_h, _l) in enumerate(zip(handles, labels)):
if _l not in _seen_labels:
_seen_labels.append(_l)
_seen_label_handles.append([_h])
else:
_idx = _seen_labels.index(_l)
_seen_label_handles[_idx].append(_h)
for _i, (_sh, _sl) in enumerate(zip(_seen_label_handles, _seen_labels)):
_seen_label_handles[_i] = tuple(_seen_label_handles[_i])
return _seen_label_handles, _seen_labels
@staticmethod
def _sort_legend_handles_labels(handles, labels, stack_labels=None):
'''sort handles and labels, reversing the order of those that are part of a stack'''
# if no stacks or a stack with a single label, don't sort
if stack_labels is None or len(stack_labels) <= 1:
return handles, labels
# temporarily cast to array to use numpy indexing
_hs, _ls = np.asarray(handles), np.asarray(labels)
_criterion = np.vectorize(lambda label: label in stack_labels)
# reverse sublist selected by criterion
_ls[_criterion(_ls)] = _ls[_criterion(_ls)][::-1]
_hs[_criterion(_ls)] = _hs[_criterion(_ls)][::-1]
# return as lists
return list(_hs), list(_ls)
# -- actions
def _request(self, config):
'''request all objects encountered in all subplot expressions'''
for _subplot_cfg in config['subplots']:
request_params = dict(self._global_request_params, **_subplot_cfg.get('request_params', {}))
self._input_controller._request_all_objects_in_expression(_subplot_cfg['expression'], **request_params)
#print('REQ', _subplot_cfg['expression'])
def _plot(self, config):
'''plot all figures'''
_mplrc()
# register expressions as locals for lookup by the input controller's `get` call
self._input_controller.register_local('expressions', [_subplot_cfg['expression'] for _subplot_cfg in config['subplots']])
_filename = os.path.join(self._output_folder, config['filename'])
# prepare dict for YAML dump, if requested
_dump_yaml = config.pop('dump_yaml', False)
if _dump_yaml:
_yaml_filename = '.'.join(_filename.split('.')[:-1]) + '.yml'
# need to create directory first
_make_directory(os.path.dirname(_yaml_filename))
# add input files to dump
_config_for_dump = dict(deepcopy(config), input_files=self._config['input_files'])
else:
_config_for_dump = config # dummy link to original config
# step 1: create figure and pads
_figsize = config.pop('figsize', None)
_fig = self._get_figure(_filename, figsize=_figsize)
# obtain configuration of pads
_pad_configs = config.get('pads', None)
if _pad_configs is None:
# default pad configuration
_pad_configs = [dict()]
# get share
_height_ratios = [_pc.get('height_share', 1) for _pc in _pad_configs]
# construct GridSpec from `pad_spec` or make default
_gridspec_kwargs = config.get('pad_spec', dict())
_gridspec_kwargs.pop('height_ratios', None) # ignore explicit user-provided `height_ratios`
_gs = GridSpec(nrows=len(_pad_configs), ncols=1, height_ratios=_height_ratios, **_gridspec_kwargs)
# store `Axes` objects in pad configuration
for _i_pad, _pad_config in enumerate(_pad_configs):
_pad_config['axes'] = _fig.add_subplot(_gs[_i_pad])
_stack_bottoms = _pad_config.setdefault('stack_bottoms', {})
_bin_labels = _pad_config.setdefault('bin_labels', {})
_bin_label_anchors = _pad_config.setdefault('bin_label_anchors', {})
# enable text output, if requested
if config.pop("text_output", False):
_text_filename = '.'.join(_filename.split('.')[:-1]) + '.txt'
# need to create directory first
_make_directory(os.path.dirname(_text_filename))
_text_file = open(_text_filename, 'w')
else:
_text_file = None
# step 2: retrieve data and plot
assert len(config['subplots']) == len(_config_for_dump['subplots'])
for _pc, _pc_for_dump in zip(config['subplots'], _config_for_dump['subplots']):
_kwargs = deepcopy(_pc)
# obtain and validate pad ID
_pad_id = _kwargs.pop('pad', 0)
if _pad_id >= len(_pad_configs):
raise ValueError("Cannot plot to pad {}: only pads up to {} have been configured!".format(_pad_id, len(_pad_configs)-1))
# select pad axes and configuration
_pad_config = _pad_configs[_pad_id]
_ax = _pad_config['axes']
_stack_bottoms = _pad_config.setdefault('stack_bottoms', {})
_stack_labels = _pad_config.setdefault('stack_labels', [])
_bin_labels = _pad_config.setdefault('bin_labels', {})
_bin_label_anchors = _pad_config.setdefault('bin_label_anchors', {})
_expression = _kwargs.pop('expression')
#("PLT {}".format(_expression))
_plot_object = self._input_controller.get_expr(_expression)
# extract arrays for keys which could be masked by 'mask_zero_errors'
_plot_data = {
_property_name : np.array(list(getattr(_plot_object, _property_name)()))
for _property_name in ('x', 'xerr', 'y', 'yerr', 'xwidth', 'efficiencies', 'errors')
if hasattr(_plot_object, _property_name)
}
# extract individual bin labels (if they exist)
for _i_axis, _axis in enumerate("xyz"):
try:
_root_obj_axis = _plot_object.axis(_i_axis)
except AttributeError:
_root_obj_axis = None
if _root_obj_axis is not None and bool(_root_obj_axis.GetLabels()):
_axis_nbins_method = getattr(_plot_object, "GetNbins{}".format(_axis.upper()))
_plot_data['{}binlabels'.format(_axis)] = [_root_obj_axis.GetBinLabel(_i_bin) for _i_bin in range(1, _axis_nbins_method() + 1)]
# map fields for TEfficiency objects
if isinstance(_plot_object, Efficiency):
_total_hist = _plot_object.total
_plot_data['x'] = np.array(list(_total_hist.x()))
_plot_data['xerr'] = np.array(list(_total_hist.xerr()))
_plot_data['y'] = _plot_data.pop('efficiencies', None)
_plot_data['yerr'] = _plot_data.pop('errors', None)
# map fields for TF1 objects
elif isinstance(_plot_object, F1):
_xmin, _xmax = _plot_object.xaxis.get_xmin(), _plot_object.xaxis.get_xmax()
# compute support points (evenly-spaced)
_plot_data['x'] = np.linspace(_xmin, _xmax, 100) # TODO: make configurable
_plot_data['xerr'] = np.zeros_like(_plot_data['x'])
# evaluate TF1 at every point
_plot_data['y'] = np.asarray(list(map(_plot_object, _plot_data['x'])))
_plot_data['yerr'] = np.zeros_like(_plot_data['y']) # TODO: function errors (?)
# mask all points with erorrs set to zero
_mze = _kwargs.pop('mask_zero_errors', False)
if _mze and len(_plot_data['yerr']) != 0:
_mask = np.all((_plot_data['yerr'] != 0), axis=1)
_plot_data = {
_key : np.compress(_mask, _value, axis=0)
for _key, _value in six.iteritems(_plot_data)
}
# extract arrays for keys which cannot be masked
_plot_data.update({
_property_name : np.array(list(getattr(_plot_object, _property_name)()))
for _property_name in ('xedges', 'yedges', 'z')
if hasattr(_plot_object, _property_name)
})
# -- draw
_plot_method_name = _kwargs.pop('plot_method', 'errorbar')
# -- obtain plot method
try:
# use external method (if available) and curry in the axes object
_plot_method = partial(self._EXTERNAL_PLOT_METHODS[_plot_method_name], _ax)
except KeyError:
#
_plot_method = getattr(_ax, _plot_method_name)
if _plot_method_name in ['errorbar', 'step']:
_kwargs.setdefault('capsize', 0)
if 'color' in _pc:
_kwargs.setdefault('markeredgecolor', _kwargs['color'])
# remove connecting lines for 'errorbar' plots only
if _plot_method_name == 'errorbar':
_kwargs.setdefault('linestyle', '')
# marker styles
_marker_style = _kwargs.pop('marker_style', None)
if _marker_style is not None:
if _marker_style == 'full':
_kwargs.update(
markerfacecolor=_kwargs['color'],
markeredgewidth=0,
)
elif _marker_style == 'empty':
_kwargs.update(
markerfacecolor='w',
markeredgewidth=1,
)
else:
raise ValueError("Unkown value for 'marker_style': {}".format(_marker_style))
# handle stacking
_stack_name = _kwargs.pop('stack', None)
_y_bottom = 0
if _stack_name is not None:
_y_bottom = _stack_bottoms.setdefault(_stack_name, 0.0) # actually 'get' with default
# keep track of stack labels in order to reverse the legend order later
_stack_label = _kwargs.get('label', None)
if _stack_label is not None:
_stack_labels.append(_stack_label)
# different methods handle information differently
if _plot_method_name == 'bar':
_kwargs['width'] = _plot_data['xwidth']
_kwargs.setdefault('align', 'center')
_kwargs.setdefault('edgecolor', '')
_kwargs.setdefault('linewidth', 0)
if 'color' in _kwargs:
# make error bar color match fill color
_kwargs.setdefault('ecolor', _kwargs['color'])
_kwargs['y'] = _plot_data['y']
_kwargs['bottom'] = _y_bottom
else:
_kwargs['y'] = _plot_data['y'] + _y_bottom
_kwargs['xerr'] = _plot_data['xerr'].T
_show_yerr = _kwargs.pop('show_yerr', True)
if _show_yerr:
_kwargs['yerr'] = _plot_data['yerr'].T
_y_data = _kwargs.pop('y')
_normflag = _kwargs.pop('normalize_to_width', False)
if _normflag:
_y_data /= _plot_data['xwidth']
if 'yerr' in _kwargs and _kwargs['yerr'] is not None:
_kwargs['yerr'] /= _plot_data['xwidth']
# -- sort out positional arguments to plot method
if _plot_method_name == 'pcolormesh':
# mask zeros
_z_masked = np.ma.array(_plot_data['z'], mask=_plot_data['z']==0)
# determine data range in z
_z_range = _pad_config.get('z_range', None)
if _z_range is not None:
# use specified values as range
_z_min, _z_max = _z_range
else:
# use data values
_z_min, _z_max = _z_masked.min(), _z_masked.max()
# determine colormap normalization (if not explicitly given)
if 'norm' not in _kwargs:
_z_scale = _pad_config.get('z_scale', "linear")
if _z_scale == 'linear':
_norm = Normalize(vmin=_z_min, vmax=_z_max)
elif _z_scale == 'log':
_norm = LogNorm(vmin=_z_min, vmax=_z_max)
else:
raise ValueError("Unknown value '{}' for keyword 'z_scale': known are {{'linear', 'log'}}".format(_z_scale))
_kwargs['norm'] = _norm
# Z array needs to be transposed because 'X' refers to columns and 'Y' to rows...
_args = [_plot_data['xedges'], _plot_data['yedges'], _z_masked.T]
_kwargs.pop('color', None)
_kwargs.pop('xerr', None)
_kwargs.pop('yerr', None)
# some kwargs must be popped and stored for later use
_label_bins_with_content = _kwargs.pop('label_bins_with_content', False)
_bin_label_format = _kwargs.pop('bin_label_format', "{:f}")
_bin_label_color = _kwargs.pop('bin_label_color', 'k')
else:
_args = [_plot_data['x'], _y_data]
# skip empty arguments
if len(_args[0]) == 0:
continue
# run the plot method
_plot_handle = _plot_method(
*_args,
**_kwargs
)
# store 2D plots for displaying color bars
if _plot_method_name == 'pcolormesh':
_pad_config.setdefault('2d_plots', []).append(_plot_handle)
# add 2D bin annotations, if requested
if _label_bins_with_content:
_bin_center_x = 0.5 * (_plot_data['xedges'][1:] + _plot_data['xedges'][:-1])
_bin_center_y = 0.5 * (_plot_data['yedges'][1:] + _plot_data['yedges'][:-1])
_bin_center_xx, _bin_center_yy = np.meshgrid(_bin_center_x, _bin_center_y)
_bin_content = _args[2]
for _row_x_y_content in zip(_bin_center_xx, _bin_center_yy, _bin_content):
for _x, _y, _content in zip(*_row_x_y_content):
# skip masked and invalid bin contents
if not isinstance(_content, np.ma.core.MaskedConstant) and not np.isnan(_content):
if _bin_label_color == 'auto':
_patch_color_lightness = colorsys.rgb_to_hls(*(_plot_handle.to_rgba(_content)[:3]))[1]
_text_color = 'w' if _patch_color_lightness < 0.5 else 'k'
else:
_text_color = _bin_label_color
_ax.text(_x, _y, _bin_label_format.format(_content),
ha='center', va='center',
fontsize=16,
color=_text_color,
transform=_ax.transData
)
# write results to config dict that will be dumped
if _dump_yaml:
_pc_for_dump['plot_args'] = dict(
# prevent dumping numpy arrays as binary
args=[_a.tolist() if isinstance(_a, np.ndarray) else _a for _a in _args],
**{_kw : _val.tolist() if isinstance(_val, np.ndarray) else _val for _kw, _val in six.iteritems(_kwargs)}
)
if _text_file is not None:
np.set_printoptions(threshold=np.inf)
_text_file.write("- {}(\n\t{},\n\t{}\n)\n".format(
_plot_method_name,
',\n\t'.join(["{}".format(repr(_arg)) for _arg in _args]),
',\n\t'.join(["{} = {}".format(_k, repr(_v)) for _k, _v in six.iteritems(_kwargs)]),
))
np.set_printoptions(threshold=1000)
# update stack bottoms
if _stack_name is not None:
_stack_bottoms[_stack_name] += _plot_data['y']
# keep track of the bin labels of each object in a pad
for _i_axis, _axis in enumerate("xyz"):
_bl_key = '{}binlabels'.format(_axis)
_bl = _plot_data.get(_bl_key, None)
if _bl is not None:
_bin_labels.setdefault(_axis, []).append(_bl)
_bin_label_anchors.setdefault(_axis, []).append(_plot_data.get(_axis, None))
# close text output
if _text_file is not None:
_text_file.close()
# step 3: pad adjustments
for _pad_config in _pad_configs:
_ax = _pad_config['axes']
# simple axes adjustments
for _prop_name, _meth_dict in six.iteritems(self._PC_KEYS_MPL_AXES_METHODS):
_prop_val = _pad_config.get(_prop_name, None)
if _prop_val is not None:
#print(_prop_name, _prop_val)
getattr(_ax, _meth_dict['method'])(_prop_val, **_meth_dict.get('kwargs', {}))
# draw colorbar if there was a 2D plot involved
if _pad_config.get('2d_plots', None):
for _2d_plot in _pad_config['2d_plots']:
_cbar = _fig.colorbar(_2d_plot, ax=_ax)
_z_label = _pad_config.get('z_label', None)
_z_labelpad = _pad_config.get('z_labelpad', None)
if _z_label is not None:
_cbar.ax.set_ylabel(_z_label, rotation=90, va="bottom", ha='right', y=1.0, labelpad=_z_labelpad)
# handle sets of horizontal and vertical lines
for _axlines_key in ('axhlines', 'axvlines'):
_ax_method_name = _axlines_key[:-1]
assert hasattr(_ax, _ax_method_name)
_axlines = _pad_config.pop(_axlines_key, [])
# wrap in list if not already list
if not isinstance(_axlines, list):
_axlines = [_axlines]
for _axlines_set in _axlines:
if not isinstance(_axlines_set, dict):
# wrap inner 'values' in list if not already list
if not isinstance(_axlines_set, list):
_axlines_set = [_axlines_set]
_axlines_set = dict(values=_axlines_set)
_vals = _axlines_set.pop('values')
# draw the line
for _val in _vals:
getattr(_ax, _ax_method_name)(_val, **dict(self._DEFAULT_LINE_KWARGS, **_axlines_set))
# -- handle plot legend
# obtain legend handles and labels
_hs, _ls = _ax.get_legend_handles_labels()
# re-sort, reversing the order of labels that are part of a stack
_hs, _ls = self._sort_legend_handles_labels(_hs, _ls, stack_labels=_pad_config.get("stack_labels", None))
# merge legend entries with identical labels
_hs, _ls = self._merge_legend_handles_labels(_hs, _ls)
# draw legend with user-specified kwargs
_legend_kwargs = self._DEFAULT_LEGEND_KWARGS.copy()
_legend_kwargs.update(_pad_config.pop('legend_kwargs', {}))
_ax.legend(_hs, _ls, **_legend_kwargs)
# handle log x-axis formatting (only if 'x_ticklabels' is not given as [])
if _pad_config.get('x_scale', None) == 'log' and _pad_config.get('x_ticklabels', True):
_log_decade_ticklabels = _pad_config.get('x_log_decade_ticklabels', {1.0, 2.0, 5.0, 10.0})
_formatter = LogFormatterSciNotationForceSublabels(base=10.0, labelOnlyBase=False)
_ax.xaxis.set_minor_formatter(_formatter)
_formatter.set_locs(locs=_log_decade_ticklabels)
# NOTE: do not force labeling of minor ticks in log-scaled y axes
## handle log y-axis formatting (only if 'y_ticklabels' is not given as [])
#if _pad_config.get('y_scale', None) == 'log' and _pad_config.get('y_ticklabels', True):
# _log_decade_ticklabels = _pad_config.get('y_log_decade_ticklabels', {1.0, 5.0})
# _formatter = LogFormatterSciNotationForceSublabels(base=10.0, labelOnlyBase=False)
# _ax.yaxis.set_minor_formatter(_formatter)
# _formatter.set_locs(locs=_log_decade_ticklabels)
# draw bin labels instead of numeric labels at ticks
for _axis in "xyz":
_bl_sets = _pad_config["bin_labels"].get(_axis, None)
_ba_sets = _pad_config["bin_label_anchors"].get(_axis, None)
# skip for axes without bin labels
if not _bl_sets:
continue
# check if bin labels are identical for all objects in the pad
if len(_bl_sets) > 1:
if False in [_bl_sets[_i_set] == _bl_sets[0] for _i_set in range(1, len(_bl_sets))]:
raise ValueError("Bin labels for axis '{}' differ across objects for the same pad! Got the following sets: {}".format(_axis, _bl_sets))
elif False in [np.all(_ba_sets[_i_set] == _ba_sets[0]) for _i_set in range(1, len(_ba_sets))]:
raise ValueError("Bin label anchors for axis '{}' differ across objects for the same pad! Got the following sets: {}".format(_axis, _ba_sets))
# draw bin labels
if _axis == 'x':
for _bl, _ba in zip(_bl_sets[0], _ba_sets[0]):
_ax.annotate(_bl, xy=(_ba, 0), xycoords=('data', 'axes fraction'), xytext=(0, -6), textcoords='offset points', va='top', ha='right', rotation=30)
_ax.xaxis.set_ticks(_ba_sets[0]) # reset tick marks
_ax.xaxis.set_ticklabels([]) # hide numeric tick labels
elif _axis == 'y':
for _bl, _ba in zip(_bl_sets[0], _ba_sets[0]):
_ax.annotate(_bl, xy=(0, _ba), xycoords=('axes fraction', 'data'), xytext=(-6, 0), textcoords='offset points', va='center', ha='right')
_ax.yaxis.set_ticks(_ba_sets[0]) # reset tick marks
_ax.yaxis.set_ticklabels([]) # hide numeric tick labels
else:
print("WARNING: Bin labels found for axis '{}', but this is not supported. Ignoring...".format(_axis))
# step 4: text and annotations
# draw text/annotations
_text_configs = config.pop('texts', [])
for _text_config in _text_configs:
# retrieve target pad
_pad_id = _text_config.pop('pad', 0)
_ax = _pad_configs[_pad_id]['axes']
# handle deprecated keyword 'transform'
if 'transform' in _text_config:
raise ValueError(
"Keyword 'transform' is deprecated. Use keywords "
"'xycoords' and 'textcoords' to specify a coordinate "
"system.")
# retrieve coordinates and text
_xy = _text_config.pop('xy')
_s = _text_config.pop('text')
# draw annotation
_text_config.setdefault('ha', 'left')
_ax.annotate(_s, _xy,
**_text_config
)
# step 5: figure adjustments
# handle figure label ("upper_label")
_upper_label = config.pop('upper_label', None)
if _upper_label is not None:
# keyword is deprecated
warnings.warn(
"Keyword `upper_label` is deprecated and will be removed in "
"the future. Replace it with the following equivalent annotation under "
"`texts`: dict(text='...', xy=(1, 1), xycoords='axes fraction', xytext=(0, 5), "
"textcoords='offset points', ha='right', pad=0).", DeprecationWarning)
# place above topmost `Axes`
_pad_configs[0]['axes'].annotate(_upper_label, xy=(1, 1),
xycoords='axes fraction',
xytext=(0, 5),
textcoords='offset points',
ha='right',
)
# step 6: save figures
_make_directory(os.path.dirname(_filename))
_fig.savefig('{}'.format(_filename))
#plt.close(_fig) # close figure to save memory
# dump YAML to file, if requested
if _dump_yaml:
with open(_yaml_filename, 'w') as _yaml_file:
yaml.dump(_config_for_dump, _yaml_file)
# de-register all the locals after a plot is done
self._input_controller.clear_locals()
# -- register action slots
_ACTIONS = [_request, _plot]
# -- additional public API