@pdmack
Last active December 20, 2017 00:04
Adaptation of PySpark/MongoDB example illustrating dynamic jar loading: works with spark-submit AND python interpreter
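The description's point is that the same file runs under either launcher: for example (the file name is a placeholder, not part of the gist), spark-submit mongo_dynamic_jars.py, or python mongo_dynamic_jars.py provided pyspark itself is already importable (e.g. a pip-installed pyspark or SPARK_HOME/python on PYTHONPATH).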
from __future__ import print_function
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from pyspark.sql import SparkSession

# local paths to the MongoDB Spark connector and its Java driver dependency
mongo_jars = ["/root/spark2.2/mongo-spark-connector_2.11-2.2.1.jar",
              "/root/spark2.2/mongo-java-driver-3.5.0.jar"]
spark = SparkSession \
    .builder \
    .appName("PySpark dynamic jar loading example") \
    .config("spark.mongodb.input.uri", "mongodb://127.0.0.1/test.coll") \
    .config("spark.mongodb.output.uri", "mongodb://127.0.0.1/test.coll") \
    .getOrCreate()
# get the SparkContext singleton from the JVM (not the pyspark API)
context = spark._jvm.org.apache.spark.SparkContext.getOrCreate()
# get the MutableURLClassLoader from the JVM
loader = spark._jvm.Thread.currentThread().getContextClassLoader()
# load jars for our driver AND executors
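# (addURL() makes the connector classes resolvable in the driver JVM immediately,
# while SparkContext.addJar() distributes the jar to the executors so tasks can
# load the same classes on the worker side)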
for jar in mongo_jars:
    url = spark._jvm.java.net.URL("file:" + jar)
    loader.addURL(url)   # driver side
    context.addJar(jar)  # executor side
# sanity check: print the URLs now registered on the driver's class loader
urls = loader.getURLs()
for p in urls:
    print(p)
# quiet the console: only FATAL log4j messages from here on
logger = spark._jvm.org.apache.log4j
logger.LogManager.getRootLogger().setLevel(logger.Level.FATAL)
# Save some data
characters = spark.createDataFrame([
    ("Bilbo Baggins", 50), ("Gandalf", 1000), ("Thorin", 195), ("Balin", 178),
    ("Kili", 77), ("Dwalin", 169), ("Oin", 167), ("Gloin", 158),
    ("Fili", 82), ("Bombur", None)], ["name", "age"])
characters.write.format("com.mongodb.spark.sql").mode("overwrite").save()
# print the schema
print("Schema:")
characters.printSchema()
# read from MongoDB collection
df = spark.read.format("com.mongodb.spark.sql").load()
# SQL
df.createOrReplaceTempView("temp")
centenarians = spark.sql("SELECT name, age FROM temp WHERE age >= 100")
print("Centenarians:")
centenarians.show()
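For comparison, the conventional route is to declare the connector before the session starts rather than loading it afterwards; a minimal sketch, assuming no SparkContext is running yet and the driver can resolve the Maven coordinate at startup:

from pyspark.sql import SparkSession

spark = SparkSession \
    .builder \
    .appName("PySpark static jar loading example") \
    .config("spark.jars.packages", "org.mongodb.spark:mongo-spark-connector_2.11:2.2.1") \
    .config("spark.mongodb.input.uri", "mongodb://127.0.0.1/test.coll") \
    .config("spark.mongodb.output.uri", "mongodb://127.0.0.1/test.coll") \
    .getOrCreate()

The class-loader approach above is what remains when that is not an option, e.g. when a SparkSession already exists in the interpreter or the machine cannot fetch packages from a Maven repository.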