Skip to content

Instantly share code, notes, and snippets.

@iwelch
Last active September 9, 2018 00:58
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save iwelch/5090b475b07bf19320b3b6ea4a603a6c to your computer and use it in GitHub Desktop.
Save iwelch/5090b475b07bf19320b3b6ea4a603a6c to your computer and use it in GitHub Desktop.
AS75 in julia
################################################################
## Implementation of Gentleman's AS75 Regression in Julia
## open source, attribute first julia draft to ivo welch, but please improve
##
## available as https://gist.github.com/iwelch/5090b475b07bf19320b3b6ea4a603a6c in iwelch/Gentleman's AS75 WLS Regression in Julia
##
## feb 2018, first draft, iaw
## sep 2018, updated to julia 1.0
################################################################
using LinearAlgebra
const TINY= 1e-8
"""
    Reg

State for Gentleman's AS75 weighted-least-squares regression, updated one
observation at a time.  The X'X matrix is kept as the factorization R'DR,
with D in `d` and the strict upper triangle of R packed row-wise in `rbar`.
"""
mutable struct Reg
    regname::String
    varnames::Array{String}
    k::Int                ## number of independent variables, must be provided
    neverabort::Bool      ## on degenerate input, return NaNs instead of error()
    d::Vector{Float64}    ## diagonal D of the R'DR factorization
    thetabar::Vector{Float64}
    rbar::Vector{Float64} ## packed strict upper triangle of R: length k(k+1)/2
    sy::Float64
    sx::Float64
    syy::Float64
    sse::Float64
    wghtn::Float64        ## sum of observation weights
    sigmasq::Float64
    n::Int
    iscomputed::Bool      ## fields below are valid only when iscomputed
    theta::Vector{Float64}
    ybar::Float64
    sst::Float64
    rsq::Float64
    adjrsq::Float64
    scale::Float64        ## control print output
    storeobs::Bool        ## if set, then keep obs in stack below
    obs::Vector{ Vector{Float64} } ## each entry is vcat(y, x)
    iscomputedx::Bool     ## true once xpxinv/thetase are valid
    xpxinv::Matrix{Float64}
    thetase::Vector{Float64}
end#struct Reg

"""
    Reg(numindependent; regname, varnames, keepobs)

Construct an empty regression accumulator for `numindependent` independent
variables.  `iregvarnames` must list the y name followed by the x names.
Set `keepobs=true` to store raw observations (needed for residuals/fitted).
"""
function Reg( numindependent::Int, iregname::String= "unknown",
              iregvarnames::Vector{String}= fill( "", numindependent+1 ),
              keepobs::Bool= false )
    ## input validation with a real throw; @assert may be compiled out
    (length(iregvarnames)-1 == numindependent) ||
        throw(ArgumentError("Wrong Usage ($(length(iregvarnames))). Names Must be y x1 x2 x3..."))
    ## packed triangular storage for rbar: k*(k+1)/2 entries
    matxlen= div( (numindependent+1)*numindependent, 2 )
    ## BUGFIX: repmat() was removed in julia 1.0 -> fill(); also give obs a concrete eltype
    return Reg( iregname, iregvarnames,
        numindependent, true,
        zeros(numindependent), zeros(numindependent), zeros( matxlen ),
        0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        0,
        false, fill( NaN, numindependent ), NaN, NaN, NaN, NaN, 1.0,
        keepobs, Vector{Float64}[],
        false, fill( NaN, numindependent, numindependent ), fill( NaN, numindependent ) )
end#function Reg
################
""" include a data point (y, x vector) with weight in the WLS r """
function reginclude(r::Reg, yelement::Float64, xcopy::Vector{Float64}, weight::Float64 = 1.0)
( (isnan(yelement)) || (any(isnan.(xcopy))) ) && return
( weight == 0.0) && return
r.iscomputed= r.iscomputedx= false ; r.theta[1]= r.thetase[1]= NaN
x= deepcopy(xcopy) ## otherwise, xcopy would be trashed
(r.storeobs) && push!(r.obs, vcat( y, xcopy ))
r.syy+= (weight*(yelement*yelement))
r.sy+= (weight*(yelement))
r.n+= (weight>0.0) ? 1 : (-1) ## not checked in weighted CCREGRession: covers -1 as deletion
r.wghtn+= weight
cbar::Float64= 0.0; sbar::Float64= 0.0; xi::Float64= 0.0; xk::Float64= 0.0; di::Float64= 0.0; dprimei::Float64= 0.0
residual::Float64= yelement;
for i= 1:r.k
(weight == 0.0) && return
if (abs(x[i])>TINY)
xi= x[i];
di= r.d[i];
dprimei= di+weight*(xi*xi); cbar= di/dprimei;
sbar= weight*xi/dprimei;
weight*= (cbar)
r.d[i]= dprimei;
nextr= (i-1)* div( (2*r.k-i), 2 ) + 1
for k= (i+1):(r.k)
xk= x[k]; x[k]= xk-xi*r.rbar[nextr];
r.rbar[nextr]= cbar*r.rbar[nextr] + sbar*xk;
nextr+= 1
end#for
xk= residual; residual-= xi*r.thetabar[i];
r.thetabar[i]= cbar*r.thetabar[i]+sbar*xk;
end#if
end#for
r.sse+= weight*(residual*residual)
end#function reginclude
################
""" calculate the coefficients and r^2 of the WLS r. """
function regcoefs(r::Reg)::Vector{Float64}
(r.iscomputed) && return r.theta
if (r.n < r.k)
if (r.neverabort) r.theta= repmat( [NaN], r.k )
return r.theta
else
error("fewer observations $(r.n) than coefficient estimates $(r.k)")
end#if
end#if
for i= r.k:-1:1 ## count down, not up
r.theta[i]= r.thetabar[i];
nextr= (i-1) * div( 2*r.k-i, 2 ) + 1
for k= (i+1):r.k
r.theta[i] -= (r.rbar[nextr]*r.theta[k])
nextr+= 1
end#for k
end#for i
r.ybar= r.sy/r.wghtn;
r.sst= (r.syy - r.wghtn*r.ybar*r.ybar)
if (r.sst < 1e-7)
if (r.neverabort)
r.sst= 0.0
else
error("no meaningful variation in regression y: $(r.sst) $(r.syy) $(r.wghtn)")
end#if /* no misleading */
end#if
r.rsq= 1.0 - r.sse / r.sst;
r.adjrsq= 1.0- (1.0-r.rsq)*(r.n-1)/(r.n - r.k)
r.iscomputed= true
r.theta
end#function regcoefs
################################
## the xpxinv function is *not* well written, because julia has better facilities built-in---but it works well
""" internal function: the thetacov X' X = R' D R """
function rbrindx(r::Reg, x::Int,y::Int)::Int
( (x == y) ? error("eq $x= $y") : (x>y) ? error("$x>$y") : (((x-1)*( div(2*r.k-x, 2) )+ 1 + y - 1 - x )) )
end#function
""" internal helper: entry (x,y) of the full (virtual) R matrix — unit diagonal, zero below, packed rbar above """
function rbr(r::Reg, x::Int, y::Int)::Float64
    x == y && return 1.0
    x > y  && return 0.0
    return r.rbar[rbrindx(r, x, y)]
end#function
"""
    xpxinv(r::Reg) -> Matrix{Float64}

Compute (X'X)^{-1} from the R'DR factorization: invert the unit
upper-triangular R by back substitution, scale by D^(-1/2), and square.
Also sets r.sigmasq = sse/(n-k).  Stores and returns r.xpxinv.
"""
function xpxinv(r::Reg)::Matrix{Float64}
    ## step 1: u = R^{-1} by column-wise back substitution
    u= zeros( r.k, r.k )
    for j= r.k:-1:1
        u[j,j]= 1.0/rbr(r,j,j)
        for k= j-1:-1:1
            u[k,j]= 0.0
            for i= (k+1):j
                u[k,j]+= rbr(r,k,i)*u[i,j]
            end#for
            u[k,j]*= (-1.0)/rbr(r,k,k)
        end#for
    end#for
    ## step 2: scale columns by D^(-1/2), guarding near-singular d[j]
    for i= 1:r.k
        for j= 1:r.k
            if abs(r.d[j]) < TINY
                u[i,j]*= sqrt(1.0/TINY)
                ## BUGFIX: original called fabs(), which does not exist in julia
                (abs(r.d[j]) == 0.0) && error("cannot compute the theta-covariance matrix for variable j= $(j) (d[j]= $(r.d[j]))\n")
            else
                u[i,j]*= sqrt(1.0/r.d[j])
            end#if
        end#for
    end#for
    ## step 3: (X'X)^{-1} = U * U'
    r.sigmasq= (r.n <= r.k) ? Inf : r.sse/(r.n - r.k)
    r.xpxinv= zeros( r.k, r.k )
    for i= 1:r.k
        for j= 1:r.k
            for k= 1:r.k
                r.xpxinv[i,j]+= u[i,k]*u[j,k]
            end#for
        end#for
    end#for
    return r.xpxinv
end#function xpxinv
""" calculate the stderr of the coefficients of the WLS r. """
function regcoefsse(r::Reg)::Vector{Float64}
(r.iscomputedx) && return r.thetase
r.iscomputedx= true
return r.thetase= sqrt.(diag( xpxinv(r) )*r.sigmasq)
end#function regcoefsse
################
""" predict with the regression r the y value at a provided point x """
function regpredict(r::Reg, xrow::Vector{Float64} )::Float64
regcoefs(r)
return sum( r.theta .* xrow )
end#function fit
################
""" predict with the regression r the y value **with standard errors** at a provided point x """
function regpredictse(r::Reg, xcopy::Vector{Float64})::Float64
error("neither draw se nor mean se implemented yet")
end#function regcoefsse
################
""" convenience """
function regresiduals(r::Reg)::Vector{Float64}
if (!storeobs) error("you need to init with storeobs to obtain residuals"); end#if
error("regresiduals not yet implemented")
end#function regresiduals
################
""" convenience """
function regfitted(r::Reg)::Vector{Float64}
if (!storeobs) error("you need to init with storeobs to obtain fitted"); end#if
error("regfitted not yet implemented")
end#function regfitted
################
""" print the output for regression r in a nice way """
function regprint(r::Reg, file)::String
error("nice reg printing not yet implemented")
end#function regprint
################################################################################################################################
## self-test: runs only when this file is executed directly, not when include()d
if ( (PROGRAM_FILE != "") && endswith(@__FILE__, PROGRAM_FILE) )
    println("Blank Invocation --- Test Unit")  ## BUGFIX: typo "Invokation"
    """ regress y=i on [1, i^2, i^3] for i=1..10 and check against known values """
    function testrun(verbose::Bool= false)
        ## constant and two independent variables
        myreg= Reg(3, "testing regression", [ "i", "constant", "i-squared", "i-cubed" ])
        for i= 1:10
            f= Float64(i)
            println("y= $f\t x= 1.0 $(f*f) $(f*f*f)")
            reginclude(myreg, f, [ 1.0, f*f, f*f*f ])
        end#for
        println("\nResults:")
        c= regcoefs(myreg)
        println("Coefs= ", c)
        @assert abs( sum( c - [ 1.247, 0.201943, -0.0116423 ] ) ) < 1e-3 "Coefficients are Wrong"
        cse= regcoefsse(myreg)
        println("Coefs.se= ", cse)
        @assert abs( sum( cse - [ 0.179253, 0.0158779, 0.00157859 ] ) ) < 1e-3 "Standard Errors are Wrong"
        println( "rsq= ", myreg.rsq )
        @assert abs( myreg.rsq - 0.9942721640021418 ) < 1e-6 "wrong R^2 $(myreg.rsq)"
        p= regpredict(myreg, [ 1.0, 4.0, -8.0 ])
        println( "predict(1,4,-8)= ", p )
        @assert abs( p - 2.1479079017823004 ) < 1e-6 "wrong p"
        println("\nresults checked out.")
        if verbose
            println( "\n\nInternals:" )
            dump(myreg)
        end#if
    end#function testrun
    testrun()
end#if
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment