Last active
February 13, 2021 02:12
-
-
Save HYChou0515/5f7b1374e83b33ac27bca1d913765d24 to your computer and use it in GitHub Desktop.
Line search: searching for the zero point of an increasing function.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Given func() and cond(), assume there is a number A such that
#     cond(func(x)) == 0 if x < A
#                   == 1 if x >= A
# Line search finds the number A from func() and cond().
# Starting from the point x0 with step s0, the search has two phases:
# 1) expanding and 2) concentrating.
#
# 1) Expanding
# In the expanding phase, if cond(func(x0)) == 0,
# we take larger and larger steps toward positive x,
# doubling the step size each time.
# That is, at the kth iteration,
#     x_k <- x_{k-1} + s_{k-1}
#     s_k <- s_{k-1} * 2
# until cond(func(x_k)) == 1. Then we know
#     x_{k-1} < A <= x_k,
# and we enter phase 2.
# On the other hand, if cond(func(x0)) == 1,
# we take larger and larger steps toward negative x,
# i.e., at the kth iteration,
#     x_k <- x_{k-1} - s_{k-1}
#     s_k <- s_{k-1} * 2
# until cond(func(x_k)) == 0, which gives x_k < A <= x_{k-1}.
#
# 2) Concentrating
# When the search enters the 2nd phase, A is bracketed by the last
# two points, e.g.
#     x_{k-1} < A <= x_k.
# We then halve the step size at each iteration to close in on A.
# That is, before each move, let
#     s_{k+1} <- s_k / 2.
# If cond(func(x_k)) == 0,
#     x_{k+1} <- x_k + s_{k+1},
# If cond(func(x_k)) == 1,
#     x_{k+1} <- x_k - s_{k+1}.
class LineSearch:
    """Locate the threshold A at which ``cond(func(x))`` switches from falsy to truthy.

    The search assumes ``cond(func(x))`` is falsy for every ``x < A`` and
    truthy for every ``x >= A``. Calling the object with a start point and a
    step, then iterating it, yields ``(x, func(x))`` pairs indefinitely. The
    first phase expands: it doubles the step until A is bracketed. The second
    phase concentrates: it bisects that bracket.

    Attributes:
        func: callable evaluated at each probe point ``x``.
        cond: callable applied to ``func(x)``; only its truthiness is used.
        quiet: when False, the number of ``func`` calls spent in the
            expanding phase is printed.
        x0, s0: start point and initial (positive) step of the current search.
        x, a, s: current probe point, ``func(x)`` at that point, and the
            current signed step.
        tilde_A: the latest probe point at which ``cond`` was truthy, or None
            if there has been none yet. Under the monotonicity assumption
            this is the upper end of the current bracket around A.
    """

    def __init__(self, func, cond, quiet=True):
        self.func = func
        self.cond = cond
        self.quiet = quiet
        self.x0 = None
        self.s0 = None
        self.x = None
        self.a = None
        self.s = None
        self.tilde_A = None

    def __call__(self, x0, s0):
        """Set the start point ``x0`` and initial step ``s0``; return ``self`` for iteration.

        Raises:
            ValueError: if ``s0`` is not a positive number. A zero step never
                moves, so the expanding phase would spin on ``x0`` forever. A
                negative step expands in the wrong direction, so the
                resulting bracket would not contain A.
        """
        if not s0 > 0:
            raise ValueError(f's0 must be positive, got {s0!r}')
        self.x0 = x0
        self.s0 = s0
        return self

    def _evaluate(self):
        """Evaluate ``func`` at ``self.x`` into ``self.a`` and return ``cond(self.a)``.

        Records ``self.x`` as ``tilde_A`` when the condition holds.
        """
        self.a = self.func(self.x)
        c = self.cond(self.a)
        if c:
            self.tilde_A = self.x
        return c

    def __iter__(self):
        """Yield ``(x, func(x))`` for every probe point, indefinitely.

        ``cond`` is evaluated, and ``tilde_A`` updated, *before* each pair is
        yielded. The caller therefore sees an up-to-date ``tilde_A`` and can
        stop right after the last evaluation it wants. No extra ``func`` call
        is spent on a point whose condition would never be checked.
        """
        self.x = self.x0 * 1.0
        self.s = self.s0 * 1.0
        self.tilde_A = None

        # Phase 1: expanding. Walk away from x0 with a doubling step until
        # the condition flips.
        c = self._evaluate()
        cnt = 1
        yield self.x, self.a
        if c:
            # Already at or beyond A: expand toward negative x instead.
            self.s *= -1
        while True:
            self.x += self.s
            c = self._evaluate()
            cnt += 1
            yield self.x, self.a
            if (self.s < 0 and not c) or (self.s > 0 and c):
                # The condition flipped, so A lies between the last two
                # probes. Turn around and start bisecting.
                self.s *= -1
                break
            self.s *= 2
        if not self.quiet:
            print(f'run func {cnt} times in phase 1')

        # Phase 2: concentrating (bisection). At the top of each iteration
        # A lies in (min(x, x + s), max(x, x + s)]. x is the cond-truthy end
        # when s < 0 and the cond-falsy end when s > 0.
        while True:
            self.s /= 2
            self.x += self.s
            c = self._evaluate()
            yield self.x, self.a
            if (self.s < 0 and not c) or (self.s > 0 and c):
                # Crossed A in the direction of travel: turn around.
                self.s *= -1

    def search(self, x0, s0, nr_steps):
        """Run ``nr_steps`` evaluations starting from ``x0`` with step ``s0``.

        Args:
            x0: start point.
            s0: initial step; must be positive (see ``__call__``).
            nr_steps: number of ``func``/``cond`` evaluations to perform;
                zero or a negative value performs none.

        Returns:
            ``tilde_A``: the latest evaluated point at which ``cond`` held
            (the approximation of A from above), or None if it never held.
        """
        points = iter(self(x0, s0))
        # Reset explicitly: when nr_steps <= 0 the generator never starts, so
        # __iter__ would not get the chance to clear a stale value.
        self.tilde_A = None
        done = 0
        while done < nr_steps:
            next(points)
            done += 1
        return self.tilde_A
if __name__ == '__main__':
    import math

    # Demo: func(x) = x + pi**10 and cond(a) = 2*a > 0, so cond(func(x))
    # switches from False to True at A = -pi**10.
    ls = LineSearch(lambda x: x + math.pi ** 10, lambda x: 2*x > 0, False)
    print(ls.search(0, 1, 200))
    # Reference value to compare the estimate against. Note the sign: the
    # threshold is at -pi**10, not +pi**10.
    print(-math.pi ** 10)
    # Show the first 51 probe points of a fresh search.
    for i, (x, a) in enumerate(ls(0, 1)):
        print(x, a)
        if i >= 50:
            break
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment