Skip to content

Instantly share code, notes, and snippets.

@tnagler
Last active May 17, 2023 15:30
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save tnagler/62f6ce1f996333c799c81f1aef147e72 to your computer and use it in GitHub Desktop.
Save tnagler/62f6ce1f996333c799c81f1aef147e72 to your computer and use it in GitHub Desktop.
# Simulation study for the paper "Statistical Foundations of Prior-Data Fitted Networks" (ICML 2023)
# required libraries
library("reticulate")
library("tidyverse")
library("RANN")
library("qrng")
library("readr")
library("ggthemes")
## Python setup -------------------------------------------
# Select the Python interpreter reticulate should bind to; the path is
# machine-specific.
use_python("/usr/local/bin/python") # may need modification
# python library in env: https://github.com/automl/TabPFN
# Import TabPFN's prediction-interface module through reticulate.
interface <- import("tabpfn.scripts.transformer_prediction_interface")
# One classifier object, created once and referenced as the global `pfn` by
# the functions below (they call `pfn$fit()` on it).
# NOTE(review): device = "cuda" presumably requires a CUDA-capable GPU --
# confirm, and switch the device if running on a CPU-only machine.
pfn <- interface$TabPFNClassifier(device = "cuda")
## Simulation/evaluation logic -----------------------------
# true model: success probability as a sine of the row-wise feature sum
#
# x: matrix-like object with one observation per row
# returns: numeric vector of probabilities in [0, 1], one per row of `x`
p_0 <- function(x) {
  # as.matrix() lets rowSums() accept data frames as well as matrices
  row_totals <- rowSums(as.matrix(x))
  0.5 + 0.5 * sin(row_totals)
}
# fitted model: turn the classifier's prediction into a probability for
# label 1
#
# x:   inputs handed to `fit$predict()`
# fit: object exposing `predict(x, return_winning_probability = TRUE)`
# returns: numeric vector, one value per prediction
p_hat <- function(x, fit) {
  # presumably [[1]] holds the predicted labels (0/1) and [[2]] the
  # probability of the predicted label -- confirm against the TabPFN interface
  pred <- fit$predict(x, return_winning_probability = TRUE)
  label <- pred[[1]]
  win_prob <- pred[[2]]
  # label 0 contributes 1 - win_prob, label 1 contributes win_prob
  (1 - win_prob) * (label == 0) + win_prob * label
}
# localized version: for every test point, refit the global `pfn` classifier
# on that point's nearest training neighbours and predict only that point.
#
# x_test:  test inputs; reshaped to a matrix with NCOL(x_train) columns, so a
#          plain vector of length d is treated as a single test point
# x_train: training inputs, one observation per row
# y_train: training labels, indexed like the rows of `x_train`
# returns: numeric vector holding one `p_hat()` value per test row
localized_pfn_predict <- function(x_test, x_train, y_train) {
  d <- NCOL(x_train)
  n_train <- NROW(x_train)
  x_test <- matrix(x_test, ncol = d)
  n_test <- NROW(x_test)
  p_test <- numeric(n_test)

  # nothing to predict: return before touching the neighbour search or the
  # classifier (a `1:n_test` loop would iterate over c(1, 0) here)
  if (n_test == 0) {
    return(p_test)
  }

  # find k_n nearest neighbors: the share of training points kept is
  # (n_train / 500)^(-d / (d + 4)), capped at 1, so the whole training set is
  # used while n_train <= 500
  k_opt <- round(n_train * min((n_train / 500)^(-d / (d + 4)), 1))
  nns <- RANN::nn2(x_train, x_test, k = k_opt)

  # predict each test sample from a model fitted on its own neighbourhood
  for (i in seq_len(n_test)) {
    i_sel <- nns$nn.idx[i, ]
    pfn_fit <- pfn$fit(
      x_train[i_sel, , drop = FALSE], y_train[i_sel],
      overwrite_warning = TRUE
    )
    p_test[i] <- p_hat(x_test[i, , drop = FALSE], pfn_fit)
  }
  p_test
}
# one round of simulation/evaluation
#
# seed:   RNG seed for this replication (also printed for progress tracking)
# n:      training sample size
# x_test: matrix of test inputs, one evaluation point per row
# returns: data frame with two rows per test point (method "tab" and "ltab")
#          and columns `error` (p_0 - p_hat), `i` (test-point index),
#          `method`, `seed`, `n` and `true` (p_0 at the test point)
do_one <- function(seed, n, x_test) {
  print(paste0("n = ", n, " seed = ", seed))
  set.seed(seed)
  x_train <- replicate(5, rnorm(n))
  y_train <- rbinom(n, 1, p_0(x_train))

  # the truth at the test points is needed three times; evaluate it once
  p_true <- p_0(x_test)
  idx <- seq_len(nrow(x_test))

  # plain TabPFN fit; stored under its own name so the global `pfn` used by
  # localized_pfn_predict() is not shadowed inside this function
  # NOTE(review): localized_pfn_predict() refits the global classifier; if
  # `fit` refers to that same Python object, err_tab must be computed before
  # the localized call -- keep this order
  fit <- pfn$fit(x_train, y_train, overwrite_warning = TRUE)
  err_tab <- p_true - p_hat(x_test, fit)

  # localization is only run for n > 1000; smaller samples reuse the plain
  # TabPFN errors
  # NOTE(review): the neighbourhood size in localized_pfn_predict() already
  # drops below n once n > 500, so n = 1000 is treated as non-localized
  # here -- confirm the threshold is intended
  err_ltab <- if (n > 1000) {
    p_true - localized_pfn_predict(x_test, x_train, y_train)
  } else {
    err_tab
  }

  rbind(
    data.frame(error = err_tab, i = idx, method = "tab"),
    data.frame(error = err_ltab, i = idx, method = "ltab")
  ) |>
    mutate(seed = seed, n = n, true = rep(p_true, 2))
}
## run study ---------------------------------------
# 500 five-dimensional test points: output of qrng::ghalton() mapped through
# the standard normal quantile function
x_test <- qrng::ghalton(500, 5) |> qnorm()
# every seed is combined with every training sample size
scenarios <- crossing(seed = 1:500, n = c(200, 500, 1000, 2000, 4000))
# run each scenario and stack the per-run data frames row-wise
results <- scenarios$seed |>
  map2(scenarios$n, \(seed, n) do_one(seed, n, x_test)) |>
  bind_rows()
readr::write_csv(results, "sim-results.csv")
# results <- readr::read_csv("sim-results.csv")
## analyze results --------------------------------
# Both summaries are built in two stages: first per test point `i` (across
# seeds), then averaged over the test points for each (n, method).
# `.groups = "drop"` makes every summarize() return an ungrouped tibble, so no
# grouping leaks into later steps and no grouping message is printed.

# squared bias: squared mean error at each test point
# NOTE(review): the first-stage `se` is var(error) / N, i.e. the variance of
# the mean error rather than a standard error of mean(error)^2 -- confirm
# this is the intended uncertainty measure
bias <- results |>
  group_by(n, method, i) |>
  summarize(
    val = mean(error)^2,
    se = var(error) / length(error),
    .groups = "drop"
  ) |>
  group_by(n, method) |>
  summarize(
    val = mean(val),
    se = mean(se) / sqrt(length(se)),
    .groups = "drop"
  )

# variance: sample variance of the error at each test point; its `se` equals
# var * sqrt(2 / N) (presumably the normal-theory approximation -- confirm)
variance <- results |>
  group_by(n, method, i) |>
  summarize(
    val = var(error),
    se = var(error) / sqrt(0.5 * length(error)),
    .groups = "drop"
  ) |>
  group_by(n, method) |>
  summarize(
    val = mean(val),
    se = mean(se) / sqrt(length(se)),
    .groups = "drop"
  )

# stack both summaries; `type` records which quantity a row belongs to
df <- bind_rows(
  `average squared bias` = bias,
  `average variance` = variance,
  .id = "type"
)
# Plot the averaged squared bias and variance against n: one panel per
# quantity, one colour per method, error bars at +/- 1.96 * se. The figure is
# then written to "bias-variance.pdf".
df |>
  mutate(
    # fix the legend order (tab first) and attach display labels
    method = fct_recode(
      factor(method, c("tab", "ltab")),
      `TabPFN` = "tab",
      `TabPFN + Localization` = "ltab"
    )
  ) |>
  ggplot(aes(x = n, y = val, colour = method)) +
  facet_wrap(vars(type), scales = "free") +
  geom_point() +
  geom_line() +
  geom_errorbar(
    aes(ymin = val - 1.96 * se, ymax = val + 1.96 * se),
    width = 1,
    position = position_dodge(0.05)
  ) +
  theme_minimal() +
  labs(y = "") +
  # force both axes to include the origin
  expand_limits(y = 0, x = 0) +
  scale_colour_manual(values = rev(ggthemes::hc_pal()(2))) +
  theme(legend.position = "bottom")
# no `plot` argument: ggsave() saves the most recently created plot
ggsave(filename = "bias-variance.pdf", width = 8, height = 4)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment