Created September 24, 2021 21:13
Save b-rodrigues/4915812a8fa983651b3389c254f9b7cb to your computer and use it in GitHub Desktop.
Script for the video https://youtu.be/xErgS19d4vw
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
library(tidyverse) | |
# Stalagmite test surface, adapted from
# https://skill-lync.com/student-projects/week-4-genetic-algorithm-295
# The surface is the product of the same one-dimensional factor in x and in
# y, so it is symmetric and its maximum lies on the line x = y. Solving
# 5.1 * pi * t + 0.5 = pi / 2 puts it near x = y ~ 0.0668 with f ~ 1
# (the source page quotes x = 0.0663, y = 0.0673, f = 1).
stalagmite <- function(x, y) {
  # sin^6 ripple with peaks spaced 1 / 5.1 apart
  ripple <- function(t) sin(5.1 * pi * t + 0.5)^6
  # wide Gaussian envelope centred at 0.0667
  bell <- function(t) exp(-4 * log(2) * (t - 0.0667)^2 / 0.64)
  # vectorised over x and y
  ripple(x) * bell(x) * ripple(y) * bell(y)
}
# Vector-argument form of stalagmite(), as needed by the genetic algorithm:
# x[1] and x[2] are the two coordinates. The default c(0, 0) also tells
# init_pop() / init_pop_base() that the problem is two-dimensional.
stalagmite2 <- function(x = c(0, 0)) {
  ripple <- function(t) sin(5.1 * pi * t + 0.5)^6
  bell <- function(t) exp(-4 * log(2) * (t - 0.0667)^2 / 0.64)
  ripple(x[1]) * bell(x[1]) * ripple(x[2]) * bell(x[2])
}
library(rgl)
# Ask rgl to display scenes through rglwidget() when they are printed.
options(rgl.printRglwidget = TRUE)
# Stalagmite surface on [0, 1] x [0, 1]. Axis labels are set explicitly:
# persp3d() otherwise deparses the x/y/z argument expressions into labels,
# which produced long "seq(...)"/"outer(...)" captions. local() keeps the
# grid out of the global workspace.
local({
  grid_pts <- seq(0, 1, length.out = 100)
  persp3d(x = grid_pts, y = grid_pts,
          z = outer(grid_pts, grid_pts, stalagmite),
          xlab = "x", ylab = "y", zlab = "stalagmite(x, y)",
          col = "lightblue")
})
# Negated Himmelblau function: -(x^2 + y - 11)^2 - (x + y^2 - 7)^2.
# Both terms are squares, so the value is <= 0; it reaches 0 at (3, 2),
# among other points. Vectorised over x and y.
second_example <- function(x, y) {
  first_term <- x^2 + y - 11
  second_term <- x + y^2 - 7
  -first_term^2 - second_term^2
}
# Vector-argument form of second_example(): x[1] and x[2] are the
# coordinates; the default fixes the dimension at two for the GA.
second_example2 <- function(x = c(0, 0)) {
  a <- x[1]^2 + x[2] - 11
  b <- x[1] + x[2]^2 - 7
  -a^2 - b^2
}
# second_example() on [0, 5] x [0, 5], with explicit axis labels instead of
# the deparsed argument expressions persp3d() would otherwise use.
local({
  grid_pts <- seq(0, 5, length.out = 100)
  persp3d(x = grid_pts, y = grid_pts,
          z = outer(grid_pts, grid_pts, second_example),
          xlab = "x", ylab = "y", zlab = "second_example(x, y)",
          col = "lightblue")
})
# -(s * exp(s)) with s = second_example(x, y). Since s <= 0, the result is
# >= 0; as a function of s it peaks at s = -1 (value exp(-1)).
third_example <- function(x, y) {
  s <- second_example(x, y)
  -(s * exp(s))
}
# Three-argument variant: second_example(x, z) scaled by
# exp(second_example(x, y)), negated.
fourth_example <- function(x, y, z) {
  s_xz <- second_example(x, z)
  s_xy <- second_example(x, y)
  -(s_xz * exp(s_xy))
}
# Vector-argument form of fourth_example():
# fourth_example2(c(a, b, c)) computes fourth_example(a, b, c). The default
# fixes the dimension at three for the GA.
fourth_example2 <- function(x = c(0, 0, 0)) {
  s_xz <- second_example(x[1], x[3])
  s_xy <- second_example(x[1], x[2])
  -(s_xz * exp(s_xy))
}
# third_example() on [0, 5] x [0, 5]. The x/y axes used to span
# [2.3, 2.35] x [1.65, 1.78] while z was evaluated on [0, 5] x [0, 5], so the
# axes mislabelled the surface; x, y and z now share one grid. Axis labels
# are explicit instead of deparsed argument expressions.
local({
  grid_pts <- seq(0, 5, length.out = 100)
  persp3d(x = grid_pts, y = grid_pts,
          z = outer(grid_pts, grid_pts, third_example),
          xlab = "x", ylab = "y", zlab = "third_example(x, y)",
          col = "lightblue")
})
# Random initial population (tibble) for genetic_alg().
#
# objective_function: must take one vector argument whose default value,
#   e.g. `x = c(0, 0)`, gives the number of coordinates.
# pop_size: number of candidates (rows).
# upper_bound, lower_bound: range of the uniform draws.
#
# Returns a tibble with columns x1, x2, ... (one per coordinate).
init_pop <- function(objective_function, pop_size = 100, upper_bound = 1, lower_bound = 0) {
  bad_signature <- function(...) {
    stop("objective_function must take a single vector argument whose ",
         "default value (e.g. `x = c(0, 0)`) sets the number of parameters",
         call. = FALSE)
  }
  parameters <- tryCatch(eval(formals(objective_function)[[1]]),
                         error = bad_signature)
  n_params <- length(parameters)
  if (n_params == 0L) {
    bad_signature()
  }
  # One column of uniform draws per coordinate, drawn column by column (the
  # same random stream as the former purrr::rerun() version, which is
  # deprecated since purrr 1.0.0).
  columns <- lapply(seq_len(n_params), function(i) {
    runif(n = pop_size, min = lower_bound, max = upper_bound)
  })
  # Name the columns directly rather than relying on bind_cols() name repair
  # followed by janitor::clean_names(), which produced the same x1, x2, ...
  names(columns) <- paste0("x", seq_len(n_params))
  tibble::as_tibble(columns)
}
# Random initial population (numeric matrix) for genetic_alg_base().
#
# objective_function: must take one vector argument whose default value,
#   e.g. `x = c(0, 0)`, gives the number of coordinates (columns).
# pop_size: number of candidates (rows).
# upper_bound, lower_bound: range of the uniform draws.
#
# Returns a pop_size x n_params matrix without dimnames.
init_pop_base <- function(objective_function, pop_size = 100, upper_bound = 1, lower_bound = 0) {
  n_params <- length(eval(formals(objective_function)[[1]]))
  # runif() draws sequentially, so filling column-major yields the same values
  # as replicate(n_params, runif(pop_size, ...)). Unlike replicate(), the
  # result stays a matrix when pop_size is 1, which apply() downstream needs.
  # This also avoids magrittr in the base-R path.
  matrix(runif(n = pop_size * n_params, min = lower_bound, max = upper_bound),
         nrow = pop_size, ncol = n_params)
}
# Score each candidate with objective_function, row by row.
# rowwise() makes c_across(everything()) collect one row's coordinates into a
# single vector, the form the vector-argument objective functions expect.
# Returns the population plus a `score` column, ungrouped.
evaluate_candidates <- function(objective_function, population) {
  scored <- mutate(rowwise(population),
                   score = objective_function(c_across(everything())))
  ungroup(scored)
}
# Score each candidate with objective_function using purrr::pmap_dbl().
# pmap_dbl() passes each row as named arguments (x1 = ..., x2 = ...); the
# adapter packs them back into one vector for the vector-argument objective
# functions. It replaces purrr::lift_vd(), deprecated since purrr 1.0.0. The
# population is passed explicitly instead of via magrittr's `.`.
# Returns the population plus a `score` column.
evaluate_candidates_pmap <- function(objective_function, population) {
  call_with_vector <- function(...) objective_function(c(...))
  mutate(population, score = pmap_dbl(population, call_with_vector))
}
# Base-R scorer: apply objective_function to every row of `population` and
# append the results as a column named "scores". genetic_alg_base() looks
# that column up by name.
evaluate_candidates_base <- function(objective_function, population) {
  row_scores <- apply(X = population, MARGIN = 1L, FUN = objective_function)
  cbind(population, scores = row_scores)
}
# Choose parents: the k best-scoring candidates plus up to k candidates drawn
# at random from the middle of the ranking (the k worst are never drawn).
#
# scores: tibble with a `score` column.
# k: number of elite and of random parents.
#
# Returns up to 2k rows with the same columns as `scores`.
select_parents <- function(scores, k = 10) {
  # Temporary row ids let us exclude the top and bottom slices below.
  scores <- tibble::rowid_to_column(scores)
  top <- slice_max(scores, order_by = score, n = k, with_ties = FALSE)
  bottom <- slice_min(scores, order_by = score, n = k, with_ties = FALSE)
  # slice_sample() caps n at the rows available. The superseded sample_n()
  # errored whenever fewer than k middle candidates were left.
  others <- scores %>%
    filter(!(rowid %in% c(top$rowid, bottom$rowid))) %>%
    slice_sample(n = k)
  bind_rows(top, others) %>%
    select(-rowid)
}
# Base-R parent selection: the k best rows plus up to k rows drawn at random
# from the middle of the ranking (the k worst rows are never drawn).
#
# scores: matrix or data frame whose LAST column is the score.
# k: number of elite and of random parents.
#
# Returns up to 2k rows, best-first for the elite part.
select_parents_base <- function(scores, k = 10) {
  # Best first; drop = FALSE keeps a one-row population a matrix.
  scores <- scores[order(scores[, ncol(scores)], decreasing = TRUE), , drop = FALSE]
  top <- head(scores, k)
  # Everything except the best k and the worst k rows.
  middle <- head(tail(scores, -k), -k)
  # sample.int() avoids sample()'s special case for length-one input (and
  # seq(1, 0) when nothing is left). Capping the size stops small
  # populations from erroring. It consumes the same random stream as
  # sample(seq(1, n), k) whenever that call was valid.
  picked <- sample.int(nrow(middle), size = min(k, nrow(middle)))
  rbind(top, middle[picked, , drop = FALSE])
}
# Crossover: build every combination of the parents' coordinate values, then
# keep each child independently with probability r_cross.
#
# parents: tibble of parents including a `score` column (dropped here).
# r_cross: probability that a given child is kept.
#
# Returns a tibble of children without a score column.
crossover <- function(parents, r_cross = 0.99) {
  genes <- select(parents, -score)
  # expand.grid() varies the first column fastest, like purrr::cross_df()
  # (deprecated since purrr 1.0.0), so row order is unchanged.
  children <- tibble::as_tibble(expand.grid(genes, KEEP.OUT.ATTRS = FALSE))
  # One vectorised runif() per child. This draws the same numbers as the
  # previous rowwise() runif(1) without the per-row overhead.
  filter(children, runif(n()) < r_cross)
}
# Base-R crossover: build every combination of the parents' coordinate values
# and keep each child independently with probability r_cross.
#
# parents: matrix whose LAST column is the score (dropped here).
# r_cross: probability that a given child is kept.
#
# Returns a numeric matrix of children (columns Var1, Var2, ... unless the
# parents carried column names).
crossover_base <- function(parents, r_cross = 0.99) {
  # drop = FALSE: a single gene column must stay a matrix for asplit().
  genes <- parents[, -ncol(parents), drop = FALSE]
  children <- as.matrix(expand.grid(asplit(genes, 2)))
  keep <- runif(nrow(children)) < r_cross
  # drop = FALSE: a single surviving child must stay a one-row matrix.
  children[keep, , drop = FALSE]
}
# Mutation: with probability r_mut a candidate (row) is shifted by one
# N(0, 1) draw, added to every coordinate of that row.
#
# new_generation: tibble of candidates (coordinates only).
# r_mut: per-row mutation probability.
#
# Returns a tibble with the same columns.
mutation <- function(new_generation, r_mut = 0.05) {
  n_rows <- nrow(new_generation)
  mutated <- runif(n_rows) < r_mut
  # Mutated rows each get their own draw, in row order. This consumes the
  # random stream exactly as the previous rowwise() version did, without
  # per-row evaluation overhead.
  noise <- numeric(n_rows)
  noise[mutated] <- rnorm(sum(mutated))
  mutate(new_generation, across(everything(), ~ .x + noise))
}
# Base-R mutation: with probability r_mut a candidate (row) is shifted by its
# own N(0, 1) draw, added to every coordinate of that row.
#
# new_generation: numeric matrix of candidates.
# r_mut: per-row mutation probability.
#
# Returns a matrix of the same shape.
mutation_base <- function(new_generation, r_mut = 0.05) {
  n_rows <- nrow(new_generation)
  mutated <- runif(n_rows) < r_mut
  # One independent draw per mutated row. The former
  # ifelse(..., rnorm(1), 0) shared a single draw across all mutated rows,
  # unlike the tidyverse mutation().
  noise <- numeric(n_rows)
  noise[mutated] <- rnorm(sum(mutated))
  # A length-nrow vector recycles down each column, so row i of every column
  # is shifted by noise[i].
  new_generation + noise
}
# Tidyverse genetic algorithm that maximises `objective_function`.
#
# objective_function: takes one numeric vector; the default value of its first
#   argument fixes the dimension (see init_pop()).
# evaluate_function: scorer called as evaluate_function(objective_function,
#   population = ...) that adds a `score` column. Defaults to
#   evaluate_candidates; previously there was no default, so calls omitting it
#   failed.
# population_size, upper_bound, lower_bound: passed to init_pop().
# iter: number of generations.
#
# Returns a tibble with the best candidate of each generation plus `score`
# and `iteration` columns. Progress is reported with message().
genetic_alg <- function(objective_function,
                        evaluate_function = evaluate_candidates,
                        population_size = 100,
                        iter = 10,
                        upper_bound = 1,
                        lower_bound = 0) {
  evaluate <- function(population) {
    evaluate_function(objective_function, population = population)
  }
  # One generation: selection, crossover, mutation, re-scoring.
  one_run <- function(previous_run_result) {
    previous_run_result %>%
      select_parents() %>%
      crossover() %>%
      mutation() %>%
      evaluate()
  }
  # The generation number is attached only to the reported best row, never to
  # the population. Previously `iteration` leaked into the population, where
  # it was crossed, mutated and passed to the objective function.
  best_of <- function(run, step) {
    run %>%
      slice_max(order_by = score, n = 1, with_ties = FALSE) %>%
      mutate(iteration = step)
  }

  prev_run <- init_pop(objective_function,
                       pop_size = population_size,
                       upper_bound = upper_bound,
                       lower_bound = lower_bound) %>%
    evaluate() %>%
    one_run()

  results <- vector("list", max(1, iter))
  results[[1]] <- best_of(prev_run, 1)
  step <- 1
  while (step < iter) {
    step <- step + 1
    prev_run <- one_run(prev_run)
    results[[step]] <- best_of(prev_run, step)
    if (step >= iter) {
      message("Maximum iterations reached!")
      break
    }
    message("Iteration number: ", step)
  }
  bind_rows(results)
}
# Run the tidyverse GA with both scoring implementations.
genetic_alg(stalagmite2, evaluate_function = evaluate_candidates, iter = 10)
genetic_alg(stalagmite2, evaluate_function = evaluate_candidates_pmap, iter = 10)

# Two independent runs (purrr::rerun() is deprecated since purrr 1.0.0).
opt_stalagmite2 <- lapply(1:2, function(run) {
  genetic_alg(objective_function = stalagmite2,
              evaluate_function = evaluate_candidates,
              iter = 10)
})

genetic_alg(second_example2, evaluate_function = evaluate_candidates,
            iter = 20, upper_bound = 10, lower_bound = 0)
# third_example() and fourth_example() take separate scalar arguments, but
# the GA needs one vector argument whose default fixes the dimension, so they
# are wrapped here.
genetic_alg(function(x = c(0, 0)) third_example(x[1], x[2]),
            evaluate_function = evaluate_candidates,
            iter = 20, upper_bound = 10, lower_bound = 0)
genetic_alg(function(x = c(0, 0, 0)) fourth_example(x[1], x[2], x[3]),
            evaluate_function = evaluate_candidates,
            iter = 20, upper_bound = 10, lower_bound = 0)
genetic_alg(fourth_example2, evaluate_function = evaluate_candidates,
            iter = 20, upper_bound = 10, lower_bound = 0)
# Base-R genetic algorithm that maximises `objective_function`.
#
# objective_function: takes one numeric vector; the default value of its first
#   argument fixes the dimension (see init_pop_base()).
# population_size, upper_bound, lower_bound: passed to init_pop_base().
# iter: number of generations.
#
# Returns a matrix with one row per generation: the best candidate's
# coordinates plus its "scores" column. It is a matrix even when iter is 1;
# the old code returned a bare vector in that case. Progress is reported
# with message().
genetic_alg_base <- function(objective_function,
                             population_size = 100,
                             iter = 10,
                             upper_bound = 1,
                             lower_bound = 0) {
  evaluate <- function(population) {
    evaluate_candidates_base(objective_function, population = population)
  }
  # One generation: selection, crossover, mutation, re-scoring.
  one_run <- function(previous_run_result) {
    previous_run_result |>
      select_parents_base() |>
      crossover_base() |>
      mutation_base() |>
      evaluate()
  }
  # drop = FALSE keeps the best row a one-row matrix.
  best_row <- function(run) {
    run[which.max(run[, "scores"]), , drop = FALSE]
  }

  prev_run <- init_pop_base(objective_function,
                            pop_size = population_size,
                            upper_bound = upper_bound,
                            lower_bound = lower_bound) |>
    evaluate() |>
    one_run()

  best <- vector("list", max(1, iter))
  best[[1]] <- best_row(prev_run)
  step <- 1
  while (step < iter) {
    step <- step + 1
    prev_run <- one_run(prev_run)
    best[[step]] <- best_row(prev_run)
    if (step >= iter) {
      message("Maximum iterations reached!")
      break
    }
    message("Iteration number: ", step)
  }
  do.call(rbind, best)
}
# Compare the base-R GA with the tidyverse GA under both scoring functions.
# The expressions are named so the output rows carry the labels used in the
# recorded results below.
microbenchmark::microbenchmark(
  base = genetic_alg_base(stalagmite2, iter = 10,
                          upper_bound = 1, lower_bound = 0),
  candidates_rowwise = genetic_alg(stalagmite2,
                                   evaluate_candidates,
                                   iter = 10, upper_bound = 1, lower_bound = 0),
  candidates_pmap = genetic_alg(stalagmite2,
                                evaluate_candidates_pmap,
                                iter = 10, upper_bound = 1, lower_bound = 0),
  times = 5
)
# Results recorded when the video was made:
# Unit: milliseconds
# expr                       min          lq        mean      median          uq       max neval
# base                  74.76085    81.26259    98.64852    85.78427     89.7189   161.716     5
# candidates_pmap    13107.68274 13173.23539 13231.69735 13257.03845  13286.4004 13334.130     5
# candidates_rowwise 21702.32029 21740.35980 21846.23679 21741.14040  21773.5148 22273.849     5

# Profiling both implementations:
# library(profvis)
#
# profvis({
#   genetic_alg(stalagmite2, evaluate_function = evaluate_candidates,
#               iter = 10, upper_bound = 1, lower_bound = 0)
# })
#
# profvis({
#   genetic_alg_base(stalagmite2, iter = 10, upper_bound = 1, lower_bound = 0)
# })
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.