Last active
November 8, 2019 14:16
-
-
Save internaut/2db3d0f0c753fa1e6caaa1e6b7e0103b to your computer and use it in GitHub Desktop.
A function to calculate word co-occurrence from a document-term matrix, and a test for it using the hypothesis package
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
def word_cooccurrence(dtm):
    """
    Calculate the co-document frequency (aka word co-occurrence) matrix for a document-term matrix `dtm`, i.e. how often
    each pair of tokens occurs together at least once in the same document.

    :param dtm: (sparse) document-term-matrix of size NxM (N docs, M is vocab size) with raw term counts.
    :return: co-document frequency (aka word co-occurrence) matrix with shape MxM; entry (i, j) is the number of
             documents that contain both token i and token j, so the diagonal holds each token's document frequency
    :raises ValueError: if `dtm` is not two-dimensional
    """
    if dtm.ndim != 2:
        raise ValueError('`dtm` must be a 2D array/matrix')

    # Binarize: only presence/absence of a token in a document matters, not its raw count.
    # The builtin `int` is used instead of `np.int`, which was merely an alias of `int`,
    # deprecated in NumPy 1.20 and removed in NumPy 1.24 (it now raises AttributeError).
    bin_dtm = (dtm >= 1).astype(int)

    # (M x N) @ (N x M): entry (i, j) sums, over all documents, presence(i) * presence(j).
    return bin_dtm.T @ bin_dtm
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
from hypothesis import given, strategies as st | |
from hypothesis.extra.numpy import arrays, array_shapes | |
from cooc import word_cooccurrence | |
@given(dtm=arrays(int, array_shapes(min_dims=2, max_dims=2),
                  elements=st.integers(min_value=0, max_value=1000)))
def test_word_cooccurrence(dtm):
    """
    Property-based test of `word_cooccurrence` on random 2D matrices of non-negative integer term counts.

    Uses the builtin `int` dtype because `np.int` was removed in NumPy 1.24, and passes keyword arguments to
    `array_shapes` because its parameters are keyword-only in current hypothesis releases.
    """
    res = word_cooccurrence(dtm)
    n_docs, vocab_size = dtm.shape

    # type and shape of the result
    assert isinstance(res, np.ndarray)
    assert res.dtype == int
    assert res.ndim == 2
    assert res.shape == (vocab_size, vocab_size)

    # each entry counts documents, so it is bounded by [0, n_docs]; co-occurrence is symmetric
    assert np.all((res >= 0) & (res <= n_docs))
    assert np.array_equal(res, res.T)

    # the diagonal is the document frequency of each token
    assert np.array_equal(np.diag(res), (dtm >= 1).sum(axis=0))

    # an all-zero DTM has no co-occurrences at all
    if not dtm.any():
        assert not res.any()

    # if each token appears in exactly one distinct document, nothing co-occurs except with itself
    if n_docs == vocab_size:
        ident = np.eye(n_docs, dtype=int)
        if np.array_equal(dtm, ident):
            assert np.array_equal(res, ident)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment