Last active
January 21, 2024 10:40
-
-
Save MaartenGr/6d131e497e15ade4b9cb5356d514307c to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import pandas as pd | |
import scipy.sparse as sp | |
from sklearn.preprocessing import normalize | |
from sklearn.feature_extraction.text import TfidfTransformer, CountVectorizer | |
class CTFIDFVectorizer(TfidfTransformer):
    """Transform a count matrix into an L1-normalized c-TF-IDF matrix.

    The idf weight of feature ``j`` is ``log(n_samples / df_j)``, where
    ``df_j`` is the sum of column ``j`` of the matrix passed to ``fit``.
    ``transform`` scales every column by its weight and then L1-normalizes
    each row.

    NOTE(review): in c-TF-IDF each row of the count matrix presumably holds
    the joined documents of one class -- confirm against the caller.

    No ``__init__`` override is defined on purpose: scikit-learn's parameter
    introspection (``get_params``, ``clone``, ``repr``) rejects estimators
    whose constructor takes ``*args``. The constructor inherited from
    ``TfidfTransformer`` accepts the same keyword arguments.
    """

    def fit(self, X: sp.csr_matrix, n_samples: int):
        """Learn the idf vector (global term weights).

        Parameters
        ----------
        X : sparse matrix of shape (n_rows, n_features)
            Count matrix; its column sums are used as ``df``.
        n_samples : int
            Numerator of the idf ratio (presumably the total number of
            documents across all rows -- confirm against the caller).

        Returns
        -------
        self

        Notes
        -----
        A weight is negative whenever a column sum exceeds ``n_samples``.
        """
        n_features = X.shape[1]

        # ravel (rather than squeeze) keeps a 1-D vector even when the
        # matrix has a single feature.
        df = np.asarray(X.sum(axis=0)).ravel()

        # Features with a zero column sum would cause a division by zero and
        # an infinite weight; give them weight 0 instead.
        idf = np.zeros(n_features, dtype=np.float64)
        seen = df > 0
        idf[seen] = np.log(n_samples / df[seen])

        self._idf_diag = sp.diags(idf, offsets=0,
                                  shape=(n_features, n_features),
                                  format='csr',
                                  dtype=np.float64)
        return self

    def transform(self, X: sp.csr_matrix) -> sp.csr_matrix:
        """Transform a count-based matrix to c-TF-IDF.

        Parameters
        ----------
        X : sparse matrix of shape (n_rows, n_features)
            Count matrix with the same number of features seen in ``fit``.

        Returns
        -------
        Sparse matrix of shape (n_rows, n_features) whose columns are scaled
        by the learned idf weights and whose rows are L1-normalized.
        """
        # `@` is always a matrix product; `*` is element-wise for scipy
        # sparse arrays, so it is only correct for the legacy matrix classes.
        X = X @ self._idf_diag
        X = normalize(X, axis=1, norm='l1', copy=False)
        return X
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment