Created
July 28, 2014 04:41
-
-
Save alpicola/da5db098d2d5bac274ad to your computer and use it in GitHub Desktop.
Biterm Topic Model
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// X. Yan, J. Guo, Y. Lan, and X. Cheng, A Biterm Topic Model for Short Texts, | |
// in WWW. ACM, 2013, pp. 1445–1456.
import scala.collection._ | |
import scala.io.Source | |
import scala.util.Random | |
import java.io._ | |
/**
 * Biterm Topic Model for short texts, trained with collapsed Gibbs sampling
 * over biterms: unordered pairs of distinct-position words that co-occur in
 * one document.
 *
 * Usage: `load` a corpus, run `estimate`, then `report`.
 *
 * @param alpha    smoothing added to every topic count when sampling and when
 *                 estimating theta (symmetric Dirichlet prior on theta)
 * @param beta     smoothing added to every topic-word count when sampling and
 *                 when estimating phi (symmetric Dirichlet prior on phi)
 * @param k        number of topics; must be positive
 * @param iterN    number of Gibbs sweeps over all biterms performed by `estimate`
 * @param encoding charset used to read the corpus and to write the report files
 * @param topN     number of highest-probability words listed per topic by `report`
 */
class BTM(val alpha: Double, val beta: Double, val k: Int, val iterN: Int,
          val encoding: String = "UTF-8", val topN: Int = 20) {
  require(k > 0, s"k must be positive: $k")
  require(iterN >= 0, s"iterN must be non-negative: $iterN")
  require(topN >= 0, s"topN must be non-negative: $topN")

  // Vocabulary: words(id) is the surface string of word id.
  private var words: Array[String] = null
  // One (document id, word ids) entry per non-blank input line, in file order.
  private var documents: Array[(String, Array[Int])] = null
  // All biterms of the corpus as (smaller word id, larger word id).
  private var biterms: Array[(Int, Int)] = null
  // Vocabulary size.
  private var m: Int = 0
  // b_z(i): topic currently assigned to biterms(i).
  private var b_z: Array[Int] = null
  // n_z(z): number of biterms assigned to topic z.
  private var n_z: Array[Long] = null
  // n_w_z(w*k+z): occurrences of word w within biterms assigned to topic z.
  private var n_w_z: Array[Long] = null
  // theta(z): estimated p(z); filled by estimate.
  private var theta: Array[Double] = null
  // phi(w*k+z): estimated p(w|z), same word-major layout as n_w_z; filled by estimate.
  private var phi: Array[Double] = null
  // Scratch buffer of cumulative topic weights, reused by sampleTopic.
  private var table: Array[Double] = null

  /**
   * Reads the corpus at `file` (decoded with `encoding`), builds the vocabulary
   * and the biterm list, and allocates the count and parameter arrays.
   *
   * Each line is `docId<TAB>word<TAB>word...`; blank lines are skipped.
   */
  def load(file: String): Unit = {
    val dict = mutable.HashMap[String, Int]()
    val nextId = Iterator.from(0)
    val docBuf = mutable.ArrayBuffer[(String, Array[Int])]()
    val bitermBuf = mutable.ArrayBuffer[(Int, Int)]()
    val source = Source.fromFile(file, encoding)
    try {
      source.getLines().filter(_.trim.nonEmpty).foreach { line =>
        val row = line.stripLineEnd.split("\t")
        // getOrElseUpdate evaluates nextId.next() only for unseen words.
        val d = row.tail.map(word => dict.getOrElseUpdate(word, nextId.next()))
        docBuf += ((row.head, d))
        bitermBuf ++= getBiterms(d)
      }
    } finally {
      source.close()
    }
    // The next unused id equals the number of distinct words.
    m = nextId.next()
    words = new Array[String](m)
    dict.foreach { case (word, id) => words(id) = word }
    documents = docBuf.toArray
    biterms = bitermBuf.toArray
    b_z = new Array[Int](biterms.length)
    n_z = new Array[Long](k)
    n_w_z = new Array[Long](k * m)
    theta = new Array[Double](k)
    phi = new Array[Double](k * m)
    table = new Array[Double](k)
    println(s"|B|: ${biterms.length}, K: $k, M: $m")
  }

  /**
   * Runs `iterN` collapsed-Gibbs sweeps from a uniformly random initial topic
   * assignment, then computes theta and phi from the final counts.
   * Must be called after `load`.
   */
  def estimate: Unit = {
    require(biterms != null, "load must be called before estimate")
    java.util.Arrays.fill(n_z, 0L)
    java.util.Arrays.fill(n_w_z, 0L)
    biterms.indices.foreach { i => setTopic(biterms(i), i, Random.nextInt(k)) }
    (0 until iterN).foreach { n =>
      println(s"iteration ${n + 1}")
      biterms.indices.foreach { i =>
        val b = biterms(i)
        // Remove b's own assignment so it is sampled conditioned on all other biterms.
        unsetTopic(b, b_z(i))
        setTopic(b, i, sampleTopic(b))
      }
    }
    calcTheta()
    calcPhi()
    println("done!")
  }

  /**
   * Writes three tab-separated report files (encoded with `encoding`) into the
   * working directory:
   *  - `topics.k<k>`: per topic, p(z) followed by its `topN` most probable words;
   *  - `words.k<k>`: per word, the word followed by p(z|w) for every topic;
   *  - `documents.k<k>`: per document, its id followed by p(z|d) for every topic.
   * Must be called after `load` (and normally `estimate`).
   */
  def report: Unit = {
    require(phi != null && theta != null, "load and estimate must be called before report")
    writeTopics()
    writeWords()
    writeDocuments()
  }

  /** Opens `path` for writing with `encoding`, runs `body`, and always closes the writer. */
  private def withWriter(path: String)(body: PrintWriter => Unit): Unit = {
    val out = new PrintWriter(new File(path), encoding)
    try body(out) finally out.close()
  }

  /** Writes `topics.k<k>`: p(z) and the `topN` words with the largest p(w|z). */
  private def writeTopics(): Unit = withWriter(s"topics.k$k") { out =>
    (0 until k).foreach { z =>
      val top = (0 until m).sortBy(w => -phi(w * k + z)).take(topN)
      out.println(s"${theta(z)}\t" + top.map(words).mkString("\t"))
    }
  }

  /** Writes `words.k<k>`: p(z|w) ∝ p(w|z) p(z), normalised over all k topics. */
  private def writeWords(): Unit = withWriter(s"words.k$k") { out =>
    (0 until m).foreach { w =>
      // All k topics of word w: indices w*k .. w*k + k - 1.
      val weight = Array.tabulate(k)(z => phi(w * k + z) * theta(z))
      val norm = 1.0 / weight.sum
      out.println(s"${words(w)}\t" + weight.map(_ * norm).mkString("\t"))
    }
  }

  /**
   * Writes `documents.k<k>`: p(z|d) = sum over the document's biterms b of
   * p(z|b) * p(b|d), where p(z|b) ∝ theta_z phi_{w1|z} phi_{w2|z} and each
   * biterm of d gets p(b|d) = 1/|biterms of d|. Documents without biterms get
   * all zeros.
   */
  private def writeDocuments(): Unit = withWriter(s"documents.k$k") { out =>
    documents.foreach { case (id, d) =>
      val bs = getBiterms(d).toArray
      // hs(i) = p(b_i|d) / sum_z theta_z phi_{w1|z} phi_{w2|z}
      val hs = bs.map { case (w1, w2) =>
        1.0 / Iterator.range(0, k).map(z => theta(z) * phi(w1 * k + z) * phi(w2 * k + z)).sum / bs.length
      }
      val p_z_d = (0 until k).map { z =>
        bs.indices.iterator.map { i =>
          val (w1, w2) = bs(i)
          theta(z) * phi(w1 * k + z) * phi(w2 * k + z) * hs(i)
        }.sum
      }
      out.println(s"${id}\t" + p_z_d.mkString("\t"))
    }
  }

  /**
   * Returns the biterms of a document as (smaller id, larger id) pairs.
   * `combinations` yields each distinct multiset of size 2 once, so a pair that
   * co-occurs several times produces one biterm, and a repeated word w yields
   * (w, w). Documents with fewer than two words yield no biterms.
   */
  private def getBiterms(d: Array[Int]): Iterator[(Int, Int)] =
    d.toSeq.combinations(2).map { case Seq(w1, w2) =>
      if (w1 < w2) (w1, w2) else (w2, w1)
    }

  /** Assigns topic z to biterm i (words b) and adds it to the counts; both words are counted. */
  private def setTopic(b: (Int, Int), i: Int, z: Int): Unit = {
    val (w1, w2) = b
    b_z(i) = z
    n_z(z) += 1
    n_w_z(w1 * k + z) += 1
    n_w_z(w2 * k + z) += 1
  }

  /** Removes biterm b, currently assigned to topic z, from the counts. */
  private def unsetTopic(b: (Int, Int), z: Int): Unit = {
    val (w1, w2) = b
    n_z(z) -= 1
    n_w_z(w1 * k + z) -= 1
    n_w_z(w2 * k + z) -= 1
  }

  /**
   * Draws a topic for biterm b from
   * p(z) ∝ (n_z + alpha) (n_{w1|z} + beta) (n_{w2|z} + beta) / (2 n_z + m beta)^2.
   * The extra factor m in `h` scales every topic's weight by m^2, which cancels
   * on normalisation (presumably there to keep the products away from underflow).
   */
  private def sampleTopic(b: (Int, Int)): Int = {
    val (w1, w2) = b
    Iterator.range(0, k).map { z =>
      val h = m / (n_z(z) * 2 + m * beta)
      val p_z_w1 = (n_w_z(w1 * k + z) + beta) * h
      val p_z_w2 = (n_w_z(w2 * k + z) + beta) * h
      (n_z(z) + alpha) * p_z_w1 * p_z_w2
    }.scanLeft(0.0)(_ + _).drop(1).copyToArray(table)
    // Inverse-CDF sampling over the cumulative weights in `table`.
    val r = Random.nextDouble() * table.last
    table.indexWhere(_ >= r)
  }

  /** theta(z) = (n_z + alpha) / (|B| + k alpha). */
  private def calcTheta(): Unit =
    Iterator.range(0, k).map { z =>
      (n_z(z) + alpha) / (biterms.length + k * alpha)
    }.copyToArray(theta)

  /** phi(w*k+z) = (n_{w|z} + beta) / (2 n_z + m beta); 2 n_z is the number of words in topic z. */
  private def calcPhi(): Unit =
    Iterator.range(0, m).flatMap { w =>
      Iterator.range(0, k).map { z =>
        (n_w_z(w * k + z) + beta) / (n_z(z) * 2 + m * beta)
      }
    }.copyToArray(phi)
}
/**
 * Command-line entry point: `BTM <corpus.tsv>`.
 * Trains a model with alpha = 1/20, beta = 0.01, 20 topics and 200 Gibbs
 * sweeps on the given corpus and writes the report files to the working
 * directory.
 */
object BTM {
  def main(args: Array[String]): Unit = {
    if (args.isEmpty) {
      Console.err.println("usage: BTM <corpus.tsv>")
      sys.exit(1)
    }
    val btm = new BTM(1.0 / 20, 0.01, 20, 200)
    btm.load(args(0))
    btm.estimate
    btm.report
  }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# coding: utf-8 | |
require 'mongo' | |
require 'MeCab' | |
require 'dotenv' | |
require 'set'

Dotenv.load

db = Mongo::Connection.new.db(ENV['MONGODB_DB'])
collection = db.collection(ENV['MONGODB_COLLECTION'])
mecab = MeCab::Tagger.new('-d /usr/share/mecab/dic/ipadic')

# Stopwords, one per line. Read explicitly as UTF-8 and kept in a Set so each
# per-token lookup is a hash probe instead of a linear Array#include? scan.
stopwords = File.readlines('stopwords.txt', encoding: 'UTF-8').map(&:chomp).to_set

# Removes the user-mention, hashtag, URL and media entities of +status+ from
# its text. Returns [text, entity_words]:
# - text is the remaining text, a copy, so +status+ is not modified;
# - entity_words are the "@screen_name", "#hashtag" and URL-domain tokens, in
#   that order.
def strip_entities(status)
  entities = status['entities']
  text = status['text'].dup
  mentions = entities['user_mentions'].map { |item| '@' + item['screen_name'] }
  hashtags = entities['hashtags'].map { |item| '#' + item['text'] }
  # String#sub! removes only the first occurrence of each token.
  (mentions + hashtags).each { |token| text.sub!(token, '') }
  domains = entities['urls'].map do |item|
    domain = item['display_url'].split('/')[0]
    text.sub!(item['url'], '')
    domain
  end
  (entities['media'] || []).each { |item| text.sub!(item['url'], '') }
  [text, mentions + hashtags + domains]
end

# Tokenizes +text+ with MeCab. It keeps:
# - nouns (名詞) by surface form;
# - verbs and adjectives (動詞 / 形容詞 / 形容動詞) by their base form in
#   feature[6], when that field is present and not '*'.
# Tokens found in +stopwords+ are dropped.
def content_words(mecab, text, stopwords)
  words = []
  node = mecab.parseToNode(text)
  while node
    feature = node.feature.split(',')
    word =
      case feature[0]
      when '名詞' then node.surface
      when '動詞', '形容詞', '形容動詞' then (feature[6] unless feature[6] == '*')
      end
    words << word if word && !stopwords.include?(word)
    node = node.next
  end
  words
end

File.open('tweets.tsv', 'w') do |f|
  collection.find.each do |status|
    # For retweets, use the original tweet's text and entities.
    status = status['retweeted_status'] || status
    text, entity_words = strip_entities(status)
    words = content_words(mecab, text, stopwords) + entity_words
    # Tweets with fewer than two tokens are dropped (presumably because at
    # least two are needed to form a biterm).
    next unless words.length > 1
    # Output line: tweet id followed by its tokens, tab-separated.
    f.puts([status['id_str'], *words].join("\t"))
  end
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# coding: utf-8 | |
require 'tweetstream' | |
require 'mongo' | |
require 'dotenv' | |
Dotenv.load

db = Mongo::Connection.new.db(ENV['MONGODB_DB'])
collection = db.collection(ENV['MONGODB_COLLECTION'])

TweetStream.configure do |config|
  config.consumer_key = ENV['CONSUMER_KEY']
  config.consumer_secret = ENV['CONSUMER_SECRET']
  config.oauth_token = ENV['ACCESS_TOKEN_KEY']
  config.oauth_token_secret = ENV['ACCESS_SECRET']
  config.auth_method = :oauth
end

count = 0
limit = 100000

# Stores sample-stream tweets whose author's user.lang is 'ja' until +limit+
# have been saved, printing progress every 100 tweets.
# - The limit is checked after every insert, so it need not be a multiple of 100.
# - The stream is shut down via the client instead of calling exit inside the
#   streaming callback.
TweetStream::Client.new.sample do |status, client|
  # Ignore anything delivered once the limit has been reached (e.g. while the
  # client is shutting down).
  next if count >= limit
  next unless status.user.lang == 'ja'

  collection.insert(status.to_h)
  count += 1
  puts "saved #{count} tweets" if count % 100 == 0

  if count >= limit
    puts "done!"
    client.stop
  end
end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment