Last active
September 11, 2023 19:37
-
-
Save prednaz/2a14e10f450d27e8851376dfb8536c0d to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from dataclasses import dataclass
from typing import Union

import numpy as np
# A classification tree is either a leaf (`Verdict`) or a binary inner node
# (`Split`). The members are written as strings (forward references) because
# both classes are only defined further down in this file.
ClassificationTree = Union["Verdict", "Split"]  # type alias for https://mypy.readthedocs.io/
@dataclass(frozen=True)
class Verdict:
    """Leaf of a `ClassificationTree`: carries the class label stored there."""

    class_label: int  # the label held by this leaf
@dataclass(frozen=True)
class Split:
    """Binary inner node of a `ClassificationTree`."""

    attribute_index: int  # index of the attribute this node splits on
    split_point: float  # the value of that attribute at which we split
    left: ClassificationTree  # first of the two subtrees below this node
    right: ClassificationTree  # second of the two subtrees below this node
def tree_grow(x: np.ndarray, y: np.ndarray, nmin: int, minleaf: int, nfeat: int) -> ClassificationTree:
    """
    grow a classification tree using the gini index and return it.

    a node becomes a leaf (`Verdict`) when it is pure, when it contains
    fewer than `nmin` observations, or when none of the drawn features
    offers a split that respects `minleaf`. a leaf predicts the majority
    class of its observations; a tie is resolved in favour of class 0.
    in a `Split`, the observations whose attribute value is
    `<= split_point` go to `left`, all others go to `right`.

    arguments:
    x -- data matrix (2 dimensional array) containing the numeric attribute
    values. each row contains the attribute values of one training example.
    y -- vector (1 dimensional array) of class labels. the class label
    is binary, with values coded as 0 and 1.
    nmin -- number of observations that a node must contain at least, for
    it to be allowed to be split. in other words: if a node contains fewer
    cases than `nmin`, it becomes a leaf node.
    minleaf -- minimum number of observations required for a leaf node;
    hence, a split that creates a node with fewer than `minleaf`
    observations is not acceptable.
    nfeat -- number of features that should be considered for each split.
    every time we compute the best split in a particular node, we first
    draw at random `nfeat` features from which the best split is to be
    selected. if `nfeat` is at least the number of attributes, all
    attributes are considered and no random numbers are drawn.

    raises:
    ValueError -- if `y` is empty or `x` is not a 2 dimensional array with
    one row per element of `y`.
    """
    attributes = np.asarray(x)
    is_one = np.asarray(y) != 0  # labels are coded 0/1, work with booleans
    n_obs = is_one.shape[0]
    if n_obs == 0:
        raise ValueError("cannot grow a tree from zero observations")
    if attributes.ndim != 2 or attributes.shape[0] != n_obs:
        raise ValueError("x must be 2 dimensional with one row per element of y")

    n_ones = int(np.count_nonzero(is_one))
    majority = Verdict(int(2 * n_ones > n_obs))  # tie -> class 0

    # stop early: the node is pure or too small to be split
    if n_ones == 0 or n_ones == n_obs or n_obs < nmin:
        return majority

    n_attributes = attributes.shape[1]
    if nfeat >= n_attributes:
        # every attribute is a candidate; keep the natural order so that the
        # result is deterministic when no feature sampling is requested
        candidates = np.arange(n_attributes)
    else:
        # uses the global numpy random state so `np.random.seed` applies
        candidates = np.random.choice(n_attributes, size=nfeat, replace=False)

    best = _best_split(attributes, is_one, candidates, minleaf)
    if best is None:
        return majority

    attribute_index, split_point = best
    goes_left = attributes[:, attribute_index] <= split_point
    return Split(
        attribute_index=attribute_index,
        split_point=split_point,
        left=tree_grow(attributes[goes_left], is_one[goes_left], nmin, minleaf, nfeat),
        right=tree_grow(attributes[~goes_left], is_one[~goes_left], nmin, minleaf, nfeat),
    )


def _best_split(
    attributes: np.ndarray,
    is_one: np.ndarray,
    candidates: np.ndarray,
    minleaf: int,
) -> Union[tuple[int, float], None]:
    """
    find the split with the lowest weighted gini impurity of the children.

    return `(attribute_index, split_point)`, or `None` if no candidate
    attribute has a split that leaves at least `minleaf` observations on
    both sides. ties are resolved in favour of the earliest candidate
    attribute and, within an attribute, the lowest split point.
    """
    n_obs = is_one.shape[0]
    best = None
    best_impurity = np.inf
    for attribute_index in candidates:
        order = np.argsort(attributes[:, attribute_index], kind="stable")
        values = attributes[order, attribute_index]
        ones_up_to = np.cumsum(is_one[order])  # number of 1-labels in values[:i + 1]

        # a split can only be placed between two consecutive distinct values
        boundaries = np.flatnonzero(values[:-1] < values[1:])
        n_left = boundaries + 1
        n_right = n_obs - n_left
        allowed = (n_left >= minleaf) & (n_right >= minleaf)
        if not allowed.any():
            continue
        boundaries = boundaries[allowed]
        n_left = n_left[allowed]
        n_right = n_right[allowed]

        p_left = ones_up_to[boundaries] / n_left
        p_right = (ones_up_to[-1] - ones_up_to[boundaries]) / n_right
        # the parent impurity is the same for every split of this node, so
        # maximal impurity reduction == minimal weighted child impurity
        impurity = (
            n_left * p_left * (1 - p_left) + n_right * p_right * (1 - p_right)
        ) / n_obs

        i = int(np.argmin(impurity))
        if impurity[i] < best_impurity:
            best_impurity = impurity[i]
            low = float(values[boundaries[i]])
            high = float(values[boundaries[i] + 1])
            split_point = (low + high) / 2
            if not low <= split_point < high:
                # the midpoint rounded onto `high` (or overflowed); fall back
                # to `low`, which yields the very same partition under `<=`
                split_point = low
            best = (int(attribute_index), split_point)
    return best
def tree_pred(x: np.ndarray, tr: ClassificationTree) -> np.ndarray:
    """
    predict the class of new cases.

    return a vector (1 dimensional array) `y` of predicted class labels for
    the cases in `x`, that is, `y[i]` contains the predicted class label
    for row `i` of `x`. at a `Split`, a case whose attribute value is
    `<= split_point` descends into `left`, every other case into `right`.

    arguments:
    x -- data matrix (2 dimensional array) containing the attribute values
    of the cases for which predictions are required.
    tr -- tree object created with the function `tree_grow`.
    """
    cases = np.asarray(x)
    n_cases = cases.shape[0]
    predictions = np.empty(n_cases, dtype=int)

    # walk the tree iteratively, routing whole groups of rows at once; each
    # entry pairs a subtree with the indices of the rows that reach it
    pending = [(tr, np.arange(n_cases))]
    while pending:
        node, rows = pending.pop()
        if rows.size == 0:
            continue  # no case reaches this subtree
        if isinstance(node, Verdict):
            predictions[rows] = node.class_label
            continue
        goes_left = cases[rows, node.attribute_index] <= node.split_point
        pending.append((node.left, rows[goes_left]))
        pending.append((node.right, rows[~goes_left]))
    return predictions
def tree_grow_b(
    x: np.ndarray,
    y: np.ndarray,
    nmin: int,
    minleaf: int,
    nfeat: int,
    m: int,
) -> list[ClassificationTree]:
    """
    grow classification trees using the gini index and return them in a
    list.

    each tree is grown with `tree_grow` on its own bootstrap sample, that
    is, on as many rows as `x` has, drawn from `x` with replacement.

    arguments:
    x -- data matrix (2 dimensional array) containing the numeric attribute
    values. each row contains the attribute values of one training example.
    y -- vector (1 dimensional array) of class labels. the class label
    is binary, with values coded as 0 and 1.
    nmin -- number of observations that a node must contain at least, for
    it to be allowed to be split. in other words: if a node contains fewer
    cases than `nmin`, it becomes a leaf node.
    minleaf -- minimum number of observations required for a leaf node;
    hence, a split that creates a node with fewer than `minleaf`
    observations is not acceptable.
    nfeat -- number of features that should be considered for each split.
    every time we compute the best split in a particular node, we first
    draw at random `nfeat` features from which the best split is to be
    selected.
    m -- number of bootstrap samples to be drawn. On each bootstrap sample
    a tree is grown.

    raises:
    ValueError -- if `m` is positive and `y` is empty.
    """
    attributes = np.asarray(x)
    classes = np.asarray(y)
    n_obs = classes.shape[0]
    if m > 0 and n_obs == 0:
        raise ValueError("cannot draw a bootstrap sample from zero observations")

    trees = []
    for _ in range(m):
        # row indices drawn with replacement; uses the global numpy random
        # state so `np.random.seed` makes the forest reproducible
        sample = np.random.randint(0, n_obs, size=n_obs)
        trees.append(tree_grow(attributes[sample], classes[sample], nmin, minleaf, nfeat))
    return trees
def tree_pred_b(x: np.ndarray, trees: list[ClassificationTree]) -> np.ndarray:
    """
    predict the class of new cases taking the majority vote.

    apply `tree_pred` to `x` using each tree in the list in turn. for each
    row of `x` the final prediction is obtained by taking the majority vote
    of the `m` predictions. return a vector (1 dimensional array) `y` of
    predicted class labels for the cases in `x`, that is, `y[i]` contains
    the predicted class label for row `i` of `x`. the vote assumes the
    binary 0/1 label coding; a tie is resolved in favour of class 0.

    arguments:
    x -- data matrix (2 dimensional array) containing the attribute values
    of the cases for which predictions are required.
    trees -- tree object list created with the function `tree_grow_b`.

    raises:
    ValueError -- if `trees` is empty, since there is nothing to vote on.
    """
    if not trees:
        raise ValueError("at least one tree is required for a majority vote")

    # one row of predictions per tree, one column per case
    votes = np.stack([tree_pred(x, tree) for tree in trees])
    votes_for_one = np.count_nonzero(votes, axis=0)
    # class 1 wins only with a strict majority of the votes
    return (2 * votes_for_one > len(trees)).astype(int)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment