Skip to content

Instantly share code, notes, and snippets.

@adrianseeley
Last active October 21, 2015 01:32
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save adrianseeley/3b760ff5bfc65c1d56c2 to your computer and use it in GitHub Desktop.
Save adrianseeley/3b760ff5bfc65c1d56c2 to your computer and use it in GitHub Desktop.
Mighty RF
var fs = require('fs');
var rl = require('readline');
/**
 * Streams ./train.csv line by line and collects the labelled training set.
 *
 * The first line is always discarded as a header. Lines without any comma
 * (for example a blank line) are ignored. For each remaining line the first
 * column is parsed as the expected output and every following column as one
 * component of the input vector (all via parseFloat).
 *
 * @param {function(Array, Array)} cb - called once, when the reader closes,
 *   with (train_inputs, train_outputs); both arrays share the same order.
 */
function read_train (cb) {
  var inputs = [];
  var outputs = [];
  var seen_header = false;
  var reader = rl.createInterface({input: fs.createReadStream('./train.csv')});
  reader.on('line', function (row) {
    if (!seen_header) {
      seen_header = true;
      return;
    }
    var cells = row.split(',');
    if (cells.length == 1) {
      // no comma on this line, so there is nothing to parse
      return;
    }
    outputs.push(parseFloat(cells[0]));
    inputs.push(cells.slice(1).map(function (cell) {
      return parseFloat(cell);
    }));
  });
  reader.on('close', function () {
    return cb(inputs, outputs);
  });
}
/**
 * Streams ./test.csv line by line and collects the unlabelled test set.
 *
 * The first line is always discarded as a header. Lines without any comma
 * (for example a blank line) are ignored. Every column of a remaining line is
 * parsed with parseFloat into the input vector, and the case is given a
 * 1-based id equal to its position among the accepted lines.
 *
 * @param {function(Array, Array)} cb - called once, when the reader closes,
 *   with (test_inputs, test_ids); both arrays share the same order.
 */
function read_test (cb) {
  console.log('reading test');
  var test_inputs = [];
  var test_ids = [];
  var headers_read = false;
  var file = rl.createInterface({input: fs.createReadStream('./test.csv')});
  file.on('line', function (line) {
    if (!headers_read) {
      headers_read = true;
      return;
    }
    var line_parts = line.split(',');
    if (line_parts.length === 1) {
      return;
    }
    // ids are simply 1, 2, 3, ... in the order cases are accepted
    test_ids.push(test_ids.length + 1);
    var test_input_values = [];
    for (var line_part_idx = 0; line_part_idx < line_parts.length; line_part_idx++) {
      test_input_values.push(parseFloat(line_parts[line_part_idx]));
    }
    test_inputs.push(test_input_values);
  });
  file.on('close', function () {
    // A header-only or empty file yields no cases; do not dereference test_inputs[0] then.
    var component_count = test_inputs.length > 0 ? test_inputs[0].length : 0;
    console.log('read ' + test_inputs.length + ' test cases, with ' + component_count + ' components each');
    return cb(test_inputs, test_ids);
  });
}
/**
 * Distributes randomly drawn training cases over a number of partitions.
 *
 * Exactly train_inputs.length draws are made. Each draw picks a uniformly
 * random index into the training set (so a case may be picked more than once,
 * or not at all) and appends that case to the next partition in round-robin
 * order.
 *
 * @param {number} number_of_partitions - how many partitions to create.
 * @param {Array} train_inputs - input vectors.
 * @param {Array} train_outputs - expected outputs, parallel to train_inputs.
 * @returns {Array} a pair [partition_inputs, partition_outputs], each an array
 *   of number_of_partitions arrays.
 */
function partition_train (number_of_partitions, train_inputs, train_outputs) {
  var input_buckets = [];
  var output_buckets = [];
  var bucket;
  for (bucket = 0; bucket < number_of_partitions; bucket++) {
    input_buckets.push([]);
    output_buckets.push([]);
  }
  var sample_count = train_inputs.length;
  bucket = 0;
  for (var draw = 0; draw < sample_count; draw++) {
    var pick = Math.floor(Math.random() * sample_count);
    input_buckets[bucket].push(train_inputs[pick]);
    output_buckets[bucket].push(train_outputs[pick]);
    // advance round-robin, wrapping back to the first partition
    bucket = (bucket + 1 >= number_of_partitions) ? 0 : bucket + 1;
  }
  return [input_buckets, output_buckets];
}
/**
 * Builds the relative frequency of every class found in a list of outputs.
 *
 * @param {Array} train_outputs - class labels; each label is used as an object
 *   key, so it is stored in its string form.
 * @returns {Object} map from class label to its share of train_outputs
 *   (count / train_outputs.length). Empty input yields an empty object.
 */
function create_random_forest_tree_node_class_distribution (train_outputs) {
  var class_distribution = {};
  var total = train_outputs.length;
  // The counter is declared locally; it used to be an implicit global.
  for (var train_idx = 0; train_idx < total; train_idx++) {
    var label = train_outputs[train_idx];
    if (!Object.prototype.hasOwnProperty.call(class_distribution, label)) {
      class_distribution[label] = 0;
    }
    class_distribution[label]++;
  }
  for (var class_key in class_distribution) {
    class_distribution[class_key] /= total;
  }
  return class_distribution;
}
/**
 * Computes the entropy, in bits, of a class distribution.
 *
 * @param {Object} class_distribution - map from class label to probability.
 * @returns {number} the sum over all classes of -p * log2(p); 0 for an empty
 *   distribution.
 */
function calculate_random_forest_tree_node_class_distribution_entropy (class_distribution) {
  var total = 0;
  var label;
  for (label in class_distribution) {
    var probability = class_distribution[label];
    total -= probability * log2(probability);
  }
  return total;
}
// Natural logarithm of 2, evaluated once so log2() only has to divide by it.
var log2_base = Math.log(2);
/**
 * Base-2 logarithm via change of base.
 *
 * @param {number} value - number to take the logarithm of.
 * @returns {number} Math.log(value) divided by Math.log(2).
 */
function log2 (value) {
  var natural_log = Math.log(value);
  return natural_log / log2_base;
}
/**
 * Measures how much entropy a candidate split removes.
 *
 * Cases whose component at component_idx is <= component_value go to the left
 * side, all others to the right side. Each side's class counts are turned into
 * a distribution, its entropy is computed, and the two entropies are combined
 * weighted by the share of cases on each side.
 *
 * @param {number} entropy_before - entropy of the unsplit set.
 * @param {number} component_idx - index of the input component to split on.
 * @param {number} component_value - threshold for the split.
 * @param {Array} train_inputs - input vectors.
 * @param {Array} train_outputs - class labels, parallel to train_inputs.
 * @returns {number} entropy_before minus the weighted entropy after the split.
 */
function measure_random_forest_tree_node_information_gain (entropy_before, component_idx, component_value, train_inputs, train_outputs) {
  var total = train_inputs.length;
  var left = {counts: {}, size: 0};
  var right = {counts: {}, size: 0};
  for (var i = 0; i < total; i++) {
    var side = (train_inputs[i][component_idx] <= component_value) ? left : right;
    var label = train_outputs[i];
    if (!side.counts.hasOwnProperty(label)) {
      side.counts[label] = 0;
    }
    side.counts[label]++;
    side.size++;
  }
  // turn raw counts into per-side probabilities
  var key;
  for (key in left.counts) {
    left.counts[key] /= left.size;
  }
  for (key in right.counts) {
    right.counts[key] /= right.size;
  }
  var weighted_left = calculate_random_forest_tree_node_class_distribution_entropy(left.counts) * (left.size / total);
  var weighted_right = calculate_random_forest_tree_node_class_distribution_entropy(right.counts) * (right.size / total);
  return entropy_before - (weighted_left + weighted_right);
}
/**
 * Recursively builds one decision-tree node.
 *
 * A node is either a leaf (class_distribution set; component_idx,
 * component_value, left and right all null) or a branch (class_distribution
 * null, the other four fields set). Cases whose component at component_idx is
 * <= component_value belong to the left child, all others to the right child.
 *
 * A leaf is produced when:
 *   - there are no training cases (empty distribution {}),
 *   - the maximum depth has been reached,
 *   - every case already has the same class (entropy 0), or
 *   - the best split found sends every case to one side.
 * The last two rules stop the tree from growing chains of branches whose
 * other child is an empty leaf that contributes nothing at lookup time.
 *
 * @param {number} current_tree_depth - depth of this node (root is 0).
 * @param {number} maximum_tree_depth - depth at which nodes must be leaves.
 * @param {Array} components_observed - input component indices this tree may split on.
 * @param {Array} train_inputs - input vectors reaching this node.
 * @param {Array} train_outputs - class labels, parallel to train_inputs.
 * @returns {Object} the node, with any children already built.
 */
function create_random_forest_tree_node (current_tree_depth, maximum_tree_depth, components_observed, train_inputs, train_outputs) {
  var random_forest_tree_node = {
    component_idx: null,
    component_value: null,
    left: null,
    right: null,
    class_distribution: null
  };
  if (train_inputs.length === 0) {
    // no cases to consider
    random_forest_tree_node.class_distribution = {};
    return random_forest_tree_node;
  }
  var class_distribution = create_random_forest_tree_node_class_distribution(train_outputs);
  if (current_tree_depth >= maximum_tree_depth) {
    // terminal node by depth
    random_forest_tree_node.class_distribution = class_distribution;
    return random_forest_tree_node;
  }
  var entropy_before = calculate_random_forest_tree_node_class_distribution_entropy(class_distribution);
  if (entropy_before === 0) {
    // only one class left: no split can improve on this
    random_forest_tree_node.class_distribution = class_distribution;
    return random_forest_tree_node;
  }
  // try every observed component with every value seen in the training cases
  // and keep the first candidate with the highest information gain
  var best_information_gain = null;
  var best_component_idx = null;
  var best_component_value = null;
  var train_idx;
  for (var components_observed_idx = 0; components_observed_idx < components_observed.length; components_observed_idx++) {
    var candidate_component_idx = components_observed[components_observed_idx];
    for (train_idx = 0; train_idx < train_inputs.length; train_idx++) {
      var candidate_value = train_inputs[train_idx][candidate_component_idx];
      var current_information_gain = measure_random_forest_tree_node_information_gain(entropy_before, candidate_component_idx, candidate_value, train_inputs, train_outputs);
      if (best_information_gain === null || current_information_gain > best_information_gain) {
        best_information_gain = current_information_gain;
        best_component_idx = candidate_component_idx;
        best_component_value = candidate_value;
      }
    }
  }
  // create left right train inputs and outputs
  var left_train_inputs = [];
  var right_train_inputs = [];
  var left_train_outputs = [];
  var right_train_outputs = [];
  for (train_idx = 0; train_idx < train_inputs.length; train_idx++) {
    if (train_inputs[train_idx][best_component_idx] <= best_component_value) {
      left_train_inputs.push(train_inputs[train_idx]);
      left_train_outputs.push(train_outputs[train_idx]);
    } else {
      right_train_inputs.push(train_inputs[train_idx]);
      right_train_outputs.push(train_outputs[train_idx]);
    }
  }
  if (left_train_inputs.length === 0 || right_train_inputs.length === 0) {
    // the best split separates nothing, so keep the whole set in a leaf
    random_forest_tree_node.class_distribution = class_distribution;
    return random_forest_tree_node;
  }
  // create child nodes
  random_forest_tree_node.component_idx = best_component_idx;
  random_forest_tree_node.component_value = best_component_value;
  random_forest_tree_node.left = create_random_forest_tree_node(current_tree_depth + 1, maximum_tree_depth, components_observed, left_train_inputs, left_train_outputs);
  random_forest_tree_node.right = create_random_forest_tree_node(current_tree_depth + 1, maximum_tree_depth, components_observed, right_train_inputs, right_train_outputs);
  return random_forest_tree_node;
}
/**
 * Builds one decision tree that may only split on a random selection of
 * input components.
 *
 * number_of_components_observed component indices are drawn uniformly at
 * random (the same index may be drawn more than once) from the components of
 * the first training case, then the root node is built at depth 0.
 *
 * When train_inputs is empty there is no case to read the component count
 * from, so no components are drawn; the node builder then returns an empty
 * leaf instead of this function throwing on train_inputs[0].
 *
 * @param {number} maximum_tree_depth - depth limit passed to the node builder.
 * @param {number} number_of_components_observed - how many component indices to draw.
 * @param {Array} train_inputs - input vectors for this tree.
 * @param {Array} train_outputs - class labels, parallel to train_inputs.
 * @returns {Object} root node of the tree.
 */
function create_random_forest_tree (maximum_tree_depth, number_of_components_observed, train_inputs, train_outputs) {
  var components_observed = [];
  if (train_inputs.length > 0) {
    var component_count = train_inputs[0].length;
    for (var component_idx = 0; component_idx < number_of_components_observed; component_idx++) {
      components_observed.push(Math.floor(Math.random() * component_count));
    }
  }
  return create_random_forest_tree_node(0, maximum_tree_depth, components_observed, train_inputs, train_outputs);
}
/**
 * Builds a forest of decision trees, one training partition per tree.
 *
 * Trees take partitions in round-robin order, so partitions are reused when
 * there are more trees than partitions. The index of each tree is logged
 * before it is built.
 *
 * @param {number} number_of_trees - how many trees to build.
 * @param {number} maximum_tree_depth - depth limit for every tree.
 * @param {number} number_of_components_observed - components drawn per tree.
 * @param {Array} train_partition_inputs - one array of input vectors per partition.
 * @param {Array} train_partition_outputs - one array of labels per partition.
 * @returns {Array} the root nodes of all trees.
 */
function create_random_forest (number_of_trees, maximum_tree_depth, number_of_components_observed, train_partition_inputs, train_partition_outputs) {
  var forest = [];
  for (var tree_number = 0; tree_number < number_of_trees; tree_number++) {
    // round-robin over the available partitions
    var slot = tree_number % train_partition_inputs.length;
    console.log(tree_number);
    var tree = create_random_forest_tree(maximum_tree_depth, number_of_components_observed, train_partition_inputs[slot], train_partition_outputs[slot]);
    forest.push(tree);
  }
  return forest;
}
/**
 * Walks a tree from the given node down to a leaf for one input vector.
 *
 * At every node without a class distribution the walk goes left when the
 * input's component at node.component_idx is <= node.component_value and
 * right otherwise.
 *
 * @param {Object} node - node to start from (normally the tree root).
 * @param {Array} single_test_input - input vector to classify.
 * @returns {Object} class_distribution of the leaf that was reached.
 */
function recurse_lookup_random_forest_tree_node (node, single_test_input) {
  var current = node;
  while (current.class_distribution == null) {
    var goes_left = single_test_input[current.component_idx] <= current.component_value;
    current = goes_left ? current.left : current.right;
  }
  return current.class_distribution;
}
/**
 * Classifies every test input by summing the class distributions that all
 * trees of the forest return for it and picking the class with the largest
 * total.
 *
 * The running best class and the running best score are tracked separately;
 * each candidate's score is compared against the best score so far (the
 * previous code compared it against the best class label instead). On a tie
 * the class encountered first while iterating the summed estimates wins.
 *
 * @param {Array} random_forest - root nodes of the trees.
 * @param {Array} test_inputs - input vectors to classify.
 * @returns {Array} one entry per test input: the winning class as an object
 *   key (a string), or null when no tree produced any estimate.
 */
function run_random_forest (random_forest, test_inputs) {
  var test_outputs = [];
  for (var test_idx = 0; test_idx < test_inputs.length; test_idx++) {
    // sum the per-tree probabilities of every class
    var class_estimates = {};
    for (var tree_idx = 0; tree_idx < random_forest.length; tree_idx++) {
      var current_class_estimates = recurse_lookup_random_forest_tree_node(random_forest[tree_idx], test_inputs[test_idx]);
      for (var class_key in current_class_estimates) {
        if (!Object.prototype.hasOwnProperty.call(class_estimates, class_key)) {
          class_estimates[class_key] = 0;
        }
        class_estimates[class_key] += current_class_estimates[class_key];
      }
    }
    // pick the class with the largest summed estimate
    var best_class = null;
    var best_score = null;
    for (var candidate_class in class_estimates) {
      if (best_class === null || class_estimates[candidate_class] > best_score) {
        best_class = candidate_class;
        best_score = class_estimates[candidate_class];
      }
    }
    test_outputs.push(best_class);
  }
  return test_outputs;
}
/**
 * Computes the fraction of validation cases that were estimated correctly.
 *
 * @param {Array} validation_outputs - expected classes.
 * @param {Array} est_validation_outputs - estimated classes, parallel to
 *   validation_outputs.
 * @returns {number} matches divided by validation_outputs.length (NaN when
 *   validation_outputs is empty).
 */
function calculate_validation_score (validation_outputs, est_validation_outputs) {
  var total = validation_outputs.length;
  var hits = 0;
  var position = total;
  while (position-- > 0) {
    // Loose equality is deliberate: run_random_forest returns classes as
    // string keys while the expected classes are parsed numbers.
    if (validation_outputs[position] == est_validation_outputs[position]) {
      hits += 1;
    }
  }
  return hits / total;
}
/**
 * Sweeps a grid of forest parameters and records the validation score of
 * each combination.
 *
 * Four nested ranges are walked, each from *_low to *_high inclusive in
 * *_step increments: number of partitions, number of trees, maximum tree
 * depth, and number of components observed. The training set is partitioned
 * once per partition count. For every combination a forest is trained,
 * run on the validation inputs and scored, and the results collected so far
 * are written out as a partial result file.
 *
 * Progress is written to stdout. process.stdout.cursorTo only exists when
 * stdout is a TTY, so when it is missing (piped or redirected output) each
 * progress message is written on its own line instead of overwriting the
 * previous one.
 *
 * @returns {Array} one row per combination:
 *   [partitions, trees, max_depth, components_observed, validation_score].
 */
function monte_carlo_random_forest (partition_low, partition_high, partition_step, number_of_trees_low, number_of_trees_high, number_of_trees_step, maximum_tree_depth_low, maximum_tree_depth_high, maximum_tree_depth_step, number_of_components_observed_low, number_of_components_observed_high, number_of_components_observed_step, train_inputs, train_outputs, validation_inputs, validation_outputs) {
  var results = [];
  for (var partition_idx = partition_low; partition_idx <= partition_high; partition_idx += partition_step) {
    var train_partitions = partition_train(partition_idx, train_inputs, train_outputs);
    var train_partition_inputs = train_partitions[0];
    var train_partition_outputs = train_partitions[1];
    for (var number_of_trees_idx = number_of_trees_low; number_of_trees_idx <= number_of_trees_high; number_of_trees_idx += number_of_trees_step) {
      for (var maximum_tree_depth_idx = maximum_tree_depth_low; maximum_tree_depth_idx <= maximum_tree_depth_high; maximum_tree_depth_idx += maximum_tree_depth_step) {
        for (var number_of_components_observed_idx = number_of_components_observed_low; number_of_components_observed_idx <= number_of_components_observed_high; number_of_components_observed_idx += number_of_components_observed_step) {
          var progress =
            'partition: [' + partition_low + ' (' + partition_idx + ') ' + partition_high + ']' +
            ' Ntrees: [' + number_of_trees_low + ' (' + number_of_trees_idx + ') ' + number_of_trees_high + ']' +
            ' treeD: [' + maximum_tree_depth_low + ' (' + maximum_tree_depth_idx + ') ' + maximum_tree_depth_high + ']' +
            ' numCO: [' + number_of_components_observed_low + ' (' + number_of_components_observed_idx + ') ' + number_of_components_observed_high + ']';
          if (typeof process.stdout.cursorTo === 'function') {
            // TTY: rewrite the same line in place
            process.stdout.cursorTo(0);
            process.stdout.write(progress);
          } else {
            // not a TTY: cursor control is unavailable, emit one line per step
            process.stdout.write(progress + '\n');
          }
          var random_forest = create_random_forest(number_of_trees_idx, maximum_tree_depth_idx, number_of_components_observed_idx, train_partition_inputs, train_partition_outputs);
          var random_forest_validation_outputs = run_random_forest(random_forest, validation_inputs);
          var validation_score = calculate_validation_score(validation_outputs, random_forest_validation_outputs);
          results.push([partition_idx, number_of_trees_idx, maximum_tree_depth_idx, number_of_components_observed_idx, validation_score]);
          // persist what has been measured so far
          write_monte_carlo_results(results, true);
        }
      }
    }
  }
  return results;
}
/**
 * Writes the test estimates to a CSV file named est<milliseconds>.csv in the
 * current working directory.
 *
 * The file starts with the header "ImageId,Label" followed by one
 * "<id>,<output>" line per entry of test_ids.
 *
 * @param {Array} test_ids - ids of the test cases.
 * @param {Array} test_outputs - estimated outputs, parallel to test_ids.
 */
function write_test (test_ids, test_outputs) {
  console.log('writing test to est[date].csv');
  var rows = ['ImageId,Label\n'];
  for (var row = 0; row < test_ids.length; row++) {
    rows.push(test_ids[row] + ',' + test_outputs[row] + '\n');
  }
  var target = 'est' + new Date().getTime() + '.csv';
  fs.writeFileSync(target, rows.join(''), 'utf8');
  return;
}
/**
 * Writes parameter-sweep results as CSV into the current working directory.
 *
 * The file has a fixed header line followed by one comma-joined line per
 * result row. Partial results always overwrite mc_partial.csv; final results
 * go to a new file named mc<milliseconds>.csv and are announced on the console.
 *
 * @param {Array} monte_carlo_results - result rows (arrays of values).
 * @param {boolean} partial - truthy to write the partial file, falsy to write
 *   a final, timestamped file.
 */
function write_monte_carlo_results (monte_carlo_results, partial) {
  var header = 'number_of_partitions,number_of_trees,maximum_tree_depth,number_of_components_observed,validation_score\n';
  var body = '';
  for (var row = 0; row < monte_carlo_results.length; row++) {
    body += monte_carlo_results[row].join(',') + '\n';
  }
  var csv = header + body;
  if (partial) {
    fs.writeFileSync('mc_partial.csv', csv, 'utf8');
    return;
  }
  console.log('writing monte carlo results to mc[date].csv');
  fs.writeFileSync('mc' + new Date().getTime() + '.csv', csv, 'utf8');
}
// ---- configuration for the train-and-predict run ----
// number of partitions the training set is distributed over
var cfg_number_of_train_partitions = 500;
// number of trees in the forest
var cfg_number_of_random_forest_trees = 500;
// depth at which tree nodes must become leaves
var cfg_maximum_random_forest_tree_depth = 100;
// number of input components drawn for each tree
var cfg_number_of_components_observed_per_random_forest_tree = 100;

// ---- configuration for the parameter sweep (currently disabled in the entry point) ----
// share of the training set held back for validation
var cfg_validation_percent = 0.5;
// each swept parameter runs from *_low to *_high inclusive in *_step increments
var cfg_monte_carlo_partition_low = 25;
var cfg_monte_carlo_partition_high = 25;
var cfg_monte_carlo_partition_step = 1;
var cfg_monte_carlo_number_of_trees_low = 25;
var cfg_monte_carlo_number_of_trees_high = 25;
var cfg_monte_carlo_number_of_trees_step = 1;
var cfg_monte_carlo_maximum_tree_depth_low = 100;
var cfg_monte_carlo_maximum_tree_depth_high = 100;
var cfg_monte_carlo_maximum_tree_depth_step = 1;
var cfg_monte_carlo_number_of_components_observed_low = 100;
var cfg_monte_carlo_number_of_components_observed_high = 100;
var cfg_monte_carlo_number_of_components_observed_step = 1;

// Refuse to start when there are more partitions than trees. An Error object
// is thrown (rather than a bare string) so the failure carries a stack trace.
if (cfg_number_of_train_partitions > cfg_number_of_random_forest_trees) {
  throw new Error('there should always be more trees than train partitions');
}
// The equivalent check for the parameter sweep is intentionally disabled.
if (cfg_monte_carlo_partition_high > cfg_monte_carlo_number_of_trees_low) {
  //throw 'monte carlo error, there should always be more trees than train partitions';
}
// Entry point: read the training set, train a forest on random partitions of
// it, then read the test set, classify it and write the estimates to disk.
read_train(function (train_inputs, train_outputs) {
  // Disabled: restrict training to the first 5000 cases.
  //train_inputs = train_inputs.splice(0, 5000);
  //train_outputs = train_outputs.splice(0, 5000);
  // Disabled: hold back part of the training set and sweep forest parameters.
  /*var validation_count = Math.floor(train_inputs.length * cfg_validation_percent);
  var validation_inputs = train_inputs.splice(0, validation_count);
  var validation_outputs = train_outputs.splice(0, validation_count)
  var monte_carlo_results = monte_carlo_random_forest(cfg_monte_carlo_partition_low, cfg_monte_carlo_partition_high, cfg_monte_carlo_partition_step, cfg_monte_carlo_number_of_trees_low, cfg_monte_carlo_number_of_trees_high, cfg_monte_carlo_number_of_trees_step, cfg_monte_carlo_maximum_tree_depth_low, cfg_monte_carlo_maximum_tree_depth_high, cfg_monte_carlo_maximum_tree_depth_step, cfg_monte_carlo_number_of_components_observed_low, cfg_monte_carlo_number_of_components_observed_high, cfg_monte_carlo_number_of_components_observed_step, train_inputs, train_outputs, validation_inputs, validation_outputs);
  write_monte_carlo_results(monte_carlo_results, false);*/
  var train_partitions = partition_train(cfg_number_of_train_partitions, train_inputs, train_outputs);
  var train_partition_inputs = train_partitions[0];
  var train_partition_outputs = train_partitions[1];
  // declared locally; the nested callback below still sees it through its closure
  var random_forest = create_random_forest(cfg_number_of_random_forest_trees, cfg_maximum_random_forest_tree_depth, cfg_number_of_components_observed_per_random_forest_tree, train_partition_inputs, train_partition_outputs);
  read_test(function (test_inputs, test_ids) {
    // The estimates are written as returned; a leftover assignment from an
    // undefined variable used to raise a ReferenceError at this point.
    var test_outputs = run_random_forest(random_forest, test_inputs);
    write_test(test_ids, test_outputs);
  });
});
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment