This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from functools import reduce | |
from timeit import timeit | |
# Naive implementation of a polynomial function. | |
def polynomial(A, x):
    """Evaluate the polynomial with coefficients ``A`` at ``x``, term by term.

    ``A[i]`` is the coefficient of ``x**i``, so ``A = [a0, a1, a2]`` means
    ``a0 + a1*x + a2*x**2``. Every power ``x**i`` is computed independently,
    which is what makes this the naive variant.

    Returns 0 for an empty coefficient list.
    """
    # sum() replaces the former reduce(lambda a, b: a + b, [...], 0) fold and
    # streams the terms instead of building a list first. Note: on Python
    # 3.12+ sum() compensates float rounding, so float results may differ from
    # the old fold in the last bits.
    return sum((a_i * x**i for i, a_i in enumerate(A)), 0)
# More efficient implementation of a polynomial function. | |
def smarter_polynomial(A, x): |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import matplotlib.pyplot as plt | |
from jax import grad | |
from jax.numpy import log, sin | |
import numpy as np | |
def f(x): | |
""" | |
Function to be minimized | |
""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import matplotlib.pyplot as plt | |
from jax.numpy import log, sin | |
from jax import grad | |
import numpy as np | |
def f(x):
    """
    Function to be minimized.

    It is the negation of ``-(x - 6)**2 / 3 + log(6 + x) + 2*sin(x)``, so
    minimizing ``f`` maximizes that bracketed expression. The log term is only
    real-valued for ``x > -6``.
    """
    parabola = -(x - 6) ** 2 / 3.0
    objective = parabola + log(6 + x) + 2 * sin(x)
    # Negate so that a minimizer of f is a maximizer of the objective.
    return -objective
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import matplotlib.pyplot as plt | |
from jax import grad | |
import numpy as np | |
def f(x):
    """
    Function to be minimized: the cubic ``(x - 6)**3``.

    Note that this cubic is unbounded below. Its derivative vanishes only at
    ``x = 6``, which is an inflection point rather than a minimum.
    """
    shifted = x - 6
    return shifted ** 3
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from jax import grad, jit | |
import jax.numpy as jnp | |
import numpy as np | |
from smooth_binary_node import SmoothBinaryNode | |
from utils import mse | |
node = SmoothBinaryNode() | |
# feature 1 will be used s criteria |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import jax.numpy as jnp | |
from smooth_binary_node import SmoothBinaryNode | |
from utils import mse | |
node = SmoothBinaryNode() | |
# feature 1 will be used s criteria | |
# The threshold is 50 | |
params = {'weights': jnp.array([0, 10, 0]), |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def mse(params, tree, features, y_true):
    """
    Build a mean-squared-error loss function bound to ``tree``.

    Only ``tree`` is captured. The outer ``params``, ``features`` and
    ``y_true`` arguments are not used, because the returned function takes
    its own arguments of the same names.

    Returns:
        A function ``(params, features, y_true)`` that returns
        ``((tree.predict(params, features) - y_true) ** 2).mean()``.
    """
    def ferr(params, features, y_true):
        residual = tree.predict(params, features) - y_true
        return (residual ** 2).mean()

    return ferr
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import jax.numpy as jnp | |
from smooth_binary_node import SmoothBinaryNode | |
node_left = SmoothBinaryNode() | |
node_right = SmoothBinaryNode() | |
root = SmoothBinaryNode(left_node=node_left, right_node=node_right) | |
left_params = {'weights': jnp.array([0.0, 10.0, 0.0]), |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import jax.numpy as jnp | |
from smooth_binary_node import SmoothBinaryNode | |
node = SmoothBinaryNode() | |
# feature 1 will be used s criteria | |
# The threshold is 50 | |
params = {'weights': jnp.array([0, 10, 0]), | |
'leaves': jnp.array([-1, 1]), |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import jax.numpy as jnp | |
from entmax_jax import entmax, entmax15, sparsemax | |
class SmoothBinaryNode(): | |
def __init__(self, right_node=None, left_node=None): | |
self.left_node = left_node | |
self.right_node = right_node | |
self.scale = 1.0 | |
def _get_params(self, params): |
NewerOlder