Skip to content

Instantly share code, notes, and snippets.

@yaroslavvb
Last active January 26, 2020 00:30
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save yaroslavvb/ce6331e5a0bc1f3a8d8a9d225ae0dda1 to your computer and use it in GitHub Desktop.
Save yaroslavvb/ce6331e5a0bc1f3a8d8a9d225ae0dda1 to your computer and use it in GitHub Desktop.
An example of turning a contraction order into a sequence of einsum calls
# An example of turning a contraction order into a sequence of einsum calls
from opt_einsum import helpers as oe_helpers
import opt_einsum as oe
def print_schedule(path, indices, output_subscript, terms):
    """Print the sequence of einsum calls that realizes a contraction path.

    Each step of ``path`` is printed as one line of the form
    ``derivedN=einsum(<subscripts>, [<operand names>])``. The operands consumed
    by a step are removed, and the step's result (named ``derivedN``) is
    appended to the end of the operand list, matching opt_einsum's path format.

    Args:
        path: contraction path in opt_einsum format, i.e. a list of tuples of
            operand positions, e.g. [(0,), (2,), (1, 3), (0, 2), (0, 1)].
        indices: subscript string of each input operand,
            e.g. ['ij', 'jk', 'kl', 'lm'].
        output_subscript: subscript string of the final result, e.g. 'kl'.
        terms: name of each input operand,
            e.g. ['term1', 'term2', 'term3', 'term4'].

    ``indices`` and ``terms`` are copied internally, so the caller's lists are
    not modified.

    Example:
        print_schedule([(0,), (2,), (1, 3), (0, 2), (0, 1)],
                       ['ij', 'jk', 'kl', 'lm'], 'kl',
                       ['term1', 'term2', 'term3', 'term4'])
    """
    # Work on copies: the loop pops consumed operands and appends results,
    # which previously clobbered the caller's lists.
    indices = list(indices)
    terms = list(terms)
    input_index_sets = [set(x) for x in indices]
    output_indices = set(output_subscript)
    last_step = len(path) - 1
    for step, contract_inds in enumerate(path):
        # Pop from the highest position down so lower positions stay valid.
        contract_inds = tuple(sorted(contract_inds, reverse=True))
        # find_contraction returns the result's index set and the updated
        # operand index sets (with the result appended at the end).
        out_inds, input_index_sets, _, _ = oe_helpers.find_contraction(
            contract_inds, input_index_sets, output_indices)
        current_input_index_sets = [indices.pop(x) for x in contract_inds]
        current_terms = [terms.pop(x) for x in contract_inds]
        if step == last_step:
            # The final step must produce exactly the requested output layout.
            current_output_indices = output_subscript
        else:
            # Order intermediate indices by first appearance among this step's
            # inputs so the printed subscripts don't depend on set ordering.
            all_input_inds = "".join(current_input_index_sets)
            current_output_indices = "".join(
                sorted(out_inds, key=all_input_inds.find))
        derived_term = f'derived{step}'
        indices.append(current_output_indices)
        terms.append(derived_term)
        einsum_str = ",".join(current_input_index_sets) + "->" + current_output_indices
        print(f'{derived_term}=einsum({einsum_str}, {current_terms})')
def optimize_and_print_schedule(einsum_str, dim=5, optimize='dp'):
    """Find a contraction path for ``einsum_str`` and print it as einsum calls.

    Input operands are named term0, term1, ... . Placeholder arrays are built
    with every index given size ``dim`` so that opt_einsum can search for a
    path.

    Args:
        einsum_str: einsum subscripts such as 'ij,jk->ik'. If '->' is omitted,
            the output follows NumPy's implicit-mode rule: the indices that
            appear exactly once, in sorted order. Ellipsis ('...') is not
            supported.
        dim: size assigned to every index (defaults to 5).
        optimize: path optimizer passed to ``opt_einsum.contract_path``
            (defaults to 'dp').
    """
    inputs, arrow, output_indices = einsum_str.partition('->')
    indices = inputs.split(',')
    if not arrow:
        # Implicit mode: previously this raised IndexError on split('->')[1].
        letters = inputs.replace(',', '')
        output_indices = ''.join(
            s for s in sorted(set(letters)) if letters.count(s) == 1)
    unique_inds = set(einsum_str) - {',', '-', '>'}
    sizes_dict = dict.fromkeys(unique_inds, dim)
    views = oe.helpers.build_views(einsum_str, sizes_dict)
    path, _ = oe.contract_path(einsum_str, *views, optimize=optimize)
    terms = [f'term{i}' for i in range(len(indices))]
    print('\noptimizing ', einsum_str, terms)
    print_schedule(path, indices, output_indices, terms)
def test_print_schedule():
    """Run the example schedules and verify the deterministic one.

    The first case uses a hand-written path, so its output is fully determined
    by ``print_schedule``. It is captured, echoed (so running this as a script
    still shows it), and asserted line by line. The remaining cases go through
    opt_einsum's optimizer and are only printed; their output looked roughly
    like this when recorded:

        optimizing ij,jk,kl,lm-> ['term0', 'term1', 'term2', 'term3']
        derived0=einsum(ij->j, ['term0'])
        derived1=einsum(lm->l, ['term3'])
        derived2=einsum(l,kl->k, ['derived1', 'term2'])
        derived3=einsum(k,jk->j, ['derived2', 'term1'])
        derived4=einsum(j,j->, ['derived3', 'derived0'])
        optimizing ij,jk,kl,lm->im ['term0', 'term1', 'term2', 'term3']
        derived0=einsum(lm,kl->mk, ['term3', 'term2'])
        derived1=einsum(mk,jk->mj, ['derived0', 'term1'])
        derived2=einsum(mj,ij->im, ['derived1', 'term0'])
        optimizing ijkl,ijkl-> ['term0', 'term1']
        derived0=einsum(ijkl,ijkl->, ['term1', 'term0'])
        optimizing ijkl,ijkl->i ['term0', 'term1']
        derived0=einsum(ijkl,ijkl->i, ['term1', 'term0'])
        optimizing ijk,jkl,klm,lmn->in ['term0', 'term1', 'term2', 'term3']
        derived0=einsum(lmn,klm->lnk, ['term3', 'term2'])
        derived1=einsum(lnk,jkl->nkj, ['derived0', 'term1'])
        derived2=einsum(nkj,ijk->in, ['derived1', 'term0'])
        optimizing i,j,k,l,m-> ['term0', 'term1', 'term2', 'term3', 'term4']
        derived0=einsum(i->, ['term0'])
        derived1=einsum(j->, ['term1'])
        derived2=einsum(,->, ['derived1', 'derived0'])
        derived3=einsum(k->, ['term2'])
        derived4=einsum(,->, ['derived3', 'derived2'])
        derived5=einsum(l->, ['term3'])
        derived6=einsum(,->, ['derived5', 'derived4'])
        derived7=einsum(m->, ['term4'])
        derived8=einsum(,->, ['derived7', 'derived6'])
    """
    import io
    from contextlib import redirect_stdout

    expected = [
        "derived0=einsum(ij->j, ['term1'])",
        "derived1=einsum(lm->l, ['term4'])",
        "derived2=einsum(l,kl->lk, ['derived1', 'term3'])",
        "derived3=einsum(lk,jk->lkj, ['derived2', 'term2'])",
        "derived4=einsum(lkj,j->kl, ['derived3', 'derived0'])",
    ]
    buf = io.StringIO()
    with redirect_stdout(buf):
        print_schedule([(0,), (2,), (1, 3), (0, 2), (0, 1)],
                       ['ij', 'jk', 'kl', 'lm'], 'kl',
                       ['term1', 'term2', 'term3', 'term4'])
    captured = buf.getvalue()
    print(captured, end='')
    assert captured.splitlines() == expected, captured

    for einsum_str in ('ij,jk,kl,lm->',
                       'ij,jk,kl,lm->im',
                       'ijkl,ijkl->',
                       'ijkl,ijkl->i',
                       'ijk,jkl,klm,lmn->in',
                       'i,j,k,l,m->'):
        optimize_and_print_schedule(einsum_str)
if __name__ == "__main__":
    # Executing the file directly runs the example schedules.
    test_print_schedule()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment