Created May 19, 2019 18:29
-
-
Save ychennay/33cd7decf12ad8a343734400003793ba to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import List | |
import numpy as np | |
from sklearn.base import BaseEstimator | |
def bootstrap_contour_predict(
    bootstrapped_models: "List[BaseEstimator]",
    xx: np.ndarray,
    yy: np.ndarray,
) -> np.ndarray:
    """
    Predict over a mesh grid by averaging the predictions of an ensemble of models.

    Every (xx, yy) point of the grid is fed to each model's ``predict``. The
    per-point predictions are summed across models, divided by the number of
    models and rounded to the nearest whole number. For 0/1 labels this behaves
    like a majority vote.

    :param bootstrapped_models: a non-empty list of fitted estimators whose
        ``predict`` accepts an (n_points, 2) feature matrix
    :param xx: mesh-grid coordinates along the 1st feature (e.g. from ``np.meshgrid``)
    :param yy: mesh-grid coordinates along the 2nd feature; same shape as ``xx``
    :return: float array with the same shape as ``xx`` holding the rounded
        average prediction for each grid point
    :raises ValueError: if ``bootstrapped_models`` is empty (averaging would
        otherwise divide by zero and yield an all-NaN grid)
    """
    if not bootstrapped_models:
        raise ValueError("bootstrapped_models must contain at least one fitted model")

    # The (n_points, 2) feature matrix is identical for every model, so build it once.
    grid = np.c_[xx.ravel(), yy.ravel()]

    Z = np.zeros(xx.shape)
    for model in bootstrapped_models:
        Z += model.predict(grid).reshape(xx.shape)

    # Average, then round to the nearest whole number. Note: np.around rounds
    # exact .5 ties to the nearest even value (e.g. 0.5 -> 0.0, 1.5 -> 2.0).
    return np.around(Z / len(bootstrapped_models))
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment