# train for wavenet binary target
#' @title Trains a (mostly) LSTM model on genomic data. Designed for developing genome-based language models (GenomeNet)
#'
#' @description
#' Depth and number of neurons per layer of the network can be specified. The first layer can be a Convolutional Neural Network (CNN) that is designed to capture codons.
#' If a path to a folder where FASTA files are located is provided, batches will be generated by an external generator, which
#' is recommended for big training sets. Alternatively, a dataset can be supplied that holds the preprocessed batches (generated by \code{preprocessSemiRedundant()})
#' and keeps them in RAM. Also supports training on instances with multiple GPUs and scales linearly with the number of GPUs present.
#' @param train_type Either "lm" for language model, "label_header" or "label_folder". A language model is trained to predict the next character in a sequence.
#' label_header/label_folder models are trained to predict a corresponding class, given a sequence as input. If "label_header", the class will be read from the fasta headers.
#' If "label_folder", the class will be read from the folder, i.e. all fasta files in one folder must belong to the same class.
#' @param model A keras model.
#' @param built_model Call to a function that creates a model. \code{create_model_function} can be "create_model_lstm_cnn", "create_model_lstm_cnn_target_middle" or "create_model_wavenet".
#' In \code{function_args} the arguments of the corresponding function can be specified; if no argument is given, default values will be used.
#' Example: \code{built_model = list(create_model_function = "create_model_lstm_cnn", function_args = list(maxlen = 50, layer.size = 32, layers.lstm = 1))}
#' @param model_path Path to a pretrained model.
#' @param path Path to folder where individual or multiple FASTA files are located for training. If \code{train_type} is \code{label_folder}, should be a vector
#' containing a path for each class.
#' @param path.val Path to folder where individual or multiple FASTA files are located for validation. If \code{train_type} is \code{label_folder}, should be a vector
#' containing a path for each class.
#' @param dataset Dataframe holding training samples in RAM, used instead of the generator.
#' @param checkpoint_path Path to checkpoints folder.
#' @param validation.split Defines the fraction of the batches that will be used for validation (compared to size of training data).
#' @param run.name Name of the run (without file ending). Name will be used to identify output from callbacks.
#' @param batch.size Number of samples that are used for one network update.
#' @param epochs Number of training epochs.
#' @param max.queue.size Maximum size of the generator queue in \code{fit_generator()}.
#' @param lr.plateau.factor Factor by which the learning rate is reduced when a plateau is reached.
#' @param patience Number of epochs to wait for a decrease in loss before reducing the learning rate.
#' @param cooldown Number of epochs to wait after a learning rate reduction before the learning rate can be reduced again.
#' @param steps.per.epoch Number of batches per epoch.
#' @param step Distance between the start positions of two successive samples.
#' @param randomFiles Logical, whether to shuffle the file order beforehand instead of going through the files sequentially.
#' @param vocabulary Vector of allowed characters; characters outside the vocabulary get encoded as zero-vectors.
#' @param initial_epoch Epoch at which to start training, set to 0 if no \code{model_path} argument is given. Note that the network
#' will run for (\code{epochs} - \code{initial_epoch}) rounds and not \code{epochs} rounds.
#' @param tensorboard.log Path to tensorboard log directory.
#' @param save_best_only Only save model that improved on best val_loss score.
#' @param compile Whether to compile the model after loading.
#' @param solver Optimization method, options are "adam", "adagrad", "rmsprop" or "sgd". Only used when pretrained model is given (\code{model_path} is not NULL) and compile is FALSE.
#' Otherwise solver is determined when model is created.
#' @param learning.rate Learning rate for optimizer. Only used when pretrained model is given (\code{model_path} is not NULL) and compile is FALSE.
#' Otherwise learning rate is determined when model is created.
#' @param seed Sets seed for the \code{set.seed} function, for reproducible results when using \code{randomFiles} or \code{shuffleFastaEntries}.
#' @param shuffleFastaEntries Logical, whether to shuffle the entries within each file.
#' @param output List of optional outputs; if the \code{none} element is TRUE, all outputs are disabled.
#' @param tb_images Boolean, whether to show plots in tensorboard. Note that this doubles the time needed for the validation step.
#' @param format File format, "fasta" or "fastq".
#' @param fileLog If a path is specified, the names of the processed files are written to a CSV file at that location.
#' @param labelVocabulary Character vector of possible targets. Targets outside \code{labelVocabulary} will get discarded.
#' @param numberOfFiles Use only the specified number of files; ignored if greater than the number of files in \code{path}.
#' @param reverseComplements Logical; if TRUE, half of each batch contains sequences and the other half their reverse complements. The reverse complement
#' is given by reversing the order of a sequence and switching A/T and C/G. The \code{batch.size} argument has to be even; otherwise 1 will be added
#' to \code{batch.size}.
#' @param wavenet_format Boolean. If TRUE, the target is a sequence equal to the input shifted by one position to the right (the last target position is not part of the input).
#' For example, if the sequence is ACGT and maxlen = 3, the first input corresponds to ACG and the target to CGT (see the illustration after the function definition below).
#' @param target_middle Boolean, whether the target is in the middle of the sequence.
#' @param reset_states Boolean, whether to reset hidden states of RNN layer at every new input file.
#' @param ambiguous_nuc How to handle nucleotides outside the vocabulary, either "zero", "discard", "empirical" or "equal". If "zero", the input gets encoded as a zero-vector;
#' if "equal", every entry of the input vector is 1/length(vocabulary). If "discard", samples containing nucleotides outside the vocabulary get discarded.
#' If "empirical", the nucleotide distribution of the current file is used.
#' @param percentage_per_file Numerical value between 0 and 1. Proportion of possible samples to take from one file. Samples are taken from a random subsequence.
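#' @examples
#' \dontrun{
#' # Minimal sketch of a training call; all paths are placeholders and the
#' # hyperparameters are illustrative, not recommendations.
#' trainNetwork(train_type = "lm",
#'              built_model = list(create_model_function = "create_model_lstm_cnn",
#'                                 function_args = list(maxlen = 50, layer.size = 32,
#'                                                      layers.lstm = 1)),
#'              path = "training/fasta/folder",
#'              path.val = "validation/fasta/folder",
#'              checkpoint_path = "checkpoints/folder",
#'              tensorboard.log = "tensorboard/folder",
#'              run.name = "example_run",
#'              batch.size = 64,
#'              epochs = 5,
#'              steps.per.epoch = 100)
#' }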
#' @export
trainNetwork <- function(train_type = "lm",
                         model_path = NULL,
                         built_model = list(create_model_function = NULL, function_args = list()),
                         model = NULL,
                         path = NULL,
                         path.val = NULL,
                         dataset = NULL,
                         checkpoint_path,
                         validation.split = 0.2,
                         run.name = "run",
                         batch.size = 64,
                         epochs = 10,
                         max.queue.size = 100,
                         lr.plateau.factor = 0.9,
                         patience = 20,
                         cooldown = 1,
                         steps.per.epoch = 1000,
                         step = 1,
                         randomFiles = FALSE,
                         initial_epoch = 0,
                         vocabulary = c("a", "c", "g", "t"),
                         tensorboard.log,
                         save_best_only = TRUE,
                         compile = TRUE,
                         learning.rate = NULL,
                         solver = NULL,
                         seed = c(1234, 4321),
                         shuffleFastaEntries = FALSE,
                         output = list(none = FALSE,
                                       checkpoints = TRUE,
                                       tensorboard = TRUE,
                                       log = FALSE,
                                       serialize_model = FALSE,
                                       full_model = FALSE),
                         tb_images = FALSE,
                         format = "fasta",
                         fileLog = NULL,
                         labelVocabulary = NULL,
                         numberOfFiles = NULL,
                         reverseComplements = FALSE,
                         wavenet_format = FALSE,
                         target_middle = FALSE,
                         reset_states = FALSE,
                         ambiguous_nuc = "zero",
                         percentage_per_file = NULL) {
  stopifnot(train_type %in% c("lm", "label_header", "label_folder"))
  stopifnot(ambiguous_nuc %in% c("zero", "equal", "discard", "empirical"))
  if (!is.null(built_model$create_model_function) && !is.null(model)) {
    stop("Two models were specified. Set either model or built_model$create_model_function argument to NULL.")
  }
  if (train_type == "lm") {
    labelGen <- FALSE
    labelByFolder <- FALSE
  }
  if (train_type == "label_header") {
    labelGen <- TRUE
    labelByFolder <- FALSE
    stopifnot(!is.null(labelVocabulary))
  }
  if (train_type == "label_folder") {
    labelGen <- TRUE
    labelByFolder <- TRUE
    stopifnot(!is.null(labelVocabulary))
    stopifnot(length(path) == length(labelVocabulary))
  }
  if (output$none) {
    output$checkpoints <- FALSE
    output$tensorboard <- FALSE
    output$log <- FALSE
    output$serialize_model <- FALSE
    output$full_model <- FALSE
  }
  # set model arguments
  # initialize objects that the tensorboard callback accesses later, so they
  # exist even when no model-building function or image callback is used
  default_arguments <- NULL
  gen_cb <- NULL
  if (!is.null(built_model[[1]])) {
    if (built_model[[1]] == "create_model_lstm_cnn_target_middle") {
      target_middle <- TRUE
      wavenet_format <- FALSE
    }
    if (built_model[[1]] == "create_model_lstm_cnn") {
      target_middle <- FALSE
      wavenet_format <- FALSE
    }
    if (built_model[[1]] == "create_model_wavenet") {
      target_middle <- TRUE
      wavenet_format <- TRUE
    }
    new_arguments <- names(built_model[[2]])
    default_arguments <- formals(built_model[[1]])
    # overwrite default arguments with user-supplied ones
    for (arg in new_arguments) {
      default_arguments[arg] <- built_model[[2]][arg]
    }
    # create model
    if (built_model[[1]] == "create_model_lstm_cnn") {
      formals(create_model_lstm_cnn) <- default_arguments
      model <- create_model_lstm_cnn()
    }
    if (built_model[[1]] == "create_model_lstm_cnn_target_middle") {
      formals(create_model_lstm_cnn_target_middle) <- default_arguments
      model <- create_model_lstm_cnn_target_middle()
    }
    if (built_model[[1]] == "create_model_wavenet") {
      if (!wavenet_format) {
        warning("Argument wavenet_format should be TRUE when using wavenet architecture.")
      }
      formals(create_model_wavenet) <- default_arguments
      model <- create_model_wavenet()
    }
  }
  # function arguments
  argumentList <- as.list(match.call(expand.dots = FALSE))
  label.vocabulary.size <- length(labelVocabulary)
  vocabulary.size <- length(vocabulary)
  # extract maxlen from model
  if (!target_middle) {
    maxlen <- model$input$shape[[2]]
  } else {
    maxlen <- model$input[[1]]$shape[[2]] + model$input[[2]]$shape[[2]]
  }
  if (labelByFolder) {
    if (length(path) == 1) warning("Training with just one label")
  }
  if (output$checkpoints) {
    # create folder for checkpoints using run.name
    # filenames contain epoch, validation loss and validation accuracy
    checkpoint_dir <- paste0(checkpoint_path, "/", run.name, "_checkpoints")
    dir.create(checkpoint_dir, showWarnings = FALSE)
    filepath_checkpoints <- file.path(checkpoint_dir, "Ep.{epoch:03d}-val_loss{val_loss:.2f}-val_acc{val_acc:.3f}.hdf5")
  }
  # Check if fileLog is unique
  if (!is.null(fileLog) && file.exists(fileLog)) {
    stop("fileLog entry is already present. Please give this file a unique name.")
  }
  # Check if run.name is unique
  if (output$tensorboard && dir.exists(file.path(tensorboard.log, run.name))) {
    stop(paste0("Tensorboard entry '", run.name, "' is already present. Please give your run a unique name."))
  }
  # Load pretrained model
  if (!is.null(model_path)) {
    # epochs argument can be misleading
    if (!missing(initial_epoch)) {
      if (initial_epoch >= epochs) {
        stop("Network trains for (epochs - initial_epoch) rounds overall, NOT epochs rounds. Increase epochs or decrease initial_epoch.")
      }
    }
    # extract initial_epoch from filename if no argument is given
    if (is.null(initial_epoch)) {
      epochFromFilename <- stringr::str_extract(model_path, "Ep.\\d+")
      initial_epoch <- as.integer(substring(epochFromFilename, 4, nchar(epochFromFilename)))
      if (initial_epoch >= epochs) {
        stop("Network trains for (epochs - initial_epoch) rounds overall, NOT epochs rounds. Increase epochs or decrease initial_epoch.")
      }
    }
    # load model
    model <- keras::load_model_hdf5(model_path, compile = compile)
    model$hparam <- reticulate::dict()
    summary(model)
    # extract maxlen
    if (!target_middle) {
      maxlen <- model$input$shape[[2]]
    } else {
      maxlen <- model$input[[1]]$shape[[2]] + model$input[[2]]$shape[[2]]
    }
    if (compile && (!is.null(learning.rate) || !is.null(solver))) {
      message("Arguments for solver and learning rate will be ignored. Set compile to FALSE to use a custom solver and learning rate.")
    }
    if (!compile) {
      # choose optimization method
      if (is.null(solver) || !solver %in% c("adam", "adagrad", "rmsprop", "sgd")) {
        stop("When compile is FALSE, solver must be one of \"adam\", \"adagrad\", \"rmsprop\" or \"sgd\".")
      }
      optimizer <- switch(solver,
        adam = keras::optimizer_adam(lr = learning.rate),
        adagrad = keras::optimizer_adagrad(lr = learning.rate),
        rmsprop = keras::optimizer_rmsprop(lr = learning.rate),
        sgd = keras::optimizer_sgd(lr = learning.rate)
      )
      model %>% keras::compile(loss = "categorical_crossentropy",
                               optimizer = optimizer,
                               metrics = c("acc", percentage_training_files_cb))
    }
  }
  # if no dataset is supplied, the external fasta generator will generate batches
  if (is.null(dataset)) {
    message("Starting fasta generator...")
    # temporary file to log training data
    removeLog <- FALSE
    if (is.null(fileLog)) {
      removeLog <- TRUE
      fileLog <- tempfile(pattern = "", fileext = ".csv")
    }
    if (reset_states) {
      fileLogVal <- tempfile(pattern = "", fileext = ".csv")
    } else {
      fileLogVal <- NULL
    }
    if (!labelGen) {
      # generator for training
      gen <- fastaFileGenerator(corpus.dir = path, batch.size = batch.size,
                                maxlen = maxlen, step = step, randomFiles = randomFiles,
                                vocabulary = vocabulary, seed = seed[1],
                                shuffleFastaEntries = shuffleFastaEntries, format = format,
                                fileLog = fileLog, reverseComplements = reverseComplements,
                                wavenet_format = wavenet_format, target_middle = target_middle,
                                ambiguous_nuc = ambiguous_nuc, percentage_per_file = percentage_per_file)
      # generator for validation
      gen.val <- fastaFileGenerator(corpus.dir = path.val, batch.size = batch.size,
                                    maxlen = maxlen, step = step, randomFiles = randomFiles,
                                    vocabulary = vocabulary, seed = seed[2],
                                    shuffleFastaEntries = shuffleFastaEntries, format = format,
                                    fileLog = fileLogVal, reverseComplements = FALSE,
                                    wavenet_format = wavenet_format, target_middle = target_middle,
                                    ambiguous_nuc = ambiguous_nuc, percentage_per_file = percentage_per_file)
      if (tb_images) {
        # TODO: check if gen_cb uses same data if max_samples_per_file != NULL
        gen_cb <- fastaFileGenerator(corpus.dir = path.val, batch.size = batch.size,
                                     maxlen = maxlen, step = step, randomFiles = randomFiles,
                                     vocabulary = vocabulary, seed = seed[2],
                                     shuffleFastaEntries = shuffleFastaEntries, format = format,
                                     fileLog = NULL, reverseComplements = FALSE,
                                     wavenet_format = wavenet_format, target_middle = target_middle,
                                     ambiguous_nuc = ambiguous_nuc, percentage_per_file = percentage_per_file)
      }
    } else { # label generator
      # label by folder
      if (labelByFolder) {
        # initialize training generators
        initializeGenerators(directories = path, format = format, batch.size = batch.size, maxlen = maxlen,
                             vocabulary = vocabulary, verbose = FALSE, randomFiles = randomFiles, step = step,
                             showWarnings = FALSE, seed = seed[1], shuffleFastaEntries = shuffleFastaEntries,
                             numberOfFiles = numberOfFiles, fileLog = fileLog,
                             reverseComplements = reverseComplements, val = FALSE,
                             ambiguous_nuc = ambiguous_nuc, percentage_per_file = percentage_per_file)
        # initialize validation generators
        initializeGenerators(directories = path.val, format = format, batch.size = batch.size, maxlen = maxlen,
                             vocabulary = vocabulary, verbose = FALSE, randomFiles = randomFiles, step = step,
                             showWarnings = FALSE, seed = seed[2], shuffleFastaEntries = shuffleFastaEntries,
                             numberOfFiles = NULL, fileLog = fileLogVal, reverseComplements = FALSE, val = TRUE,
                             ambiguous_nuc = ambiguous_nuc, percentage_per_file = percentage_per_file)
        gen <- labelByFolderGeneratorWrapper(val = FALSE, path = path)
        gen.val <- labelByFolderGeneratorWrapper(val = TRUE, path = path.val)
        if (tb_images) {
          gen_cb <- labelByFolderGeneratorWrapper(val = TRUE, path = path.val)
        }
      } else {
        # generator for training
        gen <- fastaLabelGenerator(corpus.dir = path, format = format, batch.size = batch.size, maxlen = maxlen,
                                   vocabulary = vocabulary, verbose = FALSE, randomFiles = randomFiles, step = step,
                                   showWarnings = FALSE, seed = seed[1], shuffleFastaEntries = shuffleFastaEntries,
                                   fileLog = fileLog, labelVocabulary = labelVocabulary,
                                   reverseComplements = reverseComplements,
                                   ambiguous_nuc = ambiguous_nuc, percentage_per_file = percentage_per_file)
        # generator for validation
        gen.val <- fastaLabelGenerator(corpus.dir = path.val, format = format, batch.size = batch.size, maxlen = maxlen,
                                       vocabulary = vocabulary, verbose = FALSE, randomFiles = randomFiles, step = step,
                                       showWarnings = FALSE, seed = seed[2], shuffleFastaEntries = shuffleFastaEntries,
                                       fileLog = fileLogVal, labelVocabulary = labelVocabulary, reverseComplements = FALSE,
                                       ambiguous_nuc = ambiguous_nuc, percentage_per_file = percentage_per_file)
        if (tb_images) {
          gen_cb <- fastaLabelGenerator(corpus.dir = path.val, format = format, batch.size = batch.size, maxlen = maxlen,
                                        vocabulary = vocabulary, verbose = FALSE, randomFiles = randomFiles, step = step,
                                        showWarnings = FALSE, seed = seed[2], shuffleFastaEntries = shuffleFastaEntries,
                                        fileLog = fileLogVal, labelVocabulary = labelVocabulary, reverseComplements = FALSE,
                                        ambiguous_nuc = ambiguous_nuc, percentage_per_file = percentage_per_file)
        }
      }
    }
    # callbacks
    callbacks <- vector("list")
    callbacks[[1]] <- reduce_lr_cb(patience = patience, cooldown = cooldown, lr.plateau.factor = lr.plateau.factor)
    if (output$log) {
      callbacks <- c(callbacks, log_cb(run.name))
    }
    if (output$tensorboard) {
      # count files in path
      num_train_files <- rep(0, length(path))
      if (train_type != "label_folder" && endsWith(path, paste0(".", format))) {
        num_train_files <- 1
      } else {
        for (k in seq_along(path)) {
          if (endsWith(path[k], paste0(".", format))) {
            num_train_files[k] <- 1
          } else {
            num_train_files[k] <- length(list.files(path[k], pattern = paste0(".", format)))
          }
        }
      }
      complete_tb <- tensorboard_complete_cb(
        default_arguments = default_arguments, model = model, tensorboard.log = tensorboard.log,
        run.name = run.name, train_type = train_type, model_path = model_path, path = path,
        validation.split = validation.split, batch.size = batch.size, epochs = epochs,
        max.queue.size = max.queue.size, lr.plateau.factor = lr.plateau.factor, patience = patience,
        cooldown = cooldown, steps.per.epoch = steps.per.epoch, step = step, randomFiles = randomFiles,
        initial_epoch = initial_epoch, vocabulary = vocabulary, learning.rate = learning.rate,
        shuffleFastaEntries = shuffleFastaEntries, labelVocabulary = labelVocabulary, solver = solver,
        numberOfFiles = numberOfFiles, reverseComplements = reverseComplements, wavenet_format = wavenet_format,
        create_model_function = built_model$create_model_function, vocabulary.size = vocabulary.size,
        gen_cb = gen_cb, argumentList = argumentList, maxlen = maxlen, labelGen = labelGen,
        labelByFolder = labelByFolder, label.vocabulary.size = label.vocabulary.size, tb_images = tb_images,
        target_middle = target_middle, num_train_files = num_train_files, fileLog = fileLog,
        percentage_per_file = percentage_per_file
      )
      callbacks <- c(callbacks, complete_tb)
    }
    if (output$checkpoints) {
      callbacks <- c(callbacks, checkpoint_cb(filepath = filepath_checkpoints, save_weights_only = TRUE,
                                              save_best_only = save_best_only))
    }
    if (reset_states) {
      callbacks <- c(callbacks, reset_states_cb(fileLog = fileLog, fileLogVal = fileLogVal))
    }
    # training
    message("Start training ...")
    history <- model %>% keras::fit_generator(
      generator = gen,
      validation_data = gen.val,
      validation_steps = ceiling(steps.per.epoch * validation.split),
      steps_per_epoch = steps.per.epoch,
      max_queue_size = max.queue.size,
      epochs = epochs,
      initial_epoch = initial_epoch,
      callbacks = callbacks,
      verbose = 1
    )
  } else {
    message("Start training ...")
    history <- model %>% keras::fit(
      dataset$X,
      dataset$Y,
      batch_size = batch.size,
      validation_split = validation.split,
      epochs = epochs
    )
  }
  # remove the temporary log file that was created when no fileLog was given
  # (removeLog only exists in the generator branch, hence the is.null check)
  if (is.null(dataset) && removeLog) {
    file.remove(fileLog)
  }
  # save final model
  message("Training done.")
  if (output$serialize_model) {
    Rmodel <- keras::serialize_model(model, include_optimizer = TRUE)
    save(Rmodel, file = paste0(run.name, "_full_model.Rdata"))
  }
  if (output$full_model) {
    keras::save_model_hdf5(
      model,
      paste0(run.name, "_full_model.hdf5"),
      overwrite = TRUE,
      include_optimizer = TRUE
    )
  }
  return(history)
}
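
# ---------------------------------------------------------------------------
# Illustration of the wavenet_format target shift described in the @param docs
# above: the target sequence equals the input shifted one position to the
# right. A minimal standalone sketch with a hypothetical toy sequence; the
# actual encoding happens inside the fasta generators.
toy_sequence <- c("a", "c", "g", "t", "a")
maxlen <- 3
for (i in seq_len(length(toy_sequence) - maxlen)) {
  input <- toy_sequence[i:(i + maxlen - 1)]
  target <- toy_sequence[(i + 1):(i + maxlen)]
  cat("input:", paste(input, collapse = ""), "-> target:", paste(target, collapse = ""), "\n")
}
# prints:
# input: acg -> target: cgt
# input: cgt -> target: gta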