Skip to content

Instantly share code, notes, and snippets.

@jensleitloff
Last active June 16, 2023 08:34
Show Gist options
  • Save jensleitloff/f8c253ca8fb68cfabfff5b0cf1353429 to your computer and use it in GitHub Desktop.
Save jensleitloff/f8c253ca8fb68cfabfff5b0cf1353429 to your computer and use it in GitHub Desktop.
[Python] Fitting plane/surface to a set of data points with optional weights
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import numpy as np
from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LinearRegression
from sklearn.pipeline import make_pipeline
import matplotlib.pyplot as plt
USE_WEIGHTS = True  # False: every point gets weight 1 (ordinary least squares)
N_POINTS = 50  # number of random sample points
ORDER = 2  # polynomial degree of the fitted surface: 1 = plane, 2 = quadratic
GRID_STEPS = 12  # surface grid resolution along each axis
SEED = None  # set to an int for a reproducible sample


def make_sample_data(n_points=N_POINTS, use_weights=USE_WEIGHTS, rng=None):
    """Draw correlated 3-D sample points and per-point fitting weights.

    Args:
        n_points: Number of points to draw.
        use_weights: If true, draw random non-negative weights; otherwise
            every point gets weight 1.
        rng: ``numpy.random.Generator`` to draw from; a fresh, unseeded one
            is created when omitted.

    Returns:
        ``(data, weights)`` where ``data`` has shape ``(n_points, 3)`` with
        columns x, y, z, and ``weights`` has shape ``(n_points,)``.
    """
    rng = np.random.default_rng() if rng is None else rng
    mean = np.zeros(3)
    cov = np.array([[1.0, -0.5, 0.8], [-0.5, 1.1, 0.0], [0.8, 0.0, 1.0]])
    data = rng.multivariate_normal(mean, cov, n_points)
    if use_weights:
        # Least-squares weights must be non-negative, so fold the normal
        # draws onto [0, inf).
        weights = np.abs(rng.normal(loc=1.0, scale=1.0, size=n_points))
    else:
        weights = np.ones(n_points)
    return data, weights


def fit_surface(xy, z, weights, order=ORDER):
    """Fit ``z = f(x, y)`` with a weighted polynomial least-squares model.

    Args:
        xy: Array of shape ``(n, 2)`` with the x and y coordinates.
        z: Array of shape ``(n,)`` with the values to fit.
        weights: Per-point sample weights, shape ``(n,)``.
        order: Polynomial degree (1 fits a plane, 2 a quadratic surface).

    Returns:
        The fitted scikit-learn pipeline; call ``.predict`` on ``(m, 2)``
        coordinates to evaluate the surface.
    """
    model = make_pipeline(PolynomialFeatures(degree=order), LinearRegression())
    # Pipeline routes "<step>__<param>" fit arguments to the named step.
    model.fit(xy, z, linearregression__sample_weight=weights)
    return model


def surface_grid(xy, steps=GRID_STEPS):
    """Return a regular ``(steps, steps)`` meshgrid spanning the points' x/y extent.

    The grid is built from the data's own bounding box rather than a fixed
    range, so every sample point lies under the drawn surface.

    Args:
        xy: Array of shape ``(n, 2)`` with the x and y coordinates.
        steps: Number of grid nodes along each axis.

    Returns:
        ``(X, Y)`` coordinate arrays, each of shape ``(steps, steps)``.
    """
    xy = np.asarray(xy, dtype=float)
    lo = xy.min(axis=0)
    hi = xy.max(axis=0)
    return np.meshgrid(np.linspace(lo[0], hi[0], steps),
                       np.linspace(lo[1], hi[1], steps))


def plot_fit(data, X, Y, Z):
    """Draw the fitted surface (translucent) and the sample points (red) in 3-D."""
    ax = plt.figure().add_subplot(projection='3d')
    ax.plot_surface(X, Y, Z, rstride=1, cstride=1, alpha=0.2)
    ax.scatter(data[:, 0], data[:, 1], data[:, 2], c='r', s=50)
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')
    # axis('tight') resets the aspect to 'auto', so equal scaling has to be
    # requested after it, not before.
    ax.axis('tight')
    try:
        ax.set_aspect('equal')
    except NotImplementedError:
        # Older matplotlib (< 3.6) has no equal aspect for 3-D axes; keep
        # the tight, unequal scaling in that case.
        pass
    plt.show()


def main():
    """Generate sample data, fit the surface, and show the result."""
    rng = np.random.default_rng(SEED)
    data, weights = make_sample_data(N_POINTS, USE_WEIGHTS, rng)
    model = fit_surface(data[:, :2], data[:, 2], weights, ORDER)
    X, Y = surface_grid(data[:, :2], GRID_STEPS)
    Z = model.predict(np.column_stack([X.ravel(), Y.ravel()])).reshape(X.shape)
    plot_fit(data, X, Y, Z)


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment