Created
May 5, 2017 13:22
-
-
Save rleegates/44e5838e8c3089ea277c70b655f21259 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
module SimplePolynomials

# Integer type used for the per-variable exponents of every term.
typealias ExponentInt Int

import StaticArrays: SVector
import Base: (+), (*), (^), zero, one

"""
    Polynomial{T,NV}(exps, coeffs)

Sparse polynomial in `NV` variables with coefficients of type `T`, stored as two
parallel vectors. Term `k` is

    coeffs[k] * x[1]^exps[k][1] * ... * x[NV]^exps[k][NV]

Terms keep the order in which the operations below create them. Terms whose
coefficient is zero (e.g. the constant term introduced by `zero`) are not removed.
"""
immutable Polynomial{T,NV}
    exps::Vector{SVector{NV,ExponentInt}}
    coeffs::Vector{T}
end

# Additive identity: a single constant term (all exponents 0) with coefficient zero(T).
zero{T,NV}(::Type{Polynomial{T,NV}}) =
    Polynomial([zeros(SVector{NV,ExponentInt})], [zero(T)])

# Multiplicative identity: a single constant term with coefficient one(T).
one{T,NV}(::Type{Polynomial{T,NV}}) =
    Polynomial([zeros(SVector{NV,ExponentInt})], [one(T)])

"""
    vars(Polynomial{T,NV})

Return an `SVector` of the `NV` coordinate monomials `x_1, ..., x_NV`.
Entry `i` is the one-term polynomial with coefficient `one(T)` and exponent 1 in
variable `i` (0 in all others).
"""
function vars{T,NV}(::Type{Polynomial{T,NV}})
    # Unit exponent vector for variable i.
    unitexp(i) = SVector{NV,ExponentInt}(
        ntuple(k -> ifelse(k == i, one(ExponentInt), zero(ExponentInt)), Val{NV}))
    return SVector{NV,Polynomial{T,NV}}(
        ntuple(i -> Polynomial([unitexp(i)], [one(T)]), Val{NV}))
end

"""
    fast_findfirst(x, v)

Return the index of the first element of `v` that is `===` to `x`, or 0 if none is.
For `SVector`s of integers (immutable bits values), `===` is value equality.
"""
function fast_findfirst{T}(x::T, v::Vector{T})
    for i = 1:length(v)
        if v[i] === x
            return i
        end
    end
    return 0
end

# Polynomial product: multiply every term of `a` by every term of `b`
# (coefficients multiply, exponent vectors add) and merge like terms via `+`.
# No Base methods are added on (coeff, exponent) tuples: that would be type
# piracy on Tuple/SVector, which this module does not own.
function (*){T1,T2,NV}(a::Polynomial{T1,NV}, b::Polynomial{T2,NV})
    res = zero(Polynomial{promote_type(T1,T2),NV})
    for i = 1:length(a.exps), j = 1:length(b.exps)
        term = Polynomial([a.exps[i] + b.exps[j]], [a.coeffs[i] * b.coeffs[j]])
        res += term
    end
    return res
end

# Scalar * polynomial: scales every coefficient. The exponent vector is shared
# with `b`, not copied; nothing in this module mutates a Polynomial's fields.
function (*){T1<:Number,T2,NV}(a::T1, b::Polynomial{T2,NV})
    return Polynomial(b.exps, a * b.coeffs)
end

# Polynomial * scalar: delegates to the scalar-first method.
(*){T1,T2<:Number,NV}(a::Polynomial{T1,NV}, b::T2) = *(b, a)

# Polynomial sum. Terms of `a` come first, in order; a term of `a` whose exponent
# vector also occurs in `b` absorbs that coefficient. Remaining terms of `b`
# follow in their original order.
function (+){T1,T2,NV}(a::Polynomial{T1,NV}, b::Polynomial{T2,NV})
    ae, ac = a.exps, a.coeffs
    be, bc = b.exps, b.coeffs
    ce = Vector{SVector{NV,ExponentInt}}()
    cc = Vector{promote_type(T1,T2)}()
    # matched[k] is true once b's term k has been merged into a term of `a`.
    # This replaces a linear search over the list of merged indices.
    matched = falses(length(be))
    for i = 1:length(ae)
        bi = fast_findfirst(ae[i], be)
        if bi > 0
            push!(cc, ac[i] + bc[bi])
            matched[bi] = true
        else
            push!(cc, ac[i])
        end
        push!(ce, ae[i])
    end
    for k = 1:length(be)
        if !matched[k]
            push!(cc, bc[k])
            push!(ce, be[k])
        end
    end
    return Polynomial(ce, cc)
end

# Non-negative integer power by repeated squaring.
# Throws ArgumentError for negative powers. Previously a negative power recursed
# forever, e.g. a^(-1) -> a^(-1) * a^0.
function (^){T1,T2<:Integer,NV}(a::Polynomial{T1,NV}, power::T2)
    if power < 0
        throw(ArgumentError("negative power $power of a Polynomial is not supported"))
    elseif power == 0
        return one(Polynomial{T1,NV})
    elseif power == 1
        return a
    end
    half = a^div(power, 2)   # computed once and squared
    sq = half * half
    return isodd(power) ? sq * a : sq
end

# Shared evaluation kernel: returns acc + sum over terms k of
#   coeffs[k] * unit * prod_j x[j]^exps[k][j]
# `unit` and `acc` fix the result type: numbers for point evaluation,
# polynomials for substitution.
function _evaluate(a::Polynomial, x, unit, acc)
    for k = 1:length(a.exps)
        ek = a.exps[k]
        term = a.coeffs[k] * unit
        for j = 1:length(ek)
            term *= x[j]^ek[j]
        end
        acc += term
    end
    return acc
end

# Substitution: replace variable j of `a` by the polynomial x[j] (in MV variables).
function (a::Polynomial{T1,NV}){T1,T2,NV,MV}(x::SVector{NV,Polynomial{T2,MV}})
    RES_T = Polynomial{promote_type(T1,T2),MV}
    return _evaluate(a, x, one(RES_T), zero(RES_T))
end

# Point evaluation: `x` must hold exactly one value per variable (length NV).
# Before, any length was accepted: a shorter vector threw BoundsError and
# extra entries were silently ignored.
function (a::Polynomial{T1,NV}){T1,T2<:Number,NV}(x::SVector{NV,T2})
    RES_T = promote_type(T1,T2)
    return _evaluate(a, x, one(RES_T), zero(RES_T))
end

end
# Benchmark driver comparing SimplePolynomials with MultiPoly.
# Globals are interpolated into @benchmark with `$`. This times the calls
# themselves instead of dynamic lookups of untyped global variables.
import SimplePolynomials: Polynomial, vars
using BenchmarkTools
import StaticArrays: SVector

info("Testing init")
x = vars(Polynomial{Float64,2})
display(@benchmark vars(Polynomial{Float64,2}))
(xx, yy) = x

info("Testing (+)")
# Adds x + x + y + y into an accumulator 10000 times; exercises polynomial (+).
function test_plus{T,NV}(x::Polynomial{T,NV}, y::Polynomial{T,NV})
    r = x
    for _ = 1:10000
        r += x + x + y + y
    end
    return r
end
test_plus(xx, yy)   # warm-up call (compilation)
display(@benchmark test_plus($xx, $yy))

info("Testing evaluation at polynomials")
p = xx + yy + 2*xx^2 + yy^3
p(x)
display(@benchmark $p($x))

info("Testing evaluation at floats")
vals = SVector{2,Float64}((1., 2.))
p(vals)
display(@benchmark $p($vals))

info("Testing evaluation at floats (MPoly)")
import MultiPoly: MPoly, generators, evaluate
xxx, yyy = generators(MPoly{Float64}, :x, :y)
mp = xxx + yyy + 2*xxx^2 + yyy^3
evaluate(mp, 1., 2.)
display(@benchmark evaluate($mp, 1., 2.))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment