Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save robertfeldt/103f078b3154c5621f52cee3d061bf81 to your computer and use it in GitHub Desktop.
Save robertfeldt/103f078b3154c5621f52cee3d061bf81 to your computer and use it in GitHub Desktop.
Speeding up StringDistances.jl qgram distances with precalculation of qgram counts (here in count dict)
using StringDistances
"""
    countdict(qgrams)

Count how many times each element occurs in the iterable `qgrams`.

Returns a `Dict{eltype(qgrams), Int32}` that maps every distinct element to its number
of occurrences. An empty input yields an empty dictionary.
"""
function countdict(qgrams)
    counts = Dict{eltype(qgrams), Int32}()
    for qg in qgrams
        # Use the public `get`/`setindex!` API rather than the unexported
        # `Base.ht_keyindex2!` / `Base._setindex!` helpers and raw `Dict` fields:
        # those internals are not stable across Julia releases.
        # Adding the `Int` literal widens the sum, so storing it back into the `Int32`
        # slot is a checked conversion (it throws instead of wrapping on overflow).
        counts[qg] = get(counts, qg, zero(Int32)) + 1
    end
    return counts
end
"""
    QgramDict{Q,K}

Pre-computed q-gram counts for a single sequence.

`Q` is the q-gram length the counts were built with (a value used as a type parameter)
and `K` is the type of the individual q-grams, i.e. the key type of the dictionary.
"""
struct QgramDict{Q,K}
# Maps each distinct q-gram to the number of times it occurs.
qdict::Dict{K,Int}
end
"""
    QgramDict(s::Union{AbstractString, AbstractVector}, q::Integer = 2)

Build the q-gram count table of `s` for q-grams of length `q` (which must be at least 1).

The q-grams come from `StringDistances.qgrams(s, q)` and are tallied with `countdict`;
`q` itself is recorded as the first type parameter of the result.
"""
function QgramDict(s::Union{AbstractString, AbstractVector}, q::Integer = 2)
    @assert q >= 1
    grams = StringDistances.qgrams(s, q)
    counts = countdict(grams)
    return QgramDict{q, eltype(grams)}(counts)
end
# Fallback for any other iterable: materialise it with `collect` and hand the resulting
# array to the `AbstractVector` method.
function QgramDict(s, q::Integer = 2)
    elements = collect(s)
    return QgramDict(elements, q)
end
# Return the q-gram length stored as the first type parameter of a `QgramDict`.
function q(::QgramDict{Q,K}) where {Q,K}
    return Q
end
"""
    IntersectionCounter()

Running tally over pairs of counts `(n1, n2)`: how many pairs had a positive first
count, how many had a positive second count, and how many had both. All three tallies
start at zero.
"""
mutable struct IntersectionCounter
    ndistinct1::Int
    ndistinct2::Int
    nintersect::Int
    function IntersectionCounter()
        return new(0, 0, 0)
    end
end

# Feed one pair of counts into the tally. Returns the updated `nintersect`.
function (c::IntersectionCounter)(n1::Integer, n2::Integer)
    in_first = n1 > 0
    in_second = n2 > 0
    c.ndistinct1 += in_first
    c.ndistinct2 += in_second
    c.nintersect += in_first & in_second
    return c.nintersect
end
# Create the fresh accumulator used when evaluating a q-gram distance; the distance
# argument only selects the method.
function newcounter(::StringDistances.QGramDistance)
    return IntersectionCounter()
end
# Jaccard distance from the tallies: one minus (shared distinct q-grams divided by the
# number of distinct q-grams in the union).
function calc(d::StringDistances.Jaccard, c::IntersectionCounter)
    shared = c.nintersect
    union_size = c.ndistinct1 + c.ndistinct2 - shared
    return 1.0 - shared / union_size
end
"""
    _yield_on_co_count_pairs(fn, d1, d2)

Call `fn(n1, n2)` exactly once for every key that occurs in `d1` or `d2`, where `n1` and
`n2` are the counts stored for that key in `d1` and `d2`. A key missing from one of the
dictionaries is reported with a count of `zero(I)` on that side.

Keys of `d1` are visited first (in `d1`'s iteration order), followed by the keys that
occur only in `d2`. Returns `nothing`.
"""
function _yield_on_co_count_pairs(fn, d1::Dict{K,I}, d2::Dict{K,I}) where {K,I<:Integer}
    # Read-only public lookups (`get`, `haskey`) replace `Base.ht_keyindex2!`, which is
    # an unexported insertion helper: it is not stable across Julia releases and is not
    # meant for plain membership queries.
    for (key, n1) in d1
        fn(n1, get(d2, key, zero(I)))
    end
    # Keys shared by both dictionaries were already reported by the loop above.
    for (key, n2) in d2
        haskey(d1, key) || fn(zero(I), n2)
    end
    return nothing
end
# Evaluate a q-gram distance directly on two q-gram count dictionaries: feed every pair
# of per-key counts into a fresh counter and turn the tallies into the distance value.
function StringDistances.evaluate(
    d::StringDistances.QGramDistance, qc1::Dict{S,I}, qc2::Dict{S,I}
) where {S,I<:Integer}
    counter = newcounter(d)
    _yield_on_co_count_pairs(counter, qc1, qc2)
    return calc(d, counter)
end
# Evaluate a q-gram distance on two pre-computed `QgramDict`s. Both must have been built
# with the same q-gram length as the distance `d`.
function StringDistances.evaluate(
    d::StringDistances.QGramDistance, qd1::QgramDict, qd2::QgramDict
)
    @assert d.q == q(qd1)
    @assert d.q == q(qd2)
    counts1 = qd1.qdict
    counts2 = qd2.qdict
    return evaluate(d, counts1, counts2)
end
using Random, Test, BenchmarkTools
# Per-iteration ratio: time of the plain evaluate / time of the pre-calculated evaluate.
TimeEval = Float64[]
# Same ratio, but with the time spent building both QgramDicts added to the
# pre-calculated side.
TimeTotal = Float64[]
@testset "compare evaluate with and without pre-calculation" begin
    for _ in 1:1000
        qlen = rand(2:9)
        d = StringDistances.Jaccard(qlen)
        s1 = randstring(rand(5:10000))
        # s2 has the same length as s1 and shares the substring s1[ci1:ci2] with it,
        # surrounded by fresh random text.
        ci1 = rand(2:div(length(s1), 2))
        ci2 = rand((ci1+1):(length(s1)-1))
        s2 = randstring(ci1-1) * s1[ci1:ci2] * randstring(length(s1)-ci2)
        p1 = @elapsed qd1 = QgramDict(s1, qlen)
        p2 = @elapsed qd2 = QgramDict(s2, qlen)
        t1 = @elapsed distval1 = evaluate(d, s1, s2)
        t2 = @elapsed distval2 = evaluate(d, qd1, qd2)
        # The strings can be shorter than qlen (lengths start at 5, qlen goes up to 9).
        # With no q-grams on either side the Jaccard formula is 0/0 = NaN, and
        # `NaN == NaN` is false, so compare with `isequal` to keep the test from
        # failing spuriously. (Presumably the plain evaluate also yields NaN there.)
        @test isequal(distval1, distval2)
        push!(TimeEval, t1/t2)
        push!(TimeTotal, t1/(t2+p1+p2))
    end
end
mean(TimeEval) # About 2.5-3x faster on my machine
mean(TimeTotal) # About 0.5x slower on my machine
# So we 'lose' about 50% performance when pre-calculating if we only compare two strings once,
# but we can gain quite a lot if we need to calculate distances repeatedly, for example
# when calculating a distance matrix.
"""
    dist_matrix(d, strs; precalc = true)

Compute the symmetric `length(strs)` x `length(strs)` matrix of pairwise distances
between the strings in `strs` under the q-gram distance `d`. The diagonal is left at
zero.

With `precalc = true` every string is first converted to a `QgramDict` (using `d.q`), so
its q-gram counts are computed once instead of once per pair.
"""
function dist_matrix(
    d::StringDistances.QGramDistance,
    strs::AbstractVector{<:AbstractString};
    precalc = true,
)
    items = if precalc
        map(strs) do s
            QgramDict(s, d.q)
        end
    else
        strs
    end
    n = length(strs)
    dm = zeros(Float64, n, n)
    # Only the upper triangle is evaluated; each value is mirrored below the diagonal.
    for i in 1:n, j in (i + 1):n
        dist = evaluate(d, items[i], items[j])
        dm[j, i] = dist
        dm[i, j] = dist
    end
    return dm
end
N = 100
strs = map(_ -> randstring(rand(2:1000)), 1:N)
d = Jaccard(2)
t1 = @elapsed dist_matrix(d, strs; precalc = true)
t2 = @elapsed dist_matrix(d, strs; precalc = false)
t2/t1 # 1.5x to 3.7x faster to pre-calc on my machine (depending on N)
# Interpolate the non-constant globals with `$` so the benchmark measures dist_matrix
# itself rather than the overhead of accessing untyped global variables.
@benchmark dist_matrix($d, $strs; precalc = true)
@benchmark dist_matrix($d, $strs; precalc = false)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment