Skip to content

Instantly share code, notes, and snippets.

@kylebgorman
Created September 8, 2013 00:36
Show Gist options
  • Save kylebgorman/6480811 to your computer and use it in GitHub Desktop.
Save kylebgorman/6480811 to your computer and use it in GitHub Desktop.
lda-match.R: perform group matching via backward selection using a heuristic based on Fisher's linear discriminant
#!/usr/bin/env Rscript
# lda-match.R: Perform group matching via backward selection using a heuristic based on Fisher's
# linear discriminant
# Kyle Gorman <gormanky@ohsu.edu>
require(MASS)
lda.match <- function(x, grouping, term.fnc=univariate.all) {
  # Create a matched group via backward selection using a heuristic
  # based on Fisher's linear discriminant. The features are projected
  # onto a single axis (the first linear discriminant, or the feature
  # itself when there is only one column). Observations are then removed
  # one at a time, always from the currently largest group, starting
  # with the observation lying furthest out on the side of the overall
  # projection mean where that group's own mean lies. After every
  # removal the termination function is consulted.
  #
  # @param x a matrix in which columns contain numerical
  # features on which to match (coerced with as.matrix)
  # @param grouping a factor vector containing corresponding group
  # labels (coerced to a factor; unused levels are dropped;
  # NA labels are rejected)
  # @param term.fnc function to which (x, grouping) will be applied;
  # selection halts iff it returns TRUE. It always receives
  # a matrix, even when x has a single column.
  # @return a logical vector, TRUE iff row is in the match. If every
  # observation is removed without term.fnc returning TRUE,
  # a warning is issued and an all-FALSE vector is returned.
  #
  x <- as.matrix(x)
  grouping <- droplevels(as.factor(grouping))
  # NA labels belong to no group and could never be removed, so the
  # selection loop below would not terminate
  stopifnot(nrow(x) == length(grouping), !anyNA(grouping))
  # compute the projection
  if (ncol(x) == 1) {
    projection <- x[, 1]
  } else {
    # coefficients of the first linear discriminant
    ld1 <- MASS::lda(x, grouping)$scaling[, 1]
    projection <- drop(x %*% ld1)
  }
  # things to keep track of
  include <- rep(TRUE, nrow(x))
  # by-group list of row indices, sorted by projected value
  ord <- order(projection)
  by.group <- split(ord, grouping[ord])
  # number of rows still included per group, and the position in
  # by.group of the next row to remove from each group
  remaining <- lengths(by.group)
  nxt <- setNames(rep(1L, nlevels(grouping)), levels(grouping))
  # reverse order of indices if group mean is on the righthand side,
  # so that each list starts with the group's most extreme observation
  projection.mu <- mean(projection)
  for (group in levels(grouping)) {
    if (mean(projection[grouping == group]) > projection.mu) {
      by.group[[group]] <- rev(by.group[[group]])
    }
  }
  while (any(include)) {
    # determine larger group (ties go to the first level) and take one out
    group <- levels(grouping)[which.max(remaining)]
    include[by.group[[group]][nxt[[group]]]] <- FALSE
    # check for convergence; drop=FALSE keeps x a matrix for term.fnc
    if (term.fnc(x[include, , drop=FALSE], grouping[include])) {
      return(include)
    }
    # adjust group sizes and "nxt" tables
    remaining[group] <- remaining[group] - 1L
    nxt[group] <- nxt[group] + 1L
  }
  warning("lda.match: all observations removed before term.fnc returned TRUE")
  include
}
## some sample termination functions
univariate.all <- function(x, grouping, p.value=.2, fun=t.test) {
  # Termination function: apply a by-group test to every column of x
  # and report whether none of them is significant.
  #
  # @param x matrix of numerical features, one per column; a bare
  # vector is treated as a single column
  # @param grouping vector of group labels, one per row of x
  # @param p.value threshold below which a column counts as
  # differing between groups
  # @param fun test called as fun(column ~ grouping); its result must
  # have a $p.value component (default: t.test)
  # @return FALSE as soon as one column yields a p-value below
  # p.value, TRUE otherwise (including when x has no columns)
  #
  x <- as.matrix(x)
  for (i in seq_len(ncol(x))) {
    column <- x[, i]
    if (fun(column ~ grouping)$p.value < p.value) {
      return(FALSE)
    }
  }
  TRUE
}
manova.one <- function(x, grouping, p.value=.2) {
  # Termination function: fit a MANOVA of the feature matrix on the
  # grouping factor and compare the Wilks-test p-value with the threshold.
  #
  # @param x matrix of numerical features (the multivariate response)
  # @param grouping vector of group labels, one per row of x
  # @param p.value threshold the p-value has to exceed
  # @return TRUE iff the p-value in the first row of the summary table
  # is greater than p.value
  #
  fit <- manova(x ~ grouping)
  wilks.stats <- summary(fit, test='Wilks')$stats
  # first row is the grouping term; the sixth column holds its p-value
  observed <- wilks.stats[1, 6]
  observed > p.value
}
## sample run
# Read the data with strings as factors: since R 4.0.0 read.csv leaves
# them as character by default, in which case droplevels() would do
# nothing and lda.match() would receive labels without levels.
d <- read.csv('DX-NRT.csv', stringsAsFactors=TRUE)
# keep only the two groups to be matched and forget the other levels
d <- droplevels(subset(d, DX %in% c('ALI', 'ALN')))
# group sizes before matching
print(table(d$DX))
# feature matrix: one column per matching variable
mm <- cbind(d$NVIQ, d$ADOS, d$CA, deparse.level=0)
# group sizes after matching
print(table(d[lda.match(mm, d$DX), ]$DX))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment