Created December 27, 2021 14:04
Save jtrecenti/3c2cd573a184576a27e8ffe5ce04498a to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# download and create dataset ---------------------------------------------

# Start from an empty download directory. fs::dir_delete() errors when the
# path does not exist (e.g. on the very first run), so only delete it when
# it is actually there.
if (fs::dir_exists("data-raw/trt")) {
  fs::dir_delete("data-raw/trt")
}

## create custom transformers using purrr::compose
## (kept for reference only; not passed to captcha_dataset() below)
# tr <- purrr::compose(
#   captcha::captcha_transform_image,
#   purrr::partial(torchvision::transform_resize, size = c(32, 192)),
#   .dir = "forward"
# )

# Build the "trt2" captcha torch dataset rooted at data-raw/trt, with
# download = TRUE so the images are fetched into that (fresh) directory.
captcha_ds <- captcha::captcha_dataset(
  root = "data-raw/trt",
  captcha = "trt2",
  download = TRUE
)
# range(as.numeric(captcha_ds$data))
# create train and validation data loaders --------------------------------

# Reproducible 80/20 random split of the dataset indices.
set.seed(1)
n_obs <- length(captcha_ds)
ids <- seq_len(n_obs)
# sample.int() avoids sample()'s "length-1 x means 1:x" gotcha and draws the
# same random stream as sample(ids, ...). floor() makes explicit the
# truncation that sample() would otherwise apply silently to a fractional size.
id_train <- sample.int(n_obs, size = floor(0.8 * n_obs))
id_valid <- setdiff(ids, id_train)

# Training batches are reshuffled each epoch; the validation loader keeps the
# default (shuffle = FALSE), so its order is fixed.
captcha_dl_train <- torch::dataloader(
  torch::dataset_subset(captcha_ds, id_train),
  batch_size = 40,
  shuffle = TRUE
)
captcha_dl_valid <- torch::dataloader(
  torch::dataset_subset(captcha_ds, id_valid),
  batch_size = 40
)
# specify model -----------------------------------------------------------

# NOTE(review): no `%>%` appears in this script (the training pipeline uses
# the native `|>`), so attaching magrittr looks unnecessary; confirm and drop.
library(magrittr)

# Network module from the captcha package; its constructor arguments are
# supplied later through luz::set_hparams().
model <- captcha::net_captcha
# run model ---------------------------------------------------------------

# Train with luz using a multilabel soft-margin loss, the Adam optimizer
# (lr = 0.01) and captcha_accuracy() as the tracked metric, evaluating on
# the validation loader for 100 epochs.
fitted <- model |>
  luz::setup(
    loss = torch::nn_multilabel_soft_margin_loss(),
    optimizer = torch::optim_adam,
    metrics = list(captcha::captcha_accuracy())
  ) |>
  luz::set_hparams(
    # every dimension of `data` except the first (the observation index)
    input_dim = dim(captcha_ds$data)[-1],
    # target dim 3 is passed as the vocabulary size, dim 2 as the number
    # of characters per captcha
    output_vocab_size = dim(captcha_ds$target)[3],
    output_ndigits = dim(captcha_ds$target)[2],
    vocab = captcha_ds$vocab,
    dropout = c(0.5, 0.5),
    dense_units = 800
  ) |>
  luz::set_opt_hparams(lr = 0.01) |>
  luz::fit(
    captcha_dl_train,
    valid_data = captcha_dl_valid,
    epochs = 100
  )

# Plot the metrics luz recorded while fitting.
plot(fitted)
# luz::luz_save(fitted, "data-raw/trt.pt")
# evaluate ----------------------------------------------------------------

# Predict the validation images with the fitted model.
files_valid <- captcha_ds$files[id_valid]
pred <- captcha::decrypt(files_valid, fitted)

# The true label is the alphanumeric run after an underscore in the file
# name. Anchor it to the end of the extension-less name so an underscore in
# the prefix cannot make the pattern pick up part of the prefix instead.
labs <- files_valid |>
  basename() |>
  tools::file_path_sans_ext() |>
  stringr::str_extract("(?<=_)[0-9a-zA-Z]+$")

# A file name with no parsable label yields NA; count it as a miss (and say
# so) rather than letting a single NA turn the whole accuracy into NA.
if (anyNA(labs)) {
  warning(
    sum(is.na(labs)), " validation file name(s) had no parsable label",
    call. = FALSE
  )
}
hits <- pred == labs
hits[is.na(hits)] <- FALSE
mean(hits)
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.