# Symbolic Math in Ruby
# Source: GitHub gist Immortalin/76760e67c4b1044cd01fee7334c2fb8e
# Mixin that makes the arithmetic operators build symbolic expression nodes
# (Sum, Difference, Product, Quotient, Exponentiation) instead of computing a
# value. Each operator is also reachable under a word-style alias (see OPS).
module Expression
  # Operator symbol => word-style alias for the same method.
  OPS = {
    :+ => :add,
    :- => :sub,
    :* => :mult,
    :/ => :div,
    :** => :exp
  }.freeze

  # @return [Hash{Symbol => Symbol}] the operator => alias table (frozen)
  def ops
    OPS
  end
  module_function :ops

  # NOTE: the original guarded pre-pass (`alias_method orig, new` when the
  # operator was already defined) could never fire here, because no operator
  # is defined on this module before the definitions below; it was removed.

  # @return [Sum] node representing self + other
  def +(other)
    Sum.new(self, other)
  end

  # @return [Difference] node representing self - other
  def -(other)
    Difference.new(self, other)
  end

  # @return [Product] node representing self * other
  def *(other)
    Product.new(self, other)
  end

  # @return [Quotient] node representing self / other
  def /(other)
    Quotient.new(self, other)
  end

  # @return [Exponentiation] node representing self ** other
  def **(other)
    Exponentiation.new(self, other)
  end

  # @return [true] anything that mixes in Expression is an expression
  def expression?
    true
  end

  # Publish the word-style aliases (add, sub, mult, div, exp) unless a method
  # of that name already exists on the module.
  ops.each do |operator, word|
    alias_method word, operator unless method_defined?(word)
  end
end
# Default answer for every object: not a symbolic expression. Classes that
# mix in Expression get that module's +expression?+ (which returns true)
# ahead of this one in method lookup.
class Object
  # @return [false]
  def expression?
    false
  end
end
# Generic expression-tree node: an operator applied to a left operand and an
# optional right operand. A nil right operand marks the node as unary.
class Exp
  include Expression

  attr_reader :op, :left, :right

  # @param op [Symbol] operator (or function name for unary nodes)
  # @param left [Object] left operand
  # @param right [Object, nil] right operand; nil for unary nodes
  def initialize(op, left, right)
    @op = op
    @left = left
    @right = right
  end

  # Builds another node through the receiver's own class, so the argument
  # list must match that class's constructor (the subclasses in this file
  # take operands only, without the operator).
  def new_of_type(*args)
    self.class.new(*args)
  end

  # @return [Symbol, nil] word-style alias of +op+ from Expression.ops
  def op_alias
    Expression.ops[op]
  end

  # Reduces the tree as far as its numeric leaves allow.
  #
  # * Unary nodes are returned unchanged (subclasses may override).
  # * Two numeric operands are combined by sending +op+ to the left operand.
  #   This replaces +eval(to_s)+, which re-parsed printed values and so got
  #   e.g. Rational operands ("1/2" became integer division) and negative
  #   bases ("-3 ** 2") wrong.
  # * Otherwise each Exp operand is simplified on its own, so a node with a
  #   single Exp operand is reduced too, and the rebuilt node is simplified
  #   again only when both operands became numeric. That bounds the
  #   recursion; the original re-simplified unconditionally and never
  #   terminated on trees that still contained symbols.
  #
  # @return [Numeric, Exp]
  def simplify
    return self if unary?
    return left.public_send(op, right) if left.is_a?(Numeric) && right.is_a?(Numeric)

    lhs, rhs = [left, right].map { |operand| operand.is_a?(Exp) ? operand.simplify : operand }
    return self if lhs.equal?(left) && rhs.equal?(right)

    reduced = new_of_type(lhs, rhs)
    # Dispatch through the subclass (e.g. Quotient builds a Rational).
    lhs.is_a?(Numeric) && rhs.is_a?(Numeric) ? reduced.simplify : reduced
  end

  # @return [Array<Symbol>] distinct symbols in the tree, left to right
  def symbols
    [left, right].flat_map do |operand|
      case operand
      when Symbol then [operand]
      when Exp then operand.symbols
      else []
      end
    end.uniq
  end

  # Walks the expression tree while preserving its structure: like "map", but
  # for expression trees. The block is called with every non-Exp operand and
  # its result takes that operand's place in the rebuilt tree.
  def walk(&blk)
    visit = ->(operand) { operand.is_a?(Exp) ? operand.walk(&blk) : blk.call(operand) }

    if unary?
      new_of_type(op, visit.call(left))
    else
      new_of_type(visit.call(left), visit.call(right))
    end
  end

  # @return [Integer] number of distinct symbols (memoized)
  def arity
    @_arity ||= symbols.count
  end

  # @return [Boolean] true when there is no right operand
  def unary?
    right.nil?
  end

  def inspect
    if unary?
      "#{op.inspect}#{left.inspect}"
    else
      "(#{left.inspect} #{op.inspect} #{right.inspect})"
    end
  end

  def to_s
    if unary?
      "#{op}#{left}"
    else
      "(#{left} #{op} #{right})"
    end
  end

  # Renders the node as a method call that uses the word-style alias.
  def to_ruby
    if unary?
      "#{left.inspect}.#{op_alias}"
    else
      "#{left.inspect}.#{op_alias}(#{right.inspect})"
    end
  end

  # Substitutes +args+ for the tree's symbols, positionally in the order
  # given by +symbols+, then simplifies. Symbols without a matching argument
  # (or bound to nil) stay in place.
  def call(*args)
    bindings = symbols.zip(args).to_h
    substituted = walk do |operand|
      operand.is_a?(Symbol) ? (bindings[operand] || operand) : operand
    end
    substituted.simplify
  end
  alias [] call
end
# Division node. Exp already mixes in Expression, so the redundant include
# was dropped.
class Quotient < Exp
  # @param numerator [Object] left operand
  # @param denominator [Object] right operand
  def initialize(numerator, denominator)
    super(:/, numerator, denominator)
  end

  # Two numeric operands become a Rational rather than going through the
  # generic numeric path; anything else is left to Exp#simplify.
  def simplify
    return Rational(left, right) if left.is_a?(Numeric) && right.is_a?(Numeric)

    super()
  end

  alias numerator left
  alias denominator right
end
# Multiplication node. Exp already mixes in Expression, so the redundant
# include was dropped.
class Product < Exp
  # Only the first two factors are stored; any further arguments are ignored.
  def initialize(*factors)
    super(:*, factors[0], factors[1])
  end

  # @return [Array] the two operands as [left, right]
  def factors
    [left, right]
  end
end
# Power node (base ** exponent). Exp already mixes in Expression, so the
# redundant include was dropped.
class Exponentiation < Exp
  # @param base [Object] left operand
  # @param exponent [Object] right operand
  def initialize(base, exponent)
    super(:**, base, exponent)
  end

  alias base left
  alias exponent right
end
# Addition node. Exp already mixes in Expression, so the redundant include
# was dropped.
class Sum < Exp
  # Only the first two terms are stored; any further arguments are ignored.
  def initialize(*terms)
    super(:+, terms[0], terms[1])
  end

  # @return [Array] the two operands as [left, right]
  def terms
    [left, right]
  end
end
# Subtraction node (minuend - subtrahend). Exp already mixes in Expression,
# so the redundant include was dropped.
#
# The original named the operands the wrong way round: in "a - b" the
# minuend is a (left) and the subtrahend is b (right). The positional
# constructor is unchanged; only the parameter names and reader aliases were
# corrected.
class Difference < Exp
  # @param minuend [Object] value subtracted from (left operand)
  # @param subtrahend [Object] value being subtracted (right operand)
  def initialize(minuend, subtrahend)
    super(:-, minuend, subtrahend)
  end

  alias minuend left
  alias subtrahend right
end
# Placeholder for a named function definition. The original constructor
# discarded its arguments; they are now stored and exposed through readers so
# an instance carries what it was built with. No further behavior is defined.
class Function
  include Expression

  attr_reader :name, :args, :body

  # @param name [Object] function name
  # @param args [Object] formal arguments
  # @param body [Object] function body
  def initialize(name, args, body)
    @name = name
    @args = args
    @body = body
  end
end
# Unary node applying a named function to one argument, e.g. :sqrt[:x].
# Exp already mixes in Expression, so the redundant include was dropped.
class Application < Exp
  # @param name [Symbol] function name, resolved against Math on simplify
  # @param exp [Object] the argument
  def initialize(name, exp)
    super(name, exp, nil)
  end

  alias name op
  # NOTE: this reader takes the place of the +exp+ word-alias for ** that
  # Expression provides; use ** directly on Application nodes.
  alias exp left

  # Simplifies the argument first, then:
  # * a numeric argument is passed to the Math function called +name+;
  # * an argument that changed but is still symbolic yields a new Application
  #   around it;
  # * otherwise the node is returned unchanged.
  #
  # The original returned the simplified argument on its own when it was an
  # Exp, silently dropping the function, so sqrt(2**2 + 3**2) came out as 13.
  # +public_send+ is used so that only Math's public functions are reachable.
  #
  # @return [Numeric, Application]
  def simplify
    argument = exp.is_a?(Exp) ? exp.simplify : exp

    if argument.is_a?(Numeric)
      Math.public_send(name, argument)
    elsif argument.equal?(exp)
      self
    else
      new_of_type(name, argument)
    end
  end

  def inspect
    "#{op.inspect}[#{left.inspect}]"
  end

  def to_s
    "#{op}[#{left}]"
  end
end
# Makes symbols the leaves of the expression DSL: they gain the Expression
# operators, and :name[arg] builds a function application.
# NOTE: this replaces the built-in Symbol#[] for the whole process.
class Symbol
  include Expression

  # @param exp [Object] argument of the function named by this symbol
  # @return [Application]
  def [](exp)
    Application.new(self, exp)
  end
end
# Demo: build a few expressions, bind their symbols positionally and print
# the results. Calls are parenthesised explicitly so `p (...)[...]` cannot be
# misread by the parser.
p(((:x + 3) / (:y + 2))[10, 2])
p((:x / 1)[3])
sq = :x ** 2
p(sq)
p(sq[4])
p(:sqrt[:x][13])
p(:sqrt[:a ** 2 + :b ** 2][2, 3])
# (GitHub page footer: sign-up / sign-in prompts for commenting; not part of the program.)