class node:
    """A binary-tree node: a payload ``value`` plus optional children."""

    def __init__(self, value):
        """Create a leaf node carrying *value*; both children start empty."""
        # Assigned in the order left, right, value (same as before) so the
        # instance __dict__ keeps an identical key order.
        self.left, self.right, self.value = None, None, value

    def __str__(self):
        """Render the node as the string form of its stored value."""
        text = str(self.value)
        return text



def find_node(root: "node | None", target: int) -> "node | None":
    """Return the node whose value equals *target*, or the closest one seen.

    The tree is walked like a binary-search-tree lookup: from each node the
    search moves left when the node's value is greater than *target* and
    right otherwise. Only the nodes on that single root-to-leaf path are
    examined. The result is therefore the globally closest value only if
    the tree obeys BST ordering (smaller values on the left).

    Args:
        root: Root of the tree, or ``None`` for an empty tree.
        target: Value to look for.

    Returns:
        The node holding *target* if one is on the search path. Otherwise,
        the visited node whose value has the smallest absolute difference
        from *target*; on ties the node visited first (nearer the root)
        wins. Returns ``None`` when *root* is ``None``.
    """
    closest = None
    closest_dist = None  # cached |closest.value - target|

    cur = root
    while cur is not None:
        if cur.value == target:
            return cur

        dist = abs(cur.value - target)
        # Strict '<' keeps the earlier (shallower) node on ties.
        if closest is None or dist < closest_dist:
            closest, closest_dist = cur, dist

        # Always advance: the previous `if > / elif <` pair left `cur`
        # unchanged for values comparing neither way (e.g. NaN), which
        # looped forever.
        cur = cur.left if cur.value > target else cur.right

    return closest


# Build the sample tree; nodes are created in the same order as listed.
(root, node5, node16, node3, node8,
 node14, node20, node1, node4, node11) = (
    node(v) for v in (12, 5, 16, 3, 8, 14, 20, 1, 4, 11)
)

# Wire up children:
#            12
#          /    \
#         5      16
#        / \    /  \
#       3   8  14   20
#      / \   \
#     1   4   11
root.left, root.right = node5, node16
node5.left, node5.right = node3, node8
node3.left, node3.right = node1, node4
node8.right = node11
node16.left, node16.right = node14, node20

# 13 is absent: the closest node on its search path is reported.
found = find_node(root, 13)
print(f"found = {found}")

# 4 is present, so its own node is reported.
found = find_node(root, 4)
print(f"found 2 = {found}")