Skip to content

Instantly share code, notes, and snippets.

@ResidentMario
Created March 18, 2018 22:34
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ResidentMario/6abf17a728e9891ebd1c07b8b2a141cf to your computer and use it in GitHub Desktop.
Save ResidentMario/6abf17a728e9891ebd1c07b8b2a141cf to your computer and use it in GitHub Desktop.
Symbolic differentiation.

This is a short symbolic differentiation program. It works on single-variable polynomial expressions.

The main code is in `sym.py`. To run the tests, execute:

python -m unittest tests.py
class Term:
    """
    Polynomial term of the form ``mult * var ** power``.
    """

    def __init__(self, var, power, mult):
        """
        :param var: The variable used in the expression, e.g. `x` for an expression in x, or `1` for a constant term.
        :param power: The variable power, e.g. `2` for squared. Can be `0`.
        :param mult: The constant multiple, e.g. `2` for a multiple of 2. Can be `0`.
        """
        self.var, self.power, self.mult = var, power, mult

    def diff(self, var):
        """
        Differentiate with respect to a variable using the power rule.

        :param var: The variable being diffed.
        :return: A new Term object encapsulating the resulting state. A term that is constant with
            respect to `var` (different variable, zero power, or zero multiple) yields a zero term.
        """
        if self.var != var or self.power == 0 or self.mult == 0:
            return Term(self.var, 0, 0)
        # d/dx (m * x**p) = (p * m) * x**(p - 1)
        return Term(self.var, self.power - 1, self.power * self.mult)

    def __eq__(self, other):
        """
        Compare two terms by value.

        Zero-multiple terms are all equal to each other, and zero-power (constant) terms ignore the
        variable. Comparing against a non-Term defers to Python's default (i.e. not equal).
        """
        if not isinstance(other, Term):
            return NotImplemented
        # If the multiple is 0, do not compare the variable or the power.
        if self.mult == 0:
            return other.mult == 0
        # If the power is 0, do not compare the variable.
        if self.power == 0:
            return other.power == 0 and self.mult == other.mult
        # Otherwise compare everything.
        return (self.var, self.power, self.mult) == (other.var, other.power, other.mult)

    def __ne__(self, other):
        # Python 2 does not derive `!=` from `__eq__`; mirror it explicitly.
        result = self.__eq__(other)
        return result if result is NotImplemented else not result

    def __repr__(self):
        return "<Term(var={0}, power={1}, mult={2})>".format(self.var, self.power, self.mult)
class Expression:
    """
    Polynomial expression consisting of a list of Term objects.

    Note that all Terms are considered w.r.t. addition; the subtraction operator is instead stored as
    a negative multiple.
    """

    def __init__(self, input):
        """
        :param input: A list of Term objects.
        :raises NotImplementedError: If `input` is a string (string parsing is not implemented).
        """
        if isinstance(input, str):
            # TODO: Allow/parse string input, much more convenient for the end-user.
            raise NotImplementedError
        self.terms = input

    def diff(self, var):
        """
        Differentiate the expression term by term with respect to a variable.

        :param var: The variable being diffed.
        :return: A new Expression object encapsulating the resulting state.
        """
        return Expression(input=[term.diff(var) for term in self.terms])

    def _nonzero_terms(self):
        """Return the terms with a non-zero multiple; zero terms contribute nothing to the sum."""
        return [term for term in self.terms if term.mult != 0]

    def __eq__(self, other):
        """
        Compare two expressions term by term, in order, ignoring zero-multiple terms.

        Comparing against a non-Expression defers to Python's default (i.e. not equal).
        """
        if not isinstance(other, Expression):
            return NotImplemented
        mine, theirs = self._nonzero_terms(), other._nonzero_terms()
        # zip() alone truncates to the shorter list, which made an expression equal to any other
        # whose terms were a prefix of it; the length check prevents that.
        return len(mine) == len(theirs) and all(a == b for a, b in zip(mine, theirs))

    def __ne__(self, other):
        # Python 2 does not derive `!=` from `__eq__`; mirror it explicitly.
        result = self.__eq__(other)
        return result if result is NotImplemented else not result

    def __repr__(self):
        return "<Expression([...{0} term(s)])>".format(len(self.terms))
import unittest
from sym import Term, Expression
import itertools
class TestTermCases(unittest.TestCase):
    """Unit tests for `Term` equality and single-term differentiation."""

    # Test equality.
    def testZeroMultEq(self):
        """Zero multiple terms evaluate to the same value."""
        exps = [Term(var='x', power=0, mult=0), Term(var='x', power=1, mult=0),
                Term(var='1', power=0, mult=0)]
        for a, b in itertools.combinations(exps, 2):
            self.assertEqual(a, b)

    def testZeroPowerEq(self):
        """Zero power terms evaluate without regard to the variable."""
        result = Term(var='x', power=0, mult=1)
        expected = Term(var='y', power=0, mult=1)
        self.assertEqual(result, expected)

    def testNonzeroEq(self):
        """Non-zero terms evaluate using normal rules."""
        exp = Term(var='x', power=1, mult=1)
        self.assertEqual(exp, Term(var='x', power=1, mult=1))
        self.assertNotEqual(exp, Term(var='x', power=1, mult=2))
        self.assertNotEqual(exp, Term(var='x', power=2, mult=1))
        self.assertNotEqual(exp, Term(var='y', power=1, mult=1))

    # Test zero-term differentiation.
    def testZeroTermMultDiff(self):
        """Zero-multiple terms are nil."""
        result = Term(var='x', power=1, mult=0).diff('x')
        expected = Term(var='x', power=0, mult=0)
        self.assertEqual(result, expected)

    def testZeroTermPowerDiff(self):
        """Zero-power terms are nil."""
        result = Term(var='x', power=0, mult=1).diff('x')
        expected = Term(var='x', power=0, mult=0)
        self.assertEqual(result, expected)

    def testZeroTermMultAndPowerDiff(self):
        """Zero {multiple, power} terms are nil."""
        result = Term(var='x', power=0, mult=0).diff('x')
        expected = Term(var='x', power=0, mult=0)
        self.assertEqual(result, expected)

    # Test routine differentiation.
    def testPowerOneDiff(self):
        """Power-one terms are reduced to a constant term."""
        result = Term(var='x', power=1, mult=1).diff('x')
        expected = Term(var='x', power=0, mult=1)
        self.assertEqual(result, expected)

    def testPowerTenDiff(self):
        """Higher power terms are reduced correctly."""
        result = Term(var='x', power=10, mult=1).diff('x')
        expected = Term(var='x', power=9, mult=10)
        self.assertEqual(result, expected)

    def testPowerNegTenDiff(self):
        """Negative-power terms are reduced correctly."""
        result = Term(var='x', power=-10, mult=1).diff('x')
        expected = Term(var='x', power=-11, mult=-10)
        self.assertEqual(result, expected)
class TestExpressionCases(unittest.TestCase):
    """Unit tests for `Expression` equality and term-by-term differentiation."""

    def testEq(self):
        """Test trivial equality."""
        result = Expression(input=[Term(var='x', power=2, mult=1), Term(var='x', power=1, mult=1)])
        expected = Expression(input=[Term(var='x', power=2, mult=1), Term(var='x', power=1, mult=1)])
        self.assertEqual(result, expected)

    def testDiff(self):
        """Test differentiation."""
        result = Expression(input=[
            Term(var='x', power=2, mult=1), Term(var='x', power=1, mult=1),
        ]).diff('x')
        expected = Expression(input=[
            Term(var='x', power=1, mult=2), Term(var='x', power=0, mult=1),
        ])
        self.assertEqual(result, expected)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment