Skip to content

Instantly share code, notes, and snippets.

@salomaestro
Created August 29, 2022 15:46
Show Gist options
  • Save salomaestro/c29c6313415ed86308dab703615b4529 to your computer and use it in GitHub Desktop.
Save salomaestro/c29c6313415ed86308dab703615b4529 to your computer and use it in GitHub Desktop.
File for recursively finding a root of a given function f on an interval [a, b] using the Intermediate Value Theorem (bisection method)
import matplotlib.pyplot as plt
import numpy as np
class Node:
    """A single bisection step: the half-interval that was kept.

    Attributes (all set once in ``__init__``):
        boundary: the ``(lo, hi)`` pair passed in for this step.
        depth: how many bisection steps deep this node is.
        right: True when the right half was kept.
        left: always the negation of ``right`` at construction time.
    """

    def __init__(self, boundary: tuple, depth: int, right: bool):
        self.boundary, self.depth = boundary, depth
        self.right = right
        self.left = not self.right

    def __repr__(self):
        side = "right" if self.right else "left"
        return f"Node(boundary={self.boundary}, depth={self.depth}, direction={side})"

    # str() and repr() render the same text.
    __str__ = __repr__
class Tree:
    """Ordered record of the sub-intervals visited by a bisection search.

    ``func`` is the function being searched, ``[leftbound, rightbound]`` the
    starting interval, and ``nodes`` the appended steps in insertion order.
    """

    def __init__(self, func, leftbound, rightbound):
        self.func, self.leftbound, self.rightbound = func, leftbound, rightbound
        self.nodes = []

    def append(self, node):
        """Record one bisection step (``draw`` reads its ``boundary`` pair)."""
        self.nodes.append(node)

    def __repr__(self):
        return f"Tree(nodes={self.nodes})"

    # str() and repr() render the same text.
    __str__ = __repr__

    def draw(self, *, num_points=100, show=True):
        """Plot ``func`` over the starting interval and mark every recorded step.

        Each node's two boundary endpoints are drawn as red dots at a shared
        height; successive nodes are stacked 0.1 higher, starting at 0.05, so
        the shrinking intervals can be told apart.

        Args:
            num_points: number of samples used to draw ``func`` (default 100).
            show: whether to call ``plt.show()`` at the end (default True).
                Pass False to keep customising the returned axes.

        Returns:
            The matplotlib ``Axes`` the plot was drawn on.

        Note: ``func`` is called once on a NumPy array of sample points, so it
        must accept array input.
        """
        _, ax = plt.subplots(1, 1)
        xs = np.linspace(self.leftbound, self.rightbound, num_points)
        ys = self.func(xs)
        # Zero line: a root is where the curve crosses it.
        ax.axhline(y=0, color='k')
        # axvline autoscales the x-limits to include x=0, which would stretch
        # the plot for an interval that does not contain 0, so draw it only
        # when 0 is actually inside the interval.
        if min(self.leftbound, self.rightbound) <= 0 <= max(self.leftbound, self.rightbound):
            ax.axvline(x=0, color='k')
        ax.plot(xs, ys)
        for level, node in enumerate(self.nodes):
            height = 0.05 + 0.1 * level
            # boundary supplies the x-coordinates; both dots share one height.
            ax.scatter(node.boundary, (height, height), color='r')
        if show:
            plt.show()
        return ax
def recursive_bisect(f, a, b, n):
    """Approximate a root of ``f`` on ``[a, b]`` with ``n`` bisection steps.

    By the Intermediate Value Theorem, a continuous ``f`` whose values have
    opposite signs at the two ends of an interval has a root inside it. Each
    step evaluates ``f`` at the midpoint, keeps the half whose endpoints still
    differ in sign, and records that half as a ``Node`` in a ``Tree``.

    Args:
        f: the function to search, called on scalar arguments.
        a: left end of the starting interval.
        b: right end of the starting interval.
        n: number of bisection steps to perform (must be >= 0).

    Returns:
        A pair ``(root, tree)``. ``root`` is the midpoint of the final interval
        after ``n`` steps, or an exact root met along the way, or the string
        ``"no root"`` if no sign change (and no exact zero) was found. ``tree``
        holds the half-intervals that were kept, in order.

    Raises:
        ValueError: if ``n`` is negative (the recursion would never reach its
            base case).
    """
    if n < 0:
        raise ValueError(f"n must be non-negative, got {n}")
    tree = Tree(f, a, b)

    def recursive_bisect_exec(f, a, b, i):
        c = (a + b) / 2
        # Out of steps: the midpoint is the best estimate we have.
        if i == 0:
            return c
        fa, fb, fc = f(a), f(b), f(c)
        # An exact zero at the midpoint makes both sign tests below false, so
        # without this check it would be misreported as "no root".
        if fc == 0:
            return c
        if fa * fc < 0:
            tree.append(Node((a, c), n - i, False))
            return recursive_bisect_exec(f, a, c, i - 1)
        if fc * fb < 0:
            tree.append(Node((c, b), n - i, True))
            return recursive_bisect_exec(f, c, b, i - 1)
        # No sign change inside [a, b]; an endpoint may still be an exact root.
        if fa == 0:
            return a
        if fb == 0:
            return b
        return "no root"

    return recursive_bisect_exec(f, a, b, n), tree
if __name__ == "__main__":
    # Demo: x**2 - x - 1 has roots (1 +/- sqrt(5)) / 2, about -0.618 and 1.618,
    # both inside [-1, 2]. f(-1) = 1 and f(0.5) = -1.25, so the first step
    # keeps the left half and the search converges toward -0.618.
    def f(x):
        return x**2 - x - 1

    approx_root, tree = recursive_bisect(f, -1, 2, 10)
    print(f"approximate root: {approx_root}")
    print(tree)
    tree.draw()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment