from dataclasses import dataclass
from typing import Union

import numpy as np  # required for the `np.ndarray` annotations used below

ClassificationTree = Union["Verdict", "Split"]  # type for https://mypy.readthedocs.io/


@dataclass(frozen=True)
class Verdict:  # leaves of `ClassificationTree`s
    class_label: int


@dataclass(frozen=True)
class Split:  # binary nodes of `ClassificationTree`s
    attribute_index: int  # which attribute do we split on?
    split_point: float  # a value of the attribute we split on
    left: ClassificationTree
    right: ClassificationTree
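

# a minimal, purely illustrative example of the data structure above (not part
# of any training procedure): a single split on attribute 0 at the arbitrarily
# chosen split point 2.5, sending the left branch to class 0 and the right
# branch to class 1.
_example_tree: ClassificationTree = Split(
    attribute_index=0,
    split_point=2.5,
    left=Verdict(class_label=0),
    right=Verdict(class_label=1),
)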


def tree_grow(x: np.ndarray, y: np.ndarray, nmin: int, minleaf: int, nfeat: int) -> ClassificationTree:
    """
    grow a classification tree using the gini index and return it.

    arguments:
    x -- data matrix (2 dimensional array) containing the numeric attribute
      values. each row contains the attribute values of one training example.
    y -- vector (1 dimensional array) of class labels. the class label
      is binary, with values coded as 0 and 1.
    nmin -- number of observations that a node must contain at least, for
      it to be allowed to be split. in other words: if a node contains fewer
      cases than `nmin`, it becomes a leaf node.
    minleaf -- minimum number of observations required for a leaf node;
      hence, a split that creates a node with fewer than `minleaf`
      observations is not acceptable.
    nfeat -- number of features that should be considered for each split.
      every time we compute the best split in a particular node, we first
      draw at random `nfeat` features from which the best split is to be
      selected.
    """
    # placeholder body: the hard-coded tree below only illustrates how a
    # `ClassificationTree` is constructed and returned; it does not implement
    # the gini-based growing procedure described in the docstring.
    attributes = x
    classes = y
    return Split(
        attribute_index=2,
        split_point=(attributes[42][2] + attributes[43][2]) / 2,
        left=Verdict(classes[0]),
        right=tree_grow(
            attributes[2:], classes[2:],
            nmin, minleaf, nfeat
        ),
    )
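

# sketch, not part of the original gist: the docstring of `tree_grow` refers to
# the gini index, and for binary labels coded as 0 and 1 that impurity measure
# reduces to 2 * p * (1 - p), where p is the fraction of cases with class
# label 1.
def _gini_impurity_sketch(labels: np.ndarray) -> float:
    """gini impurity of a vector of binary class labels."""
    if labels.size == 0:
        return 0.0
    p = float(labels.mean())  # fraction of cases with class label 1
    return 2 * p * (1 - p)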


def tree_pred(x: np.ndarray, tr: ClassificationTree) -> np.ndarray:
    """
    predict the class of new cases.

    return a vector (1 dimensional array) `y` of predicted class labels for
    the cases in `x`, that is, `y[i]` contains the predicted class label
    for row `i` of `x`.

    arguments:
    x -- data matrix (2 dimensional array) containing the attribute values
      of the cases for which predictions are required.
    tr -- tree object created with the function `tree_grow`.
    """
    pass
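

# sketch, not part of the original gist: one plausible way to route a single
# case down a `ClassificationTree` for use in `tree_pred`. the convention that
# a case goes left when its value for the split attribute is at most the split
# point is an assumption, not something the gist specifies.
def _predict_single_sketch(case: np.ndarray, tr: ClassificationTree) -> int:
    node = tr
    while isinstance(node, Split):
        if case[node.attribute_index] <= node.split_point:
            node = node.left
        else:
            node = node.right
    return node.class_label  # `node` is now a `Verdict` leaf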


def tree_grow_b(
    x: np.ndarray,
    y: np.ndarray,
    nmin: int,
    minleaf: int,
    nfeat: int,
    m: int,
) -> list[ClassificationTree]:
    """
    grow classification trees using the gini index and return them in a
    list.

    arguments:
    x -- data matrix (2 dimensional array) containing the numeric attribute
      values. each row contains the attribute values of one training example.
    y -- vector (1 dimensional array) of class labels. the class label
      is binary, with values coded as 0 and 1.
    nmin -- number of observations that a node must contain at least, for
      it to be allowed to be split. in other words: if a node contains fewer
      cases than `nmin`, it becomes a leaf node.
    minleaf -- minimum number of observations required for a leaf node;
      hence, a split that creates a node with fewer than `minleaf`
      observations is not acceptable.
    nfeat -- number of features that should be considered for each split.
      every time we compute the best split in a particular node, we first
      draw at random `nfeat` features from which the best split is to be
      selected.
    m -- number of bootstrap samples to be drawn. on each bootstrap sample
      a tree is grown.
    """
    pass
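

# sketch, not part of the original gist: drawing one bootstrap sample for
# `tree_grow_b`. row indices are drawn with replacement and reused for both
# `x` and `y`; `rng` is assumed to be a `numpy.random.Generator`.
def _bootstrap_sample_sketch(
    x: np.ndarray, y: np.ndarray, rng: np.random.Generator
) -> tuple[np.ndarray, np.ndarray]:
    indices = rng.integers(0, x.shape[0], size=x.shape[0])  # rows, with replacement
    return x[indices], y[indices]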


def tree_pred_b(x: np.ndarray, trees: list[ClassificationTree]) -> np.ndarray:
    """
    predict the class of new cases taking the majority vote.

    apply `tree_pred` to `x` using each tree in the list in turn. for each
    row of `x` the final prediction is obtained by taking the majority vote
    of the `m` predictions. return a vector (1 dimensional array) `y` of
    predicted class labels for the cases in `x`, that is, `y[i]` contains
    the predicted class label for row `i` of `x`.

    arguments:
    x -- data matrix (2 dimensional array) containing the attribute values
      of the cases for which predictions are required.
    trees -- list of tree objects created with the function `tree_grow_b`.
    """
    pass
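

# sketch, not part of the original gist: a majority vote over per-tree
# predictions for `tree_pred_b`. with labels coded as 0 and 1, a case is
# assigned class 1 exactly when more than half of the trees predict 1; ties
# fall to class 0, which is an arbitrary choice.
def _majority_vote_sketch(predictions: list[np.ndarray]) -> np.ndarray:
    stacked = np.stack(predictions)  # shape: (number of trees, number of cases)
    return (stacked.mean(axis=0) > 0.5).astype(int)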