Created November 7, 2023 04:57
Save tor-gu/a6632fb0c0ae75fb4ecbce54e5c6a218 to your computer and use it in GitHub Desktop.
Find the optimal set of points to remove so that a model fits the remaining points better.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
from matplotlib import pyplot as plt | |
from scipy.optimize import LinearConstraint, minimize | |
# Fixed seed so every run draws the same data and therefore finds the same subset.
np.random.seed(3737)

# Both coordinates are sampled uniformly from the range [LO, HI).
LO = 0
HI = 100

n = 100  # Number of points
k = 73  # Points to remove
def obj_fun(arr, X, Y):
    """
    Function to minimize: weighted least-squares error of the model
    y = A * log(x + B).

    Parameters
    ----------
    arr : sequence of float
        Packed decision vector ``[w_0, ..., w_{m-1}, A, B]`` where
        ``m == len(X)``: one weight per data point, followed by the two
        model parameters A and B.
    X, Y : sequence of float
        Data coordinates, both of length ``m``.

    Returns
    -------
    float
        ``sum_i w_i * (A * log(x_i + B) - y_i) ** 2``.

    Notes
    -----
    The number of weights is taken from ``len(X)`` rather than from a
    module-level constant, so the function works for any data size.
    ``log`` needs ``x_i + B > 0``; otherwise NumPy produces ``-inf``/``nan``
    (with a RuntimeWarning) and the objective becomes non-finite.
    """
    arr = np.asarray(arr, dtype=float)
    n_points = len(X)
    W = arr[:n_points]
    A, B = arr[n_points], arr[n_points + 1]
    residuals = (
        A * np.log(np.asarray(X, dtype=float) + B)
        - np.asarray(Y, dtype=float)
    )
    return float(np.sum(W * residuals ** 2))
# Draw n x-coordinates (stored in ascending order) and n y-values, both
# uniform on [LO, HI). X is drawn before Y, matching the original RNG order.
X = np.sort(np.random.uniform(LO, HI, n))
Y = np.random.uniform(LO, HI, n)
# Decision vector = n point weights followed by the model parameters A and B.
# Every weight starts at (n - k) / n, so the starting point meets the sum
# constraint below (up to float rounding). A starts at 1 and B at 0.
init = [(n - k) / n for _ in range(n)] + [1, 0]
# Each weight lies in [0, 1]; A is free; B is bounded below by 0.
bounds = [(0, 1) for _ in range(n)] + [(None, None), (0, None)]
# One equality constraint: the weights (not A or B) sum to exactly n - k.
row = [1 for _ in range(n)] + [0, 0]
constraints = LinearConstraint([row], lb=n - k, ub=n - k)
# Solve the continuous relaxation (weights in [0, 1]) with SLSQP, which
# handles both the bounds and the linear equality constraint.
res = minimize(
    obj_fun,
    init,
    args=(X, Y),
    method="SLSQP",
    bounds=bounds,
    constraints=constraints,
)
# Check explicitly instead of with `assert`: asserts are stripped under
# `python -O`, which would let a failed solve continue silently.
if not res.success:
    raise RuntimeError(f"SLSQP optimization failed: {res.message}")
# Unpack the optimal decision vector: n weights, then A and B.
W = res.x[:n]
A = res.x[n]
B = res.x[n + 1]
# Classify each weight: ~1 means the point is kept, ~0 means it is removed.
selected = W > 0.999
removed = W < 0.001
# The relaxation only yields a clean subset if every weight ended up at
# (numerically) 0 or 1. Check explicitly rather than with `assert`, which
# `python -O` would skip.
if not np.logical_xor(selected, removed).all():
    n_fractional = int(np.count_nonzero(~(selected | removed)))
    raise RuntimeError(
        f"{n_fractional} of {n} weights are not within 0.001 of 0 or 1; "
        "the optimum does not define a clean selection"
    )
# Plot kept points (red), removed points (blue) and the fitted curve.
_ = plt.scatter(X[selected], Y[selected], color="red", label="kept")
_ = plt.scatter(X[removed], Y[removed], color="blue", label="removed")
# Evaluate the model on a dense grid spanning the kept points. Drawing it
# only at the few kept x-values would render a coarse polyline.
x_kept = X[selected]
if x_kept.size:
    x_grid = np.linspace(x_kept.min(), x_kept.max(), 500)
    _ = plt.plot(x_grid, A * np.log(x_grid + B), color="orange", label="fit")
_ = plt.title(f"y = {A:.2f} * log(x + {B:.2f})")
_ = plt.legend()
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.