Skip to content

Instantly share code, notes, and snippets.

Created April 27, 2019 01:26
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jiahao/5dc9242e3e6d0c8e5c59e2251c07a955 to your computer and use it in GitHub Desktop.
Save jiahao/5dc9242e3e6d0c8e5c59e2251c07a955 to your computer and use it in GitHub Desktop.
Multinomial naive Bayes in Julia, allowing generic numeric types for the conditional probabilities (including rational numbers, which let you calculate exact probabilities).
"""
    NaiveBayes{T, V, M}

Two-class multinomial naive Bayes model.

# Fields
- `probabilities::M`: `2 × n` matrix; `probabilities[k, j]` is the smoothed estimate
  of feature `j` given class `k` (class 1 = `true` labels, class 2 = `false` labels).
- `priors::V`: length-2 vector of class prior probabilities.

`T` is the numeric type the model was trained with (e.g. `Rational` for exact
arithmetic, `Float32` for floating point).
"""
struct NaiveBayes{T, V<:AbstractVector, M<:AbstractMatrix}
    probabilities::M
    priors::V
end
"""
    train(NaiveBayes, T, features, labels, α = 1)

Convenience entry point: train a model whose probabilities use element type `T`,
stored as a `Vector{T}` of priors and a `Matrix{T}` of conditional probabilities.
"""
function train(::Type{NaiveBayes}, T::Type{<:Real}, features, labels, α = 1)
    model_type = NaiveBayes{T, Vector{T}, Matrix{T}}
    return train(model_type, features, labels, α)
end
# Division used throughout the model: exact `//` when the target type is rational,
# ordinary `/` otherwise. Broadcasting lets the same helper divide scalars, vectors
# by a scalar, or a matrix by a column vector.
_ratio(::Type{<:Rational}, num, den) = num .// den
_ratio(::Type{<:Real}, num, den) = num ./ den

"""
    train(NaiveBayes{T, S, R}, features, labels, α = 1)

Fit a two-class multinomial naive Bayes model with additive smoothing.

# Arguments
- `features`: `m × n` matrix; row `i` holds the feature counts of example `i`.
- `labels`: `m` booleans; `true` examples form class 1, `false` examples class 2.
- `α`: pseudocount added to every feature count (1 = Laplace smoothing).

# Returns
A `NaiveBayes{T, S, R}` with `priors[k]` = fraction of examples in class `k` and
`probabilities[k, j]` = (count of feature `j` in class `k` + α) /
(total feature count of class `k` + α·n). Uses `//` when `T <: Rational`, else `/`.

# Throws
- `ArgumentError` if `features` has no rows.
- `DimensionMismatch` if `length(labels) != size(features, 1)`.
"""
function train(::Type{NaiveBayes{T, S, R}}, features, labels,
               α = 1) where {T<:Real, S<:AbstractVector{T}, R<:AbstractMatrix{T}}
    m, n = size(features)
    m > 0 || throw(ArgumentError("cannot train on an empty feature matrix"))
    length(labels) == m || throw(DimensionMismatch(
        "expected $m labels (one per row of features), got $(length(labels))"))

    # Accumulator type: Int for integer/Bool counts, widened for fractional feature
    # values so that they do not raise an InexactError.
    F = eltype(features)
    C = F <: Real ? promote_type(Int, F) : Int

    # Class 1 is `true`, class 2 is `false`.
    class_of = [labels[i] ? 1 : 2 for i in 1:m]
    examples_per_class = zeros(Int, 2)
    for k in class_of
        examples_per_class[k] += 1
    end

    # feature_counts[k, j]: total of feature j over the examples of class k.
    # The row index `i` runs innermost to follow Julia's column-major layout.
    feature_counts = zeros(C, 2, n)
    for j in 1:n, i in 1:m
        feature_counts[class_of[i], j] += features[i, j]
    end
    tokens_per_class = vec(sum(feature_counts; dims = 2))

    priors = _ratio(T, examples_per_class, m)
    # α pseudocounts per feature means α * n extra tokens per class.
    probabilities = _ratio(T, feature_counts .+ α, tokens_per_class .+ α * n)
    # The default constructor converts both arrays to the requested R and S types.
    return NaiveBayes{T, S, R}(probabilities, priors)
end

"""
    predict(NB::NaiveBayes, T, features, normalize::Bool = false)

Score the feature-count vector `features` against both classes of `NB`.

Returns `(score1, score2)` converted to `T`, where `scoreₖ = priors[k] *
∏ⱼ probabilities[k, j]^features[j]` over the nonzero features. With
`normalize = true` the scores are divided by their sum, giving posterior
probabilities. Uses `//` when `T <: Rational`, else `/`. The model is not modified.

# Throws
- `DimensionMismatch` if `length(features)` differs from the number of features
  the model was trained on.
"""
function predict(NB::NaiveBayes, T::Type{R}, features,
                 normalize::Bool = false) where {R<:Real}
    n = length(features)
    n_model = size(NB.probabilities, 2)
    n == n_model || throw(DimensionMismatch(
        "model was trained on $n_model features, got $n"))

    # Always copy: `convert` would return `NB.priors` itself when its eltype is
    # already `T`, and the in-place updates below would overwrite the model.
    scores = Vector{T}(NB.priors)
    for k in 1:2, j in 1:n
        x = features[j]
        if x != 0
            scores[k] *= NB.probabilities[k, j]^x
        end
    end
    if normalize
        scores = _ratio(T, scores, sum(scores))
    end
    return T(scores[1]), T(scores[2])
end
"""
    predict(NB::NaiveBayes, features)

Unnormalized class scores for `features`, computed and returned as `Float32`.
"""
function predict(NB::NaiveBayes, features)
    output_type = Float32
    return predict(NB, output_type, features)
end
# Worked example: Example 13.1 / Table 13.1 of Manning, Raghavan & Schütze,
# "Introduction to Information Retrieval" (2008), Section 13.2.
# Rows are training documents; columns count occurrences of each term:
#   Chinese Beijing Shanghai Macao Japan Tokyo
featureset = [2 1 0 0 0 0
              2 0 1 0 0 0
              1 0 0 1 0 0
              1 0 0 0 1 1]
labels = [true, true, true, false]   # first three documents are in the class

NB = train(NaiveBayes, Rational, featureset, labels)
@assert NB.priors == [3//4, 1//4]
@assert NB.probabilities[1, 1] == 3//7
@assert NB.probabilities[1, 5] == 1//14
@assert NB.probabilities[1, 6] == 1//14
@assert NB.probabilities[2, 1] == 2//9
@assert NB.probabilities[2, 5] == 2//9
@assert NB.probabilities[2, 6] == 2//9

testdoc = [3, 0, 0, 0, 1, 1]   # Chinese ×3, Japan, Tokyo
# normalize = true: the expected values below are posteriors, not raw scores.
pred_yes, pred_no = predict(NB, Rational, testdoc, true)

# Unnormalized scores: prior × Π P(term | class)^count
upr_yes = 3//4 * (3//7)^3 * 1//14 * 1//14
upr_no = 1//4 * (2//9)^3 * 2//9 * 2//9
pr_yes = upr_yes // (upr_yes + upr_no)
pr_no = upr_no // (upr_yes + upr_no)
@assert pred_yes == pr_yes
@assert pred_no == pr_no

# Same data in Float32 arithmetic, for comparison with the exact model.
NBF = train(NaiveBayes, Float32, featureset, labels)
@show predict(NB, Float32, testdoc)
@show predict(NBF, Float32, testdoc)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment