@fKunstner
Created December 22, 2018 07:35
Optimizing einsum formulas

`opt_einsum` is a neat library that optimizes `einsum` contractions on the fly. It can also be used offline to simplify a formula and reveal a better structure.
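NumPy also has a path optimizer of its own: `np.einsum_path` reports a contraction order, and `np.einsum` accepts an `optimize` argument. The following is a minimal sketch on a matrix chain, which is not part of the original gist. The shapes are chosen only for illustration so that contracting `B` with `C` first is much cheaper than forming `A @ B`.

```python
import numpy as np

# Illustrative shapes: (A @ B) @ C is expensive, A @ (B @ C) is cheap
rng = np.random.default_rng(0)
A = rng.standard_normal((100, 10))
B = rng.standard_normal((10, 100))
C = rng.standard_normal((100, 5))

# Ask NumPy for a contraction order without evaluating the contraction
path, report = np.einsum_path('ij,jk,kl->il', A, B, C, optimize='optimal')
print(path)    # a list starting with 'einsum_path', followed by the pairwise steps
print(report)  # human-readable summary, similar in spirit to opt_einsum's path_info

# optimize=True lets NumPy pick an order (greedy) and contract pairwise;
# the result agrees with the unoptimized single call
fast = np.einsum('ij,jk,kl->il', A, B, C, optimize=True)
slow = np.einsum('ij,jk,kl->il', A, B, C, optimize=False)
assert np.allclose(fast, slow)
```

Like `opt_einsum`, this optimizer sees only shapes and subscripts. It has no way of knowing that two operands are the same array.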

The script `eincheck.py` shows an example run on

```python
X = np.random.randn(16, 9, 676)
Y = np.random.randn(16, 32, 676)
np.einsum('bkl,bml,bkn,bmn->mk', X, Y, X, Y)
```
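Every index that is missing from the output (`b`, `l`, `n`) is summed over. The formula therefore reads out[m, k] = Σ_{b,l,n} X[b,k,l]·Y[b,m,l]·X[b,k,n]·Y[b,m,n]. The sketch below is a loop-based reference that makes this reading concrete. The helper name `reference` is hypothetical, and the sizes are tiny because the loops are slow. It can be checked against `np.einsum`.

```python
import itertools
import numpy as np

def reference(X, Y):
    """Hypothetical loop version of 'bkl,bml,bkn,bmn->mk' for X[b,k,l], Y[b,m,l]."""
    B, K, L = X.shape
    _, M, _ = Y.shape
    out = np.zeros((M, K))
    # b, l and n are summed; n runs over the same axis as l because X and Y are reused
    for m, k in itertools.product(range(M), range(K)):
        for b, l, n in itertools.product(range(B), range(L), range(L)):
            out[m, k] += X[b, k, l] * Y[b, m, l] * X[b, k, n] * Y[b, m, n]
    return out

rng = np.random.default_rng(0)
X = rng.standard_normal((2, 3, 4))  # illustrative sizes: b=2, k=3, l=4
Y = rng.standard_normal((2, 5, 4))  # b=2, m=5, l=4
assert np.allclose(reference(X, Y), np.einsum('bkl,bml,bkn,bmn->mk', X, Y, X, Y))
```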

The "optimized path" found by `opt_einsum` gives

```python
np.einsum('bmk,bmk->mk', np.einsum('bml,bkl->bmk', Y, X), np.einsum('bmn,bkn->bmk', Y, X))
```

but the optimizer does not know that the inputs are `(X, Y, X, Y)`, so it cannot tell that both inner contractions produce the same `bmk` tensor. Simplifying by hand to compute that tensor only once leads to

```python
np.sum(np.einsum('bml,bkl->bmk', Y, X)**2, axis=0)
```

which roughly halves the time of the optimized version:

```text
+-----------------+-----------------+-----------------+
|       Functions |      Time (tot) | Time (per iter) |
+-----------------+-----------------+-----------------+
|       [0] naive |       0.047003s |       0.004700s |
|   [1] optimized |       0.033002s |       0.003300s |
|  [2] optimized2 |       0.017001s |       0.001700s |
+-----------------+-----------------+-----------------+
```
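The hand simplification is valid because the sums over `l` and `n` factor separately for each `b`, and the two factors are the same tensor. Writing $T_{bmk}$ for the result of `np.einsum('bml,bkl->bmk', Y, X)`:

$$
T_{bmk} = \sum_{l} Y_{bml}\,X_{bkl}
$$

$$
\begin{aligned}
\text{out}_{mk}
&= \sum_{b,l,n} X_{bkl}\,Y_{bml}\,X_{bkn}\,Y_{bmn} \\
&= \sum_{b} \Big(\sum_{l} X_{bkl}\,Y_{bml}\Big)\Big(\sum_{n} X_{bkn}\,Y_{bmn}\Big) \\
&= \sum_{b} T_{bmk}\,T_{bmk} = \sum_{b} T_{bmk}^{2}
\end{aligned}
$$

As a rough cost model, count the index combinations each step loops over and ignore memory traffic. With the sizes from `eincheck.py` ($B=16$, $K=9$, $L=N=676$, $M=32$), the naive formula visits $BKLMN \approx 2.1\times 10^{9}$ combinations. The optimized path visits $2BMKL + BMK \approx 6.2\times 10^{6}$, and the hand-simplified version visits $BMKL + BMK \approx 3.1\times 10^{6}$. The factor of about two between the last two counts is consistent with the measured times of `optimized` and `optimized2` in the table above.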
```python
r"""
Optimized Einsum formula using opt_einsum

Requirements
* numpy
* https://github.com/fKunstner/quickbench/tree/v0.1.0
* https://github.com/dgasmith/opt_einsum/tree/v2.3.2
"""
import numpy as np
import quickbench
import opt_einsum as oe

###
print()
print("Analysis of the starting formula")
print("--------------------------------")
print()
###

einsum_string = 'bkl,bml,bkn,bmn->mk'

# Build random views to represent this contraction.
# sorted() pins the index order to b, k, l, m, n so that it lines up with
# index_size; iterating over a bare set gives an arbitrary, run-dependent order.
index_size = [16, 9, 676, 32, 676]
unique_inds = sorted(set(einsum_string) - {',', '-', '>'})
sizes_dict = dict(zip(unique_inds, index_size))
views = oe.helpers.build_views(einsum_string, sizes_dict)

path, path_info = oe.contract_path(einsum_string, *views, optimize='optimal')

# Order of the pairwise contractions
print(path)
#> [(0, 1), (0, 1), (0, 1)]
print(path_info)
# Recorded output. It was produced with the sizes zipped against the unordered
# set, so the optimized FLOP count, theoretical speedup and largest intermediate
# reflect that run's size assignment; with b=16, k=9, l=676, m=32, n=676 the
# bmk intermediate has 16*32*9 = 4608 elements.
#> Complete contraction:  bkl,bml,bkn,bmn->mk
#>        Naive scaling:  5
#>    Optimized scaling:  4
#>     Naive FLOP count:  8.423e+9
#> Optimized FLOP count:  6.142e+8
#>  Theoretical speedup:  13.714
#> Largest intermediate:  7.312e+6 elements
#> --------------------------------------------------------------------------------
#> scaling   BLAS   current          remaining
#> --------------------------------------------------------------------------------
#>    4       0     bml,bkl->bmk     bkn,bmn,bmk->mk
#>    4       0     bmn,bkn->bmk     bmk,bmk->mk
#>    3       0     bmk,bmk->mk      mk->mk

###
print()
print("Testing the new formula")
print("-----------------------")
print()
###

X = np.random.randn(16, 9, 676)   # indices b, k, l
Y = np.random.randn(16, 32, 676)  # indices b, m, l


def datafunc():
    # Inputs shared by every benchmarked function
    return X, Y


def naive(X, Y):
    # One einsum call over all four operands, without path optimization
    return np.einsum(einsum_string, X, Y, X, Y)


def optimized(X, Y):
    # Pairwise contraction order reported by opt_einsum above
    return np.einsum('bmk,bmk->mk',
                     np.einsum('bml,bkl->bmk', Y, X),
                     np.einsum('bmn,bkn->bmk', Y, X))


quickbench.check(datafunc, [naive, optimized], compfunc=lambda x, y: np.allclose(x, y))
#> [1] optimized matches [0] naive: True
quickbench.bench(datafunc, [naive, optimized])
#> +-----------------+-----------------+-----------------+
#> |       Functions |      Time (tot) | Time (per iter) |
#> +-----------------+-----------------+-----------------+
#> |       [0] naive |       0.047003s |       0.004700s |
#> |   [1] optimized |       0.033002s |       0.003300s |
#> +-----------------+-----------------+-----------------+

###
print()
print("Simplifying the resulting formula by hand")
print("-----------------------------------------")
print()
###


def optimized2(X, Y):
    # Both inner einsums of optimized() compute the same bmk tensor:
    # compute it once, square it, and sum over b.
    T = np.einsum('bml,bkl->bmk', Y, X)
    return np.sum(T**2, axis=0)


quickbench.check(datafunc, [naive, optimized, optimized2], compfunc=lambda x, y: np.allclose(x, y))
#> [1] optimized matches [0] naive: True
#> [2] optimized2 matches [0] naive: True
quickbench.bench(datafunc, [naive, optimized, optimized2])
#> +-----------------+-----------------+-----------------+
#> |       Functions |      Time (tot) | Time (per iter) |
#> +-----------------+-----------------+-----------------+
#> |       [0] naive |       0.047003s |       0.004700s |
#> |   [1] optimized |       0.033002s |       0.003300s |
#> |  [2] optimized2 |       0.017001s |       0.001700s |
#> +-----------------+-----------------+-----------------+
```
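In `eincheck.py` the sizes passed to `build_views` are written out by hand, and they must be matched to the indices in the right order. The sketch below reads them off the operands instead. The helper name `index_sizes` is hypothetical and is not part of opt_einsum or NumPy. The helper also flags an index that appears with two different sizes. It assumes explicit subscripts without `...`.

```python
import numpy as np

def index_sizes(subscripts, *operands):
    """Hypothetical helper: map each einsum index to its dimension, checking consistency."""
    inputs = subscripts.split('->')[0].split(',')  # assumes no '...' in the subscripts
    if len(inputs) != len(operands):
        raise ValueError(f"{len(inputs)} terms but {len(operands)} operands")
    sizes = {}
    for term, arr in zip(inputs, operands):
        if len(term) != arr.ndim:
            raise ValueError(f"term {term!r} does not match shape {arr.shape}")
        for index, dim in zip(term, arr.shape):
            # setdefault keeps the first size seen; every later use must agree with it
            if sizes.setdefault(index, dim) != dim:
                raise ValueError(f"index {index!r} has sizes {sizes[index]} and {dim}")
    return sizes

X = np.zeros((16, 9, 676))
Y = np.zeros((16, 32, 676))
print(index_sizes('bkl,bml,bkn,bmn->mk', X, Y, X, Y))
# → {'b': 16, 'k': 9, 'l': 676, 'm': 32, 'n': 676}
```

The resulting dictionary has the same form as `sizes_dict` in the script, so it could presumably be passed to `oe.helpers.build_views` in its place.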