@kayhman
Last active September 16, 2023 10:30
import jax.numpy as jnp
from smooth_binary_node import SmoothBinaryNode
node = SmoothBinaryNode()
# Feature 1 (0-based index) is used as the split criterion.
# The threshold is 50.
params = {'weights': jnp.array([0, 10, 0]),
          'leaves': jnp.array([-1, 1]),
          'bias': 50}
features = jnp.array([[1, 80, 7]])
pred = node.predict(params, features)
# 80 is greater than 50
# we expect a 1
print('Prediction', pred)
#> 1
features = jnp.array([[1, 20, 7]])
pred = node.predict(params, features)
# 20 is less than 50
# we expect a -1
print('Prediction', pred)
#> -1
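The script imports `SmoothBinaryNode` from a local `smooth_binary_node` module that the gist does not include. The sketch below is one hypothetical way such a node could work, not the author's implementation. It assumes the gate is `sigmoid(sum_i w_i * (x_i - bias))`. With `weights = [0, 10, 0]` this reduces to `sigmoid(10 * (x_1 - 50))`, so `bias` acts as the threshold and the non-zero weight controls how sharp the split is, which matches the comments above. Under this assumption the predictions are float arrays very close to 1 and -1 rather than exact integers. The last lines use `jax.grad` to show why a smooth split is useful: the output is differentiable with respect to the threshold.

```python
import jax
import jax.numpy as jnp


class SmoothBinaryNode:
    """Hypothetical soft decision stump (not the gist's actual module).

    Assumed gate: sigmoid(sum_i w_i * (x_i - bias)). With weights [0, 10, 0]
    this is sigmoid(10 * (x_1 - bias)): `bias` is the threshold and the
    non-zero weight sets how sharp the split is.
    """

    def predict(self, params, features):
        # Cast to float: sigmoid and jax.grad need floating-point inputs.
        x = jnp.asarray(features, dtype=jnp.float32)
        w = jnp.asarray(params['weights'], dtype=jnp.float32)
        leaves = jnp.asarray(params['leaves'], dtype=jnp.float32)
        bias = params['bias']
        # Soft probability of routing each row to the right-hand leaf.
        gate = jax.nn.sigmoid(jnp.dot(x - bias, w))
        # Blend the leaves: gate near 0 -> leaves[0], gate near 1 -> leaves[1].
        return (1.0 - gate) * leaves[0] + gate * leaves[1]


node = SmoothBinaryNode()
params = {'weights': jnp.array([0, 10, 0]),
          'leaves': jnp.array([-1, 1]),
          'bias': 50}
high = node.predict(params, jnp.array([[1, 80, 7]]))  # close to 1
low = node.predict(params, jnp.array([[1, 20, 7]]))   # close to -1
print('Prediction', high, low)

# Gradients require float parameters; differentiate w.r.t. all of them.
float_params = {'weights': jnp.array([0.0, 10.0, 0.0]),
                'leaves': jnp.array([-1.0, 1.0]),
                'bias': 50.0}
at_threshold = jnp.array([[1.0, 50.0, 7.0]])
grads = jax.grad(lambda p: node.predict(p, at_threshold).sum())(float_params)
# At the threshold the gate is 0.5, so d(pred)/d(bias) = 2 * 0.25 * (-10).
print(grads['bias'])  # -5.0
```

In this sketch the weight magnitude plays the role of an inverse temperature: larger weights push the gate towards a hard step, while smaller ones widen the region around the threshold where gradients are non-negligible.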