Last active
December 15, 2015 00:09
-
-
Save lidavidm/5171100 to your computer and use it in GitHub Desktop.
Derivative steps using strategies module
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import sympy | |
from sympy.strategies.core import switch | |
from sympy.core.function import AppliedUndef | |
from sympy.functions.elementary.trigonometric import TrigonometricFunction | |
import collections | |
def Rule(name, props=""):
    """Create a namedtuple class describing one kind of differentiation step.

    ``props`` is a field-name string in the form accepted by
    ``collections.namedtuple`` (e.g. ``"base exp"``). Every generated class
    ends with two shared fields: ``context`` (the expression the step applies
    to) and ``symbol`` (the variable of differentiation).
    """
    common_fields = " context symbol"
    field_spec = props + common_fields
    return collections.namedtuple(name, field_spec)
# Step record types produced by the rule functions. Each has its own leading
# fields followed by the shared ``context`` and ``symbol`` fields added by Rule.
ConstantRule, PowerRule, AddRule, ExpRule = (
    Rule("ConstantRule", "number"),
    Rule("PowerRule", "base exp"),
    Rule("AddRule", "substeps"),
    Rule("ExpRule", "f base"),
)
class Derivative(sympy.Derivative):
    """An unevaluated ``sympy.Derivative`` that carries a recursion callback.

    The rule functions call ``recurse`` to differentiate subexpressions (for
    example each term of a sum) with the same driver that created this node.
    """

    def __new__(cls, expr, variable, recurse=None):
        # ``recurse`` defaults to None so SymPy can rebuild the node through
        # ``obj.func(*obj.args)`` (used by e.g. subs/xreplace), which passes
        # only the expression and the variable(s).
        obj = sympy.Derivative.__new__(cls, expr, variable)
        # Defensive: only attach the callback when SymPy actually returned an
        # instance of this class; its constructor can short-circuit to a
        # different expression for some inputs.
        if isinstance(obj, cls):
            obj._recurse = recurse
        return obj

    def recurse(self, expr, symbol):
        """Differentiate ``expr`` with respect to ``symbol`` via the callback.

        Raises:
            TypeError: if this node was created without a callback.
        """
        if self._recurse is None:
            raise TypeError(
                "this Derivative has no recursion callback; construct it "
                "as Derivative(expr, variable, recurse)")
        return self._recurse(expr, symbol)
def step(rule):
    """Wrap ``rule`` so the driver returns the rule's step record unevaluated.

    This is the wrapper used to build ``diff_steps``.
    """
    def _step(derivative):
        return rule(derivative)
    return _step
def do(rule):
    """Wrap ``rule`` so its step record is evaluated immediately.

    The wrapped function applies ``rule`` to the derivative. It then looks up
    the evaluator registered in ``evaluators`` for the record's class. Finally
    it calls that evaluator with the derivative followed by the record's
    fields. An unregistered record class raises KeyError.
    """
    def _do(derivative):
        record = rule(derivative)
        evaluator = evaluators[record.__class__]
        return evaluator(derivative, *record)
    return _do
# Registry mapping each rule record class to the function that evaluates it;
# filled by the @evaluates decorator and consulted by do().
evaluators = {}


def evaluates(rule):
    """Return a decorator that registers a function as ``rule``'s evaluator.

    The decorated function gets a ``rule`` attribute, is stored in
    ``evaluators[rule]``, and is returned unchanged.
    """
    def register(func):
        func.rule = rule
        evaluators[rule] = func
        return func
    return register
@evaluates(PowerRule)
def eval_power(derivative, base, exp, expr, symbol):
    """Evaluate a PowerRule step: d/d(symbol) of ``base**exp``.

    Returns ``exp * base**(exp - 1)``, multiplied (chain rule) by the
    derivative of ``base``. When ``base`` is ``symbol`` itself that inner
    derivative is 1. It is skipped there, because recursing on the bare
    symbol dispatches straight back to this evaluator.

    NOTE(review): power_rule also emits a PowerRule when ``exp`` depends on
    ``symbol`` (e.g. ``x**x``); this formula is not valid in that case.
    """
    outer = exp * base ** (exp - 1)
    if base == symbol:
        return outer
    # Chain rule: differentiate the base with the same driver.
    return outer * derivative.recurse(base, symbol)
@evaluates(AddRule)
def eval_add(derivative, substeps, expr, symbol):
    """Evaluate an AddRule step by summing the per-term results.

    ``substeps`` are the already-evaluated derivatives of the sum's terms, so
    the derivative of the sum is their total, accumulated left to right.
    """
    total = substeps[0]
    for term in substeps[1:]:
        total = total + term
    return total
@evaluates(ConstantRule)
def eval_number(derivative, number, expr, symbol):
    """Evaluate a ConstantRule step: the derivative of a constant is zero.

    Returns SymPy's zero rather than the Python int ``0``. Like the other
    evaluators, the result is then always a SymPy expression; for example,
    ``diff(2, x)`` supports ``.subs`` and ``.is_zero``.
    """
    return sympy.S.Zero
@evaluates(ExpRule)
def eval_exp(derivative, f, base, expr, symbol):
    """Evaluate an ExpRule step: d/d(symbol) of ``base**g`` for constant ``base``.

    ``f`` is the whole power ``base**g``. The result is ``f * ln(base)``,
    multiplied (chain rule) by the derivative of the exponent ``g``. That
    factor is skipped when ``g`` is ``symbol`` itself, where it equals 1.
    """
    # power_rule built ``base`` from the same as_base_exp() split, so the
    # exponent recovered here matches it.
    exponent = f.as_base_exp()[1]
    result = f * sympy.ln(base)
    if exponent == symbol:
        return result
    return result * derivative.recurse(exponent, symbol)
def power_rule(derivative):
    """Classify the power being differentiated.

    A base that is constant with respect to the variable (e.g. ``2**x``)
    yields an ExpRule. Anything else yields a PowerRule holding the power's
    base and exponent.
    """
    # Read the expression and variable through the Derivative API instead of
    # unpacking ``.args``: recent SymPy releases store ``(symbol, count)``
    # pairs in ``.args``, which would make ``symbol`` a tuple rather than the
    # variable.
    expr = derivative.expr
    symbol = derivative.variables[0]
    base, exp = expr.as_base_exp()
    if base.is_constant(symbol):
        return ExpRule(expr, base, expr, symbol)
    return PowerRule(base, exp, expr, symbol)
def add_rule(derivative):
    """Differentiate a sum term by term.

    Each term goes through ``derivative.recurse``. The results become the
    AddRule's ``substeps``: step records under ``diff_steps``, evaluated
    expressions under ``diff``.
    """
    expr = derivative.expr
    # Not ``.args[1]``: its layout differs between SymPy versions.
    symbol = derivative.variables[0]
    substeps = [derivative.recurse(term, symbol) for term in expr.args]
    return AddRule(substeps, expr, symbol)
def constant_rule(derivative):
    """Record that the expression does not depend on the variable.

    The expression is stored both as the rule's ``number`` and as its
    ``context``.
    """
    expr = derivative.expr
    # Not ``.args[1]``: its layout differs between SymPy versions.
    symbol = derivative.variables[0]
    return ConstantRule(expr, expr, symbol)
def _make_diff(stepfunction):
    """Build a differentiation driver from a rule wrapper.

    ``stepfunction`` (``step`` or ``do``) is applied to each rule function.
    The returned ``_diff_steps(expr, symbol)`` wraps ``expr`` in a Derivative
    whose recursion callback is ``_diff_steps`` itself. It then dispatches on
    the expression's kind through ``sympy.strategies.core.switch``.
    NOTE: kinds with no entry in the table are handed to switch's fallback.
    In sympy.strategies that fallback is the identity, so the unevaluated
    Derivative is returned.
    """
    # The dispatch table does not depend on the input, so build it once.
    rules = {
        sympy.Pow: stepfunction(power_rule),
        sympy.Symbol: stepfunction(power_rule),
        sympy.Add: stepfunction(add_rule),
        'constant': stepfunction(constant_rule),
    }

    def _diff_steps(expr, symbol):
        deriv = Derivative(expr, symbol, _diff_steps)

        def key(derivative):
            inner = derivative.expr
            # Test constancy first, so inputs such as sin(2) or f(2) are
            # differentiated to zero instead of landing on an unhandled key.
            if inner.is_constant(symbol):
                return 'constant'
            if isinstance(inner, TrigonometricFunction):
                return TrigonometricFunction
            if isinstance(inner, AppliedUndef):
                return AppliedUndef
            return inner.func

        return switch(key, rules)(deriv)
    return _diff_steps
# Public drivers: diff_steps returns the (nested) step records, diff returns
# the evaluated derivative.
diff_steps = _make_diff(step)
diff = _make_diff(do)

x = sympy.Symbol('x')
# print(...) with a single argument behaves the same on Python 2 and 3,
# whereas the Python 2 print statement is a SyntaxError on Python 3.
print(diff_steps(x**2 + x**3 + x, x))
print(diff(2 ** x + 2, x))
print(diff(x**2 + x**3 + x, x))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment