Skip to content

Instantly share code, notes, and snippets.

@HYChou0515
Last active February 13, 2021 02:12
Show Gist options
  • Save HYChou0515/5f7b1374e83b33ac27bca1d913765d24 to your computer and use it in GitHub Desktop.
Save HYChou0515/5f7b1374e83b33ac27bca1d913765d24 to your computer and use it in GitHub Desktop.
Line search, searching for the zero point of an increasing function.
# Given func() and cond(), assume there is a number A such that
# cond(func(x)) == 0 if x < A
# == 1 if x >= A
# line search will find the number A according to func() and cond().
# Starting from point=x0 with step=s0, the searching has two phases:
# 1) expanding and 2) concentrating.
#
# 1) Expanding
# In the expanding phase, if cond(func(x0)) == 0,
# then we take larger and larger steps toward positive,
# so we multiply the step size by 2.
# That is, at the kth iteration,
# x_k <- x_k-1 + s_k-1
# s_k <- s_k-1 * 2
# Until cond(func(x_k)) == 1, we know
# x_k-1 < A <= x_k,
# and we enter phase 2.
# On the other hand, if cond(func(x0)) == 1,
# we take larger and larger steps toward negative,
# i.e., at the kth iteration,
# x_k <- x_k-1 - s_k-1
# s_k <- s_k-1 * 2,
# until cond(func(x_k)) == 0.
#
# 2) Concentrating
# If at the kth iteration it enters the 2nd phase,
# we have
# x_k-1 < A <= x_k.
# So we shrink the step size by 2 each step and find A.
# That is, before moving, let
# s_k+1 <- s_k / 2.
# If cond(func(x_k)) == 0,
# x_k+1 <- x_k + s_k+1,
# If cond(func(x_k)) == 1,
# x_k+1 <- x_k - s_k+1.
class LineSearch:
    """Locate the threshold ``A`` at which ``cond(func(x))`` turns true.

    The search assumes ``cond(func(x))`` is false for ``x < A`` and true for
    ``x >= A``.  Starting from ``x0`` it first walks away with a step that
    doubles on every move until the condition flips (phase 1), then closes in
    on ``A`` by halving the step on every move (phase 2).

    Typical use::

        ls = LineSearch(func, cond)
        estimate = ls.search(x0, s0, nr_steps)      # fixed budget
        for x, a in ls(x0, s0):                     # endless stream of probes
            ...

    All progress is stored on the instance, so two iterations over the same
    instance must not be interleaved.

    Attributes:
        func: callable evaluated at every probed point.
        cond: predicate applied to the value returned by ``func``.
        quiet: when false, the number of phase-1 evaluations is printed.
        x0, s0: starting point and initial step, set by calling the instance.
        x, a: the most recently probed point and the value of ``func`` there.
        s: signed step that led to ``x``.
        tilde_A: the most recently probed point at which ``cond`` held, or
            ``None`` if no such point has been seen yet.
    """

    def __init__(self, func, cond, quiet=True):
        self.func = func
        self.cond = cond
        self.quiet = quiet
        self.x0 = None
        self.s0 = None
        self.x = None
        self.a = None
        self.s = None
        self.tilde_A = None

    def __call__(self, x0, s0):
        """Set the starting point and initial step; return ``self``.

        Only the magnitude of ``s0`` matters: the direction of the walk is
        chosen from ``cond`` at ``x0``.

        Raises:
            ValueError: if ``s0`` is zero, because the search could never
                leave ``x0``.
        """
        if s0 == 0:
            raise ValueError('s0 must be non-zero')
        self.x0 = x0
        self.s0 = s0
        return self

    def _probe(self):
        """Evaluate ``func`` at ``self.x``, record the result, return ``cond``."""
        self.a = self.func(self.x)
        holds = self.cond(self.a)
        if holds:
            self.tilde_A = self.x
        return holds

    def _overshot(self, holds):
        """Tell whether the move just made stepped past the threshold."""
        return (self.s < 0 and not holds) or (self.s > 0 and holds)

    def __iter__(self):
        """Yield ``(x, func(x))`` for every probed point, without end.

        Each point is classified with ``cond`` (and ``tilde_A`` updated)
        before it is yielded, so ``tilde_A`` is current whenever the consumer
        stops iterating.
        """
        self.x = self.x0 * 1.0
        # The sign tests below rely on the step starting out positive; a
        # negative s0 would make phase 1 report a crossing immediately.
        self.s = abs(self.s0) * 1.0
        self.tilde_A = None

        holds = self._probe()
        cnt = 1
        yield self.x, self.a
        if holds:
            # Already at or above the threshold: walk downwards instead.
            self.s *= -1

        # Phase 1: expand until the condition flips.
        while True:
            self.x += self.s
            holds = self._probe()
            cnt += 1
            crossed = self._overshot(holds)
            if crossed and not self.quiet:
                print(f'run func {cnt} times in phase 1')
            yield self.x, self.a
            if crossed:
                # Turn around; the threshold lies within the last step.
                self.s *= -1
                break
            self.s *= 2

        # Phase 2: the threshold lies between self.x and self.x + self.s.
        # Halve the step before each move and turn around on every overshoot.
        while True:
            self.s /= 2
            self.x += self.s
            holds = self._probe()
            yield self.x, self.a
            if self._overshot(holds):
                self.s *= -1

    def search(self, x0, s0, nr_steps):
        """Probe ``nr_steps`` points and return the resulting estimate.

        Args:
            x0: starting point.
            s0: initial step size (non-zero; the sign is ignored).
            nr_steps: number of ``func`` evaluations to spend.

        Returns:
            ``tilde_A`` after the last probe, or ``None`` if ``cond`` never
            held at a probed point (including when ``nr_steps <= 0``).

        Raises:
            ValueError: if ``s0`` is zero.
        """
        probes = iter(self(x0, s0))
        # The generator resets tilde_A only once it starts running; clear it
        # here so a zero budget does not return a stale value.
        self.tilde_A = None
        used = 0
        while used < nr_steps:
            next(probes)
            used += 1
        return self.tilde_A
if __name__ == '__main__':
    import math

    # Demo: func crosses zero at x = pi**10, so the value returned by the
    # search should agree with the reference value printed right after it.
    ls = LineSearch(lambda x: x - math.pi ** 10, lambda x: 2*x > 0, False)
    print(ls.search(0, 1, 200))
    print(math.pi**10)

    # Show the first probed points and the function values found there.
    for i, (x, a) in enumerate(ls(0, 1)):
        print(x, a)
        if i >= 50:
            break
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment