Source code for pyirf.cuts

import numpy as np
from astropy.table import Table, QTable
from scipy.ndimage import gaussian_filter1d
import astropy.units as u

from .binning import calculate_bin_indices, bin_center

# Public API: the names exported on a star-import of this module.
__all__ = [
    'calculate_percentile_cut',
    'evaluate_binned_cut',
    'compare_irf_cuts',
]


def calculate_percentile_cut(
    values,
    bin_values,
    bins,
    fill_value,
    percentile=68,
    min_value=None,
    max_value=None,
    smoothing=None,
    min_events=10,
):
    """
    Calculate cuts as the percentile of a given quantity in bins of another
    quantity.

    Parameters
    ----------
    values: ``~numpy.ndarray`` or ``~astropy.units.Quantity``
        The values for which the cut should be calculated
    bin_values: ``~numpy.ndarray`` or ``~astropy.units.Quantity``
        The values used to sort the ``values`` into bins
    bins: ``~numpy.ndarray`` or ``~astropy.units.Quantity``
        Bin edges
    fill_value: float or quantity
        Value for bins with less than ``min_events``,
        must have same unit as values
    percentile: float
        The percentile to calculate in each bin as a percentage,
        i.e. 0 <= percentile <= 100.
    min_value: float or quantity or None
        If given, cuts smaller than this value are replaced with ``min_value``
    max_value: float or quantity or None
        If given, cuts larger than this value are replaced with ``max_value``
    smoothing: float or None
        If given, apply a gaussian filter of width ``sigma`` in terms of bins.
        The filter runs over the complete cut column, i.e. it also includes
        bins that were set to ``fill_value``.
    min_events: int
        Bins with less events than this number are replaced with ``fill_value``

    Returns
    -------
    cut_table: ``~astropy.table.QTable``
        One row per bin with the columns ``low``, ``high``, ``center``,
        ``n_events`` and ``cut``.
    """
    # create a table to make use of groupby operations
    # we use a normal table here to avoid astropy/astropy#13840
    table = Table({"values": values}, copy=False)
    unit = table["values"].unit

    # make sure units match
    if unit is not None:
        fill_value = u.Quantity(fill_value).to(unit)

        if min_value is not None:
            min_value = u.Quantity(min_value).to_value(unit)

        if max_value is not None:
            max_value = u.Quantity(max_value).to_value(unit)

    bin_index, valid = calculate_bin_indices(bin_values, bins)
    by_bin = table[valid].group_by(bin_index[valid])

    cut_table = QTable()
    cut_table["low"] = bins[:-1]
    cut_table["high"] = bins[1:]
    cut_table["center"] = bin_center(bins)
    cut_table["n_events"] = 0
    # Take the dtype from the table column instead of from ``values`` itself,
    # so that plain sequences without a ``dtype`` attribute are supported too.
    # Every bin starts out at ``fill_value``.
    cut_table["cut"] = np.asanyarray(fill_value, table["values"].dtype)

    clip_requested = min_value is not None or max_value is not None

    for bin_idx, group in zip(by_bin.groups.keys, by_bin.groups):
        n_events = len(group)
        cut_table["n_events"][bin_idx] = n_events

        # bins with too few events keep the pre-filled fill_value
        if n_events < min_events:
            continue

        value = np.nanpercentile(group["values"], percentile)
        if clip_requested:
            value = np.clip(value, min_value, max_value)

        # write the bare number; the unit is carried by the column
        cut_table["cut"].value[bin_idx] = value

    if smoothing is not None:
        cut_table['cut'].value[:] = gaussian_filter1d(
            cut_table["cut"].value,
            smoothing,
            mode='nearest',
        )

    return cut_table
def evaluate_binned_cut(values, bin_values, cut_table, op):
    """
    Apply the per-bin cut stored in ``cut_table`` to a set of events.

    Each event is assigned to a bin via its ``bin_values`` entry and its
    ``values`` entry is then compared to the cut of that bin using ``op``.
    Events whose ``bin_values`` lie outside the bin edges defined in
    ``cut_table`` are set to False.

    Parameters
    ----------
    values: ``~numpy.ndarray`` or ``~astropy.units.Quantity``
        The values on which the cut should be evaluated
    bin_values: ``~numpy.ndarray`` or ``~astropy.units.Quantity``
        The values used to sort the ``values`` into bins
    cut_table: ``~astropy.table.QTable``
        A table describing the binned cuts, e.g. as created by
        ``~pyirf.cuts.calculate_percentile_cut``.
        Required columns:

        - `low`: lower edges of the bins
        - `high`: upper edges of the bins,
        - `cut`: cut value
    op: callable(a, b) -> bool
        A function taking two arguments, comparing element-wise and
        returning an array of booleans.
        Must support vectorized application.

    Returns
    -------
    result: np.ndarray[bool]
        A mask for each entry in ``values`` indicating if the event
        passes the bin specific cut given in cut table.

    Raises
    ------
    ValueError
        If ``cut_table`` is not a ``QTable``.
    """
    if not isinstance(cut_table, QTable):
        raise ValueError('cut_table needs to be an astropy.table.QTable')

    # rebuild the full edge array: all lower edges plus the last upper edge
    last_upper_edge = cut_table["high"][-1]
    edges = np.append(cut_table["low"], last_upper_edge)

    event_bin, inside = calculate_bin_indices(bin_values, edges)

    # events outside of the binning keep the initial False
    passed = np.zeros(len(values), dtype=bool)
    selected_values = values[inside]
    cut_per_event = cut_table["cut"][event_bin[inside]]
    passed[inside] = op(selected_values, cut_per_event)

    return passed
def compare_irf_cuts(cuts):
    """
    Check whether all given sets of cuts are identical.

    Every entry of ``cuts`` is compared with its successor; the first pair
    that differs in any element ends the comparison.

    Parameters
    ----------
    cuts: list of QTables
        list of cuts, each entry in the list corresponds to one set of IRFs

    Returns
    -------
    match: Boolean
        True if the cuts are the same in all entries, also for an empty
        list or a list with a single entry. False otherwise.
    """
    n_pairs = len(cuts) - 1

    # lazily evaluated, so the search stops at the first differing pair
    return not any(
        (cuts[first] != cuts[first + 1]).any()
        for first in range(n_pairs)
    )