Skip to content

Instantly share code, notes, and snippets.

@ybenjo
Last active June 3, 2023 03:16
Show Gist options
  • Star 5 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save ybenjo/6151182 to your computer and use it in GitHub Desktop.
Save ybenjo/6151182 to your computer and use it in GitHub Desktop.
Biterm topic model (WWW 2013)
# Xiaohui Yan et al., "A Biterm Topic Model for Short Texts" (WWW 2013)
require 'set'
# Biterm Topic Model (BTM) fitted by collapsed Gibbs sampling.
#
# Typical use:
#   model = Biterm.new(alpha, beta, k)
#   model.read_document(path)   # TSV lines: doc_id \t word \t word ...
#   model.sampling_all(100)
#   model.output('/tmp')        # writes phi.tsv, theta.tsv and d_z.tsv
#
# NOTE(review): topics are assigned per *distinct* biterm (@all_b is a Set), so a
# biterm that occurs in several documents shares one assignment and is counted
# once. The WWW 2013 paper samples a topic for every biterm occurrence; confirm
# that this simplification is intended.
class Biterm
  # @param alpha [Numeric] symmetric Dirichlet prior on the corpus-level topic distribution
  # @param beta [Numeric] symmetric Dirichlet prior on each topic's word distribution
  # @param k [Integer] number of topics
  # @param seed [Integer] seed for this model's own random number generator
  def initialize(alpha, beta, k, seed: 0)
    @alpha = alpha
    @beta = beta
    @k = k
    @doc_w = {}                                  # doc_id => words on its input line
    @doc_b = Hash.new { |h, key| h[key] = [] }   # doc_id => biterms, with repeats
    @all_b = Set.new                             # distinct biterms, each a sorted [w_1, w_2]
    @all_w = Set.new                             # vocabulary
    @n_w_z = Hash.new(0.0)                       # [word, topic] => count
    @n_z = Hash.new(0.0)                         # topic => number of biterms assigned to it
    @b_z = {}                                    # biterm => currently assigned topic
    @M = 0                                       # vocabulary size
    # A private generator keeps runs reproducible without reseeding the
    # process-wide Kernel#rand generator.
    @rng = Random.new(seed)
  end

  # Assign every distinct biterm a uniformly random topic. The count tables are
  # rebuilt from scratch so that repeated calls (e.g. after reading a second
  # file) do not double-count biterms that were already assigned.
  def initialize_z
    @n_z.clear
    @n_w_z.clear
    @b_z.clear
    @all_b.each do |b|
      z = @rng.rand(@k)
      @b_z[b] = z
      update(b, z, 1)
    end
  end

  # Read a tab-separated corpus file, then (re)initialise topic assignments.
  #
  # Each line is `doc_id \t word \t word ...`; blank lines are skipped.
  #
  # @param filename [String] path to the corpus file
  def read_document(filename)
    # File.open rather than Kernel#open: Kernel#open runs a shell command when
    # the path starts with "|".
    File.open(filename, 'r') do |f|
      f.each_line do |line|
        doc_id, *words = line.chomp.split("\t")
        next if doc_id.nil? # blank line

        @doc_w[doc_id] = words
        @all_w.merge(words)
        add_biterms(doc_id, words)
      end
    end
    initialize_z
    @M = @all_w.size
  end

  # Unnormalised Gibbs conditional for assigning biterm b = [w_1, w_2] to topic z:
  #
  #   (n_z + alpha) * (n_{w_1|z} + beta) * (n_{w_2|z} + beta) / (sum_w n_{w|z} + M * beta)^2
  #
  # The alpha term keeps every topic reachable even when its count drops to 0.
  #
  # @return [Float]
  def prob(b, z)
    w_1, w_2 = b
    numerator = (@n_z[z] + @alpha) * (@n_w_z[[w_1, z]] + @beta) * (@n_w_z[[w_2, z]] + @beta)
    numerator / (topic_word_total(z) + @M * @beta)**2
  end

  # Add `num` (+1 or -1) to the counts for biterm b being assigned to topic z.
  # A self-pair [w, w] adds `num` twice to word w's count.
  def update(b, z, num)
    w_1, w_2 = b
    @n_z[z] += num
    @n_w_z[[w_1, z]] += num
    @n_w_z[[w_2, z]] += num
  end

  # One Gibbs sweep: resample the topic of every distinct biterm from its
  # conditional given all other assignments.
  def sampling
    @all_b.each do |b|
      update(b, @b_z[b], -1)
      new_z = draw_topic(b)
      update(b, new_z, 1)
      @b_z[b] = new_z
    end
  end

  # Run `iter` Gibbs sweeps.
  #
  # @param iter [Integer] number of sweeps
  # @param verbose [Boolean] when true, print each sweep index with Kernel#p first
  def sampling_all(iter = 100, verbose: true)
    iter.times do |i|
      p i if verbose
      sampling
    end
  end

  # Write the fitted model into directory `path`:
  #   phi.tsv   - topic \t word \t P(w|z), words in descending probability per topic
  #   theta.tsv - topic \t P(z), in descending probability
  #   d_z.tsv   - doc_id \t topic \t P(z|d), in descending probability per document
  # The estimates are also kept in @phi ([word, topic] => P(w|z)) and @theta.
  def output(path = '/tmp')
    write_phi(File.join(path, 'phi.tsv'))
    write_theta(File.join(path, 'theta.tsv'))
    write_doc_topics(File.join(path, 'd_z.tsv'))
  end

  private

  # Record the biterms of one document. Every ordered pair of positions is used,
  # including a word paired with itself. An unordered pair of distinct words
  # therefore appears twice in @doc_b (weighting P(b|d)) but once in @all_b.
  # NOTE(review): the paper uses unordered pairs of distinct positions only; the
  # original code comments "dont use combination", suggesting this is deliberate.
  def add_biterms(doc_id, words)
    words.product(words).each do |pair|
      b = pair.sort
      @all_b.add(b)
      @doc_b[doc_id].push(b)
    end
  end

  # sum_w n_{w|z}: each biterm adds exactly one count per word slot (see #update),
  # so the total equals 2 * n_z. This avoids an O(M) scan of the vocabulary.
  def topic_word_total(z)
    2 * @n_z[z]
  end

  # Draw a topic for biterm b by inverse-CDF sampling over the cumulative
  # unnormalised conditional.
  def draw_topic(b)
    cumulative = []
    total = 0.0
    @k.times do |z|
      total += prob(b, z)
      cumulative << total
    end
    r = @rng.rand * total
    # The fallback only guards against floating-point rounding at the top end.
    cumulative.index { |c| c > r } || (@k - 1)
  end

  # phi[w, z] = (n_{w|z} + beta) / (sum_w n_{w|z} + M * beta)
  def write_phi(file)
    @phi = {}
    File.open(file, 'w') do |f|
      @k.times do |z|
        denom = topic_word_total(z) + @M * @beta # depends only on z
        values = @all_w.each_with_object({}) { |w, h| h[w] = (@n_w_z[[w, z]] + @beta) / denom }
        values.each { |w, val| @phi[[w, z]] = val }
        values.sort_by(&:last).reverse_each { |w, val| f.puts [z, w, val].join("\t") }
      end
    end
  end

  # theta[z] = (n_z + alpha) / (|B| + K * alpha), where |B| is the number of distinct biterms.
  def write_theta(file)
    denom = @all_b.size + @k * @alpha
    @theta = {}
    @k.times { |z| @theta[z] = (@n_z[z] + @alpha) / denom }
    File.open(file, 'w') do |f|
      @theta.sort_by(&:last).reverse_each { |z, val| f.puts [z, val].join("\t") }
    end
  end

  # Write P(z|d) for every document that has at least one biterm.
  # Requires @theta and @phi to be populated.
  def write_doc_topics(file)
    File.open(file, 'w') do |f|
      @doc_b.each_pair do |doc_id, bs|
        doc_topic_distribution(bs).sort_by(&:last).reverse_each do |z, val|
          f.puts [doc_id, z, val].join("\t")
        end
      end
    end
  end

  # P(z|d) = sum_b P(z|b) * P(b|d). P(b|d) is b's relative frequency in the
  # document's biterm list, and P(z|b) is proportional to
  # theta_z * phi_{w_1|z} * phi_{w_2|z}.
  #
  # @param bs [Array<Array(String, String)>] the document's biterms, with repeats
  # @return [Hash{Integer => Float}] topic => P(z|d)
  def doc_topic_distribution(bs)
    counts = Hash.new(0.0)
    bs.each { |b| counts[b] += 1 }
    p_z_d = Hash.new(0.0)
    counts.each do |b, count|
      w_1, w_2 = b
      weights = Array.new(@k) { |z| @theta[z] * @phi[[w_1, z]] * @phi[[w_2, z]] }
      norm = weights.inject(0.0, :+)
      p_b_d = count / bs.size
      weights.each_with_index { |val, z| p_z_d[z] += val / norm * p_b_d }
    end
    p_z_d
  end
end
if __FILE__ == $PROGRAM_NAME
  # Usage: ruby <this file> [corpus_path] [output_dir]
  # Corpus format, one document per line: doc_id \t word \t word ...
  # Without arguments the original hard-coded paths are used.
  corpus_path = ARGV[0] || '/tmp/test.doc'
  output_dir = ARGV[1] || '/tmp'

  # alpha, beta, k
  model = Biterm.new(0.5, 0.01, 10)
  model.read_document(corpus_path)
  model.sampling_all(100)
  # writes phi.tsv, theta.tsv and d_z.tsv into output_dir
  model.output(output_dir)
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment