Skip to content

Instantly share code, notes, and snippets.

@jtrecenti
Created December 27, 2021 14:04
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jtrecenti/3c2cd573a184576a27e8ffe5ce04498a to your computer and use it in GitHub Desktop.
Save jtrecenti/3c2cd573a184576a27e8ffe5ce04498a to your computer and use it in GitHub Desktop.
# download and create dataset ---------------------------------------------
# Always start from a fresh copy of the "trt2" captchas. fs::dir_delete()
# errors when the path does not exist, so the delete is guarded: on a first
# run (or after a manual cleanup) there is simply nothing to remove.
trt_dir <- "data-raw/trt"
if (fs::dir_exists(trt_dir)) {
  fs::dir_delete(trt_dir)
}
## Optional: a custom image transform built with purrr::compose() that
## resizes each image to 32 x 192 after captcha_transform_image().
## Currently unused.
# tr <- purrr::compose(
#   captcha::captcha_transform_image,
#   purrr::partial(torchvision::transform_resize, size = c(32, 192)),
#   .dir = "forward"
# )
captcha_ds <- captcha::captcha_dataset(
  root = trt_dir,
  captcha = "trt2",
  download = TRUE
)
# Sanity check of the value range stored in captcha_ds$data:
# range(as.numeric(captcha_ds$data))
# create train and validation data loaders --------------------------------
# Random 80/20 split of the dataset indices; set.seed() makes it reproducible.
set.seed(1)
ids <- seq_along(captcha_ds)
# floor() spells out the truncation that sample() would otherwise apply
# silently to a fractional `size`.
n_train <- floor(0.8 * length(ids))
# Fail early with a clear message rather than building an empty loader.
stopifnot(
  "training split is empty" = n_train >= 1,
  "validation split is empty" = n_train < length(ids)
)
# Index through sample.int() instead of sample(ids, ...): sample() treats a
# length-one numeric `x` as 1:x. For length(ids) > 1 this consumes the RNG
# exactly like sample(ids, n_train), so the split itself is unchanged.
id_train <- ids[sample.int(length(ids), n_train)]
id_valid <- setdiff(ids, id_train)
batch_size <- 40
captcha_dl_train <- torch::dataloader(
  torch::dataset_subset(captcha_ds, id_train),
  batch_size = batch_size,
  shuffle = TRUE
)
# The validation loader keeps the default (unshuffled) order.
captcha_dl_valid <- torch::dataloader(
  torch::dataset_subset(captcha_ds, id_valid),
  batch_size = batch_size
)
# specify model -----------------------------------------------------------
# NOTE(review): nothing visible in this script uses %>% (it pipes with |>);
# the attach is kept so the session state is unchanged.
library(magrittr)
# Network architecture shipped with the {captcha} package; its loss,
# optimizer and hyperparameters are configured in the next section.
model <- captcha::net_captcha
# run model ---------------------------------------------------------------
# Configure the luz module step by step, then train it. The intermediate
# generator lives inside local() so only `fitted` reaches the global scope.
fitted <- local({
  # target is read for its 2nd dim (number of characters per captcha) and
  # its 3rd dim (vocabulary size), matching the hparams below.
  target_dims <- dim(captcha_ds$target)
  configured <- luz::setup(
    model,
    loss = torch::nn_multilabel_soft_margin_loss(),
    optimizer = torch::optim_adam,
    metrics = list(captcha::captcha_accuracy())
  )
  configured <- luz::set_hparams(
    configured,
    input_dim = dim(captcha_ds$data)[-1],
    output_vocab_size = target_dims[3],
    output_ndigits = target_dims[2],
    vocab = captcha_ds$vocab,
    dropout = c(0.5, 0.5),
    dense_units = 800
  )
  configured <- luz::set_opt_hparams(configured, lr = 0.01)
  luz::fit(
    configured,
    captcha_dl_train,
    valid_data = captcha_dl_valid,
    epochs = 100
  )
})
# Training/validation curves recorded by luz.
plot(fitted)
# luz::luz_save(fitted, "data-raw/trt.pt")
# evaluate ----------------------------------------------------------------
# Decode every held-out image with the trained model.
files_valid <- captcha_ds$files[id_valid]
pred <- captcha::decrypt(files_valid, fitted)
# Ground truth is parsed from the file name: with the directory and
# extension stripped, take the first run of ASCII letters/digits that
# follows an underscore. NOTE(review): assumes names like "<id>_<label>.ext";
# a name without such a run yields NA.
labs <- stringr::str_extract(
  tools::file_path_sans_ext(basename(files_valid)),
  "(?<=_)[0-9a-zA-Z]+"
)
# Exact-match accuracy on the validation split (NA if any label was NA).
mean(pred == labs)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment