Skip to content

Instantly share code, notes, and snippets.

@y-uti
Last active February 12, 2018 02:12
Show Gist options
  • Save y-uti/1e04a3dc871a5287ccbea3283b4a3d9c to your computer and use it in GitHub Desktop.
Save y-uti/1e04a3dc871a5287ccbea3283b4a3d9c to your computer and use it in GitHub Desktop.
Sample code that calculates a precision-recall curve from the probability estimates produced by SVC
<?php
require_once __DIR__ . '/vendor/autoload.php';
use Phpml\Classification\SVC;
use Phpml\SupportVectorMachine\Kernel;
/**
 * Load a dense LIBSVM-format dataset, such as the examples published at
 * https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary.html
 *
 * Every non-blank line is expected to look like "<label> <index>:<value> ...".
 * The feature indices are discarded and the values are kept in the order in
 * which they appear, so this function can't be applied to "sparse" files
 * (where some indices are omitted), see:
 * https://www.csie.ntu.edu.tw/~cjlin/libsvm/faq.html#/Q03:_Data_preparation
 *
 * @param string $url Path or URL of the dataset; anything file() can read
 * @return array A two-element array [$samples, $labels]: $samples is a list of
 *               float lists (one per line), $labels the matching list of label
 *               strings exactly as written in the file
 * @throws \RuntimeException         If the dataset cannot be read
 * @throws \UnexpectedValueException If a feature token has no "index:value" form
 */
function load_dataset($url)
{
    $lines = file($url, FILE_IGNORE_NEW_LINES);
    if ($lines === false) {
        // file() has already emitted a warning; stop instead of iterating over false
        throw new \RuntimeException("Unable to read dataset: {$url}");
    }

    $samples = [];
    $labels = [];
    foreach ($lines as $line_no => $line) {
        // Split on runs of whitespace and drop empty tokens, so trailing or
        // doubled spaces (and a stray "\r") never turn into a bogus feature
        $columns = preg_split('/\s+/', $line, -1, PREG_SPLIT_NO_EMPTY);
        if ($columns === false || $columns === []) {
            continue; // blank line: nothing to record
        }

        $labels[] = array_shift($columns);
        $samples[] = array_map(function ($column) use ($url, $line_no) {
            $parts = explode(':', $column, 2);
            if (!isset($parts[1])) {
                $human_line_no = $line_no + 1;
                throw new \UnexpectedValueException(
                    "Malformed feature \"{$column}\" on line {$human_line_no} of {$url}"
                );
            }
            return (float) $parts[1];
        }, $columns);
    }

    return [$samples, $labels];
}
// Locations of the svmguide1 training split and its held-out test split
$train_data_url = 'https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/svmguide1';
$test_data_url = 'https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/svmguide1.t';

// Download and parse both splits; load_dataset() returns a [samples, labels] pair
[$train_samples, $train_labels] = load_dataset($train_data_url);
[$test_samples, $test_labels] = load_dataset($test_data_url);

// Hyperparameters, named after the positional arguments of php-ml's SVC
// constructor (NOTE(review): confirm the order against the installed version)
$cost = 1.0;
$degree = 3;
$gamma = 1;
$coef0 = 0.0;
$tolerance = 0.001;
$cache_size = 100;
$shrinking = true;
$probability_estimates = true; // required for predictProbability() below

// Fit the classifier on the training split, then estimate per-class
// probabilities for every sample of the test split
$svc = new SVC(
    Kernel::RBF,
    $cost,
    $degree,
    $gamma,
    $coef0,
    $tolerance,
    $cache_size,
    $shrinking,
    $probability_estimates
);
$svc->train($train_samples, $train_labels);
$probabilities = $svc->predictProbability($test_samples);
// Pair each test sample's estimated probability of the positive class ('1')
// with its actual label, then rank the pairs by descending probability
$positiveness = array_map(function ($p) { return $p['1']; }, $probabilities);
$actuals = array_map(null, $positiveness, $test_labels);
usort($actuals, function ($a1, $a2) { return $a2[0] <=> $a1[0]; });
// Keep only the actual labels, now in ranked order
$actuals = array_map(function ($a) { return $a[1]; }, $actuals);

// Total number of positives in the test set: the denominator of recall
$n_positives = count(array_filter($actuals, function ($a) { return $a === '1'; }));
$n_samples = count($actuals);
if ($n_samples > 0 && $n_positives === 0) {
    // Recall is undefined without positives; fail with a clear message rather
    // than dividing by zero inside the loop
    throw new \RuntimeException("The test set contains no samples labelled '1'; recall is undefined");
}

// Walk down the ranking. Treating the top k+1 samples as predicted positive,
// precision = tp / (k + 1) and recall = tp / n_positives; print one
// "precision,recall" line per cut-off
$tp = 0;
for ($k = 0; $k < $n_samples; ++$k) {
    if ($actuals[$k] === '1') {
        ++$tp;
    }
    $precision = $tp / ($k + 1);
    $recall = $tp / $n_positives;
    echo $precision, ',', $recall, PHP_EOL;
}
@y-uti
Copy link
Author

y-uti commented Feb 12, 2018

Instead of the usort call on line 43, rsort($actuals) also works. (I always struggle to choose the proper sorting function in PHP.)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment