Skip to content

Instantly share code, notes, and snippets.

@dfalbel
Last active April 8, 2018 09:58
  • Star 0 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
Star You must be signed in to star a gist
Save dfalbel/a5d63d6bffe683072cc4781d7c8420ff to your computer and use it in GitHub Desktop.
library(readr)
library(keras)
library(purrr)
# Tunable hyperparameters; each can be overridden at run time via tfruns.
FLAGS <- flags(
  flag_integer("vocab_size", 50000),       # words kept by the tokenizer
  flag_integer("max_len_padding", 20),     # sequence length after padding
  flag_integer("embedding_size", 256),     # word-embedding dimension
  flag_numeric("regularization", 0.0001),  # L2 penalty strength
  flag_integer("seq_embedding_size", 512)  # LSTM sentence-vector dimension
)
# Load the Quora question pairs and turn each question into a fixed-length
# integer sequence. The value FLAGS$vocab_size + 1 is used as the padding
# token, which the embedding layer (sized vocab_size + 2) accommodates.
df <- read_tsv("quora_duplicate_questions.tsv")

tokenizer <- text_tokenizer(num_words = FLAGS$vocab_size)
fit_text_tokenizer(tokenizer, x = c(df$question1, df$question2))

pad_token <- FLAGS$vocab_size + 1
question1 <- texts_to_sequences(tokenizer, df$question1) %>%
  pad_sequences(maxlen = FLAGS$max_len_padding, value = pad_token)
question2 <- texts_to_sequences(tokenizer, df$question2) %>%
  pad_sequences(maxlen = FLAGS$max_len_padding, value = pad_token)
# keras model
# Siamese architecture: both questions pass through the SAME embedding and
# LSTM layers, so the two sentence vectors live in a shared space.
input1 <- layer_input(shape = c(FLAGS$max_len_padding))
input2 <- layer_input(shape = c(FLAGS$max_len_padding))

# Shared word embedding. input_dim is vocab_size + 2 so that every index in
# 0 .. vocab_size + 1 (the padding value) maps to a valid row.
embedding <- layer_embedding(
  input_dim = FLAGS$vocab_size + 2,
  output_dim = FLAGS$embedding_size,
  input_length = FLAGS$max_len_padding,
  embeddings_regularizer = regularizer_l2(l = FLAGS$regularization)
)

# Shared LSTM that summarises a padded question into one fixed-size vector.
seq_emb <- layer_lstm(
  units = FLAGS$seq_embedding_size,
  recurrent_regularizer = regularizer_l2(l = FLAGS$regularization)
)

vector1 <- input1 %>% embedding() %>% seq_emb()
vector2 <- input2 %>% embedding() %>% seq_emb()

# Dot product of the two sentence vectors, squashed to a duplicate probability.
out <- layer_dot(list(vector1, vector2), axes = 1) %>%
  layer_dense(1, activation = "sigmoid")
# Assemble the two-input graph and configure it for binary classification
# (duplicate vs. not duplicate).
model <- keras_model(list(input1, input2), out)

compile(
  model,
  optimizer = "adam",
  loss = "binary_crossentropy",
  metrics = list(acc = metric_binary_accuracy)
)
# Hold out 10% of the pairs for validation; fixed seed for reproducibility.
set.seed(1817328)
n_pairs <- nrow(question1)
val_sample <- sample.int(n_pairs, size = 0.1 * n_pairs)

train_x <- list(question1[-val_sample, ], question2[-val_sample, ])
train_y <- df$is_duplicate[-val_sample]
valid_x <- list(question1[val_sample, ], question2[val_sample, ])
valid_y <- df$is_duplicate[val_sample]

model %>%
  fit(
    train_x,
    train_y,
    batch_size = 128,
    epochs = 30,
    validation_data = list(valid_x, valid_y),
    callbacks = list(
      callback_reduce_lr_on_plateau(patience = 3),  # shrink LR when val loss stalls
      callback_early_stopping(patience = 5)         # then give up if it keeps stalling
    )
  )
# Persist the trained model (with optimizer state, so training can resume)
# and the fitted tokenizer, both needed to serve predictions later.
# NOTE(review): the tokenizer file is presumably not real HDF5 despite the
# extension — save_text_tokenizer() serializes the tokenizer object; verify
# before consumers assume the format.
model %>% save_model_hdf5("model-question-pairs.hdf5", include_optimizer = TRUE)
tokenizer %>% save_text_tokenizer("tokenizer-question-pairs.hdf5")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment