Skip to content

Instantly share code, notes, and snippets.

@mpilosov
Created February 11, 2018 23:51
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save mpilosov/26fba49424e404dcd12cb5665fb08dd8 to your computer and use it in GitHub Desktop.
Save mpilosov/26fba49424e404dcd12cb5665fb08dd8 to your computer and use it in GitHub Desktop.
piecewise function in python
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
from matplotlib import cm
from matplotlib.ticker import LinearLocator, FormatStrFormatter
import numpy as np
def q1(lam):
    """Evaluate the base smooth map q1 at every row of ``lam``.

    For each row x of ``lam`` this returns
        exp(-||x||) - x[0]**3 - x[1]**3,
    where ||x|| is the Euclidean (ell^2) norm over all columns of the row.

    Parameters
    ----------
    lam : numpy.ndarray
        2-D array of shape (n, d), d >= 2; each row is one point.

    Returns
    -------
    numpy.ndarray
        1-D array of length n.
    """
    # Norm over every column, while the cubic terms only use the first two.
    radius = np.linalg.norm(lam, axis=1)
    first, second = lam[:, 0], lam[:, 1]
    return np.exp(-radius) - first ** 3 - second ** 3
def q2(lam):
    """Evaluate q2 = 1 + q1 + (d/4) * ||x|| at every row x of ``lam``.

    Parameters
    ----------
    lam : numpy.ndarray
        2-D array of shape (n, d), d >= 2; each row is one point.

    Returns
    -------
    numpy.ndarray
        1-D array of length n.
    """
    d = lam.shape[1]  # dimension scales the norm term
    ell2 = np.linalg.norm(lam, axis=1)  # Euclidean (ell^2) norm of each row
    return 1.0 + q1(lam) + 0.25 * d * ell2
def fun2(lam):
    """Evaluate the piecewise-defined map at every row of ``lam``.

    The plane of the first two coordinates (L1, L2) is split into four
    regions, each with its own formula:

      1. 3*L1 + 2*L2 >= 0 and -L1 + 0.3*L2 <  0  ->  q1(lam) - 2
      2. 3*L1 + 2*L2 >= 0 and -L1 + 0.3*L2 >= 0  ->  q2(lam)
      3. d == 2 only: inside the open disk of radius 0.95 centred at
         (-1, -1)                                 ->  2*q1(lam) + 4
      4. 3*L1 + 2*L2 < 0 and not in that disk (for d != 2: every point
         with 3*L1 + 2*L2 < 0)                    ->  q1(lam)

    Parameters
    ----------
    lam : numpy.ndarray
        2-D array of shape (n, d) with d >= 2; each row is one point.
        The size is not validated.

    Returns
    -------
    numpy.ndarray
        1-D float array of length n. Rows that match no region (e.g. rows
        containing NaN) are left at 0.
    """
    n = lam.shape[0]
    d = lam.shape[1]
    L1 = lam[:, 0]  # names for the two coordinates the regions depend on
    L2 = lam[:, 1]

    half_plane = 3 * L1 + 2 * L2
    line = -L1 + 0.3 * L2
    region1 = (half_plane >= 0) & (line < 0)
    region2 = (half_plane >= 0) & (line >= 0)
    if d == 2:
        dist2 = (L1 + 1.0) ** 2 + (L2 + 1.0) ** 2
        # The whole disk lies in 3*L1 + 2*L2 < 0, so it cannot overlap
        # regions 1 and 2.
        region3 = dist2 < 0.95 ** 2
        # ">=" rather than ">": points exactly on the circle belong to the
        # "otherwise" region instead of matching no region and staying 0.
        region4 = (half_plane < 0) & (dist2 >= 0.95 ** 2)
    else:
        # The disk region only exists in two dimensions.
        region3 = np.zeros(n, dtype=bool)
        region4 = half_plane < 0

    output = np.zeros(n)
    output[region1] = q1(lam[region1, :]) - 2.0
    output[region2] = q2(lam[region2, :])
    output[region3] = 2 * q1(lam[region3, :]) + 4.0
    output[region4] = q1(lam[region4, :])
    return output
# Script: sample fun2 on a 2-D grid over [-1, 1) x [-1, 1) and show it as a
# colour-mapped surface viewed from directly above.
fig = plt.figure()
# Figure.gca() no longer accepts keyword arguments in current Matplotlib
# releases; add_subplot is the supported way to create 3-D axes.
ax = fig.add_subplot(111, projection='3d')
# Make data: grid spacing 0.1; np.arange excludes the endpoint 1.
X = np.arange(-1, 1, 0.1)
Y = np.arange(-1, 1, 0.1)
X, Y = np.meshgrid(X, Y)
# fun2 expects one point per row, so flatten the grid into an (N, 2) array
# and reshape the results back onto the grid.
A = np.column_stack((X.ravel(), Y.ravel()))
Z = fun2(A).reshape(X.shape)
# Plot the surface.
surf = ax.plot_surface(X, Y, Z, cmap=cm.coolwarm,
                       linewidth=0, antialiased=False)
# Customize the z axis.
# ax.set_zlim(-1, 1.01)
ax.zaxis.set_major_locator(LinearLocator(10))
ax.zaxis.set_major_formatter(FormatStrFormatter('%.02f'))
# Elevation 90 degrees: look straight down so the piecewise regions read as a map.
ax.view_init(90, 0)
# Add a color bar which maps values to colors.
fig.colorbar(surf, shrink=0.5, aspect=5)
plt.show()
@mpilosov
Copy link
Author

Taken from Example 6.2 in the paper by Butler, Jakeman, and Wildey.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment