Skip to content

Instantly share code, notes, and snippets.

@lidavidm
Last active December 15, 2015 00:09
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save lidavidm/5171100 to your computer and use it in GitHub Desktop.
Save lidavidm/5171100 to your computer and use it in GitHub Desktop.
Derivative steps using strategies module
import sympy
from sympy.strategies.core import switch
from sympy.core.function import AppliedUndef
from sympy.functions.elementary.trigonometric import TrigonometricFunction
import collections
def Rule(name, props=""):
    """Create a namedtuple class that describes one differentiation rule.

    Args:
        name: Type name given to the generated namedtuple class.
        props: Whitespace-separated names of the rule-specific fields.
            These come first in the tuple; the shared ``context`` and
            ``symbol`` fields are always appended after them.

    Returns:
        A new ``collections.namedtuple`` class with fields
        ``<props...> context symbol``.
    """
    return collections.namedtuple(name, props + " context symbol")
# Rule record types. Each one carries its own fields followed by the
# shared ``context`` and ``symbol`` fields appended by Rule().
ConstantRule = Rule("ConstantRule", props="number")
PowerRule = Rule("PowerRule", props="base exp")
AddRule = Rule("AddRule", props="substeps")
ExpRule = Rule("ExpRule", props="f base")
class Derivative(sympy.Derivative):
    """A ``sympy.Derivative`` that also carries a recursion callback.

    The callback is stored on the instance so that rule functions, which
    only receive the derivative object, can differentiate subexpressions
    through :meth:`recurse`.
    """

    def __new__(cls, expr, variable, recurse):
        """Build the derivative of *expr* w.r.t. *variable*.

        Args:
            expr: Expression being differentiated.
            variable: Symbol to differentiate with respect to.
            recurse: Callable ``(expr, symbol)`` used by :meth:`recurse`.
        """
        obj = sympy.Derivative.__new__(cls, expr, variable)
        # Attached after construction; it is not passed to the sympy
        # constructor, so it is not one of the object's ``args``.
        obj._recurse = recurse
        return obj

    def recurse(self, expr, symbol):
        """Apply the stored callback to *expr* and *symbol* and return its result."""
        return self._recurse(expr, symbol)
def step(rule):
    """Wrap *rule* so that applying it just returns the rule's own result.

    Args:
        rule: Callable taking a derivative and returning a rule record.

    Returns:
        A one-argument callable that forwards the derivative to *rule*
        and returns whatever *rule* returns, without evaluating it.
    """
    def _step(derivative):
        return rule(derivative)
    return _step
def do(rule):
    """Wrap *rule* so that its result is immediately evaluated.

    Args:
        rule: Callable taking a derivative and returning a rule record
            (a tuple-like object whose class is a key of ``evaluators``).

    Returns:
        A one-argument callable that applies *rule*, looks up the
        evaluator registered for the record's class and calls it with the
        derivative followed by the record's fields.

    Raises:
        KeyError: if no evaluator is registered for the record's class.
    """
    def _do(derivative):
        matched = rule(derivative)
        evaluate = evaluators[matched.__class__]
        # The record's fields are unpacked positionally after the derivative.
        return evaluate(derivative, *matched)
    return _do
# Registry mapping a rule class to the function that evaluates it.
evaluators = {}


def evaluates(rule):
    """Decorator factory that registers a function as the evaluator of *rule*.

    Args:
        rule: The rule class (registry key) the decorated function handles.

    Returns:
        A decorator that records the function in ``evaluators`` under
        *rule*, sets ``func.rule = rule`` and returns the function unchanged.
    """
    def _evaluates(func):
        func.rule = rule
        evaluators[rule] = func
        return func
    return _evaluates
@evaluates(PowerRule)
def eval_power(derivative, base, exp, expr, symbol):
    """Evaluate a PowerRule record as ``exp * base ** (exp - 1)``.

    Only *base* and *exp* are used; *derivative*, *expr* and *symbol* are
    accepted because evaluators receive every field of the record.

    NOTE(review): no factor for the derivative of *base* is applied, so
    the result is the full derivative only when *base* is the symbol
    itself (or differs from it by a constant).
    """
    return exp * base ** (exp - 1)
@evaluates(AddRule)
def eval_add(derivative, substeps, expr, symbol):
    """Evaluate an AddRule record by adding up all of its *substeps*.

    The first substep seeds the sum, so the substeps only need to support
    ``+`` with each other.

    Raises:
        IndexError: if *substeps* is empty.
    """
    return sum(substeps[1:], substeps[0])
@evaluates(ConstantRule)
def eval_number(derivative, number, expr, symbol):
    """Evaluate a ConstantRule record: the result is always the integer 0."""
    return 0
@evaluates(ExpRule)
def eval_exp(derivative, f, base, expr, symbol):
    """Evaluate an ExpRule record as ``f * ln(base)``.

    Only *f* and *base* are used.

    NOTE(review): no factor for the derivative of the exponent is
    applied, so this presumably expects the exponent to be the bare
    symbol - confirm against callers.
    """
    return f * sympy.ln(base)
def power_rule(derivative):
    """Describe the derivative of a power-like expression as a rule record.

    Args:
        derivative: Object whose ``args`` unpack to ``(expr, symbol)``.

    Returns:
        ``ExpRule(expr, base, expr, symbol)`` when the base of *expr* is
        constant with respect to *symbol*, otherwise
        ``PowerRule(base, exp, expr, symbol)``.
    """
    expr, symbol = derivative.args
    base, exp = expr.as_base_exp()
    if not base.is_constant(symbol):
        return PowerRule(base, exp, expr, symbol)
    # Constant base: the whole expression is passed as the ``f`` field.
    return ExpRule(expr, base, expr, symbol)
def add_rule(derivative):
    """Describe the derivative of a sum as an AddRule record.

    Each argument of the expression is differentiated through
    ``derivative.recurse``; the results become the record's ``substeps``.

    Args:
        derivative: Object whose ``args`` unpack to ``(expr, symbol)`` and
            which provides ``recurse(expr, symbol)``.

    Returns:
        ``AddRule(substeps, expr, symbol)``.
    """
    expr, symbol = derivative.args
    substeps = [derivative.recurse(term, symbol) for term in expr.args]
    return AddRule(substeps, expr, symbol)
def constant_rule(derivative):
    """Describe the derivative of a constant as a ConstantRule record.

    Args:
        derivative: Object whose first two ``args`` are the expression
            and the symbol.

    Returns:
        ``ConstantRule(expr, expr, symbol)`` - the expression serves as
        both the ``number`` and the ``context`` field.
    """
    expr = derivative.args[0]
    symbol = derivative.args[1]
    return ConstantRule(expr, expr, symbol)
def _make_diff(stepfunction):
    """Build a differentiation function driven by *stepfunction*.

    Args:
        stepfunction: Wrapper applied to each rule function (``step`` and
            ``do`` are the two wrappers used in this file). It receives a
            rule function and returns a callable taking a derivative.

    Returns:
        A function ``(expr, symbol)`` that wraps *expr* in a
        ``Derivative`` (with itself as the recursion callback), picks a
        handler based on the kind of *expr* and applies it.
    """
    # The table depends only on stepfunction, so it is built once here
    # instead of on every (recursive) call.
    handlers = {
        sympy.Pow: stepfunction(power_rule),
        sympy.Symbol: stepfunction(power_rule),
        sympy.Add: stepfunction(add_rule),
        'constant': stepfunction(constant_rule),
    }

    def _diff_steps(expr, symbol):
        deriv = Derivative(expr, symbol, _diff_steps)

        def key(derivative):
            """Classify the expression wrapped by *derivative*."""
            inner = derivative.args[0]
            # Order matters: the class checks come before the constant
            # test, and the generic ``func`` lookup is the fallback.
            if isinstance(inner, TrigonometricFunction):
                return TrigonometricFunction
            if isinstance(inner, AppliedUndef):
                return AppliedUndef
            if inner.is_constant(symbol):
                return 'constant'
            return inner.func

        # NOTE(review): keys without an entry in ``handlers`` (e.g. the
        # trigonometric and undefined-function keys) are handled by
        # ``switch`` itself - presumably the derivative is returned
        # unchanged; confirm against the sympy.strategies documentation.
        return switch(key, handlers)(deriv)

    return _diff_steps
# diff_steps returns the rule records; diff evaluates them through the
# registered evaluators.
diff_steps = _make_diff(stepfunction=step)
diff = _make_diff(stepfunction=do)
# Demo. The parenthesised single-argument print form behaves the same on
# Python 2 and Python 3, whereas the bare ``print expr`` statement is a
# SyntaxError on Python 3.
x = sympy.Symbol('x')
print(diff_steps(x**2 + x**3 + x, x))
print(diff(2 ** x + 2, x))
print(diff(x**2 + x**3 + x, x))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment