Created September 11, 2012 07:49
-
-
Save bdholt1/3696748 to your computer and use it in GitHub Desktop.
Test for memory mapping
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os

import numpy as np
from sklearn.tree import DecisionTreeRegressor
from sklearn.ensemble import ExtraTreesRegressor
from sklearn import datasets
# NOTE: sklearn.externals.joblib was deprecated in scikit-learn 0.21 and removed
# in 0.23; on newer versions import these names from the standalone `joblib`.
from sklearn.externals import joblib
from sklearn.externals.joblib import Parallel, delayed, cpu_count
def dump_to_mmemmap(data_dir="H:/temp/", n_samples=1000000, n_features=100,
                    random_state=None):
    """Generate a synthetic regression problem and dump it to disk for memmapping.

    X is converted to a Fortran-ordered float32 array and written with
    ``joblib.dump`` to ``<data_dir>/regression_X.pkl`` so it can later be
    reloaded with ``joblib.load(..., mmap_mode=...)``.  The target vector y is
    written with ``np.save`` to ``<data_dir>/regression_y.npy``.

    Parameters
    ----------
    data_dir : str
        Directory both files are written to; created if it does not exist.
    n_samples, n_features : int
        Size of the generated problem.  All features are informative.
    random_state : int, RandomState instance or None
        Forwarded to ``datasets.make_regression``; None gives different data
        on every call.
    """
    # Without this, joblib.dump / np.save fail on a missing directory.
    if not os.path.isdir(data_dir):
        os.makedirs(data_dir)
    X, y = datasets.make_regression(
        n_samples=n_samples, n_features=n_features,
        n_informative=n_features, n_targets=1, bias=0.0,
        tail_strength=0.5, noise=0.0, shuffle=True, coef=False,
        random_state=random_state)
    # Column-major float32 layout; presumably chosen so the tree builders can
    # work on the reloaded memmap without copying it -- TODO confirm for the
    # installed sklearn version.
    X = np.asfortranarray(X, dtype=np.float32)
    joblib.dump(X, os.path.join(data_dir, "regression_X.pkl"))
    np.save(os.path.join(data_dir, "regression_y.npy"), y)
def _parallel_build_trees(n_trees, X, y, verbose): | |
"""Private function used to build a batch of trees within a job.""" | |
trees = [] | |
for i in xrange(n_trees): | |
if verbose > 1: | |
print("building tree %d of %d" % (i + 1, n_trees)) | |
tree = DecisionTreeRegressor() | |
tree.fit(X, y) | |
trees.append(tree) | |
return trees | |
def train_forest1(data_dir="H:/temp/", n_jobs=3, n_trees=3, verbose=5):
    """Build trees in parallel with joblib directly on a memory-mapped X.

    Loads ``regression_X.pkl`` from ``data_dir`` with
    ``joblib.load(..., mmap_mode='c')`` (copy-on-write memmap) and
    ``regression_y.npy`` fully into memory.  It then runs ``n_jobs`` joblib
    tasks.  Each task fits ``n_trees`` trees via ``_parallel_build_trees``, so
    ``n_jobs * n_trees`` trees are built in total.

    Parameters
    ----------
    data_dir : str
        Directory holding the dumped data; also passed to joblib as
        ``temp_folder``.
    n_jobs : int
        Number of joblib tasks (and requested worker processes).
    n_trees : int
        Trees built per task.
    verbose : int
        Verbosity for joblib and for ``_parallel_build_trees``.

    Returns
    -------
    list
        All fitted trees, flattened across tasks.
    """
    X = joblib.load(os.path.join(data_dir, "regression_X.pkl"), mmap_mode='c')
    y = np.load(os.path.join(data_dir, "regression_y.npy"), mmap_mode=None)
    # Parallel loop; temp_folder is where joblib may place memmaps of large
    # arguments shared with the worker processes.
    batches = Parallel(n_jobs=n_jobs, verbose=verbose, temp_folder=data_dir)(
        delayed(_parallel_build_trees)(n_trees, X, y, verbose=verbose)
        for _ in range(n_jobs))
    # Each task returns its own list of trees; flatten for the caller.
    return [tree for batch in batches for tree in batch]
def train_forest2(data_dir="H:/temp/"):
    """Fit an ExtraTreesRegressor on the memory-mapped data and return it.

    Loads ``regression_X.pkl`` from ``data_dir`` with
    ``joblib.load(..., mmap_mode='c')`` (copy-on-write memmap) and
    ``regression_y.npy`` fully into memory.  It then lets the estimator's own
    ``n_jobs=3`` parallelism build 3 trees (max_depth=10, min_samples_leaf=5,
    no bootstrap).

    Parameters
    ----------
    data_dir : str
        Directory holding the dumped data.

    Returns
    -------
    ExtraTreesRegressor
        The fitted regressor.
    """
    X = joblib.load(os.path.join(data_dir, "regression_X.pkl"), mmap_mode='c')
    y = np.load(os.path.join(data_dir, "regression_y.npy"), mmap_mode=None)
    clf = ExtraTreesRegressor(n_estimators=3, max_depth=10, min_samples_leaf=5,
                              n_jobs=3, verbose=5, bootstrap=False)
    clf.fit(X, y)
    return clf
if __name__ == "__main__":
    # Dump the dataset once, then train on it with both parallel strategies,
    # in this fixed order.
    for step in (dump_to_mmemmap, train_forest1, train_forest2):
        step()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.