Skip to content

Instantly share code, notes, and snippets.

from functools import reduce
from timeit import timeit
# Naive implementation of a polynomial function.
def polynomial(A, x):
    """Evaluate the polynomial with coefficients ``A`` at ``x`` (naive baseline).

    ``A[i]`` is the coefficient of ``x**i``, so ``[a0, a1, a2]`` encodes
    ``a0 + a1*x + a2*x**2``. Each term recomputes its own power of ``x``.
    An empty ``A`` evaluates to 0.
    """
    # Accumulate left to right from 0 with plain ``+``. Builtin sum() is
    # avoided on purpose: since Python 3.12 it uses compensated float
    # summation, which can produce a different (more accurate) result.
    total = 0
    for power, coeff in enumerate(A):
        total = total + coeff * x ** power
    return total
# More efficient implementation of a polynomial function.
def smarter_polynomial(A, x):
import matplotlib.pyplot as plt
from jax import grad
from jax.numpy import log, sin
import numpy as np
def f(x):
"""
Function to be minimized
"""
import matplotlib.pyplot as plt
from jax.numpy import log, sin
from jax import grad
import numpy as np
def f(x):
    """
    Function to be minimized.

    Evaluates ``(x - 6)**2 / 3 - log(6 + x) - 2*sin(x)`` using the
    ``jax.numpy`` versions of ``log`` and ``sin``, so the expression is
    traceable by JAX transformations such as ``grad``.
    """
    # Parabola centred at x = 6, scaled down by 3.
    quadratic = (x - 6) ** 2 / 3.0
    return quadratic - log(6 + x) - 2 * sin(x)
import matplotlib.pyplot as plt
from jax import grad
import numpy as np
def f(x):
    """
    Function to be minimized.

    Returns the cube of the offset of ``x`` from 6, i.e. ``(x - 6)**3``.
    NOTE: this cubic is unbounded below; x = 6 is a stationary
    (inflection) point, not a minimum.
    """
    offset = x - 6
    return offset ** 3
from jax import grad, jit
import jax.numpy as jnp
import numpy as np
from smooth_binary_node import SmoothBinaryNode
from utils import mse
node = SmoothBinaryNode()
# feature 1 will be used as the criterion
import jax.numpy as jnp
from smooth_binary_node import SmoothBinaryNode
from utils import mse
node = SmoothBinaryNode()
# feature 1 will be used as the criterion
# The threshold is 50
params = {'weights': jnp.array([0, 10, 0]),
@kayhman
kayhman / mse.py
Created September 16, 2023 11:00
def mse(params, tree, features, y_true):
    """Build a mean-squared-error loss function bound to ``tree``.

    Only ``tree`` is captured. The outer ``params``, ``features`` and
    ``y_true`` arguments are ignored; the returned function takes its own
    copies of them. (Presumably this is so the loss can be handed to a
    transform such as ``jax.grad`` w.r.t. ``params`` -- confirm with callers.)

    Returns:
        ``ferr(params, features, y_true)``, which computes
        ``mean((tree.predict(params, features) - y_true) ** 2)``.
    """
    def ferr(params, features, y_true):
        # Residual between the tree's prediction and the targets.
        residual = tree.predict(params, features) - y_true
        return (residual ** 2).mean()

    return ferr
@kayhman
kayhman / 2_levels_smooth_tree.py
Last active September 18, 2023 13:59
2_levels_smooth_tree.py
import jax.numpy as jnp
from smooth_binary_node import SmoothBinaryNode
node_left = SmoothBinaryNode()
node_right = SmoothBinaryNode()
root = SmoothBinaryNode(left_node=node_left, right_node=node_right)
left_params = {'weights': jnp.array([0.0, 10.0, 0.0]),
import jax.numpy as jnp
from smooth_binary_node import SmoothBinaryNode
node = SmoothBinaryNode()
# feature 1 will be used as the criterion
# The threshold is 50
params = {'weights': jnp.array([0, 10, 0]),
'leaves': jnp.array([-1, 1]),
import jax.numpy as jnp
from entmax_jax import entmax, entmax15, sparsemax
class SmoothBinaryNode():
def __init__(self, right_node=None, left_node=None):
self.left_node = left_node
self.right_node = right_node
self.scale = 1.0
def _get_params(self, params):