Skip to content

Instantly share code, notes, and snippets.

@jgillis
Created November 16, 2021 20:35
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jgillis/59331e778bc8e2254dc961ac5a2653a9 to your computer and use it in GitHub Desktop.
Save jgillis/59331e778bc8e2254dc961ac5a2653a9 to your computer and use it in GitHub Desktop.
from casadi import *
# Problem setup shared by every approach below: a 30-element symbolic vector
# and a random numeric point at which the resulting Hessians are compared.
x = MX.sym("x", 30)
x0 = DM.rand(x.shape)

print("Reference approach")
# Repetitive: the objective is accumulated term by term in a plain Python
# loop, so the expression graph contains ten copies of the same stencil.
c = 0
for i in range(10):
    a = x[2*i:2*i+2]  # two consecutive entries of x
    b = x[3*i]        # one entry of x, used as a scalar factor
    c += dot(sin(a)*b, cos(a))
f = Function('f', [x], [c])

# The result of hessian(...) is unpacked below as (Hessian, gradient).
H = Function('f', [x], hessian(f(x), x))
H_ref, g_ref = H(x0)
H_ref.sparsity().spy()
print("Approach #1: lift the indexing outside of the stencil")

# The stencil is one term of the sum, written in its own small symbols
# rather than in terms of x.
pair = MX.sym("a", 2)
factor = MX.sym("b")
term = dot(sin(pair)*factor, cos(pair))
f_stencil = Function('f_stencil', [pair, factor], [term])

# Per-input flag: False means the input varies during a map call,
# True means the argument is static during a map call.
reduce_in = [False, False]
# Per-output flag: False means do not sum the output, True means sum it.
reduce_out = [True]
f = f_stencil.map(10, reduce_in, reduce_out)

# Index tables built outside the stencil: entry k of each table holds the
# positions in x that term k reads.
I_a = hcat([DM(range(2*k, 2*k+2)) for k in range(10)])
I_b = hcat([DM(3*k) for k in range(10)])

mapped_sum = f(x[I_a], x[I_b])
H = Function('f', [x], hessian(mapped_sum, x))
H_approach1, f_approach1 = H(x0)
H_approach1.sparsity().spy()
print("Approach #1 error:", norm_inf(H_ref - H_approach1))
print("Approach #2: index inside the stencil using symbolic indexing")

# Here the stencil receives the term number as a symbol and picks its own
# entries out of the full vector x.
i = MX.sym("i")
offsets = vertcat(0, 1)
pair = x[2*i + offsets]
factor = x[3*i]
term = dot(sin(pair)*factor, cos(pair))
f_stencil = Function('f_stencil', [i, x], [term])

# Per-input flag: False means the input varies during a map call,
# True means the argument is static during a map call.
# The term number varies; x is passed as the static argument.
reduce_in = [False, True]
# Per-output flag: False means do not sum the output, True means sum it.
reduce_out = [True]
f = f_stencil.map(10, reduce_in, reduce_out)

mapped_sum = f(range(10), x)
H = Function('f', [x], hessian(mapped_sum, x))
H_approach2, g_approach2 = H(x0)
H_approach2.sparsity().spy()
print("Approach #2 error:", norm_inf(H_ref - H_approach2))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment