Skip to content

Instantly share code, notes, and snippets.

Last active November 9, 2020 11:22
Show Gist options
  • Save asdfgeoff/5d63704c17052e642d3ea93351dda152 to your computer and use it in GitHub Desktop.
Save asdfgeoff/5d63704c17052e642d3ea93351dda152 to your computer and use it in GitHub Desktop.
A DIY implementation of Multinomial Naive Bayes using NumPy | Details →
from typing import Callable, Tuple, Union

import numpy as np
import pandas as pd
from IPython.display import display
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.utils.validation import check_X_y, check_array
# Type aliases for the estimator method signatures:
# array_like annotates label vectors (y), matrix_like annotates feature/count matrices (X).
array_like = Union[list, np.ndarray]
matrix_like = Union[np.ndarray, pd.DataFrame]
def make_spam_dataset() -> Tuple[np.ndarray, np.ndarray, Callable]:
    """ Create a small toy dataset for the MultinomialNB implementation

    Returns:
        X: word count matrix, shape (n_messages, len(vocab)), one column per vocabulary word
        y: 0/1 indicator of whether or not each message is spam
        msg_tx_func: a function that transforms a list of new messages into a word
            count matrix with the same columns as X
    """
    vocab = [
        'secret', 'offer', 'low', 'price', 'valued', 'customer', 'today',
        'dollar', 'million', 'sports', 'is', 'for', 'play', 'healthy', 'pizza',
    ]
    spam = [
        'million dollar offer',
        'secret offer today',
        'secret is secret',
    ]
    not_spam = [
        'low price for valued customer',
        'play secret sports today',
        'sports is healthy',
        'low price pizza',
    ]
    all_messages = spam + not_spam

    # A fixed vocabulary pins the column order, so X and any transformed test data line up.
    vectorizer = CountVectorizer(vocabulary=vocab)
    word_counts = vectorizer.fit_transform(all_messages).toarray()
    is_spam = np.array([1] * len(spam) + [0] * len(not_spam))

    def msg_tx_func(messages):
        """ Transform raw messages into a dense word count matrix over ``vocab``. """
        return vectorizer.transform(messages).toarray()

    return word_counts, is_spam, msg_tx_func
class NaiveBayes(object):
    """ DIY implementation of a Multinomial Naive Bayes classifier for count data (e.g. word counts)

    - parameterised like scikit-learn's ``MultinomialNB``: additive smoothing via ``alpha``
    - handles any number of classes; ``predict`` returns the labels seen during ``fit``
    - cannot fully vectorize fit method, since classes may have unequal sizes
    - NOTE(review): the original docstring said "inspired by" a source whose reference
      is missing from this copy
    """

    def __init__(self, alpha=1.0):
        """ Create an unfitted classifier.

        alpha: pseudo-count added to every (class, word) count before normalising
            (1.0 corresponds to Laplace smoothing)
        """
        self.alpha = alpha
        self.classes_ = None      # sorted unique labels seen in fit
        self.prior = None         # P(class), one entry per class in classes_
        self.word_counts = None   # smoothed word counts, shape (n_classes, n_words)
        self.lk_word = None       # P(word | class), shape (n_classes, n_words)
        self.word_proba = None    # same array as lk_word, kept for backward compatibility
        self.is_fitted_ = False

    def fit(self, X: matrix_like, y: array_like):
        """ Fit class priors and per-class word likelihoods from training data.

        X: count matrix, shape (n_samples, n_words)
        y: class label for each row of X
        Returns self, so calls can be chained.
        """
        # not strictly necessary, but this ensures we have clean input
        X, y = check_X_y(X, y)
        n = X.shape[0]
        self.classes_ = np.unique(y)
        # Keep the per-class row blocks in a plain list: classes can have different
        # numbers of rows, and np.array() over unequal blocks builds a ragged array,
        # which NumPy >= 1.24 rejects with a ValueError.
        X_by_class = [X[y == c] for c in self.classes_]
        self.prior = np.array([len(X_class) / n for X_class in X_by_class])
        self.word_counts = np.array([X_class.sum(axis=0) for X_class in X_by_class]) + self.alpha
        self.lk_word = self.word_counts / self.word_counts.sum(axis=1, keepdims=True)
        self.word_proba = self.lk_word
        self.is_fitted_ = True
        return self

    def _joint_log_likelihood(self, X: np.ndarray) -> np.ndarray:
        """ Return log P(class) + sum_w x_w * log P(w | class), shape (n_samples, n_classes). """
        # Zero likelihoods can only appear when alpha == 0. Take the log of 1 there
        # so that absent words (count 0) contribute 0 instead of 0 * -inf = NaN...
        impossible = self.lk_word == 0
        safe_log_lk = np.log(np.where(impossible, 1.0, self.lk_word))
        joint = X @ safe_log_lk.T + np.log(self.prior)
        # ...and rule a class out entirely when a present word has zero likelihood in it,
        # matching the zero factor in the direct product of probabilities.
        rules_out = (X != 0).astype(float) @ impossible.T.astype(float) > 0
        joint[rules_out] = -np.inf
        return joint

    def predict_proba(self, X: matrix_like):
        """ Predict probability of class membership.

        X: count matrix with the same columns as the training data
        Returns an array of shape (n_samples, n_classes); columns follow ``classes_``.
        Raises RuntimeError if called before ``fit``.
        """
        # Explicit check rather than assert, so it still runs under ``python -O``.
        if not self.is_fitted_:
            raise RuntimeError('Model must be fit before predicting')
        X = check_array(X)
        joint_log = self._joint_log_likelihood(X)
        # Shift each row by its maximum before exponentiating (log-sum-exp trick):
        # the raw product of many small word likelihoods underflows to 0 for long
        # messages, which previously turned the normalisation into 0/0 = NaN.
        shifted = joint_log - joint_log.max(axis=1, keepdims=True)
        unnormalised = np.exp(shifted)
        conditional_probas = unnormalised / unnormalised.sum(axis=1, keepdims=True)
        assert np.allclose(conditional_probas.sum(axis=1), 1.0, atol=0.001), 'Rows should sum to 1'
        return conditional_probas

    def predict(self, X: matrix_like):
        """ Predict the most probable class label (from ``classes_``) for each row of X. """
        return self.classes_[self.predict_proba(X).argmax(axis=1)]
def test_against_benchmark():
    """ Check that the DIY model matches the outputs of scikit-learn's MultinomialNB.

    Raises AssertionError (from numpy.testing) on the first quantity that differs by
    more than 0.1% relative tolerance.
    """
    from sklearn.naive_bayes import MultinomialNB
    X, y, _ = make_spam_dataset()
    bench = MultinomialNB(alpha=1.0).fit(X, y)
    model = NaiveBayes(alpha=1).fit(X, y)
    # Two-sided relative tolerance: the previous ``ratio - 1 < 0.001`` check also
    # passed whenever the DIY value was smaller than the benchmark, however far off.
    np.testing.assert_allclose(model.prior, np.exp(bench.class_log_prior_), rtol=1e-3)
    print('[✔︎] Identical prior probabilities')
    np.testing.assert_allclose(model.lk_word, np.exp(bench.feature_log_prob_), rtol=1e-3)
    print('[✔︎] Identical word likelihoods')
    np.testing.assert_allclose(model.predict_proba(X), bench.predict_proba(X), rtol=1e-3)
    print('[✔︎] Identical predictions')
if __name__ == '__main__':
    # NOTE(review): the body of this guard is missing from this copy of the file, so as
    # written it is a syntax error. Presumably it ran test_against_benchmark() (and
    # perhaps used the IPython `display` import, which nothing visible here uses) —
    # confirm against the original gist.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment