Source code for Karma.PostProcessing.Palisade._input

from __future__ import print_function

import ast
import functools
import ROOT
import numpy as np
import operator as op
import os
import pandas as pd
import six
import uuid

from array import array

from rootpy import asrootpy
from rootpy.io import root_open
from rootpy.plotting import Hist1D, Hist2D, Profile1D, Efficiency, Graph
from rootpy.plotting.hist import _Hist, _Hist2D
from rootpy.plotting.profile import _ProfileBase

if six.PY2:
    from collections import Mapping
elif six.PY3:
    from collections.abc import Mapping

import scipy.stats as stats


__all__ = ['InputROOTFile', 'InputROOT']


class HashableMap(Mapping):
    """A hashable, immutable mapping type.

    The arguments to ``HashableMap`` are processed the same way as those to ``dict``.
    Both the keys and the values contained in the map must be hashable.

    An exception are `lists`, which are converted to `tuple` in order  to make them
    hashable.

    Two `HashableMap` instances with identical keys and values are guaranteed to be
    identical. "Identical" in this context means having the same hash.
    """

    def __init__(self, *args, **kwargs):

        self._d = dict(*args, **kwargs)

        # validate values
        for key, value in six.iteritems(self._d):
            # check hashability
            try:
                hash(value)
            except TypeError:
                # attempt to wrap non-hashable types
                if isinstance(value, list) or isinstance(value, tuple):
                    value = tuple(enumerate(value))
                elif isinstance(value, dict):
                    value = tuple(six.iteritems(value))
                else:
                    raise

                self._d[key] = HashableMap(value)

        self._hash = None  # computed on request

    def __iter__(self):
        return iter(self._d)

    def __len__(self):
        return len(self._d)

    def __getitem__(self, key):
        return self._d[key]

    def __hash__(self):
        # compute hash if not available
        if self._hash is None:
            # hash computed from hashes of keys and values
            self._hash = 0
            for key, value in self.iteritems():
                self._hash ^= hash(key)
                self._hash ^= hash(value)

        return self._hash

    def __repr__(self):
        return "%s(%s)" % (self.__class__.__name__, self._d)


class _ROOTObjectFunctions(object):

    @staticmethod
    def get_all():
        """return all methods not beginning with a '_'"""
        return {
            _method : getattr(_ROOTObjectFunctions, _method)
            for _method in dir(_ROOTObjectFunctions)
            if not _method.startswith('_') and callable(getattr(_ROOTObjectFunctions, _method))
        }

    @staticmethod
    def _project_or_clone(tobject, projection_options=None):

        if isinstance(tobject, _ProfileBase):
            # create an "x-projection" with a unique suffix
            if projection_options is None:
                return asrootpy(tobject.ProjectionX(uuid.uuid4().get_hex()))
            else:
                return asrootpy(tobject.ProjectionX(uuid.uuid4().get_hex(), projection_options))
        else:
            return tobject.Clone()

    @staticmethod
    def histdivide(tobject_1, tobject_2, option=""):
        """divide two histograms, taking error calculation option into account"""

        _new_tobject_1 = _ROOTObjectFunctions._project_or_clone(tobject_1)
        _new_tobject_2 = _ROOTObjectFunctions._project_or_clone(tobject_2)

        _new_tobject_1.Divide(_new_tobject_1, _new_tobject_2, 1, 1, option)

        return _new_tobject_1

    @staticmethod
    def max_yield_index(yields, efficiencies, eff_threshold):
        """for each bin, return index of object in `yields` which is maximizes yield, subject to the efficiency remaining above threshold"""

        # `yields` and `efficiencies` must have the same length
        assert len(yields) == len(efficiencies)
        # all `yields` and `efficiencies` must have the same number of bins
        assert all([len(_tobj_yld) == len(yields[0]) for _tobj_yld in yields[1:]])

        _new_tobject = _ROOTObjectFunctions._project_or_clone(yields[0])

        for _bin_idx in range(len(yields[0])):
            _max_yield_for_bin = 0
            _max_yield_obj_idx = -1

            for _obj_index, (_tobj_yld, _tobj_eff) in enumerate(zip(yields, efficiencies)):
                # skip bins with efficiency below threshold
                if _tobj_eff[_bin_idx].value < eff_threshold:
                    continue
                # keep index of object with maximum yield
                if _tobj_yld[_bin_idx].value > _max_yield_for_bin:
                    _max_yield_for_bin = _tobj_yld[_bin_idx].value
                    _max_yield_obj_idx = _obj_index

            _new_tobject[_bin_idx].value = _max_yield_obj_idx
            _new_tobject[_bin_idx].error = 0

        return _new_tobject

    @staticmethod
    def max_value_index(tobjects):
        """for each bin *i*, return index of object in `tobjects` which contains the largest value for bin *i*"""

        # all `tobjects` must have the same number of bins
        assert all([len(_tobj) == len(tobjects[0]) for _tobj in tobjects[1:]])

        _new_tobject = _ROOTObjectFunctions._project_or_clone(tobjects[0])

        for _bin_idx in range(len(tobjects[0])):
            _max_yield_for_bin = 0
            _max_yield_obj_idx = -1

            for _obj_index, _tobj in enumerate(tobjects):
                # keep index of object with maximum yield
                if _tobj[_bin_idx].value > _max_yield_for_bin:
                    _max_yield_for_bin = _tobj[_bin_idx].value
                    _max_yield_obj_idx = _obj_index

            _new_tobject[_bin_idx].value = _max_yield_obj_idx
            _new_tobject[_bin_idx].error = 0

        return _new_tobject

    @staticmethod
    def select(tobjects, indices):
        """the content of each bin *i* in the return object is taken from the object whose index in `tobjects` is given by bin *i* in `indices`"""
        # if indices are outside the range of `tobjects`, the bins are set to zero
        # all `yields` and `efficiencies` must have the same binning
        # `yields` and `efficiencies` must have the same length
        assert len(tobjects) > 0
        assert len(tobjects[0]) == len(indices)
        # all `tobjects` must have the same number of bins
        assert all([len(_tobj) == len(tobjects[0]) for _tobj in tobjects[1:]])

        _new_tobject = _ROOTObjectFunctions._project_or_clone(tobjects[0])

        for _i_bin, (_bin_proxy, _obj_idx) in enumerate(zip(_new_tobject, indices)):
            # range check
            if _obj_idx.value >= 0 and _obj_idx.value < len(tobjects):
                _bin_proxy.value = tobjects[int(_obj_idx.value)][_i_bin].value
                _bin_proxy.error = tobjects[int(_obj_idx.value)][_i_bin].error
            else:
                _bin_proxy.value = 0
                _bin_proxy.error = 0

        return _new_tobject

    @staticmethod
    def mask_lookup_value(tobject, tobject_lookup, lookup_value):
        """bin *i* in return object is bin *i* in `tobject` if bin *i* in `tobject_lookup` is equal to `lookup_value`"""

        _new_tobject = _ROOTObjectFunctions._project_or_clone(tobject)

        for _bin_proxy, _bin_proxy_lookup in zip(_new_tobject, tobject_lookup):
            # mask bins where looked up value is different than the reference
            if _bin_proxy_lookup.value != lookup_value:
                _bin_proxy.value = 0
                _bin_proxy.error = 0

        return _new_tobject

    @staticmethod
    def apply_efficiency_correction(tobject, efficiency, threshold=None):
        """Divide each bin in `tobject` by the corresponding bin in `efficiency`. If `efficiency` is lower than `threshold`, the number of events is set to zero."""
        assert len(efficiency) == len(tobject)

        _new_tobject = _ROOTObjectFunctions._project_or_clone(tobject)

        for _eff, _bin_proxy in zip(efficiency.efficiencies(), _new_tobject):
            if (_eff <= 0) or (threshold is not None and _eff < threshold):
                _bin_proxy.value = 0
                _bin_proxy.error = 0
            else:
                _bin_proxy.value /= _eff
                _bin_proxy.error /= _eff  # should this be done?

        return _new_tobject

    @staticmethod
    def efficiency(tobject_numerator, tobject_denominator):
        """Compute TEfficiency"""

        return Efficiency(tobject_numerator, tobject_denominator)

    @staticmethod
    def efficiency_graph(tobject_numerator, tobject_denominator):
        """Compute TEfficiency with proper clopper-pearson intervals"""

        _eff = Efficiency(tobject_numerator, tobject_denominator)
        return asrootpy(_eff.CreateGraph())

    @staticmethod
    def project_x(tobject):
        """Apply ProjectionX() operation."""

        if hasattr(tobject, 'ProjectionX'):
            _new_tobject = asrootpy(tobject.ProjectionX(uuid.uuid4().get_hex()))
        else:
            print("[INFO] `project_x` not available for object with type {}".format(type(tobject)))
            return tobject

        return _new_tobject

    @staticmethod
    def project_y(tobject):
        """Apply ProjectionY() operation."""

        if hasattr(tobject, 'ProjectionY'):
            _new_tobject = asrootpy(tobject.ProjectionY(uuid.uuid4().get_hex()))
        else:
            raise ValueError("`project_y` not available for object with type {}".format(type(tobject)))

        return _new_tobject

    @staticmethod
    def diagonal(th2d):
        """Return a TH1D containing the main diagonal of an input TH2D."""

        if hasattr(th2d, 'ProjectionX'):
            _new_tobject = asrootpy(th2d.ProjectionX(uuid.uuid4().get_hex()))
        else:
            raise ValueError("`diagonal` not available for object with type {}".format(type(tobject)))

        for _bp in th2d:
            if _bp.xyz[0] == _bp.xyz[1]:
                _new_tobject[_bp.xyz[0]].value = _bp.value
                _new_tobject[_bp.xyz[0]].error = _bp.error

        return _new_tobject

    @staticmethod
    def yerr(tobject):
        """replace bin value with bin error and set bin error to zero"""

        # project preserving errors
        _new_tobject = _ROOTObjectFunctions._project_or_clone(tobject, "e")

        for _bin_proxy in _new_tobject:
            _bin_proxy.value, _bin_proxy.error = _bin_proxy.error, 0

        return _new_tobject

    @staticmethod
    def atleast(tobject, min_value):
        """mask all values below threshold"""

        # project preserving errors
        _new_tobject = _ROOTObjectFunctions._project_or_clone(tobject, "e")

        for _bin_proxy in _new_tobject:
            if hasattr(_bin_proxy, 'graph_'):
                # for TGraph etc.
                if _bin_proxy.y.value < min_value:
                    _bin_proxy.y.value = 0
                    _bin_proxy.y.error_hi = 0
                    # 'low' error setter has a bug in rootpy. workaround:
                    _bin_proxy.graph_.SetPointEYlow(_bin_proxy.idx_, 0)
            else:
                # for TH1D etc.
                if _bin_proxy.value < min_value:
                    _bin_proxy.value, _bin_proxy.error = 0, 0


        return _new_tobject

    @staticmethod
    def threshold(tobject, min_value):
        """returns a histogram like tobject with bins set to zero if thet fall
        below the miminum value and to one if not. Errors are always set to zero"""

        # project preserving errors
        _new_tobject = _ROOTObjectFunctions._project_or_clone(tobject)

        for _bin_proxy in _new_tobject:
            if _bin_proxy.value < min_value:
                _bin_proxy.value, _bin_proxy.error = 0, 0
            else:
                _bin_proxy.value, _bin_proxy.error = 1, 0

        return _new_tobject

    @staticmethod
    def discard_errors(tobject):
        """set all bin errors to zero"""

        _new_tobject = _ROOTObjectFunctions._project_or_clone(tobject)

        for _bin_proxy in _new_tobject:
            if hasattr(_bin_proxy, 'graph_'):
                # for TGraph etc.
                _bin_proxy.y.error_hi = 0
                # 'low' error setter has a bug in rootpy. workaround:
                _bin_proxy.graph_.SetPointEYlow(_bin_proxy.idx_, 0)
            else:
                # for TH1D etc.
                _bin_proxy.error = 0

        return _new_tobject

    @staticmethod
    def bin_width(tobject):
        """replace bin value with width of bin and set bin error to zero"""

        _new_tobject = _ROOTObjectFunctions._project_or_clone(tobject)

        for _bin_proxy in _new_tobject:
            _bin_proxy.value, _bin_proxy.error = _bin_proxy.x.width, 0

        return _new_tobject

    @staticmethod
    def max(*tobjects):
        """binwise `max` for a collection of histograms with identical binning"""

        _new_tobject = _ROOTObjectFunctions._project_or_clone(tobjects[0], "e")

        _tobj_clones = []
        for _tobj in tobjects:
            _tobj_clones.append(_ROOTObjectFunctions._project_or_clone(_tobj, "e"))

        for _bin_proxies in zip(_new_tobject, *_tobj_clones):
            _argmax = max(range(1, len(_bin_proxies)), key=lambda idx: _bin_proxies[idx].value)
            _bin_proxies[0].value = _bin_proxies[_argmax].value
            _bin_proxies[0].error = _bin_proxies[_argmax].error

        # cleanup
        for _tobj_clone in _tobj_clones:
            _tobj_clone.Delete()

        return _new_tobject

    @staticmethod
    def max_val_min_err(*tobjects):
        """binwise 'max' on value followed by a binwise 'min' on error."""

        _new_tobject = _ROOTObjectFunctions._project_or_clone(tobjects[0], "e")

        _tobj_clones = []
        for _tobj in tobjects:
            _tobj_clones.append(_ROOTObjectFunctions._project_or_clone(_tobj, "e"))

        for _bin_proxies in zip(_new_tobject, *_tobj_clones):
            _maxval = None
            for _bin_proxy in _bin_proxies[1:]:
                if _maxval is None or _bin_proxy.value > _maxval:
                    _maxval = _bin_proxy.value
                    _minerr = _bin_proxy.error
                elif _maxval == _bin_proxy.value and _bin_proxy.error < _minerr:
                    _minerr = _bin_proxy.error
            _bin_proxies[0].value = _maxval
            _bin_proxies[0].error = _minerr

        # cleanup
        for _tobj_clone in _tobj_clones:
            _tobj_clone.Delete()

        return _new_tobject

    @staticmethod
    def mask_if_less(tobject, tobject_ref):
        """set `tobject` bins and their errors to zero if their content is less than the value in `tobject_ref`"""

        assert len(tobject) == len(tobject_ref)

        _new_tobject = _ROOTObjectFunctions._project_or_clone(tobject, "e")
        _new_tobject_ref = _ROOTObjectFunctions._project_or_clone(tobject_ref, "e")

        for _bin_proxy, _bin_proxy_ref in zip(_new_tobject, _new_tobject_ref):
            if hasattr(_bin_proxy, 'graph_'):
                # for TGraph etc.
                if _bin_proxy.y < _bin_proxy_ref.y:
                    _bin_proxy.y.value = 0
                    _bin_proxy.y.error_hi = 0
                    # 'low' error setter has a bug in rootpy. workaround:
                    _bin_proxy.graph_.SetPointEYlow(_bin_proxy.idx_, 0)
            else:
                # for TH1D etc.
                if _bin_proxy.value < _bin_proxy_ref.value:
                    _bin_proxy.value, _bin_proxy.error = 0, 0

        # cleanup
        _new_tobject_ref.Delete()

        return _new_tobject

    @staticmethod
    def double_profile(tprofile_x, tprofile_y):
        """creates a graph with points whose x and y values and errors are taken from the bins of two profiles with identical binning"""
        ## Note: underflow and overflow bins are discarded

        if len(tprofile_x) != len(tprofile_y):
            raise ValueError("Cannot build double profile: x and y profiles "
                             "have different number or bins ({} and {})".format(
                                len(tprofile_x)-2, len(tprofile_y)-2))

        _dp_graph = Graph(len(tprofile_x)-2, type='errors')  # symmetric errors

        _i_point = 0
        for _i_bin, (_bin_proxy_x, _bin_proxy_y) in enumerate(zip(tprofile_x, tprofile_y)):
            # disregard overflow/underflow bins
            if _i_bin == 0 or _i_bin == len(tprofile_x) - 1:
                continue

            if _bin_proxy_y.value:
                _dp_graph.SetPoint(_i_point, _bin_proxy_x.value, _bin_proxy_y.value)
                _dp_graph.SetPointError(_i_point, _bin_proxy_x.error, _bin_proxy_y.error)
                _i_point += 1

        # remove "unfilled" points
        while (_dp_graph.GetN() > _i_point):
            _dp_graph.RemovePoint(_dp_graph.GetN()-1)

        return _dp_graph

    @staticmethod
    def threshold_by_ref(tobject, tobject_ref):
        """set `tobject` bins to zero if their content is less than the value in `tobject_ref`, and to 1 otherwise.
        Result bin errors are always set to zero."""

        assert len(tobject) == len(tobject_ref)

        _new_tobject = _ROOTObjectFunctions._project_or_clone(tobject)
        _new_tobject_ref = _ROOTObjectFunctions._project_or_clone(tobject_ref)

        for _bin_proxy, _bin_proxy_ref in zip(_new_tobject, _new_tobject_ref):
            if _bin_proxy.value < _bin_proxy_ref.value:
                _bin_proxy.value, _bin_proxy.error = 0, 0
            else:
                _bin_proxy.value, _bin_proxy.error = 1, 0

        # cleanup
        _new_tobject_ref.Delete()

        return _new_tobject

    @staticmethod
    def normalize_x(tobject):
        """Normalize bin contents of each x slice of a TH2D by dividing by the y integral over each x slice."""

        if not isinstance(tobject, _Hist2D):
            raise ValueError("Cannot apply function `normalize_x` to object of type '{}': must be Hist2D [TH2D]!".format(type(tobject)))

        _new_tobject = asrootpy(tobject.Clone())
        #_projection = asrootpy(tobject.ProjectionX(uuid.uuid4().get_hex(), 1, len(list(tobject.y()))-1))
        _projection = asrootpy(tobject.ProjectionX(uuid.uuid4().get_hex()))

        # divide 2D bin contents by integral over x slice (= result of ProjectionX())
        for _bin_proxy in _new_tobject:
            if _projection[_bin_proxy.xyz[0]].value:
                _bin_proxy.value /= _projection[_bin_proxy.xyz[0]].value
                _bin_proxy.error /= _projection[_bin_proxy.xyz[0]].value
            else:
                _bin_proxy.value, _bin_proxy.error = 0, 0

        _projection.Delete()  # cleanup

        return _new_tobject

    @staticmethod
    def unfold(th1d_input, th2d_response, th1d_marginal_gen, th1d_marginal_reco):
        """
        Use TUnfold to unfold a reconstructed spectrum.

        Parameters
        ----------
            th1d_input : `ROOT.TH1D`
                measured distribution to unfold

            th2d_response :`ROOT.TH2D`
                2D response histogram. Contains event numbers
                per (gen, reco) bin **after** rejecting spurious reconstructions
                and accounting for losses due to the reco acceptance.
                Gen bins should be on the `x` axis.
                Overflow/underflow should not be present and will be ignored!
                Acceptance losses and spurious reconstructions ("fakes") are
                inferred from the difference between the projections of the response
                and the full marginal distributions, which are given separately.

            th1d_marginal_gen : `ROOT.TH1D`
                marginal distribution on gen-level
                Contains event numbers per gen bin, **without** accounting for
                losses due to detector acceptance. The losses are inferred
                by comparing to the projection of the 2D response histogram,
                where these losses are accounted for.

            th1d_marginal_reco : `ROOT.TH1D`
                marginal distribution on reco-level
                Contains event numbers per reco bin, **without** subtracting
                spurious reconstructions ("fakes"). The fakes are inferred
                by comparing to the projection of the 2D response histogram,
                where these fakes are not present.
        """

        # input sanity checks
        _nbins_gen = th1d_marginal_gen.GetNbinsX()
        _nbins_reco = th1d_marginal_reco.GetNbinsX()
        assert(th2d_response.GetNbinsX() == _nbins_gen)
        assert(th2d_response.GetNbinsY() == _nbins_reco)
        assert(th2d_response.GetNbinsY() == _nbins_reco)

        _th2d_response_clone = asrootpy(th2d_response.Clone())
        _th1d_input_clone = asrootpy(th1d_input.Clone())

        # determine relative fake rate per reco. bin
        _th1d_true_reco = asrootpy(th2d_response.ProjectionY(uuid.uuid4().get_hex()))
        _th1d_true_fraction_reco = _th1d_true_reco / th1d_marginal_reco

        # determine absolute number of lost events per gen. bin
        _th1d_accepted_gen = asrootpy(th2d_response.ProjectionX(uuid.uuid4().get_hex()))
        _th1d_rejected_gen = th1d_marginal_gen - _th1d_accepted_gen

        # correct reco. distribution for fakes using inferred true fraction
        _th1d_input_clone = _th1d_true_fraction_reco * _th1d_input_clone

        # get rid of gen-level underflow/overflow:
        for _bin_proxy in _th2d_response_clone:
            # get rid of gen-level underflow/overflow:
            if _bin_proxy.xyz[0] == 0 or _bin_proxy.xyz[0] == _nbins_gen + 1:
                _bin_proxy.value = 0
                _bin_proxy.error = 0
            # fill losses into y-underflow bins
            elif _bin_proxy.xyz[1] == 0:
                _bin_proxy.value = _th1d_rejected_gen[_bin_proxy.xyz[0]].value
                _bin_proxy.error = _th1d_rejected_gen[_bin_proxy.xyz[0]].error

        # construct TUnfold instance and perform unfolding
        _tunfold = ROOT.TUnfold(
            _th2d_response_clone,
            ROOT.TUnfold.kHistMapOutputHoriz,   # gen-level on x axis
            ROOT.TUnfold.kRegModeNone,          # no regularization
            ROOT.TUnfold.kEConstraintNone,      # no constraints
        )
        _tunfold.SetInput(
            _th1d_input_clone
        )
        _tunfold.DoUnfold(0)

        _th1d_output = th2d_response.ProjectionX(uuid.uuid4().get_hex()) #tobject_reco.Clone()
        _tunfold.GetOutput(_th1d_output)

        _th2d_response_clone.Delete()
        _th1d_input_clone.Delete()
        _th1d_accepted_gen.Delete()
        _th1d_rejected_gen.Delete()
        _th1d_true_reco.Delete()
        _th1d_true_fraction_reco.Delete()
        _tunfold.Delete()

        return asrootpy(_th1d_output)

    @staticmethod
    def normalize_to_ref(tobject, tobject_ref):
        """Normalize `tobject` to the integral over `tobject_ref`."""

        _new_tobject = asrootpy(tobject.Clone())
        if tobject.integral():
            _factor = float(tobject_ref.integral()) / float(tobject.integral())

            return _new_tobject * _factor
        else:
            return _new_tobject

    @staticmethod
    def cumulate(tobject):
        """Make value of n-th bin equal to the sum of all bins up to and including n (but excluding underflow bins)."""
        #                                     forward  suffix
        return asrootpy(tobject.GetCumulative(True,    uuid.uuid4().get_hex()))

    @staticmethod
    def cumulate_reverse(tobject):
        """Make value of n-th bin equal to the sum of all bins from n up to and inclufing the last bin (but excluding overflow bins)."""
        #                                     forward  suffix
        return asrootpy(tobject.GetCumulative(False,   uuid.uuid4().get_hex()))

    @staticmethod
    def bin_differences(tobject):
        """Make value of n-th bin equal to the difference between the n-th and (n-1)-th bins."""

        _new_tobject = _ROOTObjectFunctions._project_or_clone(tobject, "e")

        for _bin_proxy, _bin_proxy_1, _bin_proxy_2 in zip(_new_tobject[1:-2], tobject[2:-1], tobject[1:-2]):
            if _bin_proxy_2.value:
                _bin_proxy.value = _bin_proxy_1.value - _bin_proxy_2.value
            else:
                _bin_proxy.error = 0
            _bin_proxy.error = 0

        return _new_tobject

    @staticmethod
    def bin_ratios(tobject):
        """Make value of n-th bin equal to the ratio between the n-th and (n-1)-th bins."""

        _new_tobject = _ROOTObjectFunctions._project_or_clone(tobject, "e")

        for _bin_proxy, _bin_proxy_num, _bin_proxy_denom in zip(_new_tobject[1:-2], tobject[2:-1], tobject[1:-2]):
            if _bin_proxy_denom.value:
                _bin_proxy.value = _bin_proxy_num.value / _bin_proxy_denom.value
            else:
                _bin_proxy.error = 0
            _bin_proxy.error = 0

        return _new_tobject


class InputROOTFile(object):
    """An input module for accessing objects from a single ROOT file.

    Multiple objects can be requested. They will be all be retrieved
    simultaneously and cached on the first subsequent call to `get()`.
    The file will thus only be opened once.

    Usage example:

    .. code:: python

       m = InputROOTFile('/path/to/rootfile.root')

       m.request(dict(object_path='MyDirectory/myObject'))
       my_object = m.get('MyDirectory/myObject')
    """

    def __init__(self, filename):
        self._filename = filename
        self._outstanding_requests = dict()
        self._plot_data_cache = dict()

    def _process_outstanding_requests(self):
        # if no requests, return immediately
        if not self._outstanding_requests:
            return

        # process outstanding requests
        with root_open(self._filename) as _tfile:
            for tobj_path, request_spec in six.iteritems(self._outstanding_requests):
                _rebin_factor = request_spec.pop('rebin_factor', None)
                _profile_error_option = request_spec.pop('profile_error_option', None)

                _tobj = _tfile.Get(tobj_path)
                # for histograms: move to global directory
                try:
                    _tobj.SetDirectory(0)
                except AttributeError:
                    # call not needed to other objects
                    pass
                #print(tobj_path, _tobj)

                # aply rebinning (if requested)
                if _rebin_factor is not None:
                    _tobj.Rebin(_rebin_factor)

                # set TProfile error option (if requested)
                if _profile_error_option is not None:
                    # TOOD: check if profile?
                    _tobj.SetErrorOption(_profile_error_option)

                self._plot_data_cache[tobj_path] = _tobj

        self._outstanding_requests = dict()


    def get(self, object_path):
        """
        Get an object.

        Parameters
        ----------
            object_path : string, path to resource in ROOT file (e.g. "directory/object")
        """
        # request object if not present
        if object_path not in self._outstanding_requests:
            if object_path not in self._plot_data_cache:
                self.request([dict(object_path=object_path)])

        # process request if object is waiting
        if object_path in self._outstanding_requests:
            self._process_outstanding_requests()

        # return object
        return self._plot_data_cache[object_path]

    def request(self, request_specs):
        """
        Request an object. Requested objects are all retrieved in one
        go when one of them is retrived via 'get()'

        Parameters
        ----------
            request_specs : e.g. dict(object_path="directory/object")
        """
        for request_spec in request_specs:
            _object_path = request_spec.pop('object_path')
            _force_rerequest = request_spec.pop('force_rerequest', True)

            # override earlier request iff 'force_rerequest' is True
            if (not (_object_path in self._outstanding_requests or _object_path in self._plot_data_cache)) or _force_rerequest:
                self._outstanding_requests[_object_path] = request_spec

                if _object_path in self._plot_data_cache:
                    del self._plot_data_cache[_object_path]

    def clear(self):
        """
        Remove all cached data and outstanding requests.
        """
        self._plot_data_cache = {}
        self._outstanding_requests = {}


[docs]class InputROOT(object): """An input module for accessing objects from multiple ROOT files. A nickname can be registered for each file, which then allows object retrieval by prefixing it to the object path (i.e. ``<file_nickname>:<object_path_in_file>``). Single-file functionality is delegated to child :py:class:`~DijetAnalysis.PostProcessing.Palisade.InputROOTFile` objects. """ operators = { ast.Add: op.add, ast.Sub: op.sub, ast.Mult: op.mul, ast.Div: op.truediv, ast.Pow: op.pow, ast.BitXor: op.xor, ast.USub: op.neg } # input functions (meant to be applied to ROOT objects in files) functions = dict( _ROOTObjectFunctions.get_all(), # add some useful aliases **{ 'h': _ROOTObjectFunctions.project_x, # alias 'hist': _ROOTObjectFunctions.project_x, # alias } ) # class-level cache for storing memoized function results _cache = {} def __init__(self, files_spec=None): """ Parameters ---------- files_spec : `dict`, optional specification of file nicknames (keys) and paths pointed to (values). Can be omitted (files can be added later via :py:class:`~DijetAnalysis.PostProcessing.Palisade.InputROOT.add_file`) Usage example: .. code:: python m = InputROOT() # add a file and register a nickname for it m.add_file('/path/to/rootfile.root', nickname='file0') # optional: request object first (retrieves several objects at once) m.request(dict(file_nickname='file0', object_path='MyDirectory/myObject')) # retrieve an object from a file my_object = m.get('file0:MyDirectory/myObject') # apply simple arithmetical expressions to objects my_sum_object = m.get_expr('"file0:MyDirectory1/myObject1" + "file0:MyDirectory2/myObject2"') # use basic functions in expressions my_object_noerrors = m.get_expr('discard_errors("file0:MyDirectory1/myObject1")') # register user-defined input functions InputRoot.add_function(my_custom_function) # `my_custom_function` defined beforehand # use function in expression: my_function_result = m.get_expr('my_custom_function("file0:MyDirectory1/myObject1")') """ self._input_controllers = {} self._file_nick_to_realpath = {} self._locals = {} if files_spec is not None: for _nickname, _file_path in six.iteritems(files_spec): self.add_file(_file_path, nickname=_nickname) def _get_input_controller_for_file(self, file_spec): '''get input controller for file specification. handle nickname resolution''' _file_realpath = self._file_nick_to_realpath.get(file_spec, file_spec) if _file_realpath not in self._input_controllers: raise ValueError("Cannot get input controller for file specification '{}': have you added a file with this nickname or path?".format(file_spec)) return self._input_controllers[_file_realpath] @staticmethod def _get_file_nickname_and_obj_path(object_spec): _file_nickname, _object_path_in_file = object_spec.split(':', 1) return _file_nickname, _object_path_in_file
[docs] @classmethod def add_function(cls, function=None, name=None, override=False, memoize=False): '''Register a user-defined input function. Can also be used as a decorator. .. note:: Functions are registered **globally** in the ``InputROOT`` **class** and are immediately available to all ``InputROOT`` instances. Parameters ---------- function : `function`, optional function or method to add. Can be omitted when used as a decorator. name : `str`, optional function name. If not given, taken from ``function.__name__`` override : `bool`, optional if ``True``, allow existing functions to be overridden (*default*: ``False``) memoize : `bool`, optional if ``True``, store function result in a cache on first call. For every subsequent call with identical arguments, the result will be retrieved from the cache instead of evaluating the function again. (*default*: ``False``) Usage examples: * as a simple decorator: .. code:: python @InputROOT.add_function def my_function(rootpy_object): ... * to override a function that has already been registered: .. code:: python @InputROOT.add_function(override=True) def my_function(rootpy_object): ... * to register a function under a different name: .. code:: python @InputROOT.add_function(name='short_name') def very_long_fuction_name_we_do_not_want_to_use_in_expressions(rootpy_object): ... * as a method: .. code:: python InputROOT.add_function(my_function) .. note:: All *Palisade* processors (and especially the :py:class:`~DijetAnalysis.PostProcessing.Lumberjack.PlotProcessor`) expect the objects returned by functions to be valid *rootpy* objects. When implementing user-defined functions, make sure to convert "naked" ROOT (PyROOT) objects by wrapping them in :py:func:`rootpy.asrootpy` before returning them. ''' if name is None and function is not None: name = function.__name__ if name in cls.special_functions: raise ValueError("Cannot add user-defined ROOT function with name " "'{}': name is reserved for internal use only!".format(name)) if not override and name in cls.functions: raise ValueError("Cannot add user-defined ROOT function with name " "'{}': it already exists and `override` not explicitly allowed!".format(name)) if memoize: def memoize(f): @functools.wraps(f) def _memoized_function(*args, **kwargs): # compute unique hash key for the argument structure key = HashableMap(func=f, args=args, kwargs=kwargs) # look up in cache if key in cls._cache: # return if found return cls._cache[key] # compute and store if not found _result = f(*args, **kwargs) cls._cache[key] = _result return _result return _memoized_function def _decorator(f): # replace 'f' with a version enabling memoization of results f = memoize(f) # add memoized function to mapping cls.functions[name or f.__name__] = f return f else: def _decorator(f): # add user-specified function to mapping cls.functions[name or f.__name__] = f return f if function is not None: # default decorator call, i.e. '@InputROOT.add_function' return _decorator(function) else: # decorator call with parameters , e.g. '@InputROOT.add_function(override=True)' return _decorator
[docs] @classmethod def get_function(cls, name): '''Retrieve a defined input function by name. Returns `None` if no such function exists.''' return cls.functions.get(name, None)
[docs] def add_file(self, file_path, nickname=None): """ Add a ROOT file. Parameters ---------- file_path : `str` path to ROOT file nickname : `str`, optional file nickname. If not given, ``file_path`` will be used """ # determine real (absolute) path for file if '://' in file_path: # keep URL-like paths as they are _file_realpath = file_path else: _file_realpath = os.path.realpath(file_path) # register file name (as given) as file nickname self._file_nick_to_realpath[file_path] = _file_realpath # register nickname (if given) if nickname is not None: if nickname in self._file_nick_to_realpath: raise ValueError("Cannot add file for nickname '{}': nickname already registered for file '{}'".format(nickname, filename)) self._file_nick_to_realpath[nickname] = _file_realpath # create controller for file if _file_realpath not in self._input_controllers: self._input_controllers[_file_realpath] = InputROOTFile(_file_realpath)
[docs] def get(self, object_spec): """ Get an object from one of the registered files. .. tip:: If calling :py:meth:`get` on multiple objects (e.g. in a loop), consider issuing a :py:meth:`~DijetAnalysis.PostProcessing.Palisade.InputROOT.request` call for all objects beforehand. The first call to :py:meth:`get` will then retrieve all requested objects in one go, opening and closing the file only once. Parameters ---------- object_spec : `str` file nickname and path to object in ROOT file, separated by a colon, e.g. :py:data:`"file_nickname:directory/object"` """ _file_nickname, _object_path_in_file = self._get_file_nickname_and_obj_path(object_spec) _ic = self._get_input_controller_for_file(_file_nickname) return _ic.get(_object_path_in_file)
[docs] def request(self, request_specs): """ Request objects from the registered files. Requests for objects are stored until :py:meth:`~DijetAnalysis.PostProcessing.Palisade.InputROOT.get` is called for one of the objects. All requested objects are then be retrieved in one go and cached. Parameters ---------- request_specs : `list` of `dict` each `dict` represents a request for *one* object from *one* file. A request dict must have either a key :py:const:`object_spec`, which contains both the file nickname and the path to the object within the file (separated by a colon, ``:``), or two keys :py:const:`file_nickname` and :py:const:`object_path` specifying these separately. The following requests behave identically: * :py:data:`dict(file_nickname='file0', object_path="directory/object")` * :py:data:`dict(object_spec="file0:directory/object")` """ _delegations = {} for request_spec in request_specs: _file_nickname = request_spec.pop('file_nickname', None) _object_path_in_file = request_spec.pop('object_path', None) _object_spec = request_spec.pop('object_spec', None) if not ((_object_spec is not None) == ((_file_nickname is None) and (_object_path_in_file is None))): raise ValueError("Invalid request: must either contain both 'file_nickname' and 'object_path' keys or an 'object_spec' key, but contains: {}".format(request_spec.keys())) if _object_spec is not None: _file_nickname, _object_path_in_file = self._get_file_nickname_and_obj_path(_object_spec) if _file_nickname not in _delegations: _delegations[_file_nickname] = [] _delegations[_file_nickname].append(dict(object_path=_object_path_in_file, **request_spec)) for _file_nickname, _requests in six.iteritems(_delegations): _ic = self._get_input_controller_for_file(_file_nickname) _ic.request(_requests)
[docs] def get_expr(self, expr, locals={}): """ Evaluate an expression involving objects retrieved from file(s). The string given must be a valid Python expression. Strings contained in the expression are interpreted as specifications of objects in files (see :py:meth:`~DijetAnalysis.PostProcessing.Palisade.InputROOT.get` for the object specification syntax). Before the expression is evaluated, all strings are replaced by the objects they refer to. To interpret a string as a *literal string* (i.e. not referring to an object in a file), it must be wrapped inside the special function ``str``. Any *functions* called in the expression must have been defined beforehand using :py:meth:`~DijetAnalysis.PostProcessing.Palisade.InputROOT.add_function`. There are a number of special functions, which behave as follows: * ``str``: interpret a string as a *literal string* * ``no_input``: interpret *all* strings encountered anywhere inside the function call as literal strings * ``input``: interpret *all* strings encountered anywhere inside the function call as specifications of objects in files All *Python identifiers* used in the expression are interpreted as local variables. A map specifying the values of local variables for this call to `get_expr` can be given via the keyword argument `locals`. Alternatively, local variables can be registered for use by all calls to :py:meth:`get_expr` by calling :py:meth:`~DijetAnalysis.PostProcessing.Palisade.InputROOT.register_local` beforehand. Variables given in the `locals` dictionary will take precedence over those defined via :py:meth:`register_local`. Local variable lookup can be disabled completely by passing ``locals=None``. Parameters ---------- expr : `str` valid Python expression locals : `dict` or ``None`` (default: ``{}``) mapping of local variable names that may appear in `expr` to values. These will override local variable values specified beforehand using :py:meth:`~DijetAnalysis.PostProcessing.Palisade.InputROOT.register_local` before calling this method. If :py:const:`None`, local variable lookup is disabled and a :py:exc:`NameError` will be raised if an identifier is encountered in the expression. Usage examples: .. code:: python my_result = my_input_root.get_expr( 'my_function(' # this function gets called '"my_file:path/to/my_object",' # this gets replaced by ROOT object '42' # this argument is passed literally ')' ) .. code:: python # register a local variable and assign it a value my_input_root.register_local('local_variable', 42) my_result = my_input_root.get_expr( 'my_function(' # this function gets called '"my_file:path/to/my_object",' # this gets replaced by ROOT object 'local_variable,' # this gets replaced by its assigned value ')' ) .. tip:: Writing expressions inside a single string can get very convoluted. To maintain legiblity, the expression string can be spread out on several lines, by taking advantage of Python's automatic string concatenation inside parentheses (see above). Alternatively, triple-quoted strings can be used. """ _locals = {} if locals is not None: _locals = dict(self._locals, **locals) expr = expr.strip() # extraneous spaces otherwise interpreted as indentation self._request_all_objects_in_expression(expr) _result = self._eval(node=ast.parse(expr, mode='eval').body, ctx=dict(operators=self.operators, functions=self.functions, locals=_locals, input=True)) # raise exceptions unable to be raised during `_eval` for technical reasons # (e.g. due to expressions with self-referencing local variables that would # cause infinite recursion) if isinstance(_result, Exception): raise _result return _result
def _request_all_objects_in_expression(self, expr, **other_request_params): """Walk through the expression AST and request an object for each string or identifier""" _ast = ast.parse(expr, mode='eval') _reqs = [] for _node in ast.walk(_ast): if isinstance(_node, ast.Name): _obj_spec = _node.id elif isinstance(_node, ast.Str): _obj_spec = _node.s else: continue if ':' in _obj_spec: _reqs.append(dict(object_spec=_obj_spec, force_rerequest=False, **other_request_params)) self.request(_reqs)
[docs] def register_local(self, name, value): """ Register a local variable to be used when evaluating an expression using :py:meth:`~DijetAnalysis.PostProcessing.Palisade.InputROOT.get_expr` Parameters ---------- name : `str` valid Python identifier. Must not have been registered before. value : any Python object to be made accessible in expressions under :py:const:`name`. """ try: assert name not in self._locals except AssertionError as e: print("[ERROR] Cannot register local '{}' with value {}! It already exists and is: {}".format(name, value, self._locals[name])) raise e self._locals[name] = value
[docs] def clear_locals(self): """ Clear all locals defined via :py:meth:`~DijetAnalysis.PostProcessing.Palisade.InputROOT.register_local`. """ self._locals = dict()
def _eval(self, node, ctx): """Evaluate an AST node""" if node is None: return None elif isinstance(node, ast.Name): # <identifier> # lookup identifiers in local namespace if node.id in ctx['locals']: _local = ctx['locals'][node.id] # if local variable contains a list, evaluate each element by threading 'get_expr' over it if isinstance(_local, list): _retlist = [] for _local_el in _local: # non-string elements are simply passed through if not isinstance(_local_el, str): _retlist.append(_local_el) continue # string-valued elements are evaluated try: # NOTE: local variable lookup is disabled when threading # over lists that were stored in local variables themselves. # This is done to prevent infinite recursion errors for # expressions which may reference themselves _ret_el = self.get_expr(_local_el, locals=None) except NameError as e: # one element of the list references a local variable # -> stop evaluation and return dummy # use NameError object instead of None to identifiy # dummy elements unambiguously later _retlist.append(e) else: # evaluation succeeded _retlist.append(_ret_el) return _retlist # local variables containing strings are parsed elif isinstance(_local, str): return self.get_expr(_local, locals=None) # all other types are simply passed through else: return _local # if no local is found, try a few builtin Python literals elif node.id in ('True', 'False', 'None'): # restrict subset of supported literals return ast.literal_eval(node.id) # returns corresponding Python literal from string # if nothing above matched, assume mistyped identifier and give up # NOTE: do *not* assume identifier is a ROOT file path. ROOT file paths # must be given explicitly as strings. else: raise NameError("Cannot resolve identifier '{}': not a valid Python literal or a registered local variable!".format(node.id)) elif isinstance(node, ast.Str): # <string> : array column if ctx['input']: # lookup in ROOT file return self.get(node.s) else: # return string as-is return node.s elif isinstance(node, ast.Num): # <number> return node.n elif isinstance(node, ast.Call): # node names containing parentheses (interpreted as 'Call' objects) # -- determine function to call # function handle is a simple identifier if isinstance(node.func, ast.Name): # handle special functions if node.func.id in self.special_functions: _spec_func_spec = self.special_functions[node.func.id] # callable for special function (default to no-op) _callable = _spec_func_spec.get('func', lambda x: x) # modify avaluation context for special function ctx = dict(ctx, **_spec_func_spec.get('ctx', {})) # call a registered input function else: try: _callable = ctx['functions'][node.func.id] except KeyError as e: raise KeyError( "Cannot call input function '{}': no such " "function!".format(node.func.id)) # function handle is an expression else: # evaluate 'func' as any other node _callable = self._eval(node.func, ctx) # evaluate unpacked positional arguments, if any _starargs_values = [] if node.starargs is not None: _starargs_values = self._eval(node.starargs, ctx) # starred kwargs (**) not supported for the moment if node.kwargs: raise NotImplementedError( "Unpacking keyword arguments in expressions via " "** is not supported. Expression was: '{}'".format( ast.dump(node, annotate_fields=False))) # evaluate arguments _args = map(lambda _arg: self._eval(_arg, ctx), node.args) + _starargs_values _kwargs = { _keyword.arg : self._eval(_keyword.value, ctx) for _keyword in node.keywords } # call function return _callable(*_args, **_kwargs) elif isinstance(node, ast.BinOp): # <left> <operator> <right> return ctx['operators'][type(node.op)](self._eval(node.left, ctx), self._eval(node.right, ctx)) elif isinstance(node, ast.UnaryOp): # <operator> <operand> e.g., -1 return ctx['operators'][type(node.op)](self._eval(node.operand, ctx)) elif isinstance(node, ast.Subscript): # <operator> <operand> e.g., -1 if isinstance(node.slice, ast.Index): # support subscripting via simple index return self._eval(node.value, ctx)[self._eval(node.slice.value, ctx)] elif isinstance(node.slice, ast.Slice): # support subscripting via slice return self._eval(node.value, ctx)[self._eval(node.slice.lower, ctx):self._eval(node.slice.upper, ctx):self._eval(node.slice.step, ctx)] else: raise TypeError(node) elif isinstance(node, ast.Attribute): # <value>.<attr> return getattr(self._eval(node.value, ctx), node.attr) elif isinstance(node, ast.List): # list of node names return [self._eval(_el, ctx) for _el in node.elts] elif isinstance(node, ast.Tuple): # tuple of node names return tuple(self._eval(_el, ctx) for _el in node.elts) else: raise TypeError(node) @classmethod def clear_cache(cls): cls._cache = {} # functions with special meanings/side effects # when encountered in expressions, these functions can change the # behavior of the evaluation of descendant nodes or access functionality # of the InputROOT object itself special_functions = { # enable ROOT input when evaluating descendant nodes 'input': dict(ctx=dict(input=True)), # disable ROOT input when evaluating descendant nodes 'no_input': dict(ctx=dict(input=False)), # get argument as string (i.e. without ROOT input) 'str': dict(func=str, ctx=dict(input=False)), }