Skip to content

Instantly share code, notes, and snippets.

@slowli
Created June 23, 2015 09:01
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save slowli/69ef928bf45f1251460c to your computer and use it in GitHub Desktop.
Save slowli/69ef928bf45f1251460c to your computer and use it in GitHub Desktop.
Dynamic programming in Python with decorators
#!/usr/bin/python
'''
Demonstration of functional-style dynamic programming implementation
in Python using decorators and generators.
Run the script to compare the running time of the decorator-based DP
implementations with that of imperative bottom-up implementations
(spoiler: the decorators are slow).
The following recurrences are used as examples:
* Catalan numbers (https://en.wikipedia.org/wiki/Catalan_number)
* Binomial coefficients (https://en.wikipedia.org/wiki/Binomial_coefficient)
* Edit distance (https://en.wikipedia.org/wiki/Edit_distance)
'''
##### Auxiliary functions for decorators #####
import functools
import math
def pack_tri(n, k):
    ''' Ordering of a triangular pair of arguments (0 <= n, 0 <= k <= n).

    Maps (n, k) to its zero-based position in the sequence produced by tri():
    (0, 0) -> 0, (1, 0) -> 1, (1, 1) -> 2, (2, 0) -> 3, ...
    n * (n + 1) is always even, so floor division is exact. Using // keeps
    the result an int on both Python 2 and Python 3, which lindp needs
    because it uses the value as a list index.
    '''
    return n * (n + 1) // 2 + k
def unpack_tri(idx):
    ''' Inverse of pack_tri: maps a non-negative index back to (n, k).

    Returns the pair (n, k) with 0 <= k <= n such that
    n * (n + 1) // 2 + k == idx, using int arithmetic throughout.
    The row n is first estimated with a floating-point square root. Exact
    integer arithmetic then corrects it in both directions, so float rounding
    on very large indices cannot produce a wrong row.
    '''
    n = int(math.sqrt(idx * 2))
    # Establish n * (n + 1) / 2 <= idx < (n + 1) * (n + 2) / 2 exactly.
    while n * (n + 1) // 2 > idx:
        n -= 1
    while (n + 1) * (n + 2) // 2 <= idx:
        n += 1
    return (n, idx - n * (n + 1) // 2)
def lin():
    ''' Yields one-element argument tuples in increasing order, forever:
    (0,), (1,), (2,), ... '''
    counter = 0
    while True:
        yield (counter,)
        counter += 1
def tri():
    ''' Yields every pair (n, k) with 0 <= n and 0 <= k <= n, row by row, forever:
    (0, 0), (1, 0), (1, 1), (2, 0), (2, 1), (2, 2), ... '''
    row = 0
    while True:
        for col in range(row + 1):
            yield (row, col)
        row += 1
def rect(width):
    ''' Yields pairs (x, y) with 0 <= x and 0 <= y < width in row-major order:
    (0, 0), (0, 1), ..., (0, width - 1),
    (1, 0), (1, 1), ..., (1, width - 1),
    ...
    The sequence is infinite. y wraps back to 0 once it reaches width - 1.
    '''
    last_col = width - 1
    x, y = 0, 0
    while True:
        yield (x, y)
        if y < last_col:
            y += 1
        else:
            x, y = x + 1, 0
##### Decorators #####
def lindp(pack=lambda i: i, unpack=lambda i: (i,)):
    ''' Linearized dynamic programming decorator.

    pack should convert the function arguments into a non-negative integer;
    unpack should perform the inverse transformation integer -> argument tuple.
    The ordering provided by pack must coincide with the order in which
    function values are calculated. While computing the value at index i,
    the decorated function may only call itself with arguments whose packed
    index is smaller than i.

    Values are stored in a list indexed by the packed arguments. A call whose
    arguments pack to n fills every missing entry up to n in increasing order.

    Raises ValueError if pack returns a negative index. Without this check,
    list indexing would silently read an unrelated value from the end of
    the cache.
    '''
    def decorate(fn):
        cache = []

        @functools.wraps(fn)
        def dp_fn(*args):
            n = pack(*args)
            if n < 0:
                raise ValueError(
                    'arguments {!r} map to negative index {}'.format(args, n))
            # Bottom-up fill: each fn call only reads entries already cached.
            for i in range(len(cache), n + 1):
                cache.append(fn(*unpack(i)))
            return cache[n]
        return dp_fn
    return decorate
def dp(iterator=None):
    ''' Dynamic programming decorator.

    iterator should iterate over argument tuples for the target function
    in the order imposed by DP. Every call made while computing fn(*args)
    must use arguments that the iterator produced before args. Any iterable
    is accepted.

    If iterator is omitted, each decorated function gets its own fresh lin()
    generator. A default evaluated once at definition time would be a single
    generator shared by every @dp() function, and one function would consume
    the other's arguments.

    Values are kept in a dict keyed by argument tuples. Raises ValueError if
    the iterator runs out before producing the requested arguments.
    '''
    def decorate(fn):
        it = lin() if iterator is None else iter(iterator)
        cache = {}

        @functools.wraps(fn)
        def dp_fn(*args):
            if args not in cache:
                while True:
                    try:
                        # next() works on Python 2.6+ and 3; .next() is 2-only.
                        i = next(it)
                    except StopIteration:
                        raise ValueError(
                            'arguments {!r} are not produced by the DP '
                            'iterator'.format(args))
                    cache[i] = fn(*i)
                    if i == args:
                        break
            return cache[args]
        return dp_fn
    return decorate
##### DP implementations #####
@dp()
def cat(n):
    ''' Catalan numbers using the DP decorator.

    Recurrence: C(0) = 1, C(n) = sum of C(i) * C(n - 1 - i) for i in 0..n-1.
    @dp() enumerates (0,), (1,), ... in increasing order, so every recursive
    call is a cache hit. The builtin sum is used rather than shadowing it
    with a local accumulator.
    '''
    if n == 0:
        return 1
    return sum(cat(i) * cat(n - 1 - i) for i in range(n))
@lindp()
def cat_l(n):
    ''' Catalan numbers using the linearized DP decorator.

    Recurses into cat_l itself, not cat. Calling cat would route every
    lookup through the @dp cache, so the @lindp benchmark would measure the
    wrong decorator. All recursive indices are < n, as lindp requires.
    '''
    if n == 0:
        return 1
    return sum(cat_l(i) * cat_l(n - 1 - i) for i in range(n))
def cat_im(n):
    ''' Catalan numbers using an imperative bottom-up DP implementation.

    Fills C[0..n] in increasing order using
    C[i] = sum of C[j] * C[i - 1 - j] for j in 0..i-1.
    Raises ValueError for negative n.
    '''
    if n < 0:
        raise ValueError('n must be non-negative, got {}'.format(n))
    C = [1] * (n + 1)
    for i in range(1, n + 1):
        C[i] = sum(C[j] * C[i - 1 - j] for j in range(i))
    return C[n]
@lindp(pack=pack_tri, unpack=unpack_tri)
def binc_l(n, k):
    ''' Binomial coefficients using the linearized DP decorator.

    Pascal's rule: C(n, k) = C(n - 1, k - 1) + C(n - 1, k), with
    C(n, 0) = C(n, n) = 1. Recurses into binc_l itself, not binc, so the
    benchmark exercises the @lindp cache. Both recursive arguments lie in
    row n - 1 and therefore have a smaller pack_tri index than (n, k), as
    lindp requires.
    '''
    if k == 0 or k == n:
        return 1
    return binc_l(n - 1, k - 1) + binc_l(n - 1, k)
@dp(iterator=tri())
def binc(n, k):
    ''' Binomial coefficients using the DP decorator.

    Applies Pascal's rule with C(n, 0) = C(n, n) = 1. tri() enumerates
    (n, k) pairs row by row, so (n - 1, k - 1) and (n - 1, k) are cached
    before (n, k) is computed.
    '''
    if k in (0, n):
        return 1
    upper_left = binc(n - 1, k - 1)
    upper = binc(n - 1, k)
    return upper_left + upper
def binc_im(n, k):
    ''' Binomial coefficients using an imperative DP implementation.

    Builds Pascal's triangle row by row in a single list of length n + 1.
    Row i is derived from row i - 1 in place by updating entries from right
    to left, so each C[j - 1] that is read still holds the previous row's
    value.

    Returns 0 when k > n, matching math.comb. Raises ValueError for negative
    n or k; without the check, a negative k would index from the end of the
    list.
    '''
    if n < 0 or k < 0:
        raise ValueError(
            'n and k must be non-negative, got n={}, k={}'.format(n, k))
    if k > n:
        return 0
    C = [1] + [0] * n
    for i in range(1, n + 1):
        for j in range(i, 0, -1):
            C[j] += C[j - 1]
    return C[k]
def edist(s, t):
    ''' Edit (Levenshtein) distance using the DP decorator.

    D(i, j) is the distance between the prefixes s[:i] and t[:j]. The rect()
    iterator walks the (len(s) + 1) x (len(t) + 1) table row by row, so every
    cell D reads is cached before it is needed.
    '''
    @dp(iterator=rect(len(t) + 1))
    def D(i, j):
        # On the border one coordinate is 0, so the distance is the other one.
        if i == 0 or j == 0:
            return i + j
        if s[i - 1] == t[j - 1]:
            return D(i - 1, j - 1)
        return 1 + min(D(i - 1, j - 1), D(i - 1, j), D(i, j - 1))
    return D(len(s), len(t))
def edist_l(s, t):
    ''' Edit distance using the linearized DP decorator.

    Cell (i, j) of the (len(s) + 1) x (len(t) + 1) table is stored at index
    i * (len(t) + 1) + j (row-major). lindp fills indices in increasing
    order, so each cell is computed after the cells it depends on.
    '''
    width = len(t) + 1

    def to_index(i, j):
        return i * width + j

    def from_index(idx):
        return divmod(idx, width)

    @lindp(pack=to_index, unpack=from_index)
    def D(i, j):
        if i == 0 or j == 0:
            return i + j
        if s[i - 1] == t[j - 1]:
            return D(i - 1, j - 1)
        return 1 + min(D(i - 1, j - 1), D(i - 1, j), D(i, j - 1))
    return D(len(s), len(t))
def edist_im(s, t):
    ''' Edit distance, imperative style.

    Keeps only two rows of the DP table. prev holds D(i - 1, *) and cur is
    filled with D(i, *). After each row the two lists are swapped instead of
    copying cur into prev element by element. The stale row is fully
    overwritten on the next pass.
    '''
    m = len(t)
    prev = list(range(m + 1))  # D(0, j) = j
    cur = [0] * (m + 1)
    for i in range(1, len(s) + 1):
        cur[0] = i
        for j in range(1, m + 1):
            if s[i - 1] == t[j - 1]:
                cur[j] = prev[j - 1]
            else:
                cur[j] = min(cur[j - 1], prev[j], prev[j - 1]) + 1
        prev, cur = cur, prev
    return prev[m]
##### Testing #####
def performance_tests():
    ''' Benchmarks the imperative, @dp and @lindp implementations.

    Each implementation is run once, and one line with the elapsed time is
    printed per run. time.perf_counter is used where it exists (Python 3.3+).
    time.clock, which Python 3.8 removed, is only a fallback for older
    interpreters. print() is called with a single argument, so the function
    runs unchanged on Python 2 and Python 3.
    '''
    import time
    timer = getattr(time, 'perf_counter', None) or time.clock

    def timed(label, fn, *args):
        ''' Runs fn(*args) once and prints "<label>: <seconds> s". '''
        start = timer()
        fn(*args)
        print('{}: {:.3f} s'.format(label, timer() - start))

    # Both strings have length 1000.
    s, t = 'PLEAS' * 200, 'MEANL' * 200

    timed('Cat(1000), imperative programming', cat_im, 1000)
    timed('Cat(1000), with @dp decorator', cat, 1000)
    timed('Cat(1000), with @lindp decorator', cat_l, 1000)

    timed('C(1000, 500), imperative style', binc_im, 1000, 500)
    timed('C(1000, 500), with @dp decorator', binc, 1000, 500)
    timed('C(1000, 500), with @lindp decorator', binc_l, 1000, 500)

    timed('dE(s, t), |s| = |t| = 1000, imperative style', edist_im, s, t)
    timed('dE(s, t), |s| = |t| = 1000, with @lindp decorator', edist_l, s, t)
    timed('dE(s, t), |s| = |t| = 1000, with @dp decorator', edist, s, t)
if __name__ == '__main__':
    # Run the benchmarks only when executed as a script, not on import.
    performance_tests()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment