Skip to content

Instantly share code, notes, and snippets.

@mattbierbaum
Created April 23, 2016 22:53
Show Gist options
  • Save mattbierbaum/538425b8fd81ac85593d51884100f66c to your computer and use it in GitHub Desktop.
Save mattbierbaum/538425b8fd81ac85593d51884100f66c to your computer and use it in GitHub Desktop.
import pylab as pl
import numpy as np
from matplotlib.colors import Normalize
from matplotlib.collections import LineCollection
def plot_colored_curve(x, y, t=None, linewidth=2, drawpoints=False, cmap=pl.cm.copper):
    """
    Plot a curve whose color varies along its length, in a new figure.

    The curve is split into straight segments between consecutive points
    (x[i], y[i]) -> (x[i+1], y[i+1]), and the segments are colored through
    `cmap`. When `t` is None the color follows the point index; otherwise
    `t` (e.g. a 'time' value per point) is rescaled to [0, 1] with
    `matplotlib.colors.Normalize` before being mapped.

    Parameters
    ----------
    x, y : array_like
        Coordinates of the curve. Must have the same shape.
    t : array_like, optional
        Per-point color value, same length as `x`.
    linewidth : float
        Width of the line segments.
    drawpoints : bool
        If True, also draw a marker at every data point, colored by the same
        values as the line.
    cmap : Colormap
        Colormap used for both the line and the markers.

    Returns
    -------
    LineCollection
        The collection added to the axes (e.g. to pass to ``pl.colorbar``).

    Raises
    ------
    ValueError
        If `x` and `y` differ in shape, or `t` differs in length from `x`.
    """
    x = np.asarray(x)
    y = np.asarray(y)
    if x.shape != y.shape:
        raise ValueError("x and y must have the same shape")

    if t is None:
        t = np.linspace(0, 1, len(x))
    else:
        t = np.asarray(t)
        if len(t) != len(x):
            raise ValueError("t must have the same length as x")
        t = Normalize()(t)

    fig, ax = pl.subplots()

    # Shape (n, 1, 2) so that pairing points[:-1] with points[1:] along
    # axis 1 yields (n - 1, 2, 2): one two-point segment per row.
    points = np.column_stack([x, y]).reshape(-1, 1, 2)
    segments = np.concatenate([points[:-1], points[1:]], axis=1)

    lc = LineCollection(segments, cmap=cmap, linewidths=linewidth)
    lc.set_array(t)

    if drawpoints:
        # Color the markers by the same values as the line so they agree.
        ax.scatter(x, y, c=t, cmap=cmap)

    ax.add_collection(lc)
    # add_collection updates the data limits but not the view limits; without
    # this the curve can lie outside the default [0, 1] axes.
    ax.autoscale_view()
    pl.show()
    return lc
def demo():
    """Plot a random walk sampled on [0, 1], colored by tanh of the abscissa."""
    n_points = 10000
    xs = np.linspace(0, 1, n_points)
    # Random walk: cumulative sum of uniform steps centered on zero.
    steps = np.random.rand(n_points) - 0.5
    walk = np.cumsum(steps)
    plot_colored_curve(xs, walk, np.tanh(xs))
    # NOTE(review): plot_colored_curve calls pl.show(); with a blocking
    # backend these limits are applied only after its window is closed.
    pl.xlim(xs.min(), xs.max())
    pl.ylim(walk.min(), walk.max())
    pl.draw()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment