Skip to content

Instantly share code, notes, and snippets.

@profjsb
Created November 10, 2016 19:54
Show Gist options
  • Save profjsb/148a9087c2e026d71c492e81f7a206fd to your computer and use it in GitHub Desktop.
Save profjsb/148a9087c2e026d71c492e81f7a206fd to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"metadata": {
"trusted": true,
"collapsed": false
},
"cell_type": "code",
"source": "import math\n\ndef clean_mtry(x,mid,nfeat):\n ret = []\n for i in x:\n if i < mid:\n ret.append(max(2,math.floor(i)))\n else:\n ret.append(min(nfeat,math.ceil(i)))\n return sorted(list(set(ret)))\n\n\ndef get_defaults(n_features, task_type=\"classification\",\n model_type=\"random_forest\",\n opt_criteria=\"fast\"):\n\n \"\"\"\n n_features number of features\n task_type classification, regression\n model_type random_forest\n SVM [not implemented yet]\n\n opt_criteria quick = do only a few model searches\n comprehensive = do fine grained searches\n\n \"\"\"\n \n # TODO...add more defaults for different model types\n if model_type != \"random_forest\":\n raise\n\n bootstrap = [True]\n oob_score = [True]\n n_jobs = -1\n max_depth = [None]\n\n if opt_criteria == \"fast\":\n n_estimators = [50]\n min_samples_split = [2]\n min_samples_leaf = [1]\n\n if task_type == \"classification\":\n criterion = [\"gini\",\"entropy\"]\n mid = math.sqrt(n_features)\n elif task_type == \"regression\":\n criterion = [\"mse\"]\n mid = max(1, math.ceil(n_features / 3.0))\n else:\n raise\n \n initial = [0.85*mid,mid,1.15*mid]\n max_features = clean_mtry(initial,mid,n_features)\n \n elif opt_criteria == \"comprehensive\":\n n_estimators = [100,250]\n min_samples_split = [2]\n min_samples_leaf = [1,2,5]\n\n if task_type == \"classification\":\n criterion = [\"gini\",\"entropy\"]\n mid = math.sqrt(n_features)\n\n elif task_type == \"regression\":\n criterion = [\"mse\",\"mae\"]\n mid = max(1, math.ceil(n_features / 3.0))\n else:\n raise\n\n initial = [0.5*mid,0.85*mid,mid,1.15*mid,1.3*mid,1.5*mid,2*mid]\n max_features = clean_mtry(initial,mid,n_features)\n \n ret = {\"bootstrap\": bootstrap,\n \"oob_score\": oob_score,\n \"n_jobs\": -1,\n \"max_depth\": max_depth,\n \"max_features\": max_features,\n \"criterion\": criterion,\n \"min_samples_leaf\": min_samples_leaf,\n \"min_samples_split\": min_samples_split,\n \"n_estimators\": n_estimators}\n\n return ret\n",
"execution_count": 45,
"outputs": []
},
{
"metadata": {
"trusted": true,
"collapsed": false
},
"cell_type": "code",
"source": "get_defaults(100, task_type=\"classification\",\n model_type=\"random_forest\",\n opt_criteria=\"fast\")",
"execution_count": 46,
"outputs": [
{
"execution_count": 46,
"output_type": "execute_result",
"data": {
"text/plain": "{'bootstrap': [True],\n 'criterion': ['gini', 'entropy'],\n 'max_depth': [None],\n 'max_features': [8, 10, 12],\n 'min_samples_leaf': [1],\n 'min_samples_split': [2],\n 'n_estimators': [50],\n 'n_jobs': -1,\n 'oob_score': [True]}"
},
"metadata": {}
}
]
},
{
"metadata": {
"trusted": true,
"collapsed": false
},
"cell_type": "code",
"source": "get_defaults(100, task_type=\"classification\",\n model_type=\"random_forest\",\n opt_criteria=\"comprehensive\")",
"execution_count": 47,
"outputs": [
{
"execution_count": 47,
"output_type": "execute_result",
"data": {
"text/plain": "{'bootstrap': [True],\n 'criterion': ['gini', 'entropy'],\n 'max_depth': [None],\n 'max_features': [5, 8, 10, 12, 13, 15, 20],\n 'min_samples_leaf': [1, 2, 5],\n 'min_samples_split': [2],\n 'n_estimators': [100, 250],\n 'n_jobs': -1,\n 'oob_score': [True]}"
},
"metadata": {}
}
]
},
{
"metadata": {
"trusted": true,
"collapsed": false
},
"cell_type": "code",
"source": "get_defaults(2000, task_type=\"classification\",\n model_type=\"random_forest\",\n opt_criteria=\"comprehensive\")",
"execution_count": 48,
"outputs": [
{
"execution_count": 48,
"output_type": "execute_result",
"data": {
"text/plain": "{'bootstrap': [True],\n 'criterion': ['gini', 'entropy'],\n 'max_depth': [None],\n 'max_features': [22, 38, 45, 52, 59, 68, 90],\n 'min_samples_leaf': [1, 2, 5],\n 'min_samples_split': [2],\n 'n_estimators': [100, 250],\n 'n_jobs': -1,\n 'oob_score': [True]}"
},
"metadata": {}
}
]
},
{
"metadata": {
"trusted": true,
"collapsed": false
},
"cell_type": "code",
"source": "get_defaults(8, task_type=\"classification\",\n model_type=\"random_forest\",\n opt_criteria=\"comprehensive\")",
"execution_count": 49,
"outputs": [
{
"execution_count": 49,
"output_type": "execute_result",
"data": {
"text/plain": "{'bootstrap': [True],\n 'criterion': ['gini', 'entropy'],\n 'max_depth': [None],\n 'max_features': [2, 3, 4, 5, 6],\n 'min_samples_leaf': [1, 2, 5],\n 'min_samples_split': [2],\n 'n_estimators': [100, 250],\n 'n_jobs': -1,\n 'oob_score': [True]}"
},
"metadata": {}
}
]
},
{
"metadata": {
"trusted": true,
"collapsed": false
},
"cell_type": "code",
"source": "get_defaults(100, task_type=\"regression\",\n model_type=\"random_forest\",\n opt_criteria=\"comprehensive\")",
"execution_count": 51,
"outputs": [
{
"execution_count": 51,
"output_type": "execute_result",
"data": {
"text/plain": "{'bootstrap': [True],\n 'criterion': ['mse', 'mae'],\n 'max_depth': [None],\n 'max_features': [17, 28, 34, 40, 45, 51, 68],\n 'min_samples_leaf': [1, 2, 5],\n 'min_samples_split': [2],\n 'n_estimators': [100, 250],\n 'n_jobs': -1,\n 'oob_score': [True]}"
},
"metadata": {}
}
]
},
{
"metadata": {
"trusted": true,
"collapsed": true
},
"cell_type": "code",
"source": "",
"execution_count": null,
"outputs": []
}
],
"metadata": {
"language_info": {
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"mimetype": "text/x-python",
"file_extension": ".py",
"name": "python",
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"version": "3.5.2"
},
"kernelspec": {
"name": "py3k",
"display_name": "Python 3",
"language": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment