# Simulation study for the paper "Statistical Foundations of Prior-Data Fitted
# Networks" (ICML 2023).
# Source: GitHub gist tnagler/62f6ce1f996333c799c81f1aef147e72
# (last active July 22, 2024).
# required libraries
library(reticulate) # use_python(), import()
library(tidyverse)  # dplyr, tidyr, purrr, forcats, ggplot2 verbs used below
# RANN, qrng, readr and ggthemes are called via `pkg::` and only need to be
# installed, not attached.

## Python setup -------------------------------------------
use_python("/usr/local/bin/python") # may need modification
# The `tabpfn` Python package must be importable from the interpreter above.
interface <- import("tabpfn.scripts.transformer_prediction_interface")
# Single shared classifier object; later code calls `pfn$fit(...)` on it
# repeatedly. NOTE(review): "cuda" requires a GPU-enabled torch install.
pfn <- interface$TabPFNClassifier(device = "cuda")
## Simulation/evaluation logic -----------------------------
# True data-generating model: P(Y = 1 | X = x) = 0.5 + 0.5 * sin(sum of x's
# coordinates). `x` may be a matrix or data frame (one row per observation);
# a plain vector is treated as a single column. Values lie in [0, 1].
p_0 <- function(x) {
  row_total <- rowSums(as.matrix(x))
  0.5 + 0.5 * sin(row_total)
}
# Fitted model: estimated P(Y = 1 | X = x) from a fitted classifier `fit`.
# `fit$predict(x, return_winning_probability = TRUE)` is assumed (TabPFN API —
# confirm) to return a pair: [[1]] the predicted label and [[2]] the
# probability of that predicted label.
# The winning-label probability is converted to the probability of label 1:
# label 1 keeps the probability, label 0 takes its complement. This assumes
# 0/1 labels, as produced by rbinom(n, 1, .) in the simulation.
p_hat <- function(x, fit) {
  p <- fit$predict(x, return_winning_probability = TRUE)
  label <- p[[1]]
  prob_winner <- p[[2]]
  (1 - prob_winner) * (label == 0) + prob_winner * label
}
# Localized version: for each test point, refit the classifier on that point's
# k nearest training neighbours only, and predict from this local fit.
#
# Arguments:
#   x_test  - test covariates; reshaped by matrix() into rows of length
#             ncol(x_train) (so a single point may be given as a vector).
#   x_train - training covariates, one row per observation.
#   y_train - training labels aligned with the rows of x_train.
#   model   - classifier exposing `$fit(x, y, overwrite_warning = TRUE)`;
#             defaults to the global `pfn`, which is what was always used.
# Returns a numeric vector with one estimated P(Y = 1 | x) per test row.
#
# Neighbourhood size: k = n_train when n_train <= 500, otherwise
# k = round(n_train * (n_train / 500)^(-d / (d + 4))), so k never exceeds n_train.
localized_pfn_predict <- function(x_test, x_train, y_train, model = pfn) {
  d <- NCOL(x_train)
  n_train <- NROW(x_train)
  x_test <- matrix(x_test, ncol = d)
  n_test <- NROW(x_test)

  # find k_n nearest neighbors of every test point in one call
  k_opt <- round(n_train * min((n_train / 500)^(-d / (d + 4)), 1))
  nns <- RANN::nn2(x_train, x_test, k = k_opt)

  # predict each test sample from a model fitted on its neighbourhood only
  p_test <- numeric(n_test)
  for (i in seq_len(n_test)) {
    i_sel <- nns$nn.idx[i, ]
    local_fit <- model$fit(x_train[i_sel, , drop = FALSE], y_train[i_sel],
                           overwrite_warning = TRUE)
    p_test[i] <- p_hat(x_test[i, , drop = FALSE], local_fit)
  }

  p_test
}
# One round of simulation/evaluation: draw a training sample of size `n` with
# 5 standard-normal covariates and labels from p_0. Fit the classifier
# globally and, for n > 1000, also with localization. Return the pointwise
# errors p_0(x) - estimate on the fixed test design `x_test`.
#
# Returns a data frame with 2 * nrow(x_test) rows and columns:
#   error   - p_0(x) minus the estimated probability at the test point
#   i       - index (row of x_test) of the test point
#   method  - "tab" (plain TabPFN) or "ltab" (TabPFN + localization)
#   seed, n - this replicate's seed and training sample size
#   true    - p_0(x) at the test point
do_one <- function(seed, n, x_test) {
  # Seed R's RNG so each replicate is reproducible. Previously `seed` only
  # labelled the output. NOTE(review): randomness on the Python/TabPFN side
  # is not controlled by this.
  set.seed(seed)
  print(paste0("n = ", n, " seed = ", seed))
  x_train <- replicate(5, rnorm(n))
  y_train <- rbinom(n, 1, p_0(x_train))
  truth <- p_0(x_test)

  # Local name `pfn_fit` avoids shadowing the global classifier `pfn`.
  pfn_fit <- pfn$fit(x_train, y_train, overwrite_warning = TRUE)
  err_tab <- truth - p_hat(x_test, pfn_fit)

  # Localization is only run for n > 1000. Otherwise the "ltab" rows reuse
  # the plain TabPFN errors. NOTE(review): the else-value was missing in this
  # copy of the script; reusing err_tab is the assumed intent.
  err_ltab <- if (n > 1000) {
    truth - localized_pfn_predict(x_test, x_train, y_train)
  } else {
    err_tab
  }

  i_test <- seq_len(nrow(x_test))
  data.frame(error = err_tab, i = i_test, method = "tab") |>
    rbind(data.frame(error = err_ltab, i = i_test, method = "ltab")) |>
    mutate(seed = seed, n = n, true = rep(truth, 2))
}
## run study ---------------------------------------
# Fixed test design: 500 quasi-random points from qrng::ghalton in 5
# dimensions, mapped through qnorm to standard-normal coordinates (matching
# the rnorm covariates drawn in do_one).
x_test <- qnorm(qrng::ghalton(500, 5))

# Every (seed, n) combination: 500 replicates for each training-set size.
scenarios <- crossing(seed = 1:500, n = c(200, 500, 1000, 2000, 4000))

# One do_one() call per scenario row, results stacked into a single table.
results <- pmap_dfr(scenarios, \(seed, n) do_one(seed, n, x_test))

readr::write_csv(results, "sim-results.csv")
# results <- readr::read_csv("sim-results.csv")
## analyze results --------------------------------
# For every test point i (within each n and method), summarise the error
# across replicates, then average over test points. `se` feeds the
# +/- 1.96 * se error bars. Averaging `se` and dividing by sqrt(#points)
# treats test points as independent. NOTE(review): they share training data,
# so this is an approximation.
bias <- results |>
  group_by(n, method, i) |>
  summarize(
    val = mean(error)^2,
    # NOTE(review): this is var(error) / R, i.e. the squared standard error of
    # mean(error), not a standard error of mean(error)^2 — confirm intent.
    se = var(error) / length(error),
    .groups = "drop"
  ) |>
  group_by(n, method) |>
  summarize(val = mean(val), se = mean(se) / sqrt(length(se)), .groups = "drop")

variance <- results |>
  group_by(n, method, i) |>
  summarize(
    val = var(error),
    # normal-theory approximation: sd(sample variance) ~ var * sqrt(2 / R)
    se = var(error) / sqrt(0.5 * length(error)),
    .groups = "drop"
  ) |>
  group_by(n, method) |>
  summarize(val = mean(val), se = mean(se) / sqrt(length(se)), .groups = "drop")
# Stack both summaries into one long table; `type` becomes the facet label.
df <- bind_rows(
  `average squared bias` = bias,
  `average variance` = variance,
  .id = "type"
)

# Plot each quantity against n, one facet per type, with approximate 95%
# intervals (val +/- 1.96 * se), and save it to bias-variance.pdf.
bias_variance_plot <- df |>
  mutate(
    method = factor(method, c("tab", "ltab")) |>
      fct_recode(`TabPFN` = "tab", `TabPFN + Localization` = "ltab")
  ) |>
  ggplot(aes(n, val, color = method)) +
  facet_wrap(~type, scales = "free") +
  geom_point() +
  geom_line() +
  geom_errorbar(
    aes(ymin = val - 1.96 * se, ymax = val + 1.96 * se),
    width = 1,
    position = position_dodge(0.05)
  ) +
  theme_minimal() +
  ylab("") +
  expand_limits(y = 0, x = 0) +
  scale_color_manual(values = rev(ggthemes::hc_pal()(2))) +
  theme(legend.position = "bottom")
print(bias_variance_plot)
# Pass the plot explicitly rather than relying on ggplot2's last_plot().
ggsave("bias-variance.pdf", plot = bias_variance_plot, width = 8, height = 4)
Copy link

You should unlock the door to academic success with our team of expert writers poised to elevate your papers to new heights. From in-depth research to eloquent writing, our professionals ensure your ideas are brought to life with precision and creativity. Say to writing challenges and welcome top-notch quality with our dedicated service. Let us guide you towards excellence with tailored papers that showcase your knowledge and expertise. Experience academic brilliance today

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment