Skip to content

Instantly share code, notes, and snippets.

@neubig
Created September 10, 2012 02:41
Show Gist options
  • Save neubig/3688522 to your computer and use it in GitHub Desktop.
Save neubig/3688522 to your computer and use it in GitHub Desktop.
An example implementation of the String Rewriting Kernel in Python
#!/usr/bin/python
# A python implementation of the string rewriting kernel
# by Graham Neubig
#
# Reference:
# Fan Bu, Hang Li, Xiaoyan Zhu. "String Rewriting Kernel". ACL 2012
# https://aclanthology.org/P12-1047.pdf (formerly http://aclweb.org/anthology-new/P/P12/P12-1047.pdf)
from math import factorial
from collections import defaultdict
def double_count(s1, s2):
    """Count the aligned symbol pairs ("doubles") of two sequences.

    This is "#_e" in the paper: for every position ``idx`` of ``s1`` the
    pair ``(s1[idx], s2[idx])`` is counted once.

    Args:
      s1: first sequence; iterated in order.
      s2: second sequence; indexed at every position of ``s1``, so it must
        be at least as long as ``s1`` (an IndexError is raised otherwise).

    Returns:
      A ``defaultdict`` mapping each ``(s1[idx], s2[idx])`` pair to its count.
    """
    counts = defaultdict(int)
    for idx, left in enumerate(s1):
        right = s2[idx]
        counts[left, right] += 1
    return counts
def inequal_double_count(s1, s2):
    """Count only the aligned pairs whose two symbols differ.

    This is "#_N" in the paper. It is useful because only k-grams with
    exactly the same profile of differing pairs can share any valid
    rewrite rule (as mentioned in section 5.2.3), so it serves as a
    grouping key.

    Args:
      s1: first sequence; iterated in order.
      s2: second sequence; indexed at every position of ``s1``.

    Returns:
      A ``defaultdict`` mapping each ``(s1[idx], s2[idx])`` pair with
      ``s1[idx] != s2[idx]`` to its count. Identical pairs are never added.
    """
    counts = defaultdict(int)
    for idx, left in enumerate(s1):
        right = s2[idx]
        if left == right:
            continue
        counts[left, right] += 1
    return counts
def binomial(n, k):
    """Return the binomial coefficient C(n, k) using exact integer math.

    Callers are expected to pass 0 <= k <= n; otherwise ``factorial`` is
    given a negative argument and raises ValueError.
    """
    numerator = factorial(n)
    denominator = factorial(k) * factorial(n - k)
    return numerator // denominator
def alpha_val(i, e, num_es, num_et):
    """Count the rewrite-rule combinations that pair up ``i`` copies of double ``e``.

    See descriptions (3), (1), (2) at the bottom of page 5 of the paper.

    Args:
      i: number of occurrences of ``e`` paired between the two sides.
      e: the double, a pair of symbols.
      num_es: how often ``e`` occurs on the source side.
      num_et: how often ``e`` occurs on the target side.

    Returns:
      For an identical double (``e[0] == e[1]``):
      ``C(num_es, i) * C(num_et, i) * i!``. For a differing double: ``i!``
      if ``num_es == num_et == i``, else 0.
    """
    if e[0] != e[1]:
        # A differing double only matches when every occurrence on both
        # sides is paired, so all three counts must agree.
        if num_es == num_et and i == num_es:
            return factorial(i)
        return 0
    # Identical double: choose i occurrences per side, then pair them up.
    return binomial(num_es, i) * binomial(num_et, i) * factorial(i)
def double_count_kernel(asD, atD, lam = 1):
    """Count matched rewriting rules for two k-gram pairs from their double counts.

    Implements function g_e (lines 3-8 of algorithm 1) in the paper. The
    result is the product, over every double ``e`` occurring in either
    count, of ``sum_i alpha_val(i, e, ...) * lam ** (2 * i)`` for
    ``i = 0 .. min(num_es, num_et)``.

    Args:
      asD: mapping from double to count for the source-side k-grams.
      atD: mapping from double to count for the target-side k-grams.
      lam: weight; ``i`` paired occurrences contribute ``lam ** (2 * i)``.

    Returns:
      The kernel value as a float (``1.0`` when both mappings are empty).
    """
    result = 1.0
    # Visit every double seen on either side. A key-set union is used
    # because concatenating ``.items()`` raises TypeError on Python 3,
    # where dict views do not support ``+``.
    for e in set(asD) | set(atD):
        num_es = asD.get(e, 0)
        num_et = atD.get(e, 0)
        # At most min(num_es, num_et) occurrences can be paired up.
        num_e = min(num_es, num_et)
        # ge accumulates this double's factor of the product.
        ge = 0.0
        for i in range(num_e + 1):
            ge += alpha_val(i, e, num_es, num_et) * (lam ** (2 * i))
        result *= ge
        # A zero factor makes the whole product zero; stop early.
        if result == 0:
            return result
    return result
def kgram_kernel(as1, at1, as2, at2, lam = 1):
    """Count matched rewriting rules for two k-gram pairs (algorithm 1 of the paper).

    Args:
      as1, at1: source and target k-grams of the first string pair.
      as2, at2: source and target k-grams of the second string pair.
      lam: weight passed through to ``double_count_kernel``.

    Returns:
      ``double_count_kernel`` applied to the doubles of ``(as1, as2)`` and
      of ``(at1, at2)``.

    According to the paper,
    ``kgram_kernel(list('abbccbb'), list('cbcbbcb'), list('abcccdd'),
    list('cbccdcd'), lam)`` should equal
    ``12*lam**12 + 24*lam**10 + 14*lam**8 + 2*lam**6`` (52 for lam = 1),
    which is a useful sanity check for this implementation.
    """
    source_doubles = double_count(as1, as2)
    target_doubles = double_count(at1, at2)
    return double_count_kernel(source_doubles, target_doubles, lam)
def _group_kgram_pairs(x1, x2, k):
    """Group every (k-gram of x1, k-gram of x2) pair by its double profile.

    Returns a nested mapping ``{#_N: {#_e: count}}``. Both keys are
    frozensets of ``(double, count)`` items, so equal multisets of doubles
    always give equal keys regardless of the order the doubles were seen.
    """
    groups = defaultdict(lambda: defaultdict(int))
    for pos1 in range(len(x1) - k + 1):
        gram1 = x1[pos1:pos1 + k]
        for pos2 in range(len(x2) - k + 1):
            gram2 = x2[pos2:pos2 + k]
            sharp_n = frozenset(inequal_double_count(gram1, gram2).items())
            sharp_all = frozenset(double_count(gram1, gram2).items())
            groups[sharp_n][sharp_all] += 1
    return groups


def string_rewriting_kernel(s1, t1, s2, t2, k, lam=1):
    """Count matched rewriting rules for two string pairs (algorithm 2 of the paper).

    Args:
      s1, t1: source and target sequences of the first pair.
      s2, t2: source and target sequences of the second pair.
      k: k-gram length (expected to be a positive integer).
      lam: weight forwarded to ``double_count_kernel`` (default 1, as before).

    Returns:
      The sum, over all k-gram pairs, of the k-gram kernel. It is 0 when no
      k-grams exist.
    """
    ms = _group_kgram_pairs(s1, s2, k)
    mt = _group_kgram_pairs(t1, t2, k)
    result = 0
    # Only groups whose #_N agrees on both sides can contribute: any other
    # combination has a differing double with mismatched counts and scores 0.
    for sharp_n, s_groups in ms.items():
        t_groups = mt.get(sharp_n)
        if t_groups is None:
            continue
        for s_dub, s_cnt in s_groups.items():
            for t_dub, t_cnt in t_groups.items():
                result += s_cnt * t_cnt * double_count_kernel(
                    dict(s_dub), dict(t_dub), lam)
    return result
# Demo: the kernel between two paraphrase pairs using 3-grams.
# print(...) with a single argument prints the same output on Python 2
# and Python 3; the original print statement is a SyntaxError on Python 3.
print(string_rewriting_kernel(
    ('he', 'hit', 'the', 'small', 'cat'),
    ('he', 'walloped', 'the', 'small', 'feline'),
    ('he', 'hit', 'a', 'big', 'robber'),
    ('he', 'walloped', 'a', 'big', 'catthief'),
    3,
))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment