# Gist by @fela (fela/f999637ccb813318e6e5), last active August 29, 2015.
#########################################
# helpers #
#########################################
# Helper to calculate s: the fourth root of the mean, over the columns of X,
# of the fourth power of each column's Euclidean norm.
#
# Args:
#   X: numeric matrix; each column is treated as one vector.
# Returns:
#   A single number, (sum_i ||X_i||^4 / ncol(X))^(1/4).
s.func <- function(X) {
  # squared norm of every column at once (no per-column 1x1 matrix products)
  sq.norms <- colSums(X^2)
  # norm^4 is the square of the squared norm; average it over the columns
  mean.norm4 <- sum(sq.norms^2) / ncol(X)
  mean.norm4^(1/4)
}
# Helper to calculate Ntheta: the average, over the columns of X, of the
# squared product of theta with each column.
#
# Args:
#   theta: numeric vector multiplied (via %*%) with each column of X.
#   X:     numeric matrix; each column is treated as one vector.
# Returns:
#   sum_i (theta %*% X_i)^2 / ncol(X).
N.func <- function(theta, X) {
  n.cols <- ncol(X)
  # one squared product per column, collected in a list
  sq.proj <- lapply(seq_len(n.cols), function(j) (theta %*% X[, j])^2)
  sum(unlist(sq.proj)) / n.cols
}
# Helper to calculate psi, vectorised over t:
#   t >= 1        ->  log(2)
#   0 <= t < 1    -> -log(1 - t + t^2/2)
#   -1 <= t < 0   ->  log(1 + t + t^2/2)
#   t < -1        -> -log(2)
psi.func <- function(t) {
  # candidate values for every element; ifelse() below picks one per element
  cap <- log(2)
  upper.mid <- -log(1 - t + t^2 / 2)
  lower.mid <- log(1 + t + (-t)^2 / 2)

  below.one <- ifelse(t >= 0, upper.mid, ifelse(t >= -1, lower.mid, -cap))
  ifelse(t >= 1, cap, below.one)
}
#########################################
# hatN #
#########################################
# Compute hatN for direction theta and data matrix X (one vector per column).
#
# Steps, as implemented below:
#   1. build lambda from the fixed constants, s.func(X) and N.func(theta, X);
#   2. define g(alpha) = sum_i psi.func(alpha^2 * (X_i %*% theta)^2 - lambda);
#   3. locate a zero alpha.hat of g to the right of 0 with find.zero.right();
#   4. return lambda / alpha.hat^2.
#
# Args:
#   theta: numeric vector.
#   X:     numeric matrix whose columns are multiplied with theta.
# Returns:
#   A single number, lambda / alpha.hat^2.
hatN <- function(theta, X) {
  # fixed constants used in the formula for lambda
  c.const <- 15 * exp((1 + 2 * sqrt(2)) / 2) / (8 * (sqrt(2) - 1) * log(2))
  k <- 3
  eps <- 0.05
  n.cols <- ncol(X)

  s.val <- s.func(X)
  N.theta <- N.func(theta, X)

  # lambda is the square root of the product of two factors
  lambda.fac1 <- 2 / (n.cols * (k - 1))
  lambda.fac2 <- (2 + 3 * c.const) * s.val^2 /
    (4 * (2 + c.const) * sqrt(k) * N.theta) + log(1 / eps)
  lambda <- sqrt(lambda.fac1 * lambda.fac2)

  # squared product of every column of X with theta
  sq.proj <- apply(X, 2, function(column) (column %*% theta))^2

  # function of alpha whose zero is searched for; negative at alpha = 0
  # because every term is then psi.func(-lambda)
  criterion <- function(alpha) {
    sum(psi.func(sq.proj * alpha^2 - lambda))
  }
  # to inspect the criterion visually:
  #alpha <- seq(-0.4, 0.4, length.out=100)
  #plot(alpha, sapply(alpha, criterion))

  alpha.hat <- find.zero.right(criterion, 0)
  lambda / alpha.hat^2
}
# Find a zero of `func` to the right of `min`.
# Should work if the function is monotone and smooth.
#
# The zero is bracketed between `min` (where `func` must be negative) and an
# upper point obtained by starting at 1 and multiplying by 100 until the point
# lies to the right of `min` and `func` is non-negative there; uniroot() then
# refines the bracket.
#
# Args:
#   func: function of one numeric argument returning a single number.
#   min:  left end of the search interval; func(min) must be < 0.
# Returns:
#   The root reported by uniroot() (tolerance 1e-08).
# Errors:
#   - if func(min) >= 0;
#   - if func returns NA/NaN while searching for the upper end;
#   - if no non-negative value is found before the upper end overflows.
find.zero.right <- function(func, min) {
  if (func(min) >= 0) {
    stop("the function should be less than zero at `min`")
  }
  mult <- 100 # multiplier
  upper <- 1 # starting upper end, multiplied until it is big enough
  repeat {
    if (!is.finite(upper)) {
      # without this guard the search would loop forever
      stop("could not find a point to the right of `min` where the function is >= 0")
    }
    # only points strictly right of `min` give uniroot() a valid interval
    if (upper > min) {
      value <- func(upper)
      if (is.na(value)) {
        stop("the function returned NA/NaN at ", upper)
      }
      if (value >= 0) {
        break
      }
    }
    upper <- upper * mult
  }
  # local assignment: do not leak the uniroot() result into the global env
  fit <- stats::uniroot(func, c(min, upper), tol = 1e-08)
  fit$root
}
#########################################
# tests #
#########################################
# Plot psi.func on a grid from -1.2 to 1.2 (step 0.02), which covers all four
# branches of the piecewise definition.
test.psi <- function() {
  x <- seq(-1.2, 1.2, 0.02)
  # the helper is named psi.func; there is no function called `psi`
  plot(x, psi.func(x))
}
# Print N.func and hatN for a random unit-length theta and a 2 x n matrix of
# standard normal draws, so the two values can be compared by eye.
#
# Args:
#   n: number of columns (random vectors) to generate; default 10000.
test.hatN <- function(n = 10000) {
  theta <- runif(2)
  # normalise theta to unit Euclidean length; sum(theta^2) is a plain scalar,
  # whereas theta %*% theta is a 1x1 matrix whose recycling against a longer
  # vector is deprecated
  theta <- theta / sqrt(sum(theta^2))
  # 2 x n matrix, filled column by column
  X <- matrix(rnorm(2 * n), nrow = 2, ncol = n)
  cat("N", N.func(theta, X), "\n")
  cat("hatN", hatN(theta, X), "\n")
}
# (GitHub page footer: sign-in prompt for commenting on this gist.)