Navigation Menu

Skip to content

Instantly share code, notes, and snippets.

@rleegates
Created May 5, 2017 13:22
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save rleegates/44e5838e8c3089ea277c70b655f21259 to your computer and use it in GitHub Desktop.
Save rleegates/44e5838e8c3089ea277c70b655f21259 to your computer and use it in GitHub Desktop.
module SimplePolynomials
# Integer type used for every exponent entry. `const` alias instead of the deprecated
# `typealias` keyword; behaves identically on all Julia versions.
const ExponentInt = Int
import StaticArrays: SVector
import Base: (+), (*), (^), zero, one
# Sparse multivariate polynomial in NV variables with coefficients of type T.
# Term k represents coeffs[k] * x[1]^exps[k][1] * ... * x[NV]^exps[k][NV]
# (this is how the evaluation methods below interpret it).
# NOTE(review): no constructor check enforces length(exps) == length(coeffs) or that
# exponent vectors are unique; `+` only merges the first matching term, so callers
# presumably must uphold both — confirm.
immutable Polynomial{T,NV}
# exponent vector of each term, one entry per variable
exps::Vector{SVector{NV,ExponentInt}}
# coefficient of each term, parallel to `exps`
coeffs::Vector{T}
end
# Additive / multiplicative identities. Both are represented as a single constant term
# (all-zero exponent vector) with coefficient zero(T) / one(T) respectively.
zero{T,NV}(::Type{Polynomial{T,NV}}) = Polynomial([zeros(SVector{NV,ExponentInt})],[zero(T)])
one{T,NV}(::Type{Polynomial{T,NV}}) = Polynomial([zeros(SVector{NV,ExponentInt})],[one(T)])
# Instance forms so generic code can write zero(p) / one(p) for a Polynomial value.
zero{T,NV}(::Polynomial{T,NV}) = zero(Polynomial{T,NV})
one{T,NV}(::Polynomial{T,NV}) = one(Polynomial{T,NV})
# Return the NV coordinate monomials x_1, ..., x_NV of the given polynomial type as an
# SVector. Monomial k has exponent one in slot k, zero elsewhere, and coefficient one(T).
function vars{T,NV}(::Type{Polynomial{T,NV}})
    unitexp(k) = SVector{NV,ExponentInt}(
        ntuple(m -> m == k ? one(ExponentInt) : zero(ExponentInt), Val{NV}))
    monomial(k) = Polynomial([unitexp(k)], [one(T)])
    return SVector{NV,Polynomial{T,NV}}(ntuple(monomial, Val{NV}))
end
# Multiply two monomials given as (coefficient, exponent-vector) tuples: coefficients
# multiply and exponents add, e.g. (x^1*y^0*z^3)*(x^1*y^0*z^3) = x^2*y^0*z^6.
# NOTE(review): this adds a Base.* method for plain Tuples of types this module does not
# own (type piracy); a private helper function would be safer.
function (*){T1,T2,NV}(a::Tuple{T1,SVector{NV,ExponentInt}},b::Tuple{T2,SVector{NV,ExponentInt}})
    coef_a, exps_a = a
    coef_b, exps_b = b
    return (coef_a*coef_b, exps_a+exps_b)
end
# Add two like monomials given as (coefficient, exponent-vector) tuples, e.g.
# 1*(x^1*y^0*z^3) + 1*(x^1*y^0*z^3) = 2*x^1*y^0*z^3.
# The exponent vectors must be identical (`===`); otherwise the sum is not a monomial.
# This method extends Base.+ on plain tuples, so it is reachable by any caller: the
# precondition is validated with an exception rather than `@assert`, which is meant
# for debugging only.
# NOTE(review): extending Base.+ for Tuples of non-owned types is type piracy.
function (+){T1,T2,NV}(a::Tuple{T1,SVector{NV,ExponentInt}},b::Tuple{T2,SVector{NV,ExponentInt}})
    a[2] === b[2] ||
        throw(ArgumentError("cannot add monomials with different exponents $(a[2]) and $(b[2])"))
    return (a[1]+b[1],a[2])
end
# Product of two polynomials: every term of `a` times every term of `b`, with products
# that share an exponent vector merged by adding coefficients.
# The result is accumulated in place rather than by rebuilding a Polynomial per term
# pair. Term order and coefficients are the same as summing the term products onto
# zero(...) with `+`. In particular, the result always starts with the constant
# (all-zero exponent) term that zero(...) contributes, even when its coefficient is zero.
function (*){T1,T2,NV}(a::Polynomial{T1,NV},b::Polynomial{T2,NV})
    ae = a.exps
    ac = a.coeffs
    be = b.exps
    bc = b.coeffs
    R = promote_type(T1,T2)
    ce = [zeros(SVector{NV,ExponentInt})]
    cc = [zero(R)]
    for i = 1:length(ae), j = 1:length(be)
        e = ae[i] + be[j]        # monomial product: exponents add
        c = ac[i] * bc[j]        # ... and coefficients multiply
        k = fast_findfirst(e, ce)
        if k > 0
            cc[k] += c           # like term already present: merge
        else
            push!(ce, e)
            push!(cc, c)
        end
    end
    return Polynomial(ce, cc)
end
# Scale a polynomial by a number: each coefficient c becomes a*c. The exponent vector of
# `b` is passed through unchanged, so the result shares `b.exps` with `b` (not copied).
function (*){T1<:Number,T2,NV}(a::T1, b::Polynomial{T2,NV})
    scaled = a * b.coeffs
    return Polynomial(b.exps, scaled)
end
# Polynomial * number: forwards to the number * polynomial method, so coefficients are
# computed as b*c (the scalar stays on the left).
function (*){T1,T2<:Number,NV}(a::Polynomial{T1,NV}, b::T2)
    return b * a
end
# Return the 1-based index of the first element of `v` that is identical (`===`) to `x`,
# or 0 when there is none.
function fast_findfirst{T}(x::T, v::Vector{T})
    for (idx, elem) in enumerate(v)
        elem === x && return idx
    end
    return 0
end
# Sum of two polynomials.
# Result layout: first every term of `a` in order, with its coefficient increased by the
# first `b` term that has an identical exponent vector (if any). Then every `b` term that
# was not merged, in order.
function (+){T1,T2,NV}(a::Polynomial{T1,NV},b::Polynomial{T2,NV})
    ae = a.exps
    ac = a.coeffs
    be = b.exps
    bc = b.coeffs
    ce = Vector{SVector{NV,ExponentInt}}()
    cc = Vector{promote_type(T1,T2)}()
    sizehint!(ce, length(ae) + length(be))
    sizehint!(cc, length(ae) + length(be))
    # merged[j] becomes true once b's j-th term has been folded into a term of `a`;
    # a mask gives O(1) membership instead of rescanning a list of merged indices.
    merged = falses(length(be))
    for i = 1:length(ae)
        aei = ae[i]
        # fast_findfirst matches with ===, so be[bi] is exactly aei when bi > 0
        bi = fast_findfirst(aei,be)
        if bi > 0
            push!(cc, ac[i] + bc[bi])
            merged[bi] = true
        else
            push!(cc, ac[i])
        end
        push!(ce, aei)
    end
    for j = 1:length(be)
        if !merged[j]
            push!(ce, be[j])
            push!(cc, bc[j])
        end
    end
    return Polynomial(ce, cc)
end
# Raise a polynomial to a non-negative integer power by repeated squaring.
# a^0 is one(Polynomial{T1,NV}), a^1 is `a` itself. Otherwise the half power
# a^div(power,2) is computed once and squared, with one extra factor of `a` for odd
# powers, so only O(log(power)) polynomial multiplications are needed.
# Negative powers are rejected with ArgumentError, since polynomials have no inverse here.
function (^){T1,T2<:Integer,NV}(a::Polynomial{T1,NV},power::T2)
    power < 0 &&
        throw(ArgumentError("Polynomial power must be non-negative, got $power"))
    if power == 0
        return one(Polynomial{T1,NV})
    elseif power == 1
        return a
    end
    half = a^div(power, 2)
    square = half * half
    return isodd(power) ? square * a : square
end
# Composition: substitute the polynomial x[j] (in MV variables) for variable j of `a` and
# expand. Each term becomes coeff * one(...) * x[1]^e[1] * ... * x[NV]^e[NV].
# The terms are summed starting from zero(...).
function (a::Polynomial{T1,NV}){T1,T2,NV,MV}(x::SVector{NV,Polynomial{T2,MV}})
    PolyT = Polynomial{promote_type(T1,T2),MV}
    unit = one(PolyT)
    total = zero(PolyT)
    for (k, expvec) in enumerate(a.exps)
        term = a.coeffs[k] * unit
        for v = 1:NV
            term *= x[v]^expvec[v]
        end
        total += term
    end
    return total
end
# Evaluate the polynomial at a numeric point: x[j] is the value of variable j.
# Returns a scalar of type promote_type(T1,T2) (for ordinary number types).
# Throws DimensionMismatch when length(x) != NV. Without the check, extra values were
# silently ignored and too few values gave an opaque BoundsError.
function (a::Polynomial{T1,NV}){T1,T2<:Number,NV,MV}(x::SVector{MV,T2})
    # MV and NV are type parameters, so this check is resolved at compile time.
    MV == NV ||
        throw(DimensionMismatch("polynomial has $NV variables but $MV values were given"))
    ae = a.exps
    ac = a.coeffs
    RES_T = promote_type(T1,T2)
    I = one(RES_T)
    res = zero(RES_T)
    for i = 1:length(ae)
        aei = ae[i]
        tmp = ac[i] * I
        for j = 1:NV
            tmp *= x[j]^aei[j]
        end
        res += tmp
    end
    return res
end
end
import SimplePolynomials: Polynomial, vars
using BenchmarkTools
import StaticArrays: SVector
# Benchmark script. Operands are non-const globals, so they are `$`-interpolated into
# each @benchmark expression. Otherwise BenchmarkTools would also time the untyped-global
# lookup and dynamic dispatch rather than just the polynomial operation.
info("Testing init")
x = vars(Polynomial{Float64,2})
display(@benchmark vars(Polynomial{Float64,2}))
(xx,yy) = x
info("Testing (+)")
# Repeatedly accumulates x+x+y+y, exercising Polynomial `+`.
test_plus{T,NV}(x::Polynomial{T,NV},y::Polynomial{T,NV}) = (r = x; for _ = 1:10000; r += x+x+y+y; end; r)
test_plus(xx,yy)
display(@benchmark test_plus($xx,$yy))
info("Testing evaluation at polynomials")
p = xx+yy+2*xx^2+yy^3
p(x)
display(@benchmark $p($x))
info("Testing evaluation at floats")
# `p` is unchanged from the previous section, so it is reused rather than rebuilt.
vals = SVector{2,Float64}((1.,2.))
p(vals)
display(@benchmark $p($vals))
info("Testing evaluation at floats (MPoly)")
import MultiPoly: MPoly, generators, evaluate
xxx,yyy = generators(MPoly{Float64},:x,:y)
p = xxx+yyy+2*xxx^2+yyy^3
evaluate(p,1.,2.)
display(@benchmark evaluate($p,1.,2.))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment