Skip to content

Instantly share code, notes, and snippets.

@takuti
Created July 30, 2015 14:39
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save takuti/e6975eb6f755b3fbc188 to your computer and use it in GitHub Desktop.
Save takuti/e6975eb6f755b3fbc188 to your computer and use it in GitHub Desktop.
# coding: utf-8
""" Usage
$ ./IBM_Model1.py ${english_filename} ${japanese_filename}
NOTE: prefix for filenames 'kftt-alignments/data/' will be automatically added
"""
# Reference: http://www.statmt.org/book/slides/04-word-based-models.pdf
import io
import os
import sys
from collections import defaultdict
def _load_sentences(path):
    """Read a whitespace-tokenized corpus with one sentence per line.

    Lines are decoded as UTF-8 and split on runs of whitespace. A blank line
    therefore becomes an empty sentence, not a single '' token.

    Returns a list of token lists, one per line of the file.
    """
    with io.open(path, encoding='utf-8') as f:
        return [line.split() for line in f]


def _train_ibm_model1(e_sentences, j_sentences, eps=1e-3):
    """Estimate IBM Model 1 lexical probabilities t(e|j) with EM.

    Only (e, j) pairs that co-occur in some sentence pair are stored. Every
    other pair gets no expected count, so its probability would be zero
    anyway, and a dense |E| x |J| table is infeasible for real vocabularies.

    Args:
        e_sentences: list of English token lists.
        j_sentences: list of Japanese token lists, line-aligned with
            e_sentences.
        eps: training stops once no probability moves by more than this
            between two consecutive iterations.

    Returns:
        dict mapping (english_token, japanese_token) -> t(e|j).
    """
    pairs = list(zip(e_sentences, j_sentences))
    e_vocab = set(e for e_sentence, _ in pairs for e in e_sentence)
    # Any positive constant yields the same first E-step (it cancels in the
    # per-sentence normalization); 1/|E| makes each t(.|j) a proper
    # distribution over English words from the start.
    init = 1. / len(e_vocab) if e_vocab else 0.
    t = {}
    for e_sentence, j_sentence in pairs:
        for e in e_sentence:
            for j in j_sentence:
                t[(e, j)] = init

    while True:
        count = defaultdict(float)
        total = defaultdict(float)
        # E-step: spread one unit of count for each English word occurrence
        # over the Japanese words of its sentence, in proportion to t(e|j).
        for e_sentence, j_sentence in pairs:
            for e in e_sentence:
                s_total = sum(t[(e, j)] for j in j_sentence)
                for j in j_sentence:
                    c = t[(e, j)] / s_total
                    count[(e, j)] += c
                    total[j] += c
        # M-step: renormalize the expected counts into t(e|j) per Japanese j.
        changed = 0
        for (e, j), c in count.items():
            new_t = c / total[j]
            if abs(new_t - t[(e, j)]) > eps:
                changed += 1
            t[(e, j)] = new_t
        if changed == 0:
            return t
        print(changed)  # progress: number of parameters still moving


def main(argv=None, data_dir='kftt-alignments/data'):
    """Train IBM Model 1 on a line-aligned English/Japanese corpus.

    Usage: ./IBM_Model1.py ${english_filename} ${japanese_filename}

    Both filenames are resolved relative to ``data_dir``. The function prints
    the vocabulary sizes and every pair whose t(e|j) exceeds 0.9.

    Args:
        argv: argument vector with the program name first; defaults to
            sys.argv.
        data_dir: directory prepended to both corpus filenames.

    Returns:
        dict mapping (english_token, japanese_token) -> t(e|j) for every
        co-occurring pair.

    Exits (SystemExit) with a usage message when fewer than two filenames are
    given, or with an error when the two files differ in line count.
    """
    argv = sys.argv if argv is None else argv
    if len(argv) < 3:
        sys.exit('[Usage] $ ./IBM_Model1.py ${english_filename} ${japanese_filename}')

    e_sentences = _load_sentences(os.path.join(data_dir, argv[1]))
    j_sentences = _load_sentences(os.path.join(data_dir, argv[2]))
    if len(e_sentences) != len(j_sentences):
        # The model needs line-aligned corpora; zip() would silently drop the
        # unpaired tail instead of reporting the mismatch.
        sys.exit('[Error] corpora are not parallel: %d English vs %d Japanese lines'
                 % (len(e_sentences), len(j_sentences)))
    e_tokens = set(e for sentence in e_sentences for e in sentence)
    j_tokens = set(j for sentence in j_sentences for j in sentence)

    t = _train_ibm_model1(e_sentences, j_sentences)

    print('total: %d English tokens <-> %d Japanese tokens' % (len(e_tokens), len(j_tokens)))
    for (e, j), prob in t.items():
        if prob > .9:
            print('%s %s %s' % (e, j, prob))
    return t
if __name__ == "__main__":
    # Script entry point: the corpus filenames come from the command line.
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment