Last active
April 25, 2020 07:31
-
-
Save delonnewman/14d5f4b37247caf2634e to your computer and use it in GitHub Desktop.
Just for Fun: Symbolic Math in Ruby
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Mixin that makes operands (Symbols and Exp nodes) build expression trees
# instead of computing values: `:x + 1` returns a Sum node, not a number.
module Expression
  # Operator => named alias installed at the bottom of this module, so that
  # e.g. `:x.add(1)` builds the same node as `:x + 1`. Frozen: it is shared
  # lookup data (Exp#op_alias reads it) and must not be mutated at runtime.
  OPTS = {
    :+ => :add,
    :- => :sub,
    :* => :mult,
    :/ => :div,
    :** => :exp
  }.freeze

  # @return [Sum] node representing self + other
  def +(other)
    Sum.new(self, other)
  end

  # @return [Difference] node representing self - other
  def -(other)
    Difference.new(self, other)
  end

  # @return [Product] node representing self * other
  def *(other)
    Product.new(self, other)
  end

  # @return [Quotient] node representing self / other
  def /(other)
    Quotient.new(self, other)
  end

  # @return [Exponentiation] node representing self ** other
  def **(other)
    Exponentiation.new(self, other)
  end

  # Defaults for leaves such as Symbols; Exp overrides all of the methods below.
  def ground?
    false
  end

  def left_ground?
    false
  end

  def right_ground?
    false
  end

  # A leaf cannot be reduced further, so it simplifies to itself.
  def simplify
    self
  end

  # Install the named aliases, skipping any name that is already defined so an
  # existing method is never clobbered.
  OPTS.each do |operator, alias_name|
    alias_method alias_name, operator unless method_defined?(alias_name)
  end
end
# A node of a symbolic expression tree: a binary operation, or, when +right+ is
# nil, a unary one (Application is the unary node in this file).
#
# Leaves are Numerics (constants) or Symbols (free variables). Concrete
# subclasses take their operands as constructor arguments; #walk and #simplify
# rebuild nodes through them.
class Exp
  include Expression

  attr_reader :op, :left, :right

  # @param op [Symbol] operator such as :+ or, for a unary node, a function name
  # @param left [Numeric, Expression] left operand (the sole operand when unary)
  # @param right [Numeric, Expression, nil] right operand; nil marks a unary node
  def initialize(op, left, right)
    @op = op
    @left = left
    @right = right
  end

  # Instantiate a node of this node's concrete class with +args+.
  def new_of_type(*args)
    self.class.new(*args)
  end

  # Named alias for +op+ from Expression::OPTS (e.g. :add for :+), or nil.
  def op_alias
    Expression::OPTS[op]
  end

  # True when both operands are plain numbers.
  def ground?
    right.is_a?(Numeric) && left.is_a?(Numeric)
  end

  def left_ground?
    left.is_a?(Numeric)
  end

  def right_ground?
    right.is_a?(Numeric)
  end

  # True when nothing is left to evaluate.
  #
  # A binary node qualifies when it is not fully numeric and every operand is
  # either a number or an already-simplified expression. A unary node qualifies
  # only when its argument is a simplified expression; a numeric argument can
  # still be evaluated. Treating every unary node as unsimplified made #simplify
  # recurse forever around unbound applications.
  def simplified?
    return simplified_expression?(left) if unary?
    return false if ground?

    [left, right].all? { |operand| operand.is_a?(Numeric) || simplified_expression?(operand) }
  end

  # Reduce as far as possible: numeric sub-trees are evaluated, while anything
  # containing an unbound symbol stays symbolic.
  #
  # @return [Numeric, Exp] a number when everything could be evaluated
  def simplify
    return self if simplified?
    # Operands are applied in written order: left op right.
    return left.send(op, right) if ground?

    reduced_left = simplify_operand(left)
    reduced_right = simplify_operand(right)
    # Nothing could be reduced further (e.g. an operand that is neither a
    # number nor an expression): stop here instead of recursing forever.
    return self if reduced_left.equal?(left) && reduced_right.equal?(right)

    # Rebuild through the concrete subclass so its own #simplify (e.g.
    # Quotient's exact division) evaluates the reduced operands.
    rebuild(reduced_left, reduced_right).simplify
  end

  # Distinct free variables (Symbols), in left-to-right order of first appearance.
  def symbols
    [left, right].flat_map { |operand|
      case operand
      when Symbol then [operand]
      when Exp then operand.symbols
      else []
      end
    }.uniq
  end

  # Like "map", but for expression trees: yields every leaf (non-Exp operand)
  # to +blk+ and rebuilds a tree of the same shape from the results.
  def walk(&blk)
    visit = ->(operand) { operand.is_a?(Exp) ? operand.walk(&blk) : blk.call(operand) }
    new_left = visit.call(left)
    rebuild(new_left, unary? ? nil : visit.call(right))
  end

  # Number of distinct free variables (memoized; nodes are never mutated).
  def arity
    @_arity ||= symbols.count
  end

  def unary?
    right.nil?
  end

  def inspect
    if unary?
      "#{op.inspect}#{left.inspect}"
    else
      "(#{left.inspect} #{op.inspect} #{right.inspect})"
    end
  end

  def to_s
    if unary?
      "#{op}#{left}"
    else
      "(#{left} #{op} #{right})"
    end
  end

  # Ruby source that rebuilds this expression, e.g. ":x.add(1).mult(2)" for
  # (:x + 1) * 2, or ":sqrt[:x]" for a unary application. Nested nodes are
  # rendered recursively. A numeric left operand yields source such as
  # "1.add(:x)", which plain numbers cannot evaluate.
  def to_ruby
    if unary?
      "#{op.inspect}[#{ruby_source(left)}]"
    else
      "#{ruby_source(left)}.#{op_alias}(#{ruby_source(right)})"
    end
  end

  # Bind free variables and simplify.
  #
  # @param args either one Hash of symbol => value, or values given positionally
  #   in #symbols order; unbound or nil/false-bound symbols stay symbolic
  # @return [Numeric, Exp]
  def call(*args)
    bindings =
      if args.length == 1 && args[0].is_a?(Hash)
        args[0]
      else
        symbols.zip(args).to_h
      end
    walk { |leaf| leaf.is_a?(Symbol) ? bindings[leaf] || leaf : leaf }.simplify
  end
  alias [] call

  private

  # True for an Expression that reports itself simplified.
  def simplified_expression?(operand)
    operand.is_a?(Expression) && operand.simplified?
  end

  # Simplify expressions; numbers (which have no #simplify) pass through unchanged.
  def simplify_operand(operand)
    operand.is_a?(Expression) ? operand.simplify : operand
  end

  # Same-class node with new operands; a unary node keeps its op/function name.
  def rebuild(new_left, new_right)
    unary? ? new_of_type(op, new_left) : new_of_type(new_left, new_right)
  end

  # Ruby source for one operand: nested nodes recurse, leaves use #inspect.
  def ruby_source(operand)
    operand.is_a?(Exp) ? operand.to_ruby : operand.inspect
  end
end
# Division node: numerator / denominator.
class Quotient < Exp
  include Expression

  def initialize(numerator, denominator)
    super(:/, numerator, denominator)
  end

  # With two numeric operands, divide: exact operands (Integer/Rational) give a
  # Rational, so nothing is lost to integer division. Other numbers (Float,
  # Complex) use their own #/; Rational() would either turn a Float into an
  # exact binary-fraction rational or refuse a Complex.
  def simplify
    return super unless left.is_a?(Numeric) && right.is_a?(Numeric)

    if exact?(left) && exact?(right)
      Rational(left, right)
    else
      left / right
    end
  end

  alias numerator left
  alias denominator right

  private

  # True for numbers Rational() represents without approximation.
  def exact?(number)
    number.is_a?(Integer) || number.is_a?(Rational)
  end
end
# Multiplication node.
class Product < Exp
  include Expression

  # @param factors operands; with more than two, the leading factors are folded
  #   left-associatively, so Product.new(a, b, c) is (a * b) * c, matching
  #   Ruby's `a * b * c`. Previously any factor after the second was silently dropped.
  def initialize(*factors)
    if factors.size > 2
      super(:*, self.class.new(*factors[0...-1]), factors[-1])
    else
      super(:*, factors[0], factors[1])
    end
  end

  # The two direct operands (the left one is a nested Product when 3+ factors were given).
  def factors
    [left, right]
  end
end
# Power node: base ** exponent.
class Exponentiation < Exp
  include Expression

  # Domain-specific readers for the two operands.
  alias_method :base, :left
  alias_method :exponent, :right

  def initialize(base, exponent)
    super(:**, base, exponent)
  end
end
# Addition node.
class Sum < Exp
  include Expression

  # @param terms operands; with more than two, the leading terms are folded
  #   left-associatively, so Sum.new(a, b, c) is (a + b) + c, matching Ruby's
  #   `a + b + c`. Previously any term after the second was silently dropped.
  def initialize(*terms)
    if terms.size > 2
      super(:+, self.class.new(*terms[0...-1]), terms[-1])
    else
      super(:+, terms[0], terms[1])
    end
  end

  # The two direct operands (the left one is a nested Sum when 3+ terms were given).
  def terms
    [left, right]
  end
end
# Subtraction node: minuend - subtrahend.
class Difference < Exp
  include Expression

  # @param minuend the value subtracted from (left of "-")
  # @param subtrahend the value being subtracted (right of "-")
  def initialize(minuend, subtrahend)
    super(:-, minuend, subtrahend)
  end

  # In a - b, a is the minuend and b the subtrahend; the previous readers had
  # these two names swapped.
  alias minuend left
  alias subtrahend right
end
# Application of a Math module function to one argument, e.g. :sqrt[:x].
# Stored as a unary Exp: +op+ holds the function name, +left+ the argument.
class Application < Exp
  include Expression

  # @param name [Symbol] name of a public Math function, e.g. :sqrt or :exp
  # @param exp [Numeric, Expression] the argument
  def initialize(name, exp)
    super(name, exp, nil)
  end

  alias name op
  # NOTE(review): this shadows Expression#exp (the named alias of **) on
  # Application nodes; use ** to raise an application to a power.
  alias exp left

  # Simplified once the argument is a simplified expression; a numeric argument
  # can still be evaluated. Without this, an Exp containing an unbound
  # application never counted as simplified and Exp#simplify recursed forever.
  def simplified?
    exp.is_a?(Expression) && exp.simplified?
  end

  # Simplify the argument, then evaluate Math.<name> if it reduced to a number.
  # Previously the function was only applied when the argument was already
  # numeric, so e.g. :sqrt[:x + 1][3] stayed as the node :sqrt[4].
  def simplify
    return super unless exp.is_a?(Numeric) || exp.is_a?(Expression)

    argument = exp.is_a?(Expression) ? exp.simplify : exp
    # public_send: only Math's public functions may be invoked by name.
    return Math.public_send(name, argument) if argument.is_a?(Numeric)
    # Unchanged argument: return self so callers can detect that nothing reduced.
    return self if argument.equal?(exp)

    self.class.new(name, argument)
  end

  def inspect
    "#{op.inspect}[#{left.inspect}]"
  end

  def to_s
    "#{op}[#{left}]"
  end
end
# Core-class extension: bare symbols act as free variables inside expressions
# (:x + 1 builds a Sum) and as function names (:sqrt[:x] builds an Application).
class Symbol
  include Expression

  # NOTE(review): this replaces the built-in Symbol#[] for every Symbol in the
  # process, so it can no longer be used for substring lookup.
  #
  # Apply the function named by this symbol to +exp+ and simplify immediately,
  # so a numeric argument is evaluated right away.
  def [](exp)
    application = Application.new(self, exp)
    application.simplify
  end

  # A free variable is already in simplest form.
  def simplified?
    true
  end
end
# Print a one-line verdict for a smoke test: +exp+ is the expression under test,
# +a+ the value it produced and +b+ the expected value. Returns nil (from puts).
def assert(exp, a, b)
  verdict = a == b ? "Passed: #{exp} = #{a} == #{b}" : "Failed: #{exp} = #{a} != #{b}"
  puts verdict
end
# Smoke tests: each assert prints "Passed: ..." or "Failed: ...".

# Quotient divides integer operands exactly (to a Rational), not with integer division.
poly = ((:x + 3) / (:y + 2))
assert poly, poly[10, 2], Rational(10 + 3, 2 + 2)

quot = (:x / 1)
assert quot, quot[3], (3 / 1)

sq = :x ** 2
assert sq, sq[4], 4 ** 2

sqrt = :sqrt[:x]
assert sqrt, sqrt[13], Math.sqrt(13)

pythag = :sqrt[:a ** 2 + :b ** 2]
assert pythag, pythag[2, 3], Math.sqrt(2 ** 2 + 3 ** 2)

add1 = :x + 1
assert add1, add1[4], 4 + 1

sub1 = :x - 1
assert sub1, sub1[3], 3 - 1

# a*x**2 + b*x + c: binding only the coefficients leaves one free variable (x).
quad = (:a * (:x ** 2)) + (:b * :x) + :c
p quad
quad_in_x = quad[a: 1, b: 2, c: 3]
assert quad_in_x, quad_in_x.arity, 1

# Quadratic formula (-b ± sqrt(b**2 - 4ac)) / 2a. The coefficients describe
# x**2 + 5x + 6, whose discriminant is positive (real roots -2 and -3), so
# Math.sqrt does not raise Math::DomainError.
quad_plus = ((:b * -1) + (:sqrt[(:b ** 2) - (:a * :c * 4)])) / (:a * 2)
assert quad_plus, quad_plus[a: 1, b: 5, c: 6], ((5 * -1) + Math.sqrt((5 ** 2) - (1 * 6 * 4))) / (1 * 2)

quad_minus = ((:b * -1) - (:sqrt[(:b ** 2) - (:a * :c * 4)])) / (:a * 2)
assert quad_minus, quad_minus[a: 1, b: 5, c: 6], ((5 * -1) - Math.sqrt((5 ** 2) - (1 * 6 * 4))) / (1 * 2)

p :exp[2]         # numeric argument: evaluated immediately as Math.exp(2)
p :exp[:x + 1][2] # binds x = 2, i.e. Math.exp(3)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment