Skip to content

Instantly share code, notes, and snippets.

@amroamroamro
Last active January 12, 2024 22:12
Show Gist options
  • Star 77 You must be signed in to star a gist
  • Fork 15 You must be signed in to fork a gist
  • Save amroamroamro/1db8d69b4b65e8bc66a6 to your computer and use it in GitHub Desktop.
Save amroamroamro/1db8d69b4b65e8bc66a6 to your computer and use it in GitHub Desktop.
[Python] Fitting plane/surface to a set of data points

Python version of the MATLAB code in this Stack Overflow post: https://stackoverflow.com/a/18648210/97160

The example shows how to determine the best-fit plane/surface (1st or higher order polynomial) over a set of three-dimensional points.

Implemented in Python + NumPy + SciPy + matplotlib.

quadratic_surface


EDIT (2023-06-16)

I added a new example, fit.py, that shows polynomial surface fitting of any n-th order, along with fit-sklearn.py, which does the same thing using scikit-learn functions.

peaks

#!/usr/bin/env python
import numpy as np
import scipy.linalg
# Axes3D is not referenced directly; importing it registers the '3d'
# projection on old matplotlib versions (< 3.2).
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401
import matplotlib.pyplot as plt

# some 3-dim points: 50 samples from a correlated 3-D Gaussian
mean = np.array([0.0, 0.0, 0.0])
cov = np.array([[1.0, -0.5, 0.8], [-0.5, 1.1, 0.0], [0.8, 0.0, 1.0]])
data = np.random.multivariate_normal(mean, cov, 50)

# regular grid covering the domain of the data
X, Y = np.meshgrid(np.arange(-3.0, 3.0, 0.5), np.arange(-3.0, 3.0, 0.5))
XX = X.flatten()
YY = Y.flatten()

order = 1  # 1: linear, 2: quadratic
if order == 1:
    # best-fit linear plane: Z = C[0]*X + C[1]*Y + C[2]
    A = np.c_[data[:, 0], data[:, 1], np.ones(data.shape[0])]
    C, _, _, _ = scipy.linalg.lstsq(A, data[:, 2])  # coefficients

    # evaluate it on grid
    Z = C[0]*X + C[1]*Y + C[2]
    # or expressed using matrix/vector product
    # Z = np.dot(np.c_[XX, YY, np.ones(XX.shape)], C).reshape(X.shape)
elif order == 2:
    # best-fit quadratic surface:
    # Z = C[0] + C[1]*X + C[2]*Y + C[3]*X*Y + C[4]*X^2 + C[5]*Y^2
    A = np.c_[np.ones(data.shape[0]), data[:, :2],
              np.prod(data[:, :2], axis=1), data[:, :2]**2]
    C, _, _, _ = scipy.linalg.lstsq(A, data[:, 2])

    # evaluate it on a grid (columns in the same order as A)
    Z = np.dot(np.c_[np.ones(XX.shape), XX, YY, XX*YY, XX**2, YY**2], C).reshape(X.shape)
else:
    # without this branch Z would be undefined and plotting would fail
    # with an unhelpful NameError
    raise ValueError(f'unsupported order {order}; use 1 (linear) or 2 (quadratic)')

# plot points and fitted surface
fig = plt.figure()
# Figure.gca() no longer accepts keyword arguments (removed in matplotlib
# 3.6); add_subplot(projection='3d') is the supported way to get 3-D axes.
ax = fig.add_subplot(projection='3d')
ax.plot_surface(X, Y, Z, rstride=1, cstride=1, alpha=0.2)
ax.scatter(data[:, 0], data[:, 1], data[:, 2], c='r', s=50)
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('Z')
try:
    ax.axis('equal')
except NotImplementedError:
    # 3-D axes did not support the 'equal' aspect in matplotlib 3.1-3.5;
    # equal scaling is cosmetic, so fall back to the default aspect there.
    pass
ax.axis('tight')
plt.show()
import numpy as np
from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LinearRegression
from sklearn.pipeline import make_pipeline
import matplotlib.pyplot as plt
def generateData(n=30):
    """Sample MATLAB's ``peaks`` surface on an n-by-n grid over [-3, 3]^2.

    Returns X, Y, Z, each reshaped to a column of n*n values in meshgrid
    (row-major) order.
    """
    axis = np.linspace(-3.0, 3.0, n)
    gx, gy = np.meshgrid(axis, axis)
    X = gx.reshape(-1, 1)
    Y = gy.reshape(-1, 1)
    # the three exponentially-damped terms of the peaks() formula
    term1 = 3 * (1 - X)**2 * np.exp(-X**2 - (Y + 1)**2)
    term2 = 10 * (X/5 - X**3 - Y**5) * np.exp(-X**2 - Y**2)
    term3 = 1/3 * np.exp(-(X + 1)**2 - Y**2)
    Z = term1 - term2 - term3
    return X, Y, Z
def names2model(names):
    """Build a readable model string from polynomial feature names.

    Each name is a product of factors separated by single spaces; it becomes
    the term ``C[i]*<factors joined by '*'>``, and all terms are joined with
    `` + ``.
    """
    terms = []
    for idx, name in enumerate(names):
        # swap the space separators between factors for explicit '*'
        product = '*'.join(name.split(' '))
        terms.append(f"C[{idx}]*{product}")
    return ' + '.join(terms)
# sample MATLAB's peaks() surface on a regular grid (deterministic, not random)
X, Y, Z = generateData()
# (X, Y) input matrix, one row per sample; built once and reused for fit/score
XY = np.c_[X, Y]

# 1=linear, 2=quadratic, 3=cubic, ..., nth degree
order = 11

# best-fit polynomial surface: PolynomialFeatures expands (X, Y) into every
# monomial of total degree <= order, including the constant column (its
# documented default include_bias=True). That column already plays the role of
# the intercept, so LinearRegression must not fit a separate one.
model = make_pipeline(
    PolynomialFeatures(degree=order),
    LinearRegression(fit_intercept=False))
model.fit(XY, Z)
m = names2model(model[0].get_feature_names_out(['X', 'Y']))
C = model[1].coef_.T  # coefficients, transposed to list one per row
r2 = model.score(XY, Z)  # R-squared

# print summary
print(f'data = {Z.size}x3')
print(f'model = {m}')
print(f'coefficients =\n{C}')
print(f'R2 = {r2}')

# uniform grid covering the domain of the data
XX, YY = np.meshgrid(np.linspace(X.min(), X.max(), 20), np.linspace(Y.min(), Y.max(), 20))

# evaluate model on grid
ZZ = model.predict(np.c_[XX.flatten(), YY.flatten()]).reshape(XX.shape)

# plot points and fitted surface
ax = plt.figure().add_subplot(projection='3d')
ax.scatter(X, Y, Z, c='r', s=2)
ax.plot_surface(XX, YY, ZZ, rstride=1, cstride=1, alpha=0.2, linewidth=0.5, edgecolor='b')
ax.axis('tight')
ax.view_init(azim=-60.0, elev=30.0)
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('Z')
plt.show()
import numpy as np
from scipy.linalg import lstsq
import matplotlib.pyplot as plt
def generateData(n=30):
    """Return samples of the MATLAB ``peaks`` test surface.

    The surface is evaluated on a uniform n x n grid spanning [-3, 3] along
    both axes. X, Y and Z are each returned as a column of n*n values.
    """
    ticks = np.linspace(-3.0, 3.0, n)
    X, Y = (grid.reshape(-1, 1) for grid in np.meshgrid(ticks, ticks))
    Z = (3 * (1 - X)**2 * np.exp(-X**2 - (Y + 1)**2)
         - 10 * (X/5 - X**3 - Y**5) * np.exp(-X**2 - Y**2)
         - 1/3 * np.exp(-(X + 1)**2 - Y**2))
    return X, Y, Z
def exp2model(e):
    """Render a sequence of (x, y) exponent pairs as a model string.

    Pair i becomes ``C[i]`` multiplied by ``X^x`` and ``Y^y``. A factor with
    exponent 1 is written without ``^1``, and a factor with exponent 0 is
    omitted. Terms are joined with `` + ``.
    """
    terms = []
    for i, (x, y) in enumerate(e):
        xs = f'X^{x}' if x > 1 else ('X' if x == 1 else '')
        ys = f'Y^{y}' if y > 1 else ('Y' if y == 1 else '')
        # '*' after the coefficient only if some variable factor follows
        lead = '*' if (x > 0 or y > 0) else ''
        # '*' between X and Y parts only when both are present
        mid = '*' if (x > 0 and y > 0) else ''
        terms.append(f'C[{i}]{lead}{xs}{mid}{ys}')
    return ' + '.join(terms)
# sample MATLAB's peaks() surface on a regular grid (deterministic, not random)
X, Y, Z = generateData()

# 1=linear, 2=quadratic, 3=cubic, ..., nth degree
order = 11

# exponents (x, y) of every monomial X^x * Y^y with total degree <= order,
# grouped by total degree n and, within a group, by increasing power of Y
# (for order=2: 1, X, Y, X^2, X*Y, Y^2)
e = [(n - y, y) for n in range(order + 1) for y in range(n + 1)]
eX = np.array([[x for x, _ in e]])  # row vector, shape (1, len(e))
eY = np.array([[y for _, y in e]])

# best-fit polynomial surface: the design matrix has one column per monomial
# (broadcasting the (N, 1) columns against the (1, len(e)) exponent rows)
A = (X ** eX) * (Y ** eY)
C, resid, _, _ = lstsq(A, Z)  # coefficients

# calculate R-squared from the residual sum of squares. The residual is
# computed explicitly instead of reading resid[0]: lstsq returns an empty
# `resid` array when A is rank-deficient (which can happen at high orders),
# and indexing it would raise IndexError.
ss_res = np.sum((Z - np.dot(A, C)) ** 2)
r2 = 1 - ss_res / (Z.size * Z.var())

# print summary
print(f'data = {Z.size}x3')
print(f'model = {exp2model(e)}')
print(f'coefficients =\n{C}')
print(f'R2 = {r2}')

# uniform grid covering the domain of the data
XX, YY = np.meshgrid(np.linspace(X.min(), X.max(), 20), np.linspace(Y.min(), Y.max(), 20))

# evaluate model on grid, using the same monomial columns as the fit
A = (XX.reshape(-1, 1) ** eX) * (YY.reshape(-1, 1) ** eY)
ZZ = np.dot(A, C).reshape(XX.shape)

# plot points and fitted surface
ax = plt.figure().add_subplot(projection='3d')
ax.scatter(X, Y, Z, c='r', s=2)
ax.plot_surface(XX, YY, ZZ, rstride=1, cstride=1, alpha=0.2, linewidth=0.5, edgecolor='b')
ax.axis('tight')
ax.view_init(azim=-60.0, elev=30.0)
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('Z')
plt.show()
@amroamroamro
Copy link
Author

@jensleitloff see the new update, fit.py (and fit-sklearn.py, which does the same thing using scikit-learn functions).

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment