@javierluraschi
Created May 5, 2020 18:54
Script to perform distributed deep learning in Spark from R: sparklyr barrier execution turns Spark tasks into TensorFlow workers for multi-worker Keras training on MNIST, followed by two diagnostic blocks that preview the generated TF_CONFIG and probe worker port connectivity.
library(sparklyr)

# Fixed allocation: dynamic allocation is disabled so the cluster keeps
# exactly 3 executors, one TensorFlow worker per executor.
sc <- spark_connect(
  master = "yarn",
  spark_home = "/usr/lib/spark/",
  config = list(
    spark.dynamicAllocation.enabled = FALSE,
    `sparklyr.shell.executor-cores` = 8,
    `sparklyr.shell.num-executors` = 3,
    sparklyr.apply.env.WORKON_HOME = "/tmp/.virtualenvs"
  )
)
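
# Optional sanity check (a sketch, not in the original gist): confirm the
# repartitioned dataset really has 3 partitions, i.e. one barrier task per
# executor. Should print 3.
sdf_len(sc, 3, repartition = 3) %>% sdf_num_partitions()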

# Barrier-mode training: each of the 3 partitions becomes one TensorFlow worker.
sdf_len(sc, 3, repartition = 3) %>%
  spark_apply(function(df, barrier) {
    tryCatch({
      library(tensorflow)
      library(keras)
      
      # Build TF_CONFIG for this worker: strip the barrier port from each
      # peer address, assign distinct training ports (8021, 8022, ...), and
      # use the partition id as the worker index.
      Sys.setenv(TF_CONFIG = jsonlite::toJSON(list(
        cluster = list(worker = paste(gsub(":[0-9]+$", "", barrier$address), 8020 + seq_along(barrier$address), sep = ":")),
        task = list(type = 'worker', index = barrier$partition)
      ), auto_unbox = TRUE))
      
      # Install TensorFlow into the worker's virtualenv if missing.
      if (is.null(tf_version())) install_tensorflow()
      
      # Synchronous data-parallel training across all workers in TF_CONFIG.
      strategy <- tf$distribute$experimental$MultiWorkerMirroredStrategy()
      
      # The batch size passed to fit() is the global batch size, so scale it
      # by the number of workers (64 examples per worker per step).
      num_workers <- 3L
      batch_size <- 64L * num_workers
      
      mnist <- dataset_mnist()
      x_train <- mnist$train$x
      y_train <- mnist$train$y
      
      # Reshape to (n, 28, 28, 1) and rescale pixel values to [0, 1].
      x_train <- array_reshape(x_train, c(nrow(x_train), 28, 28, 1))
      x_train <- x_train / 255
      
      # Build and compile the model inside the strategy scope so its
      # variables are created as mirrored (replicated) variables.
      with(strategy$scope(), {
        model <- keras_model_sequential() %>%
          layer_conv_2d(
            filters = 32,
            kernel_size = 3,
            activation = 'relu',
            input_shape = c(28, 28, 1)
          ) %>%
          layer_max_pooling_2d() %>%
          layer_flatten() %>%
          layer_dense(units = 64, activation = 'relu') %>%
          layer_dense(units = 10)
        
        model %>% compile(
          loss = tf$keras$losses$SparseCategoricalCrossentropy(from_logits = TRUE),
          optimizer = tf$keras$optimizers$SGD(learning_rate = 0.001),
          metrics = 'accuracy')
      })
      
      # Short demonstration run: 3 epochs of only 5 steps each.
      result <- model %>% fit(
        x_train, y_train,
        batch_size = batch_size,
        epochs = 3,
        steps_per_epoch = 5
      )
      
      # Each worker saves its own copy; with MultiWorkerMirroredStrategy the
      # copies are kept in sync, so they are identical after training.
      model_file <- paste0("trained-", barrier$partition, ".hdf5")
      save_model_hdf5(model, model_file)
      
      # Return the per-epoch accuracies as this partition's result row.
      paste0(result$metrics$accuracy, collapse = ",")
      # if (barrier$partition == 0) base64enc::base64encode(model_file) else ""
    }, error = function(e) e$message)
  }, barrier = TRUE, columns = c(address = "character")) %>%
  collect()
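
# A minimal sketch (assumption, not part of the original gist): if the
# commented-out base64encode() line above is enabled so that partition 0
# returns the encoded .hdf5 file, the driver can rebuild and reload the
# trained model. `results` and the column indexing are illustrative only.
# results <- <the collected tibble from the pipeline above>
# writeBin(base64enc::base64decode(results[[1]][[1]]), "trained-0.hdf5")
# model <- keras::load_model_hdf5("trained-0.hdf5")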


# Standalone preview: emit the TF_CONFIG JSON that each of 4 barrier tasks
# would generate (ports offset from 1000 here), without training anything.
sdf_len(sc, 4, repartition = 4) %>%
  spark_apply(function(df, barrier) {
    as.character(jsonlite::toJSON(list(
      cluster = list(worker = paste(gsub(":[0-9]+$", "", barrier$address), 1000 + seq_along(barrier$address), sep = ":")),
      task = list(type = 'worker', index = barrier$partition)
    ), auto_unbox = TRUE))
  }, barrier = TRUE, columns = c(address = "character")) %>%
  collect()
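
# For reference, each returned row is a JSON string of this shape (hostnames
# below are placeholders; ports 1001..1004 follow from 1000 + seq_along):
# {"cluster":{"worker":["host1:1001","host2:1002","host3:1003","host4:1004"]},
#  "task":{"type":"worker","index":0}}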

# Connectivity probe: from each worker, test whether its own barrier address
# accepts TCP connections, using Python's socket module via reticulate.
sdf_len(sc, 3, repartition = 3) %>%
  spark_apply(function(df, barrier) {
    tryCatch({
      py_socket <- reticulate::import("socket")
      py_sock <- py_socket$socket(py_socket$AF_INET, py_socket$SOCK_STREAM)
      
      # This task's own entry (host:port) in the barrier address list.
      current_address <- barrier$address[[as.integer(barrier$partition) + 1]]
      parts <- strsplit(current_address, ":")[[1]]
      
      # connect_ex() returns 0 if the port accepts connections, an error
      # code otherwise.
      port_open <- py_sock$connect_ex(reticulate::tuple(parts[1], as.integer(parts[2])))
      py_sock$close()
      
      # Resolve this worker's own hostname and IP for the report.
      host_name <- py_socket$gethostname()
      host_ip <- py_socket$gethostbyname(host_name)
      
      paste0(current_address, " - ", host_ip, ": ", port_open)
    }, error = function(e) e$message)
  }, barrier = TRUE, columns = c(address = "character")) %>%
  collect()
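
# Close the Spark connection when done.
spark_disconnect(sc)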