@philippmuench
Last active December 23, 2019 16:51
WaveNet genomic
trainMinimalFunctionalAPI <- function(path = "example_files/fasta") {
  library(keras) # attached for the %>% pipe used below
  library(wavenet)
  message("Initializing the model; this can take a few minutes.")
  maxlen <- 1000
  # fixed batch shape: 64 sequences of length maxlen, one-hot encoded over a
  # vocabulary of 6 characters (see the generator below)
  input <- keras::layer_input(batch_shape = c(64, maxlen, 6))
  # https://github.com/ibab/tensorflow-wavenet/blob/master/wavenet/ops.py#L46
  # initial causal convolution that feeds the residual block stack
  first <- keras::layer_conv_1d(
    object = input,
    filters = 32,
    kernel_size = 2,
    padding = "causal",
    use_bias = FALSE
  )
  skip_connections <- NULL
  # a single value is interpreted as the number of residual blocks (dilation
  # rates 2, 4, ..., 2^n); a vector is used directly as the dilation rates.
  # Here 2^rep(1:8, 3) gives the rates 2, 4, ..., 256 repeated three times.
  residual_blocks <- 2^rep(1:8, 3)
  if (length(residual_blocks) == 1) {
    dilation_rates <- 2^seq_len(residual_blocks)
  } else {
    dilation_rates <- residual_blocks
  }
  # stack of dilated causal convolution blocks: each block receives the
  # residual output of the previous one and contributes a skip connection
  x <- first
  for (i in dilation_rates) {
    out <- layer_wavenet_dilated_causal_convolution_1d(
      x,
      filters = 32,
      kernel_size = 32,
      dilation_rate = i
    )
    x <- out[[1]] # residual output, fed into the next block
    s <- out[[2]] # skip connection
    skip_connections <- append(skip_connections, s)
  }
  # merge all skip connections and map them to one softmax prediction per sequence
  out1 <- keras::layer_add(skip_connections)
  out2 <- keras::layer_activation(object = out1, activation = "relu")
  out3 <- keras::layer_conv_1d(
    object = out2,
    filters = 16L,
    kernel_size = 1,
    activation = "relu",
    use_bias = FALSE
  )
  pooling <- keras::layer_global_max_pooling_1d(object = out3)
  dense <- keras::layer_dense(object = pooling, units = 6)
  output <- keras::layer_activation(object = dense, activation = "softmax")
  model <- keras::keras_model(input, output)
  summary(model)
  model %>% keras::compile(
    loss = "categorical_crossentropy",
    optimizer = "adam",
    metrics = c("acc")
  )
  # fastaFileGenerator() is not part of keras or wavenet; it is assumed to be
  # provided by the accompanying genomics package and to yield batches of
  # one-hot encoded sequences over the 6-character vocabulary below
  gen <- fastaFileGenerator(
    corpus.dir = path,
    batch.size = 64,
    maxlen = maxlen,
    step = 1,
    randomFiles = FALSE,
    seqStart = "l",
    seqEnd = "l",
    withinFile = "p",
    vocabulary = c("l", "p", "a", "c", "g", "t")
  )
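  # Optional sanity check (a sketch; assumes the generator follows the usual
  # keras generator protocol and returns list(x, y) on each call, with x of
  # shape (64, maxlen, 6)):
  # batch <- gen()
  # str(batch)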
message("Start training ...")
history <-
model %>% keras::fit_generator(
generator = gen,
steps_per_epoch = 200,
max_queue_size = 200,
epochs = 50
)
}
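
# Example usage (a minimal sketch): assumes the packages loaded above are
# installed and that `path` points to a directory of FASTA files; the default
# "example_files/fasta" is only illustrative.
if (interactive()) {
  history <- trainMinimalFunctionalAPI(path = "example_files/fasta")
  plot(history) # keras training histories have a plot() method
}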