Last active
July 21, 2021 04:55
-
-
Save pazzo83/a7bdf5ef69fea8b5cf9bac0036664a51 to your computer and use it in GitHub Desktop.
TFIDF transformer for MLJ
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using MLJModelInterface, MLJBase, TextAnalysis, SparseArrays | |
# TF-IDF text transformer exposed as an MLJ unsupervised model.
# The struct declares no fields, i.e. the model has no hyperparameters; everything
# learned from data lives in the fit result rather than in the model itself.
MLJModelInterface.@mlj_model mutable struct TfidfTransformer <: MLJModelInterface.Unsupervised
end
"""
    TfidfTransformerResult(vocab, idf_vector)

Fit result of a `TfidfTransformer`.

# Fields
- `vocab`: the terms (columns) of the document-term matrix built while fitting.
- `idf_vector`: one inverse-document-frequency weight per entry of `vocab`, in the
  same order.
"""
struct TfidfTransformerResult
    vocab::Vector{String}
    idf_vector::Vector{Float64}
end
"""
    MLJBase.fit(transformer::TfidfTransformer, verbosity::Int, X::Vector{String})

Fit on raw strings: each element of `X` is wrapped in a `StringDocument`, the documents
are collected into a `Corpus`, and fitting is delegated to the `Corpus` method.
"""
function MLJBase.fit(transformer::TfidfTransformer, verbosity::Int, X::Vector{String})
    return MLJBase.fit(transformer, verbosity, Corpus(StringDocument.(X)))
end
"""
    MLJBase.fit(transformer::TfidfTransformer, verbosity::Int,
                X::Vector{StringDocument{String}})

Fit on pre-built string documents: `X` is collected into a `Corpus` and fitting is
delegated to the `Corpus` method.
"""
function MLJBase.fit(
    transformer::TfidfTransformer,
    verbosity::Int,
    X::Vector{StringDocument{String}},
)
    return MLJBase.fit(transformer, verbosity, Corpus(X))
end
"""
    MLJBase.fit(transformer::TfidfTransformer, verbosity::Int, X::Corpus)

Learn the vocabulary and inverse document frequencies from the corpus `X`.

Returns the tuple `(fitresult, cache, report)` where `fitresult` is a
`TfidfTransformerResult`, `cache` is `nothing` and `report` is an empty `NamedTuple`.

Note that `X` is modified: its lexicon is refreshed via `update_lexicon!`.
"""
function MLJBase.fit(transformer::TfidfTransformer, verbosity::Int, X::Corpus)
    # Refresh the corpus lexicon in place before building the document-term matrix.
    update_lexicon!(X)
    doc_term = DocumentTermMatrix(X)
    counts = doc_term.dtm

    # The row count is used as the number of documents in the IDF formula.
    n_documents = size(counts, 1)

    # Per column (term): how many rows hold a strictly positive count.
    doc_freq = vec(sum(counts .> 0, dims=1))
    idf_vector = log.(n_documents ./ doc_freq)

    fitresult = TfidfTransformerResult(doc_term.terms, idf_vector)
    cache = nothing
    report = NamedTuple()
    return fitresult, cache, report
end
"""
    build_tfidf!(dtm, tfidf, idf_vector)

Overwrite the stored values of `tfidf` with TF-IDF weights computed from the counts in
`dtm`, and return `tfidf`.

For the stored entry at row `r`, column `c`, the value written is
`dtm[r, c] / max(sum(dtm[r, :]), 1) * idf_vector[c]`: the term's share of the row total
(term frequency) scaled by the column's inverse document frequency. The `max(…, 1)`
guards against dividing by a zero row total.

# Arguments
- `dtm`: sparse count matrix. Row sums are used as document lengths, so rows are treated
  as documents and columns as terms.
- `tfidf`: output matrix. It must have the same size and the same number of stored
  entries as `dtm` (e.g. `similar(dtm, F)`); its stored-entry layout is assumed to match
  that of `dtm`, and only its stored values are written.
- `idf_vector`: one weight per column of `dtm`.

# Throws
- `DimensionMismatch` if `tfidf` does not match `dtm` in size or number of stored
  entries, or if `idf_vector` does not have one entry per column of `dtm`.
"""
function build_tfidf!(
    dtm::SparseMatrixCSC{T},
    tfidf::SparseMatrixCSC{F},
    idf_vector::Vector{F},
) where {T <: Real, F <: AbstractFloat}
    rows = rowvals(dtm)
    dtmvals = nonzeros(dtm)
    tfidfvals = nonzeros(tfidf)

    # Validate up front with real exceptions: `@assert` may be disabled, and a short
    # `idf_vector` would otherwise fail mid-loop after `tfidf` was partially written.
    size(dtm) == size(tfidf) || throw(DimensionMismatch(
        "tfidf has size $(size(tfidf)), expected $(size(dtm)) to match dtm",
    ))
    length(dtmvals) == length(tfidfvals) || throw(DimensionMismatch(
        "tfidf stores $(length(tfidfvals)) entries, but dtm stores $(length(dtmvals))",
    ))
    length(idf_vector) == size(dtm, 2) || throw(DimensionMismatch(
        "idf_vector has length $(length(idf_vector)), expected one entry per " *
        "column of dtm ($(size(dtm, 2)))",
    ))

    # TF tells us what proportion of a document is defined by a term
    words_in_documents = F.(sum(dtm, dims=2))
    oneval = one(F)

    # Walk the stored entries column by column (CSC order).
    for col in axes(dtm, 2)
        for k in nzrange(dtm, col)
            row = rows[k]
            tfidfvals[k] =
                dtmvals[k] / max(words_in_documents[row], oneval) * idf_vector[col]
        end
    end
    return tfidf
end
"""
    MLJBase.transform(transformer::TfidfTransformer, result::TfidfTransformerResult,
                      v::Vector{StringDocument{String}})

Transform pre-built string documents: `v` is collected into a `Corpus` and the work is
delegated to the `Corpus` method of `transform`.
"""
function MLJBase.transform(
    transformer::TfidfTransformer,
    result::TfidfTransformerResult,
    v::Vector{StringDocument{String}},
)
    return MLJBase.transform(transformer, result, Corpus(v))
end

"""
    MLJBase.transform(transformer::TfidfTransformer, result::TfidfTransformerResult,
                      v::Vector{String})

Transform raw strings. Mirrors the `Vector{String}` method of `fit`, so that the same
kind of input accepted for fitting can also be transformed: each string is wrapped in a
`StringDocument`, the documents are collected into a `Corpus`, and the work is delegated
to the `Corpus` method of `transform`.
"""
function MLJBase.transform(
    transformer::TfidfTransformer,
    result::TfidfTransformerResult,
    v::Vector{String},
)
    return MLJBase.transform(transformer, result, Corpus(StringDocument.(v)))
end
"""
    MLJBase.transform(::TfidfTransformer, result::TfidfTransformerResult, v::Corpus)

Compute the sparse TF-IDF matrix of the corpus `v` against the fitted vocabulary.

The document-term matrix is built over `result.vocab`, an output matrix with the same
sparsity layout and the element type of `result.idf_vector` is allocated, and
`build_tfidf!` fills it in. The filled matrix is returned.
"""
function MLJBase.transform(::TfidfTransformer, result::TfidfTransformerResult, v::Corpus)
    doc_term = DocumentTermMatrix(v, result.vocab)
    counts = doc_term.dtm
    weighted = similar(counts, eltype(result.idf_vector))
    # `build_tfidf!` returns the matrix it filled, i.e. `weighted` itself.
    return build_tfidf!(counts, weighted, result.idf_vector)
end
Yes, I think that would work great!
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Okay, should I create this at JuliaAI then?