Skip to content

Instantly share code, notes, and snippets.

@pazzo83
Last active July 21, 2021 04:55
Show Gist options
  • Save pazzo83/a7bdf5ef69fea8b5cf9bac0036664a51 to your computer and use it in GitHub Desktop.
Save pazzo83/a7bdf5ef69fea8b5cf9bac0036664a51 to your computer and use it in GitHub Desktop.
TFIDF transformer for MLJ
using MLJModelInterface, MLJBase, TextAnalysis, SparseArrays
# Model type for the TF-IDF transformer, declared through `@mlj_model` as a subtype of
# `MLJModelInterface.Unsupervised`. The struct has no fields, so there are no
# hyperparameters to configure; all learned state lives in the fit result.
MLJModelInterface.@mlj_model mutable struct TfidfTransformer <: MLJModelInterface.Unsupervised
end
# Fit result of a `TfidfTransformer`.
#
# Fields:
# - `vocab`: the terms learned during `fit`; passed to `DocumentTermMatrix` in
#   `transform` to fix the column layout.
# - `idf_vector`: one inverse-document-frequency weight per entry of `vocab`,
#   in the same order.
struct TfidfTransformerResult
    vocab::Vector{String}
    idf_vector::Vector{Float64}
end
# Fit on raw strings: wrap each string in a `StringDocument`, collect the documents
# into a `Corpus`, and delegate to the `Corpus` method of `fit`.
function MLJBase.fit(transformer::TfidfTransformer, verbosity::Int, X::Vector{String})
    documents = StringDocument.(X)
    return MLJBase.fit(transformer, verbosity, Corpus(documents))
end
# Fit on a vector of string documents: collect them into a `Corpus` and delegate
# to the `Corpus` method of `fit`.
function MLJBase.fit(transformer::TfidfTransformer, verbosity::Int, X::Vector{StringDocument{String}})
    return MLJBase.fit(transformer, verbosity, Corpus(X))
end
# Learn the vocabulary and IDF weights from a corpus.
#
# Returns the MLJ triple `(fitresult, cache, report)`, where `fitresult` is a
# `TfidfTransformerResult`, `cache` is `nothing`, and `report` is an empty `NamedTuple`.
# Note that `update_lexicon!` mutates the corpus `X` that is passed in.
function MLJBase.fit(transformer::TfidfTransformer, verbosity::Int, X::Corpus)
    update_lexicon!(X)
    doc_term = DocumentTermMatrix(X)
    counts = doc_term.dtm

    # The first dimension of the matrix is used as the number of documents.
    n_docs = size(counts, 1)

    # For each term (column), count the rows holding a strictly positive entry.
    doc_frequency = vec(sum(counts .> 0, dims=1))

    # Unsmoothed inverse document frequency: log(n_docs / doc_frequency).
    idf_weights = log.(n_docs ./ doc_frequency)

    fitresult = TfidfTransformerResult(doc_term.terms, idf_weights)
    cache = nothing
    report = NamedTuple()
    return fitresult, cache, report
end
# Overwrite the stored entries of `tfidf` with TF-IDF weights derived from `dtm`.
#
# Arguments:
# - `dtm`: sparse count matrix; each row sum is used as that row's total word count.
# - `tfidf`: sparse output matrix, mutated in place. Its stored-value array must have
#   the same size as that of `dtm`; entries are written at the same storage positions,
#   so it is expected to share `dtm`'s sparsity layout (e.g. created via `similar`).
# - `idf_vector`: one IDF weight per column of `dtm`.
#
# Returns `tfidf`.
function build_tfidf!(dtm::SparseMatrixCSC{T}, tfidf::SparseMatrixCSC{F}, idf_vector::Vector{F}) where {T <: Real, F <: AbstractFloat}
    dtmvals = nonzeros(dtm)
    tfidfvals = nonzeros(tfidf)
    @assert size(dtmvals) == size(tfidfvals)

    row_of = rowvals(dtm)

    # Term frequency is the share of a row's total count taken up by a term,
    # so the row totals are needed as denominators.
    row_totals = F.(sum(dtm, dims=2))

    for col in axes(dtm, 2), k in nzrange(dtm, col)
        # Clamp the denominator to at least one so a row with no counts cannot divide by zero.
        denom = max(row_totals[row_of[k]], one(F))
        tfidfvals[k] = dtmvals[k] / denom * idf_vector[col]
    end
    return tfidf
end
# Transform a vector of string documents: collect them into a `Corpus` and delegate
# to the `Corpus` method of `transform`.
function MLJBase.transform(transformer::TfidfTransformer, result::TfidfTransformerResult, v::Vector{StringDocument{String}})
    return MLJBase.transform(transformer, result, Corpus(v))
end

# Transform raw strings. `fit` accepts a `Vector{String}`, so `transform` must accept
# the same input type; otherwise a model fitted on raw strings could not be applied to
# raw strings. The conversion mirrors the one used by `fit`: each string is wrapped in
# a `StringDocument` and the documents are collected into a `Corpus`.
function MLJBase.transform(transformer::TfidfTransformer, result::TfidfTransformerResult, v::Vector{String})
    return MLJBase.transform(transformer, result, Corpus(StringDocument.(v)))
end
# Transform a corpus into a sparse TF-IDF matrix.
#
# The document-term matrix is built against the fitted vocabulary `result.vocab`, and
# the output has the element type of `result.idf_vector`. Returns the sparse matrix
# filled in by `build_tfidf!`.
function MLJBase.transform(::TfidfTransformer, result::TfidfTransformerResult, v::Corpus)
    idf = result.idf_vector
    counts = DocumentTermMatrix(v, result.vocab).dtm

    # `similar` allocates the output with the requested element type; its stored
    # values are then overwritten by `build_tfidf!`.
    weighted = similar(counts, eltype(idf))
    build_tfidf!(counts, weighted, idf)
    return weighted
end
@pazzo83
Copy link
Author

pazzo83 commented Jul 21, 2021

Yes, I think that would work great!

@ablaom
Copy link

ablaom commented Jul 21, 2021

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment