@rdblue
Created August 3, 2018 23:48
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.spark.deploy.yarn

import java.io.File
import java.nio.charset.StandardCharsets

import com.google.common.io.Files
import org.apache.hadoop.yarn.conf.YarnConfiguration

import org.apache.spark.launcher.SparkAppHandle

class YarnPySparkSuite extends BaseYarnClusterSuite {

  override def newYarnConfig(): YarnConfiguration = new YarnConfiguration()
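
  // Verifies that spark.executor.pyspark.memory is applied to the executor's Python worker:
  // the UDF below reports the worker's address-space rlimit (RLIMIT_AS) via resource.getrlimit.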
test("PySpark memory limit") {
val (state, result, out) = runPySpark(
"""
|import resource
|def f():
| return resource.getrlimit(resource.RLIMIT_AS)
""".stripMargin,
clientMode = true,
extraConf = Map(
"spark.executor.pyspark.memory" -> "1g",
"spark.python.worker.reuse" -> "false"),
extraEnv = Map.empty)
assert(state === SparkAppHandle.State.FINISHED,
s"Application stdout:\n$out")
assert(result === "(1024, 1024)")
}
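
  /**
   * Writes `pythonMod` to mod.py, submits a small PySpark driver that calls f() from that
   * module inside a UDF, and returns the final application state, the contents of the result
   * file written by the driver, and the captured application stdout.
   */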
  private def runPySpark(
      pythonMod: String,
      clientMode: Boolean,
      extraConf: Map[String, String] = Map(),
      extraEnv: Map[String, String] = Map()): (SparkAppHandle.State, String, String) = {
    val modFile = new File(tempDir, "mod.py")
    Files.write(pythonMod, modFile, StandardCharsets.UTF_8)

    val primaryPyFile = new File(tempDir, "test.py")
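    // The driver below evaluates f() from mod.py inside a UDF so that it runs in an executor's
    // Python worker, then writes repr() of the result to the path given as the first argument.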
    Files.write(
      """
        |import sys
        |from pyspark.sql import SparkSession, SQLContext
        |from pyspark.sql.functions import udf, col
        |
        |from mod import f
        |
        |if len(sys.argv) != 2:
        |    sys.stderr.write("Usage: test.py [result file]\n")
        |    exit(-1)
        |
        |spark = SparkSession.builder.getOrCreate()
        |
        |def run_f(*args):
        |    return repr(f())
        |
        |df = spark.createDataFrame([(1,)], ["id"])
        |results = df.withColumn("result", udf(run_f)(col("id"))).collect()
        |result = results[0].asDict()["result"]
        |
        |with open(sys.argv[1], 'w') as result_file:
        |    result_file.write(result)
        |
        |spark.stop()
      """.stripMargin, primaryPyFile, StandardCharsets.UTF_8)
    // When running tests, let's not assume the user has built the assembly module, which also
    // creates the pyspark archive. Instead, let's use PYSPARK_ARCHIVES_PATH to point at the
    // needed locations.
    val sparkHome = sys.props("spark.test.home")
    val pythonPath = Seq(
      s"$sparkHome/python/lib/py4j-0.10.7-src.zip",
      s"$sparkHome/python")
    val extraEnvVars = Map(
      "PYSPARK_ARCHIVES_PATH" -> pythonPath.map("local:" + _).mkString(File.pathSeparator),
      "PYTHONPATH" -> pythonPath.mkString(File.pathSeparator)) ++ extraEnv

    val resultFile = File.createTempFile("result", null, tempDir)
    val stdoutFile = File.createTempFile("output", null, tempDir)
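
    // Ship mod.py with --py-files so the UDF can import it on the executors, and capture
    // stdout so it can be reported if the application does not finish successfully.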
    val finalState = runSpark(clientMode, primaryPyFile.getAbsolutePath,
      appArgs = Seq(resultFile.getAbsolutePath),
      sparkArgs = Seq("--py-files" -> modFile.getAbsolutePath),
      extraEnv = extraEnvVars,
      extraConf = extraConf,
      outFile = Some(stdoutFile))

    val out = Files.toString(stdoutFile, StandardCharsets.UTF_8)
    val result = Files.toString(resultFile, StandardCharsets.UTF_8)

    (finalState, result, out)
  }
}