Skip to content

Instantly share code, notes, and snippets.

@janoPig
Created April 13, 2023 18:45
Show Gist options
  • Save janoPig/51d722e46d161ad3adbc04163be3384d to your computer and use it in GitHub Desktop.
Save janoPig/51d722e46d161ad3adbc04163be3384d to your computer and use it in GitHub Desktop.
Optimize float constants in a SymPy expression
import numpy as np
import pandas as pd
import sympy as sp
import scipy.optimize as opt
def get_constants(expr):
    """Return the numeric values of every ``sp.Float`` occurring in *expr*.

    The tree is walked depth-first, each node before its children and the
    children in ``.args`` order, so the floats come out in that order with
    one entry per occurrence. Only ``sp.Expr`` nodes are descended into;
    floats nested under non-``Expr`` nodes are not reported.

    Args:
        expr: A SymPy expression.

    Returns:
        list[float]: The float constants converted to Python floats.
    """
    def _walk(node):
        # Symbols and Numbers are leaves; of those, only Floats are reported.
        if isinstance(node, (sp.Symbol, sp.Number)):
            if isinstance(node, sp.Float):
                yield float(node)
            return
        if not isinstance(node, sp.Expr):
            return
        for child in node.args:
            yield from _walk(child)

    return list(_walk(expr))
def optimize_constants(expr, X, y, variables=None):
    """Refit the float constants of *expr* to data by minimising the MSE.

    Every ``sp.Float`` occurring in *expr* becomes a free parameter, one per
    occurrence in depth-first ``.args`` order. Each parameter starts at its
    current value and is tuned with ``scipy.optimize.minimize`` to minimise
    the mean squared error between the model's prediction on *X* and *y*.
    The fitted values are then substituted back into the expression.

    Args:
        expr: SymPy expression (or anything ``sp.sympify`` accepts) whose
            free symbols are the input variables.
        X: 2-D table of inputs (pandas DataFrame or array-like) with one
            column per variable.
        y: Targets, any array-like with one value per row of *X*. Both
            ``(n,)`` and ``(n, 1)`` shapes work.
        variables: Optional sequence of symbols giving the column order of
            *X*: ``variables[i]`` is read from column ``i``. When omitted,
            the free symbols are sorted by name. If *X* has a ``columns``
            attribute containing every symbol name, columns are matched by
            name; otherwise they are matched by position in that sorted
            order.

    Returns:
        The expression with the optimised constants substituted in. If
        *expr* contains no float constants, no optimisation is run.
    """
    expr = sp.sympify(expr)
    const_symbols = []
    init_vals = []

    def replace_constants_with_symbols(node):
        # Collect each placeholder and its starting value in the same pass,
        # so the parameter vector and the initial guess always line up.
        if isinstance(node, sp.Float):
            # Dummy symbols cannot collide with user variables named c0, c1, ...
            sym = sp.Dummy(f"c{len(const_symbols)}")
            const_symbols.append(sym)
            init_vals.append(float(node))
            return sym
        if isinstance(node, (sp.Symbol, sp.Number)):
            return node
        if isinstance(node, sp.Expr) and node.args:
            # ``func(*args)`` is SymPy's invariant for rebuilding a node.
            new_args = [replace_constants_with_symbols(a) for a in node.args]
            return node.func(*new_args)
        return node

    expr_with_symbols = replace_constants_with_symbols(expr)
    if not const_symbols:
        # Nothing to fit; avoid calling minimize with an empty x0.
        return expr_with_symbols

    if variables is None:
        # free_symbols is a set whose iteration order varies between runs
        # (string hashing), so sort it for a reproducible column mapping.
        variables = sorted(expr.free_symbols, key=str)
        columns = getattr(X, "columns", None)
        match_by_name = columns is not None and all(
            str(v) in columns for v in variables
        )
    else:
        variables = list(variables)
        match_by_name = False

    if match_by_name:
        inputs = [np.asarray(X[str(v)], dtype=float) for v in variables]
    else:
        X_arr = np.asarray(X, dtype=float)
        inputs = [X_arr[:, i] for i in range(len(variables))]

    f = sp.lambdify(list(variables) + const_symbols, expr_with_symbols, "numpy")
    # Flatten so (n,) and (n, 1) targets both compare element-wise, instead
    # of broadcasting to an (n, n) matrix.
    y_true = np.ravel(np.asarray(y, dtype=float))

    def objective(coeffs):
        # np.ravel also handles a scalar prediction from a constant model.
        y_pred = np.ravel(f(*inputs, *coeffs))
        return float(np.mean((y_true - y_pred) ** 2))

    result = opt.minimize(objective, np.asarray(init_vals, dtype=float))
    # NOTE(review): result.success is not checked, as in the original;
    # the last iterate is used even if the optimiser did not converge.
    fitted = {c: sp.Float(float(v)) for c, v in zip(const_symbols, result.x)}
    return expr_with_symbols.xreplace(fitted)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment