Skip to content

Instantly share code, notes, and snippets.

Created May 25, 2017 14:54
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save anonymous/023cc2ac9007765da8e089fdcd6f44eb to your computer and use it in GitHub Desktop.
Save anonymous/023cc2ac9007765da8e089fdcd6f44eb to your computer and use it in GitHub Desktop.
<?php
/**
* Naive Bayes classifier
*/
include __DIR__ . '/../vendor/autoload.php';
/**
 * Train a Naive Bayes model from labelled samples.
 *
 * Each sample must provide a 'label' key and a 'features' key holding an
 * iterable of scalar features (here: single characters).
 *
 * The returned object has two array properties:
 *  - classes: label => prior, i.e. the share of samples carrying that label;
 *  - freq:    label => [feature => likelihood], i.e. the share of all feature
 *             occurrences seen for that label that are this feature.
 *
 * Likelihoods are normalised by the number of feature occurrences in the
 * class (not by the number of samples), so a feature that repeats inside a
 * sample can never produce a "probability" greater than 1.
 *
 * @param iterable $samples list of ['label' => scalar, 'features' => iterable]
 *
 * @return object {classes: array, freq: array}; both empty when $samples is empty
 */
function train($samples)
{
    $samples_count = 0;
    $classes = [];
    $freq = [];
    // Number of feature occurrences observed per label; the likelihood denominator.
    $feature_totals = [];

    foreach ($samples as $sample) {
        // Counted here rather than with count() so generators are accepted too.
        $samples_count++;
        $label = $sample['label'];
        $classes[$label] = ($classes[$label] ?? 0) + 1;

        foreach ($sample['features'] as $feature) {
            $freq[$label][$feature] = ($freq[$label][$feature] ?? 0) + 1;
            $feature_totals[$label] = ($feature_totals[$label] ?? 0) + 1;
        }
    }

    // Only labels that contributed at least one feature appear in $freq,
    // so $feature_totals[$label] is always >= 1 here.
    foreach ($freq as $label => $features) {
        foreach ($features as $feature => $count) {
            $freq[$label][$feature] = $count / $feature_totals[$label];
        }
    }

    // $classes is non-empty only if at least one sample was seen,
    // so $samples_count is never zero inside this loop.
    foreach ($classes as $label => $count) {
        $classes[$label] = $count / $samples_count;
    }

    return (object) compact('classes', 'freq');
}
/**
 * Pick the most likely label for a list of features.
 *
 * For every label in $classifier->classes a score is computed as
 * -log(prior) plus the sum of -log(likelihood) over all given features;
 * the label with the lowest score wins. A feature that has no entry in
 * $classifier->freq for a label is given a fixed likelihood of 10^-7.
 *
 * Likelihoods are read with a direct nested array lookup, so features or
 * labels containing '.' are matched literally.
 *
 * @param object   $classifier object with array properties 'classes'
 *                             (label => prior) and 'freq'
 *                             (label => [feature => likelihood])
 * @param iterable $features   features of the item to classify
 *
 * @return mixed the winning label, or null when the classifier has no classes
 */
function classification($classifier, $features)
{
    // Likelihood assumed for a feature never recorded for a label.
    $unseen_likelihood = 10 ** (-7);

    $best_label = null;
    // INF (not PHP_INT_MAX) so that any finite score can become the minimum.
    $best_score = INF;

    foreach ($classifier->classes as $label => $prior) {
        $features_score = 0;
        foreach ($features as $feature) {
            $likelihood = $classifier->freq[$label][$feature] ?? $unseen_likelihood;
            $features_score += -log($likelihood);
        }

        $score = -log($prior) + $features_score;
        // Strict '<' keeps the first label encountered on ties.
        if ($score < $best_score) {
            $best_score = $score;
            $best_label = $label;
        }
    }

    return $best_label;
}
/**
 * Split a value into single-character features.
 *
 * The value is cast to string and split on UTF-8 code points, so multi-byte
 * characters stay intact. An empty string yields an empty list. If the
 * string is not valid UTF-8 it is split byte by byte instead.
 *
 * @param mixed $data value to split; cast to string
 *
 * @return array list of one-character strings
 */
function extract_features($data)
{
    $text = (string) $data;
    if ($text === '') {
        // str_split('') returns [''] before PHP 8.2 and [] after; always return [].
        return [];
    }

    $characters = preg_split('//u', $text, -1, PREG_SPLIT_NO_EMPTY);

    // preg_split() returns false when $text is not valid UTF-8.
    return $characters !== false ? $characters : str_split($text);
}
// Demo: learn to tell "digit-like" strings ('d') from "word-like" strings ('w').
// Every training sample gets a 'features' entry derived from its 'value'.
$samples = array_map(
    function (array $row) {
        $row['features'] = extract_features($row['value']);

        return $row;
    },
    [
        ['label' => 'd', 'value' => '234'],
        ['label' => 'd', 'value' => 'e14'],
        ['label' => 'd', 'value' => '95094'],
        ['label' => 'd', 'value' => '456'],
        ['label' => 'w', 'value' => 'sdfsldf'],
        ['label' => 'w', 'value' => 'pwper'],
        ['label' => 'w', 'value' => 'eee'],
        ['label' => 'w', 'value' => 'ee1sd'],
        ['label' => 'w', 'value' => 'ee12d'],
    ]
);

$classifier = train($samples);

// Strings to classify; one "<input>: <label>" line is printed for each.
$tests = [
    '180456',
    'mcnxc',
    's89sf66sdf',
    '001u',
    'e0091',
    'ccc',
];

foreach ($tests as $test) {
    $predicted = classification($classifier, extract_features($test));
    echo $test, ': ', $predicted, PHP_EOL;
}
/**
* output:
* 180456: d
* mcnxc: w
* s89sf66sdf: w
* 001u: d
* e0091: d
* ccc: w
*/
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment