Skip to content

Instantly share code, notes, and snippets.

Created November 30, 2015 12:06
Show Gist options
  • Save subnivean/c622cc2b58e6376263b8 to your computer and use it in GitHub Desktop.
Save subnivean/c622cc2b58e6376263b8 to your computer and use it in GitHub Desktop.
Python module that uses scipy.interpolate.RectBivariateSpline to implement 2D parametric surfaces in 3D space.
"""Implements a b-spline surface as a 3-tuple of
scipy.interpolate.RectBivariateSpline instances, one
each for x, y and z.
"""
import math
import numpy as np
from scipy.interpolate import RectBivariateSpline
class BSplineSurf(object):
    """Parametric (u, v) B-spline surface in 3-D space.

    The surface is represented by three
    `scipy.interpolate.RectBivariateSpline` instances -- one each for the
    x, y and z coordinates -- that share the same (u, v) parameterization.
    """

    def __init__(self, u, v, xyz, ku=3, kv=3, bbox=None,
                 controlpts=False, U=None, V=None):
        """Parametric (u,v) surface approximation over a rectangular mesh.

        Parameters
        ----------
        u, v : array_like
            1-D arrays of coordinates in strictly ascending order.
        xyz : array_like
            3-D array of (x, y, z) data with shape (3, u.size, v.size).
        ku, kv : ints, optional
            Degrees of the bivariate spline. Default is 3 for each.
        bbox : array_like, optional
            Sequence of length 4 specifying the boundary of the rectangular
            approximation domain; defaults to [0, 1, 0, 1]. See
            scipy.interpolate.RectBivariateSpline for more info.
        controlpts : boolean
            Indicates if the xyz points being passed are points to spline
            *through*, or are already control points as defined in some
            other format (e.g. the points from 'stepparser', which
            returns the control points as defined in the STEP file).
        U, V : array_like, optional
            Knot vectors in u and v direction, as parsed from a STEP
            file or similar. Required when `controlpts` is true;
            otherwise, if given, they replace the knot vectors that
            scipy computes.

        Raises
        ------
        AssertionError
            If `controlpts` is true and `U` or `V` is missing.
        """
        # A fresh list per instance: the bbox is stored on the instance,
        # so a shared mutable default could leak between surfaces.
        bbox = [0, 1, 0, 1] if bbox is None else list(bbox)
        if controlpts:
            assert U is not None, \
                "Knot vector `U` must be passed when `controlpts` is True"
            assert V is not None, \
                "Knot vector `V` must be passed when `controlpts` is True"
        self._create_srf(u, v, xyz, ku, kv, bbox, controlpts, U, V)
        self.bbox = bbox
        # Kept as arrays so flipu()/flipv() can do arithmetic on them.
        self.u = np.asarray(u)
        self.v = np.asarray(v)
        self.ku = ku
        self.kv = kv

    def __call__(self, *args, **kwargs):
        """Evaluate the surface; `foosrf(0, 0)` is `foosrf.ev(0, 0)`.

        Provided mostly for consistency with other callable
        interpolator objects.
        """
        return self.ev(*args, **kwargs)

    def _create_srf(self, u, v, xyz, ku, kv, bbox, controlpts, U, V):
        """Build the x, y and z spline components and store them on self.

        When `controlpts` is true, the knots and coefficients computed by
        scipy are replaced by `U`, `V` and the control points in `xyz`.
        Otherwise `U` and/or `V`, if given, replace the computed knots.
        """
        xyz = np.asarray(xyz)
        # s=0 -> interpolate exactly through the given points.
        srfs = [RectBivariateSpline(u, v, xyz[i], bbox=bbox,
                                    kx=ku, ky=kv, s=0)
                for i in range(3)]
        if controlpts:
            # A little back-dooring here - replace the calculated
            # control points with the *actual* control points, as
            # passed in (together with their explicit knot vectors,
            # e.g. from a STEP file).
            for srf, comp in zip(srfs, xyz):
                srf.tck = (U, V, comp.ravel())
        elif U is not None or V is not None:
            for srf in srfs:
                tu, tv, coeffs = srf.tck[:3]
                srf.tck = (tu if U is None else U,
                           tv if V is None else V,
                           coeffs)
        self._xsrf, self._ysrf, self._zsrf = srfs

    def _resample_uv(self, ures, vres):
        """Return (u, v) parameters re-sampled at a finer resolution.

        Each interval between consecutive entries of self.u (self.v) is
        split into `ures` (`vres`) equal parts by linear interpolation,
        giving ures * (len(u) - 1) + 1 values (likewise for v).
        """
        def refine(params, res):
            n = len(params)
            idx = np.linspace(0, n - 1, res * n - (res - 1))
            return np.interp(idx, np.arange(n), params)

        return refine(self.u, ures), refine(self.v, vres)

    @staticmethod
    def _flatten_uv(u, v, mesh):
        """Normalize (u, v) inputs for evaluation.

        Returns (u, v, uu, vv): `u` and `v` as 1-D arrays, plus the
        arrays to hand to the spline evaluators. With `mesh` true these
        are u and v themselves; otherwise u and v are paired element-wise,
        a single value being repeated to match the other's length.

        Raises ValueError if `mesh` is false and the lengths are
        incompatible.
        """
        u = np.array([u]).reshape(-1,)
        v = np.array([v]).reshape(-1,)
        if mesh:
            return u, v, u, v
        if len(u) != len(v) and 1 not in (len(u), len(v)):
            raise ValueError(
                "With mesh=False, u and v must have the same length or one "
                "of them must be scalar (got %d and %d)" % (len(u), len(v)))
        uu, vv = np.broadcast_arrays(u, v)
        return u, v, uu, vv

    def _partials(self, uu, vv, dx, dy, mesh):
        """Stack the (dx, dy)-th partial derivatives of x, y and z.

        The result has shape (len(uu), len(vv), 3) when `mesh` is true
        (scipy then requires uu and vv sorted ascending) and
        (len(uu), 3) otherwise.
        """
        return np.stack([srf(uu, vv, dx=dx, dy=dy, grid=mesh)
                         for srf in (self._xsrf, self._ysrf, self._zsrf)],
                        axis=-1)

    @staticmethod
    def _finish_vectors(vecs, u, v, normalize):
        """Optionally normalize per-point 3-vectors (last axis) and move the
        xyz axis to the front; scalar (u, v) input gives a 1-D 3-vector."""
        if normalize:
            vecs = vecs / np.sqrt((vecs ** 2).sum(axis=-1))[..., np.newaxis]
        if u.shape == (1,) and v.shape == (1,):
            return vecs.reshape(3)
        return np.moveaxis(vecs, -1, 0)

    def ev(self, u, v, mesh=True):
        """Get point(s) on surface at (u, v).

        Parameters
        ----------
        u, v : scalar or array-like
            u and v may be scalar or vector.
        mesh : boolean
            If True, will expand the u and v values into a mesh.
            For example, with u = [0, 1] and v = [0, 1]: if 'mesh'
            is True, the surface will be evaluated at [0, 0], [0, 1],
            [1, 0] and [1, 1], while if it is False, the evaluation
            will only be made at [0, 0] and [1, 1]. When False, u and v
            must have equal lengths, or one of them must hold a single
            value (which is then paired with every value of the other).

        Returns
        -------
        If scalar values are passed for *both* u and v, a 1-D 3-element
        array [x, y, z]. Otherwise, with `mesh` True, an array of shape
        3 x len(u) x len(v), suitable for feeding to Mayavi's mlab.mesh()
        plotting function (as mlab.mesh(*arr)); with `mesh` False, an
        array of shape 3 x n, n being the number of paired points.

        Raises
        ------
        ValueError
            If `mesh` is False and the lengths of u and v are incompatible.
        """
        mesh = bool(mesh)
        u, v, uu, vv = self._flatten_uv(u, v, mesh)
        if mesh:
            # indexing='ij' yields (len(u), len(v)) arrays with
            # uu[i, j] == u[i], matching the u-major output layout below.
            uu, vv = np.meshgrid(u, v, indexing='ij')
            uu, vv = uu.ravel(), vv.ravel()
        pts = np.array([srf.ev(uu, vv)
                        for srf in (self._xsrf, self._ysrf, self._zsrf)])
        if u.shape == (1,) and v.shape == (1,):
            # Scalar u and v values; return 1-D 3-element array.
            return pts.ravel()
        if mesh:
            return pts.reshape(3, len(u), len(v))
        return pts

    def utan(self, u, v, normalize=True, mesh=True):
        """Get tangent vector(s) along the u direction at (u, v).

        Arguments and output layout are as for `ev`; the vectors are
        scaled to unit length when `normalize` is true.
        """
        mesh = bool(mesh)
        u, v, uu, vv = self._flatten_uv(u, v, mesh)
        du = self._partials(uu, vv, 1, 0, mesh)
        return self._finish_vectors(du, u, v, normalize)

    def vtan(self, u, v, normalize=True, mesh=True):
        """Get tangent vector(s) along the v direction at (u, v).

        Arguments and output layout are as for `ev`; the vectors are
        scaled to unit length when `normalize` is true.
        """
        mesh = bool(mesh)
        u, v, uu, vv = self._flatten_uv(u, v, mesh)
        dv = self._partials(uu, vv, 0, 1, mesh)
        return self._finish_vectors(dv, u, v, normalize)

    def normal(self, u, v, mesh=True):
        """Get unit normal(s) at (u, v), computed as d/du x d/dv.

        u, v : scalar or array-like
            u and v may be scalar or vector (see below)

        If scalar values are passed for *both* u and v, returns
        a 1-D 3-element array [x,y,z]. Otherwise, returns an array
        of shape 3 x len(u) x len(v) (or 3 x n when `mesh` is False),
        suitable for feeding to Mayavi's mlab.mesh() plotting function
        (as mlab.mesh(*arr)).
        """
        mesh = bool(mesh)
        u, v, uu, vv = self._flatten_uv(u, v, mesh)
        normals = np.cross(self._partials(uu, vv, 1, 0, mesh),
                           self._partials(uu, vv, 0, 1, mesh))
        return self._finish_vectors(normals, u, v, True)

    def mplot(self, ures=8, vres=8, **kwargs):
        """Plot the surface using Mayavi's `mesh()` function.

        ures, vres : int
            Specifies the oversampling of the original
            surface in u and v directions. For example:
            if `ures` = 2, and `self.u` = [0, 1, 2, 3],
            then the surface will be resampled at
            [0, 0.5, 1, 1.5, 2, 2.5, 3] prior to plotting.
        kwargs : dict
            See Mayavi docs for `mesh()`. If no 'color' is given a random
            one is chosen; an explicit color of None is passed through.
        """
        # Imported here so the rest of the module works without Mayavi.
        from mayavi import mlab
        from matplotlib.colors import ColorConverter

        if 'color' not in kwargs:
            # Generate random color, scaled to unit length so that every
            # component stays within [0, 1].
            cvec = np.random.rand(3)
            cvec /= math.sqrt(np.dot(cvec, cvec))
            kwargs['color'] = tuple(cvec)

        # The following will convert text strings representing
        # colors into their (r, g, b) equivalents (which is
        # the only way Mayavi will accept them)
        if kwargs['color'] is not None:
            kwargs['color'] = ColorConverter().to_rgb(kwargs['color'])

        # Make new u and v values of (possibly) higher resolution than
        # the original ones, then sample the surface there and plot.
        hru, hrv = self._resample_uv(ures, vres)
        meshpts = self.ev(hru, hrv, mesh=True)
        mlab.mesh(*meshpts, **kwargs)

        # Turn off perspective (switch to a parallel projection).
        fig = mlab.gcf()
        fig.scene.parallel_projection = True

    def plot(self, ures=8, vres=8, **kwargs):
        """Alias for mplot()."""
        self.mplot(ures=ures, vres=vres, **kwargs)

    def _flip(self, axis):
        """Reverse parameter direction `axis` (0 = u, 1 = v) in place.

        The parameter is remapped as t -> 1 - t, so the knots and the
        stored parameter values are mirrored and reversed, the spline
        coefficients are reversed along that axis, and the bbox range
        [a, b] becomes [1 - b, 1 - a].
        """
        for srf in (self._xsrf, self._ysrf, self._zsrf):
            tu, tv, coeffs = srf.tck[:3]
            # Coefficient grid size follows from the knots: n - k - 1.
            grid = np.asarray(coeffs).reshape(len(tu) - self.ku - 1,
                                              len(tv) - self.kv - 1)
            if axis == 0:
                srf.tck = ((1 - np.asarray(tu))[::-1], tv,
                           grid[::-1, :].ravel())
            else:
                srf.tck = (tu, (1 - np.asarray(tv))[::-1],
                           grid[:, ::-1].ravel())
        b = self.bbox
        if axis == 0:
            self.u = (1 - np.asarray(self.u))[::-1]
            self.bbox = [1 - b[1], 1 - b[0], b[2], b[3]]
        else:
            self.v = (1 - np.asarray(self.v))[::-1]
            self.bbox = [b[0], b[1], 1 - b[3], 1 - b[2]]

    def flipu(self):
        """Flips the u-direction of the surface (u -> 1 - u), in place."""
        self._flip(0)

    def flipv(self):
        """Flips the v-direction of the surface (v -> 1 - v), in place."""
        self._flip(1)

    def flipboth(self):
        """Flips both the u- and v-directions of the surface, in place."""
        self.flipu()
        self.flipv()

    def copy(self):
        """Get a (deep) copy of the surface."""
        from copy import deepcopy
        return deepcopy(self)

    def swapuv(self, flipdir=None):
        """Swap u and v directions. In-place modification.

        flipdir : Optional; one of ('u', 'v') if not `None`
            Direction to reverse (after swapping) to maintain the
            surface normal direction.

        The swapped surface is re-interpolated through the points the
        current surface evaluates to at (self.u, self.v).
        NOTE(review): for surfaces built with `controlpts=True` this
        re-interpolation is presumably not exact -- confirm if it matters.
        """
        if flipdir is not None:
            flipdir = flipdir.lower()
            DIRS = ('u', 'v')
            assert flipdir in DIRS, \
                "Invalid value for `flipdir`; must be one of " + DIRS.__repr__()

        # Swap the bounding box numbers
        obbox = self.bbox
        swbbox = [obbox[2], obbox[3], obbox[0], obbox[1]]

        # Note that the method below gives *exactly* the same
        # surface as the original, judging by the amount of 'speckling'
        # seen when 'before' and 'after' surfaces are plotted in Mayavi
        # (i.e. there is *no* speckling - the 'after' surface
        # completely replaces the 'before' surface).
        U, V = self.u, self.v
        ssrf = BSplineSurf(V, U, self(U, V).swapaxes(1, 2),
                           ku=self.kv, kv=self.ku, bbox=swbbox)
        if flipdir == 'u':
            ssrf.flipu()
        elif flipdir == 'v':
            ssrf.flipv()

        # Re-assign all attributes
        self.__dict__ = ssrf.__dict__

    def uknots(self):
        """Return the knot vector in the u-parameter direction."""
        return self._xsrf.tck[0]

    def vknots(self):
        """Return the knot vector in the v-parameter direction."""
        return self._xsrf.tck[1]
class DemoBSplineSurf(BSplineSurf):
    """Demo surface for when a 'real' surface isn't close at hand.

    Developed this at the IPython prompt. Called with no arguments, it
    creates a modified saddle that resembles an airfoil (half), sampled
    on a 200 x 10 (u, v) grid over [0, 1] x [0, 1]. Any arguments are
    passed straight through to BSplineSurf.
    """

    def __init__(self, *args, **kwargs):
        if args or kwargs:
            super(DemoBSplineSurf, self).__init__(*args, **kwargs)
            return

        u = np.linspace(0, 1, 200)
        v = np.linspace(0, 1, 10)
        # x varies along u (200 samples), z along v (10 samples);
        # indexing='ij' gives arrays of shape (200, 10).
        x, z = np.meshgrid(np.linspace(-1, 1, 200), np.linspace(-2, 2, 10),
                           indexing='ij')
        pts = np.array([
            ((x + 0.1) - (0.15 * z)**2) * (1 + ((z + 0.5) / 7)**2),
            -0.2 * (x**2 - (z / 2)**2) + 0.1 + x * np.sin(z / 8),
            z + 1.5,
        ])  # shape (3, 200, 10)
        super(DemoBSplineSurf, self).__init__(u, v, pts,
                                              bbox=[0, 1, 0, 1])
if __name__ == '__main__':
    from mayavi import mlab

    # Set up a test surface (wavy cylinder): 360 points around each
    # ring (u direction), 100 rings stacked along z (v direction).
    a = np.linspace(0, 2 * np.pi, 360)
    x, y = np.cos(a), np.sin(a)
    z = np.zeros(len(x))  # Seed value
    xyz = np.array([x, y, z])
    xyz = np.array([xyz + i * np.array([[0, 0, .03]]).T
                    for i in range(100)]).T.swapaxes(0, 1)  # (3, 360, 100)
    # Scale each ring's radius to make the cylinder wavy along z.
    f = 1.3 + .13 * np.sin(4 * np.linspace(0, 2 * np.pi, 100))
    xyz[0:2, :, :] *= f
    srf = BSplineSurf(np.linspace(0, 1, len(x)),
                      np.linspace(0, 1, xyz.shape[2]), xyz,
                      bbox=[0.0, 1.0, -0.15, 1.15])
    srf.mplot(color=(0, 1, 0), opacity=1.0, ures=1, vres=1)

    # Create a funky spiral around the surface and plot.
    u = np.linspace(0, 4, 4 * len(x)) % 1
    v = np.linspace(0, 1, 4 * len(y))
    pts = srf.ev(u, v, mesh=False)
    mlab.plot3d(*pts, tube_radius=0.02, color=(1, 1, 1))  # White line

    # Create a test plane A*x + B*y + C*z = D.
    # NOTE(review): the plane is only plotted below; no surface/plane
    # intersection is computed in this script.
    ppt = np.array([0, 0, 2.1])  # Point on plane
    pn = np.array([0.1, 0.6, 0.8])  # Normal to plane
    pn /= np.sqrt(np.dot(pn, pn))  # Create unit vector
    D = np.dot(ppt, pn)
    A, B, C = pn
    planedef = np.array([A, B, C, D])

    # Plot the plane
    def get_z(x, y):
        return (D - A * x - B * y) / C

    X, Y = 2 * np.mgrid[-1:1:2j, -1:1:2j]
    Z = get_z(X, Y)
    mlab.mesh(X, Y, Z, color=(0, 0, 1), opacity=0.5)  # Blue plane

    # Plot a u-isospline on the surface, using the full
    # surface extensions (v over the whole bbox range)
    V = np.linspace(srf.bbox[2], srf.bbox[3], 200)
    pts = srf.ev(0.0, V)
    mlab.points3d(*pts, scale_factor=0.03, color=(1, 1, 0))  # Yellow dots

    # Keep the window open when run as a script.
    mlab.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment