Created
April 13, 2023 18:45
-
-
Save janoPig/51d722e46d161ad3adbc04163be3384d to your computer and use it in GitHub Desktop.
Optimize float constants in a SymPy expression
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import pandas as pd | |
import sympy as sp | |
import scipy.optimize as opt | |
def get_constants(expr):
    """Return the value of every ``sp.Float`` leaf in *expr* as a Python float.

    Values are listed in pre-order, left to right over each node's ``args``.
    Only ``sp.Expr`` nodes are descended into: Floats that sit below a
    non-``sp.Expr`` node are not reported, and a non-``sp.Expr`` root yields
    an empty list.

    Args:
        expr: SymPy object to scan.

    Returns:
        List of floats, one per Float occurrence (duplicates are kept).
    """
    found = []
    pending = [expr]
    while pending:
        node = pending.pop()
        if isinstance(node, sp.Float):
            found.append(float(node))
            continue
        if isinstance(node, (sp.Symbol, sp.Number)) or not isinstance(node, sp.Expr):
            # Leaves (and anything that is not an sp.Expr) contribute nothing.
            continue
        # Push children right-to-left so the leftmost child is popped first,
        # giving the same pre-order sequence as a recursive walk.
        pending.extend(reversed(node.args))
    return found
def _natural_name_key(symbol):
    """Sort key that orders symbol names naturally (``x2`` before ``x10``)."""
    import re

    # re.split with a capturing group alternates text / digit-run pieces, so
    # odd positions are always digit runs; comparing them as ints keeps the
    # key lists type-consistent position by position.
    pieces = re.split(r"(\d+)", symbol.name)
    return [int(piece) if i % 2 else piece for i, piece in enumerate(pieces)]


def _replace_floats_with_dummies(expr):
    """Swap each ``sp.Float`` leaf of *expr* for a fresh ``sp.Dummy``.

    Only ``sp.Expr`` nodes are descended into. Placeholders and starting
    values are produced in the same pre-order pass, so they always line up.

    Returns:
        ``(new_expr, dummies, values)`` where ``dummies[i]`` replaced a Float
        whose value was ``values[i]``.
    """
    dummies = []
    values = []

    def swap(node):
        if isinstance(node, sp.Float):
            # A Dummy can never collide with a user symbol such as ``c0``.
            dummy = sp.Dummy(f"c{len(dummies)}")
            dummies.append(dummy)
            values.append(float(node))
            return dummy
        if isinstance(node, (sp.Symbol, sp.Number)) or not isinstance(node, sp.Expr):
            return node
        # ``node.func(*node.args) == node`` is SymPy's rebuild invariant.
        return node.func(*[swap(arg) for arg in node.args])

    return swap(expr), dummies, values


def _variable_columns(X, variables):
    """Return one array of input values per entry of *variables*, in order.

    If *X* has a ``columns`` attribute naming every variable, columns are
    selected by name. Otherwise the first ``len(variables)`` columns of *X*
    are used by position.
    """
    names = [v.name for v in variables]
    columns = getattr(X, "columns", None)
    if columns is not None and all(name in columns for name in names):
        return [np.asarray(X[name]) for name in names]
    if hasattr(X, "iloc"):
        return [np.asarray(X.iloc[:, i]) for i in range(len(variables))]
    data = np.asarray(X)
    if data.ndim == 1:
        data = data.reshape(-1, 1)
    return [data[:, i] for i in range(len(variables))]


def optimize_constants(expr, X, y, **minimize_kwargs):
    """Refit the floating-point constants of a SymPy expression to data.

    Every ``sp.Float`` leaf reachable through ``sp.Expr`` nodes of *expr* is
    replaced by its own placeholder. The expression is lambdified with numpy
    and ``scipy.optimize.minimize`` searches for the placeholder values that
    minimise the mean squared error against *y*, starting from the
    constants' current values.

    Args:
        expr: SymPy expression; its free symbols are the input variables.
        X: pandas DataFrame (or 2-D array-like) of inputs. If every free
            symbol's name is a column of *X*, columns are picked by name.
            Otherwise the variables, sorted in natural name order
            (``x1, x2, ..., x10``), are matched to the leading columns of *X*
            by position. NOTE(review): confirm callers follow that column
            convention when columns are unnamed.
        y: Targets, one per row of *X*, in any shape that flattens to that
            many values (e.g. ``(n,)`` or ``(n, 1)``).
        **minimize_kwargs: Extra keyword arguments forwarded to
            ``scipy.optimize.minimize`` (e.g. ``method``, ``options``).

    Returns:
        *expr* with the fitted constants substituted back in. It is returned
        unchanged when it has no float constants. The original constants are
        kept when the optimiser yields non-finite values, a non-finite loss,
        or a loss higher than the starting loss.
    """
    expr_with_symbols, const_symbols, init_vals = _replace_floats_with_dummies(expr)
    if not const_symbols:
        # Nothing to fit; skip the optimiser entirely.
        return expr

    # Sorted, not raw set order: set iteration order is arbitrary, and the
    # same list must drive both the lambdify signature and the column lookup.
    variables = sorted(expr.free_symbols, key=_natural_name_key)
    columns = _variable_columns(X, variables)
    target = np.ravel(np.asarray(y, dtype=float))
    f = sp.lambdify(variables + const_symbols, expr_with_symbols, "numpy")

    def objective(coeffs):
        # Flatten the prediction: subtracting an (n, 1) array from an (n,)
        # one would broadcast to (n, n) instead of n residuals.
        y_pred = np.ravel(f(*columns, *coeffs))
        return float(np.mean((target - y_pred) ** 2))

    x0 = np.asarray(init_vals, dtype=float)
    init_loss = objective(x0)
    result = opt.minimize(objective, x0, **minimize_kwargs)

    fitted = np.asarray(result.x, dtype=float)
    if (
        not np.all(np.isfinite(fitted))
        or not np.isfinite(result.fun)
        or result.fun > init_loss
    ):
        # Never hand back a worse or broken expression than we were given.
        fitted = x0

    return expr_with_symbols.xreplace(
        {c: sp.Float(float(v)) for c, v in zip(const_symbols, fitted)}
    )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment