Skip to content

Instantly share code, notes, and snippets.

Last active January 27, 2022 21:31
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save tacaswell/95177903175dbc28be5353b4a0e5118f to your computer and use it in GitHub Desktop.
Save tacaswell/95177903175dbc28be5353b4a0e5118f to your computer and use it in GitHub Desktop.
import numpy as np
import matplotlib
import matplotlib.lines
from matplotlib.artist import allow_rasterization
import matplotlib.pyplot as plt
class MatplotlibException(Exception):
    """Base class for errors raised by this data-source prototype."""


class InvalidDatasource(MatplotlibException, ValueError):
    """Raised when a data source does not provide the fields an artist needs.

    Also subclasses ``ValueError`` so callers can treat it as an ordinary
    bad-argument error.
    """
class SimpleSource:
    """Data source backed by a fixed set of in-memory arrays.

    Each keyword argument becomes one named field.  Values are converted
    with ``np.asanyarray``, so ndarray subclasses are preserved.

    Attributes
    ----------
    md : dict
        Per-field metadata: ``{name: {"ndim": int, "dtype": np.dtype}}``.
    """

    def __init__(self, **kwargs):
        self._data = {k: np.asanyarray(v) for k, v in kwargs.items()}
        # Metadata lets consumers check which fields exist (and their
        # rank/dtype) without pulling the data itself.
        self.md = {k: {"ndim": v.ndim, "dtype": v.dtype} for k, v in self._data.items()}

    def get(self, keys, ax=None, renderer=None):
        """Return ``{key: array}`` for each key in *keys*.

        *ax* and *renderer* are accepted so this source is call-compatible
        with view-dependent sources; they are ignored here.

        Raises
        ------
        KeyError
            If a requested key was not supplied at construction.
        """
        return {k: self._data[k] for k in keys}
class DFSource:
    """Data source that exposes columns of a DataFrame under new names.

    Parameters
    ----------
    df : DataFrame-like
        Object indexable by column name, where each column has a ``dtype``.
    **kwargs
        Mapping ``field_name=column_name``.  For example, ``x="time"``
        exposes ``df["time"]`` as field ``"x"``.

    Attributes
    ----------
    md : dict
        Per-field metadata: ``{field: {"ndim": 1, "dtype": column dtype}}``.
    """

    def __init__(self, df, **kwargs):
        self._remapping = kwargs
        self._data = df
        # Looking up each column here also fails fast on unknown columns.
        self.md = {k: {"ndim": 1, "dtype": df[v].dtype} for k, v in kwargs.items()}

    def get(self, keys, ax=None, renderer=None):
        """Return ``{field: column}`` for each requested field name.

        *ax* and *renderer* are ignored.  They default to ``None`` so the
        call signature matches ``SimpleSource.get``.
        """
        # Bug fix: the mapping is stored as ``_remapping``; reading
        # ``_mapping`` raised AttributeError on every call.
        return {k: self._data[self._remapping[k]] for k in keys}
class FuncSource1D:
    """Data source that samples ``y = func(x)`` across the current x view.

    The function is re-evaluated on every ``get`` call.  Sampling uses
    about one point per horizontal display pixel of the Axes.

    Attributes
    ----------
    md : dict
        Metadata for the two fields this source provides, ``"x"`` and ``"y"``.
    """

    def __init__(self, func):
        self._func = func
        self.md = {"x": {"ndim": 1, "dtype": float}, "y": {"ndim": 1, "dtype": float}}

    def get(self, keys, ax, renderer):
        """Sample the function over ``ax``'s current x limits.

        Parameters
        ----------
        keys : iterable of str
            Must be exactly the fields in ``self.md`` (``"x"`` and ``"y"``).
        ax
            Object providing ``get_xlim()`` and ``get_window_extent(renderer)``.
        renderer
            Passed through to ``ax.get_window_extent``.

        Returns
        -------
        dict
            ``{"x": sample positions, "y": func(sample positions)}``.

        Raises
        ------
        ValueError
            If *keys* is not exactly the set of provided fields.
        """
        # An explicit raise replaces the original ``assert``, which is
        # stripped when Python runs with -O.
        if set(keys) != set(self.md):
            raise ValueError(
                f"FuncSource1D provides exactly {sorted(self.md)}, got {sorted(keys)}"
            )
        xlim = ax.get_xlim()
        bbox = ax.get_window_extent(renderer)
        # np.linspace requires an integer sample count, but bbox.width is a
        # float.  Keep at least 2 samples so a tiny Axes still draws a segment.
        num = max(int(np.ceil(bbox.width)), 2)
        x = np.linspace(*xlim, num)
        return {"x": x, "y": self._func(x)}
class DSLine2D(matplotlib.lines.Line2D):
    """A Line2D whose x/y data are fetched from a data source at draw time.

    Parameters
    ----------
    DS
        Data source with an ``md`` mapping that must contain ``"x"`` and
        ``"y"``, and a ``get(keys, ax, renderer)`` method that returns a
        mapping with those keys.
    **kwargs
        Forwarded to ``Line2D`` (color, linewidth, ...).

    Raises
    ------
    InvalidDatasource
        If the source does not advertise both ``"x"`` and ``"y"``.
    """

    def __init__(self, DS, **kwargs):
        missing = [k for k in ("x", "y") if k not in DS.md]
        if missing:
            raise InvalidDatasource(f"data source is missing required fields: {missing}")
        self._DS = DS
        # Start empty; the real data is pulled from the source in draw().
        super().__init__([], [], **kwargs)

    @allow_rasterization
    def draw(self, renderer):
        """Refresh the data from the source for the current view, then draw."""
        data = self._DS.get({"x", "y"}, self.axes, renderer)
        super().set_data(data["x"], data["y"])
        return super().draw(renderer)
ax = plt.gca()
# A static source (sampled sine) and a view-dependent source (cos + 1).
DS = SimpleSource(x=np.linspace(0, 10, 100), y=np.sin(np.linspace(0, 10, 100)))
DS2 = FuncSource1D(lambda x: np.cos(x) + 1)
dsl = DSLine2D(DS, color="red")
dsl2 = DSLine2D(DS2, color="blue")
# Creating an artist does not attach it.  Add the lines to the Axes so
# they are actually drawn.
ax.add_artist(dsl)
ax.add_artist(dsl2)
# add_artist does not update the data limits, and the lines hold no data
# until draw time, so set the view explicitly.
ax.set_xlim(0, 10)
ax.set_ylim(-1.1, 2.1)
Copy link

We should maybe talk to @shoyer about how xarray does this? They've apparently got a very clean model under the hood...
And I'm thinking that, for many users, the data-source machinery will end up hidden behind `@process_data`?

Copy link

I suspect we will want an easy way to wrap an xarray, but I am not sure that xarray can be the whole data model (as I am not sure how to fit things like the function source into it).

I could see expecting the source to return a dict-of-arrays-like object (e.g. pandas or xarray) from a single `get` call, instead of the artists calling it n times?

Copy link

Not so much xarray as the data model; more stealing ideas from their architecture (and likewise from dask for some of the functional ideas).

Copy link


Copy link

ianhi commented Jan 27, 2022

for simple indexing of an array

from typing import TYPE_CHECKING, Set
from numbers import Integral
from matplotlib.axes import Axes
class ArraySource1D:
    """Data source exposing integer-indexed samples of a 1D array.

    ``get`` returns the integer positions inside the current x view
    (every ``scale``-th one), together with the array values at those
    positions.  If the array has a ``vindex`` attribute (e.g. zarr), it is
    indexed through that attribute instead of ``__getitem__``.

    Parameters
    ----------
    array
        1D array-like supporting ``len()`` and integer-array indexing
        (directly or via ``vindex``).
    scale : int, default 1
        Stride between sampled indices; must be a positive integer.
    """

    def __init__(self, array, scale=1) -> None:
        self._arr = array
        # Go through the property setter so the constructor gets the same
        # validation as later assignments.
        self.scale = scale
        if hasattr(self._arr, "vindex"):
            # account for zarr
            self._indexer = self._arr.vindex
        else:
            self._indexer = self._arr
        self.md = {"x": {"ndim": 1, "dtype": float}, "y": {"ndim": 1, "dtype": float}}

    @property
    def scale(self) -> int:
        """Stride between sampled indices."""
        # Bug fix: returning ``self.scale`` here recursed forever.
        return self._scale

    @scale.setter
    def scale(self, value: int):
        if not isinstance(value, Integral):
            raise TypeError(f"scale must be integer values but is type {type(value)}")
        # np.arange fails on a zero step and yields nothing for a negative one.
        if value < 1:
            raise ValueError(f"scale must be a positive integer but is {value}")
        self._scale = value

    def get(self, keys: Set[str], ax: "Axes", renderer):
        """Return ``{"x": indices, "y": values}`` for the visible x range.

        *keys* and *renderer* are ignored.  Only ``ax.get_xlim()`` is used.
        Indices are clipped to ``[0, len(array))``, so a view extending past
        either end of the array never indexes out of bounds.
        """
        # sorted() tolerates inverted axes (xlim[0] > xlim[1]).
        lo, hi = sorted(ax.get_xlim())
        n = len(self._arr)
        # Include the integer positions at or just outside each edge so the
        # line reaches the Axes border, but stay within the array.
        xmin = min(max(int(np.floor(lo)), 0), n)
        xmax = min(max(int(np.ceil(hi)) + 1, 0), n)
        x = np.arange(xmin, xmax, self._scale)
        return {"x": x, "y": self._indexer[x]}

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment