Skip to content

Instantly share code, notes, and snippets.

@montali
Created April 21, 2021 12:31
Show Gist options
  • Save montali/eee9a607d98622a453bb042361ef8938 to your computer and use it in GitHub Desktop.
Save montali/eee9a607d98622a453bb042361ef8938 to your computer and use it in GitHub Desktop.
import numpy as np
class InitialPointShapeException(Exception):
    """Signals that a user-supplied starting point has the wrong number of dimensions."""
class NoSimplexDefinedException(Exception):
    """Signals that an optimization was requested before any simplex was built."""
class NelderMead:
    """Minimize ``fn`` with a Nelder-Mead simplex search.

    Candidate points produced during the search are clipped to be
    non-negative, and at the end of :meth:`fit` every simplex point is
    rescaled so that its coordinates sum to ``sum_constraint``.

    Args:
        n (int): Number of dimensions of the search space.
        fn (callable): Objective to minimize. It is called both with a single
            point of shape ``(n,)`` and with the transposed simplex of shape
            ``(n, n + 1)`` (one column per point), so it must index its
            argument as ``x[0]``, ``x[1]``, ... and operate element-wise.
        sum_constraint (float): Target sum of each point's coordinates,
            enforced by :meth:`fix`.
        reflection_parameter (float): Reflection coefficient.
        expansion_parameter (float): Expansion coefficient.
        contraction_parameter (float): Contraction coefficient.
        shrinkage_parameter (float): Shrink coefficient.
    """

    def __init__(self, n, fn, sum_constraint, reflection_parameter=1,
                 expansion_parameter=2, contraction_parameter=0.5,
                 shrinkage_parameter=0.5):
        self.reflection_parameter = reflection_parameter
        self.expansion_parameter = expansion_parameter
        self.contraction_parameter = contraction_parameter
        self.shrinkage_parameter = shrinkage_parameter
        self.n = n
        self.fn = fn
        self.sum_constraint = sum_constraint
        # Filled in by initialize_simplex(); fit() refuses to run without it.
        self.simplex_points = None
        # Objective values at each simplex point, and the smallest of them.
        self.simplex_vals = None
        self.min = None

    def initialize_simplex(self, x_1=None):
        """Initializes the first simplex to begin iterations.

        Args:
            x_1 (array-like, optional): used as the first point of the
                simplex. Defaults to None, which draws a random point with
                coordinates in [0, 1).

        Raises:
            InitialPointShapeException: Raised when the provided first point
                does not have shape ``(n,)``.
        """
        if x_1 is None:
            start = np.random.rand(self.n)
        else:
            start = np.asarray(x_1, dtype=float)
            # Shapes are tuples: compare against the 1-tuple (n,), not the int n.
            if start.shape != (self.n,):
                raise InitialPointShapeException(
                    f"Please enter an initial point having {self.n} dimensions.")
        # Point i (1..n) is the start shifted along axis i-1; zero coordinates
        # get a much smaller shift.
        shifts = np.where(start != 0, 0.05, 0.0025)
        points = np.empty((self.n + 1, self.n))
        points[0] = start
        points[1:] = start + np.diag(shifts)
        # Only publish the simplex once it is fully built and validated.
        self.simplex_points = points
        print(f"Successfully initialized first simplex: {self.simplex_points}")

    @staticmethod
    def _project(point):
        """Return ``point`` with negative coordinates clipped to zero."""
        return np.maximum(point, 0)

    def _evaluate_simplex(self):
        """Evaluate fn on every simplex point; store values and their minimum."""
        self.simplex_vals = np.asarray(
            self.fn(self.simplex_points.transpose()), dtype=float)
        self.min = self.simplex_vals.min()

    def _replace(self, index, point, value):
        """Store ``point`` (whose objective is ``value``) at ``index``."""
        self.simplex_points[index] = point
        # Keep cached values in sync so fit() measures the current simplex.
        self.simplex_vals[index] = value
        self.min = self.simplex_vals.min()

    def sort(self):
        """Evaluate the simplex and rank its points.

        Fills ``self.simplex_vals`` with the function values and
        ``self.min`` with the smallest of them.

        Returns:
            tuple: indices of the best (lowest value), second worst and
            worst (highest value) simplex points.
        """
        self._evaluate_simplex()
        order = np.argsort(self.simplex_vals)
        return order[0], order[-2], order[-1]

    def iterate(self):
        """Performs one iteration of the Nelder-Mead method:
        - Sorts the simplex points
        - Computes the centroid of all points but the worst
        - Tries reflection, expansion, contraction, then shrinking
        - Updates the simplex (and its cached values)

        Reflected, expanded and contracted candidates are clipped to be
        non-negative *before* being evaluated, so the value compared is the
        value of the point actually stored.
        """
        best, sec_worst, worst = self.sort()
        vals = self.simplex_vals
        x_worst = self.simplex_points[worst]
        # axis=0 drops the worst *row*; without it np.delete flattens the array.
        centroid = np.mean(
            np.delete(self.simplex_points, worst, axis=0), axis=0)

        # Transformation: reflection
        x_reflected = self._project(
            centroid + self.reflection_parameter * (centroid - x_worst))
        y_reflected = self.fn(x_reflected)
        # Better than (or tied with) the best but not the best-new: accept it.
        if vals[best] <= y_reflected < vals[sec_worst]:
            self._replace(worst, x_reflected, y_reflected)
            print("✨ Reflected ✨")
            return

        # Strictly better than the best: try going further in that direction.
        if y_reflected < vals[best]:
            x_expanded = self._project(
                centroid + self.expansion_parameter * (x_reflected - centroid))
            y_expanded = self.fn(x_expanded)
            # Substitute the worst point with the better of the two
            if y_expanded < y_reflected:
                self._replace(worst, x_expanded, y_expanded)
                print("✨ Tried expansion and it worked! ✨")
            else:
                self._replace(worst, x_reflected, y_reflected)
                print("✨ Tried expansion but reflection was better ✨")
            return

        # Here y_reflected >= second worst: contract the worst point inwards.
        x_contracted = self._project(
            centroid + self.contraction_parameter * (x_worst - centroid))
        y_contracted = self.fn(x_contracted)
        if y_contracted < vals[worst]:
            self._replace(worst, x_contracted, y_contracted)
            print("✨ Contracted ✨")
            return

        # Last resort: shrink every point towards the best one (the best
        # point itself maps onto itself).
        x_best = self.simplex_points[best].copy()
        self.simplex_points = x_best + self.shrinkage_parameter * (
            self.simplex_points - x_best)
        self._evaluate_simplex()
        print("✨ Shrinked ✨")

    def fix(self):
        """Rescale every simplex point so its coordinates sum to sum_constraint.

        Points whose coordinates sum to zero cannot be rescaled and are left
        unchanged instead of becoming NaN.
        """
        sums = np.sum(self.simplex_points, axis=1, keepdims=True)
        safe_sums = np.where(sums == 0, 1, sums)
        self.simplex_points = (
            self.simplex_points / safe_sums) * self.sum_constraint

    def fit(self, target_stddev, max_iterations=50):
        """Iterate until the objective values on the simplex are close enough.

        Args:
            target_stddev (float): stop once the standard deviation of the
                function values at the simplex points is at or below this.
            max_iterations (int, optional): hard cap on the number of
                iterations. Defaults to 50.

        Returns:
            np.ndarray: the best (lowest-valued) simplex point, after all
            points have been rescaled by :meth:`fix`.

        Raises:
            NoSimplexDefinedException: if initialize_simplex was not called.
        """
        if not isinstance(getattr(self, "simplex_points", None), np.ndarray):
            raise NoSimplexDefinedException(
                "Call initialize_simplex() before fit().")
        self._evaluate_simplex()
        std_dev = np.std(self.simplex_vals)
        print(std_dev)
        i = 0
        while std_dev > target_stddev and i < max_iterations:
            self.iterate()
            std_dev = np.std(self.simplex_vals)
            print(
                f"🚀 Performing iteration {i}\t🥴 Standard deviation={round(std_dev, 2)}\t🏅 Value={round(self.min, 3)}")
            i += 1
        self.fix()
        # sort() returns (best, second worst, worst): take the FIRST index.
        best, _, _ = self.sort()
        return self.simplex_points[best]
if __name__ == '__main__':
    # Booth's function; defined for experimentation but not optimized below.
    def fn(x):
        first = x[0] + 2 * x[1] - 7
        second = 2 * x[0] + x[1] - 5
        return first ** 2 + second ** 2

    # 6-dimensional demo objective.
    # NOTE(review): x[2] is never used here — confirm that is intentional.
    def fn2(x):
        return x[0] ** 2 + x[1] ** 2 + 4 * x[3] ** 2 - x[4] ** 2 - x[5] ** 4

    optimizer = NelderMead(
        6,
        fn2,
        1,
        reflection_parameter=4,
        expansion_parameter=4,
        contraction_parameter=0.05,
        shrinkage_parameter=0.05,
    )
    optimizer.initialize_simplex()
    result = optimizer.fit(1e-5)
    print(result)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment