Skip to content

Instantly share code, notes, and snippets.

@maxdrohde
Created July 23, 2021 18:10
Show Gist options
  • Save maxdrohde/8ae661944de9ddd753dbe90008893f4e to your computer and use it in GitHub Desktop.
Save maxdrohde/8ae661944de9ddd753dbe90008893f4e to your computer and use it in GitHub Desktop.
library(tidymodels)
library(tidyverse)
library(discrim)
library(gganimate)
# Resolution of the prediction grid: n points per axis, n^2 points in total.
n <- 100

# Regular grid over [-5, 5] x [-5, 5] on which class probabilities are
# evaluated and contoured. `length.out` is spelled out: the original
# `length =` only worked through partial argument matching.
parabolic_grid <-
  expand.grid(
    X1 = seq(-5, 5, length.out = n),
    X2 = seq(-5, 5, length.out = n)
  )
# Fit a k-nearest-neighbour classifier on `data` and return `grid` augmented
# with the predicted probability of the "Class1" level at every grid point.
#
# Arguments:
#   neighbors: number of neighbours (k) passed to parsnip::nearest_neighbor();
#              must be a single number.
#   data:      training data with a `class` outcome column; every other column
#              is used as a predictor. Defaults to `parabolic`, which this
#              script never loads explicitly (presumably the modeldata dataset
#              attached by tidymodels -- confirm).
#   grid:      data frame of points to predict on; must contain the predictor
#              columns of `data`. Defaults to the global `parabolic_grid`,
#              which was previously rebuilt identically inside this function.
#
# Returns `grid` with two added columns:
#   fda       - predicted probability of "Class1". The name is a leftover (the
#               model is KNN, not flexible discriminant analysis); it is kept
#               because the plotting code maps `z = fda`.
#   neighbors - the value of `neighbors`, used as the animation frame key.
get_data <- function(neighbors, data = parabolic, grid = parabolic_grid) {
  stopifnot(is.numeric(neighbors), length(neighbors) == 1L)

  knn_mod <-
    nearest_neighbor(neighbors = neighbors) %>%
    set_mode("classification") %>%
    set_engine("kknn") %>%
    fit(class ~ ., data = data)

  # `.pred_Class1` only exists if the outcome has a level named "Class1".
  grid$fda <- predict(knn_mod, new_data = grid, type = "prob")$.pred_Class1
  grid$neighbors <- neighbors
  grid
}
# One prediction grid per k (every k from 1 to 10, then coarser steps up to
# 300), stacked row-wise; the `neighbors` column records the k of each row.
df <- map_df(
  c(1:10, 15, 20, 30, 40, 50, 100, 200, 300),
  function(k) get_data(neighbors = k)
)
# Animated plot: training points coloured by class, overlaid with the 0.5
# contour of the predicted Class1 probability from `df`. There is one frame
# per distinct value of `df$neighbors`, and the title shows the current k.
anim <-
  ggplot(data = parabolic, mapping = aes(x = X1, y = X2)) +
  geom_point(mapping = aes(colour = class), alpha = 0.5) +
  geom_contour(
    data = df,
    mapping = aes(z = fda),
    colour = "black",
    breaks = 0.5
  ) +
  cowplot::theme_minimal_grid(font_family = "Source Sans Pro", font_size = 12) +
  theme(legend.position = "none") +
  coord_equal() +
  transition_manual(frames = neighbors) +
  labs(title = "Neighbors: {current_frame}")
# Render the animation with the gifski renderer: 12 seconds long,
# 5 x 5 inches at 300 dpi (1500 x 1500 pixels).
gif <- animate(
  anim,
  duration = 12,
  height = 5,
  width = 5,
  units = "in",
  res = 300,
  renderer = gifski_renderer()
)

# Save to GIF (gifski_renderer() produces a GIF, not an mp4).
anim_save(filename = "knn_anim.gif", animation = gif)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment