@agoose77
Last active January 18, 2023 16:05

README

Binder

To run on Binder, right-click the notebook in JupyterLab's file browser and select "Open with Notebook".

jupytext:
  text_representation:
    extension: .md
    format_name: myst
    format_version: 0.13
    jupytext_version: 1.14.4
kernelspec:
  display_name: Python 3 (ipykernel)
  language: python
  name: python3

Hist plotting protocols

Constructing and filling a multidimensional histogram can feel cumbersome in an analysis, because the input arrays must first be broadcast against each other and flattened by hand. This particularly harms readability, e.g.

x, y = [ak.flatten(arr, axis=None) for arr in ak.broadcast_arrays(*arrays)]
hist.Hist.new.Int(0, 10, name="x").Int(0, 10, name="y").Int64().fill(
    x=x, y=y
)

It would be nice for the user to be able to opt in to a convenient shorthand that unpacks, broadcasts, and flattens, e.g.

hist.Hist.new.Int(0, 10, name="x").Int(0, 10, name="y").Int64().fill_flattened(
    record_array
)

or

hist.Hist.new.Int(0, 10, name="x").Int(0, 10, name="y").Int64().fill_flattened(
    *arrays
)

import awkward as ak
import hist
import numpy as np

Broadcasting and Unpacking

The idea here is that Hist-aware libraries can implement methods to "unpack" (decompose) an array into its members, and "broadcast-flatten" an array to produce a set of compatible 1D arrays. This can be done through two APIs: an explicit stateful registration, and a namespace protocol.

The _histogram_module_ protocol member should refer to an object that implements two methods, unpack and broadcast_and_flatten. These methods provide Hist with unpacked, broadcast, and flattened arrays. We also implement a registry-based fallback for classes whose namespace cannot be modified, so that an implementation can be defined outside the class.

_histogram_modules = {}


def histogram_module_for(cls):
    def wrapper(obj):
        _histogram_modules[cls] = obj
        return obj

    return wrapper


def find_histogram_modules(*objects):
    for arg in objects:
        try:
            yield arg._histogram_module_
        except AttributeError:
            # No protocol attribute: look up the exact class in the registry,
            # then each of its base classes in MRO order
            for cls in type(arg).__mro__:
                try:
                    yield _histogram_modules[cls]
                except KeyError:
                    continue


def unpack(obj):
    for module in find_histogram_modules(obj):
        return module.unpack(obj)
    raise TypeError(f"No histogram module found for {obj!r}")


def broadcast_and_flatten(args):
    for module in find_histogram_modules(*args):
        result = module.broadcast_and_flatten(args)
        if result is not NotImplemented:
            return result

    raise TypeError(f"No histogram module found for {args!r}")
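
To make the lookup order concrete, here is a minimal pure-Python sketch of the same dispatch idea. It restates the lookup in compact form so that it runs on its own; the names register_for, lookup_modules, Celsius, TaggedList and the two toy modules are hypothetical and exist only for illustration.

```python
# Toy registry mirroring the stateful registration API described above.
_registry = {}


def register_for(cls):
    """Register a module object for `cls` (hypothetical helper)."""

    def wrapper(module):
        _registry[cls] = module
        return module

    return wrapper


def lookup_modules(obj):
    """Yield candidate modules: protocol attribute first, else the registry."""
    try:
        yield obj._histogram_module_
    except AttributeError:
        # Walk the MRO so that subclasses pick up a base-class registration.
        for cls in type(obj).__mro__:
            if cls in _registry:
                yield _registry[cls]


class ProtocolModule:
    @staticmethod
    def unpack(obj):
        return {"value": list(obj.values)}


class Celsius:
    # Namespace protocol: the class advertises its own module.
    _histogram_module_ = ProtocolModule

    def __init__(self, values):
        self.values = values


@register_for(list)
class ListModule:
    @staticmethod
    def unpack(obj):
        return None  # plain lists have no named fields


class TaggedList(list):
    """Subclass of list with no registration of its own."""


found_protocol = next(lookup_modules(Celsius([1.0, 2.0])))
found_registry = next(lookup_modules(TaggedList([1, 2])))
print(found_protocol.__name__, found_registry.__name__)
```

In this sketch the protocol attribute takes priority, and the registry is consulted only when the attribute is missing, which is also how find_histogram_modules above behaves.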

Awkward Arrays

As we control the ak.Array namespace, the most convenient way to implement the protocol, with no additional import required of the user, is to define ak.Array._histogram_module_ directly.

A "proper" implementation of broadcast_and_flatten should return NotImplemented if any of the args is one that it cannot handle. This is similar to the __array_function__ protocol of NumPy's NEP 18, and ensures that other histogram modules can attempt the broadcasting instead.

class AwkwardHistogramModule:
    @staticmethod
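    def unpack(obj):
        # Arrays without fields cannot be unpacked; signal this with None
        if not ak.fields(obj):
            return None
        else:
            # Broadcast the fields against each other, then flatten each to 1D
            return {
                k: ak.flatten(x, axis=None)
                for k, x in zip(ak.fields(obj), ak.broadcast_arrays(*ak.unzip(obj)))
            }

    @staticmethod
    def broadcast_and_flatten(args):
        arrays = [ak.Array(x) for x in args]
        # Record arrays must be unpacked before they reach this method
        assert all(not x.fields for x in arrays)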
        return tuple(
            [
                ak.to_numpy(ak.flatten(x, axis=None))
                for x in ak.broadcast_arrays(*arrays)
            ]
        )


ak.Array._histogram_module_ = AwkwardHistogramModule

example = ak.zip(
    {"x": [1, 2, 3, 4], "y": [[1, 2, 3], [4], [5, 6], [7, 8, 9]]},
    depth_limit=1,
)
example
unpack(example)
broadcast_and_flatten((example.x, example.y))
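
The NotImplemented convention described above can be illustrated without any array library. The following is a sketch under stated assumptions: StrictIntModule, NestedListModule and first_supported are hypothetical names, and the toy modules only flatten one level of nesting (a real module would also broadcast).

```python
class StrictIntModule:
    """Hypothetical module that only understands flat lists of ints."""

    @staticmethod
    def broadcast_and_flatten(args):
        for a in args:
            if not (isinstance(a, list) and all(isinstance(v, int) for v in a)):
                return NotImplemented  # decline, so another module can try
        return tuple(list(a) for a in args)


class NestedListModule:
    """Hypothetical module that also flattens one level of nesting."""

    @staticmethod
    def broadcast_and_flatten(args):
        out = []
        for a in args:
            flat = []
            for v in a:
                flat.extend(v if isinstance(v, list) else [v])
            out.append(flat)
        return tuple(out)


def first_supported(modules, args):
    """Return the first result that is not NotImplemented."""
    for module in modules:
        result = module.broadcast_and_flatten(args)
        if result is not NotImplemented:
            return result
    raise TypeError(f"No module could handle {args!r}")


result = first_supported([StrictIntModule, NestedListModule], ([1, 2], [[3], [4, 5]]))
print(result)  # ([1, 2], [3, 4, 5])
```

The first module declines because the second argument is nested, so the second module gets its turn. Raising an exception instead of returning NotImplemented would end the search early.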

NumPy Arrays

example_numpy = np.recarray((4, 2), dtype=[("x", np.int64), ("y", np.int64)])
example_numpy[:] = np.arange(4)[:, np.newaxis]
example_numpy

NumPy's ndarray is an extension type, so attributes cannot be assigned to the class from Python. Instead, we explicitly register support.

@histogram_module_for(np.ndarray)
class NumpyHistogramModule:
    @staticmethod
    def unpack(obj):
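        # Only structured arrays (those with named fields) can be unpacked
        if obj.dtype.names is None:
            return None
        else:
            # dtype.names is the ordered tuple of field names
            contents = [obj[n] for n in obj.dtype.names]
            return {
                k: np.ravel(x)
                for k, x in zip(obj.dtype.names, np.broadcast_arrays(*contents))
            }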

    @staticmethod
    def broadcast_and_flatten(args):
        arrays = []
        for arg in args:
            # If we can't interpret this argument, it's not NumPy-friendly!
            try:
                arrays.append(np.asarray(arg))
            except (TypeError, ValueError):
                return NotImplemented

        return tuple([np.ravel(x) for x in np.broadcast_arrays(*arrays)])


unpack(example_numpy)
broadcast_and_flatten((example_numpy.x, example_numpy.y))
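
The reason for registering np.ndarray explicitly, rather than setting a protocol attribute on it, can be checked directly. In this small sketch, PurePythonArray is a hypothetical stand-in for a class defined in Python.

```python
import numpy as np


class PurePythonArray:
    """Hypothetical stand-in for an array class defined in Python."""


# Classes defined in Python accept new attributes after creation.
PurePythonArray._histogram_module_ = object()

# np.ndarray is an extension type, so the same assignment is rejected.
try:
    np.ndarray._histogram_module_ = object()
except TypeError as err:
    ndarray_is_patchable = False
    print(f"cannot patch ndarray: {err}")
else:
    ndarray_is_patchable = True
```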

Filling

Now we can put the histogram modules to good use in defining fill_flattened:

def _hist_fill_flattened_impl(self, *args, **kwargs):
    axis_names = {ax.name for ax in self.axes}

    # Single arguments are either arrays for single-dimensional histograms, or
    # structures for multi-dimensional hists that must first be unpacked
    if len(args) == 1 and not kwargs:
        (arg,) = args
        # Try to unpack the array, if it's valid to do so, e.g. if it is an
        # Awkward record array; unpack returns None otherwise
        unpacked = unpack(arg)
        if unpacked is None:
            # Can't unpack, fall back on broadcasting single array (to flatten and convert)
            as_tuple = broadcast_and_flatten((arg,))
            return self.fill(*as_tuple)
        else:
            # Result must be broadcast, so unpack and rebuild,
            # keeping only the fields that name an axis of this histogram
            as_tuple = broadcast_and_flatten(tuple(unpacked.values()))
            as_dict = {k: v for k, v in zip(unpacked, as_tuple) if k in axis_names}
            return self.fill(**as_dict)
    # Multiple args: broadcast and flatten!
    else:
        # Positional values first, then keyword values, so they can be split again
        inputs = (*args, *kwargs.values())
        arrays = broadcast_and_flatten(inputs)
        new_args = arrays[: len(args)]
        new_kwargs = dict(zip(kwargs, arrays[len(args) :]))
        return self.fill(*new_args, **new_kwargs)


hist.Hist.fill_flattened = _hist_fill_flattened_impl

hist.Hist.new.Int(0, 10, name="x").Int(0, 10, name="y").Int64().fill_flattened(
    x=example_numpy["x"], y=example_numpy["y"]
)
hist.Hist.new.Int(0, 10, name="x").Int(0, 10, name="y").Int64().fill_flattened(
    example_numpy
)
hist.Hist.new.Int(0, 10, name="x").Int(0, 10, name="y").Int64().fill_flattened(example)
hist.Hist.new.Int(0, 10, name="x").Int(0, 10, name="y").Int64().fill_flattened(
    example.x, example.y
)
hist.Hist.new.Int(0, 10, name="x").Int(0, 10, name="y").Int64().fill_flattened(
    np.arange(4), y=example.y
)
hist.Hist.new.Int(0, 10, name="x").Int(0, 10, name="y").Int64().fill_flattened(
    example.x, y=example.y
)
hist.Hist.new.Int(0, 10, name="x").Int64().fill_flattened(example)
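
The mixed positional and keyword calls above rely on packing every input into one tuple, processing it, and splitting the result back. A minimal pure-Python sketch of that round trip follows; toy_broadcast is a hypothetical stand-in for broadcast_and_flatten that only repeats length-1 lists, and split_like is likewise an illustrative name.

```python
def toy_broadcast(inputs):
    """Stand-in for broadcast_and_flatten: repeat length-1 lists to the longest length."""
    n = max(len(x) for x in inputs)
    return tuple(list(x) * n if len(x) == 1 else list(x) for x in inputs)


def split_like(args, kwargs, processed):
    """Split a flat tuple back into positional and keyword parts."""
    new_args = processed[: len(args)]
    new_kwargs = dict(zip(kwargs, processed[len(args):]))
    return new_args, new_kwargs


args = ([1, 2, 3],)
kwargs = {"y": [7]}

# Positional values first, then keyword values in insertion order.
processed = toy_broadcast((*args, *kwargs.values()))
new_args, new_kwargs = split_like(args, kwargs, processed)
print(new_args, new_kwargs)  # ([1, 2, 3],) {'y': [7, 7, 7]}
```

Because dictionaries preserve insertion order, zipping the keyword names with the tail of the processed tuple restores each value to its original name.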
awkward>2
hist
numpy
jupyterlab
jupytext