Skip to content

Instantly share code, notes, and snippets.

@sytrus-in-github
Created April 14, 2017 15:11
Show Gist options
  • Star 4 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save sytrus-in-github/a3b2ef4414fb144cb08505a060c99b18 to your computer and use it in GitHub Desktop.
Save sytrus-in-github/a3b2ef4414fb144cb08505a060c99b18 to your computer and use it in GitHub Desktop.
Example script for matplotlib (2.0.0): custom coloring of each arrow in a 3D quiver plot.
# tested with python 2 + matplotlib 2.0.0
from mpl_toolkits.mplot3d import axes3d
import matplotlib.pyplot as plt
import numpy as np
def getMaxXYZ(vec3ds):
    """Return the per-axis maximum ``(max_x, max_y, max_z)`` over ``vec3ds``.

    ``vec3ds`` must be an indexable, non-empty sequence whose items each unpack
    into exactly three mutually comparable values. An empty sequence raises
    IndexError (the first item is read as the starting point).
    """
    top_x, top_y, top_z = vec3ds[0]
    for vec in vec3ds:
        cur_x, cur_y, cur_z = vec
        # Each axis is tracked independently of the other two.
        top_x = max(cur_x, top_x)
        top_y = max(cur_y, top_y)
        top_z = max(cur_z, top_z)
    return top_x, top_y, top_z
def Vec3d2rgb(vec3ds):
    """Map a list of 3D vectors to their corresponding RGB color representation.

    Each vector is replaced by the absolute values of its components. Each
    component is then divided by the largest absolute value found on that axis
    across all vectors, so every channel lands in [0, 1].

    :param vec3ds: iterable of ``(x, y, z)`` numeric triples.
    :returns: list of ``(r, g, b)`` float triples, one per input vector, in the
        same order. An empty input yields an empty list.
    """
    # Materialise real tuples. On Python 3 ``map`` returns a one-shot iterator,
    # which would be exhausted after the first pass over the data.
    magnitudes = [(abs(x), abs(y), abs(z)) for (x, y, z) in vec3ds]
    if not magnitudes:
        return []
    # Per-axis maximum. An axis that is zero for every vector would otherwise
    # cause a ZeroDivisionError; dividing by 1 instead keeps that channel at 0.
    sx, sy, sz = [max(axis) or 1. for axis in zip(*magnitudes)]
    return [(x * 1. / sx, y * 1. / sy, z * 1. / sz) for (x, y, z) in magnitudes]
def repeatForEach(elements, times):
    """Return a new list where every item of *elements* appears *times* times in a row.

    Example: ``repeatForEach([a, b], 2) -> [a, a, b, b]``.
    A non-positive *times* produces an empty list.
    """
    repeated = []
    for item in elements:
        repeated.extend([item] * times)
    return repeated
def renderColorsForQuiver3d(colors):
    """Expand one color per arrow into the per-segment color list for a 3D quiver.

    Colors equal to black ``(0, 0, 0)`` are dropped first. The result is the
    kept colors once each, followed by every kept color repeated twice in a
    row: ``[c1, c2, ..., c1, c1, c2, c2, ...]``.

    NOTE(review): this layout presumably mirrors how matplotlib's 3D quiver
    draws each arrow (all shafts first, then two head segments per arrow) and
    how it skips zero-length arrows, which ``Vec3d2rgb`` maps to black.
    Verify this against the matplotlib version in use.

    :param colors: iterable of RGB triples (tuples, lists or array rows).
    :returns: a new list of RGB triples, three entries per kept color.
    """
    # Build a real list: on Python 3 ``filter`` returns an iterator, which can
    # neither be concatenated with ``+`` nor iterated a second time.
    # Compare as a tuple so list- or array-typed black colors are dropped too.
    kept = [c for c in colors if tuple(c) != (0., 0., 0.)]
    return kept + repeatForEach(kept, 2)
if __name__ == "__main__":
    LINE_WIDTH = 0.5
    # One displacement vector per label; the last label is the zero vector.
    _displacementLabels = [(1., 0., 0.), (2., 0., 0.), (0., 1., 0.), (0., 2., 0.),
                           (0., 0., 1.), (0., 0., 2.), (0., 0., 0.)]
    # Scale every vector by 0.3. Store tuples, not ``map`` objects, so each
    # entry can be unpacked/iterated more than once on Python 3.
    displacementLabels = [tuple(0.3 * c for c in vec) for vec in _displacementLabels]
    # Random label per cell of a 3x4x5 grid. ``randint``'s upper bound is
    # exclusive, so this draws 0..len-1 like the deprecated random_integers(0, 6).
    labelField = np.random.randint(0, len(displacementLabels), (3, 4, 5))
    colorLabels = Vec3d2rgb(displacementLabels)
    # C-order flatten matches the order of the mgrid coordinates below.
    colors = [colorLabels[label] for label in labelField.flatten()]

    fig = plt.figure()
    # ``fig.gca(projection=...)`` no longer accepts keyword arguments in recent
    # matplotlib; ``add_subplot`` works on both old and new versions.
    ax = fig.add_subplot(111, projection='3d')

    x, y, z = labelField.shape
    xs, ys, zs = np.mgrid[:x, :y, :z]
    # Look up every cell's displacement in one vectorised step: shape (x, y, z, 3).
    vectors = np.asarray(displacementLabels)[labelField]
    us, vs, ws = vectors[..., 0], vectors[..., 1], vectors[..., 2]

    ax.quiver(xs, ys, zs, us, vs, ws, linewidths=LINE_WIDTH,
              colors=renderColorsForQuiver3d(colors))
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')
    plt.show()
@gauvinalexandre
Copy link

Thank you very much! Nice solution!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment