@gbraccialli
Last active September 20, 2018 13:51
spark_scala_python_udf_battle
//SCALA CREATE DATASETS
import scala.util.Random
import org.apache.spark.sql.functions._
def randomStr(size: Int): String = {
  Random.alphanumeric.take(size).mkString("")
}
val udfRandomStr = udf(randomStr _)
val dfRnd = (1 to 30000).toDF.repartition(3000)
val dfRnd2 = (1 to 10).toDF.withColumnRenamed("value", "value2")
//creates 2.8GB dataset with 300,000 rows
dfRnd.crossJoin(broadcast(dfRnd2)).withColumn("text", udfRandomStr(lit(10000))).withColumn("del1", udfRandomStr(lit(2))).withColumn("del2", udfRandomStr(lit(2))).write.mode("overwrite").save("randomDF")
import scala.util.Random
val dfRnd = (1 to 100000).toDF.repartition(3000)
val dfRnd2 = (1 to 2000).toDF.withColumnRenamed("value", "value2")
//creates 3.1GB dataset with 200,000,000 rows
dfRnd.crossJoin(broadcast(dfRnd2)).withColumn("lat", udf{(a: Any) => -90 + 180*Random.nextDouble}.apply($"value")).withColumn("lon", udf{(a: Any) => -180 + 360*Random.nextDouble}.apply($"value")).write.mode("overwrite").save("randomDF_geo")
#PYTHON TIMES 1000
df = spark.read.load("randomDF_geo").cache()
df.count()
import pandas as pd
from pyspark.sql.functions import pandas_udf, PandasUDFType
from pyspark.sql.functions import avg, udf, substring, col
from pyspark.sql.types import StringType, DoubleType
def times1000(field):
    return field * 1000.00
udfTimes1000 = udf(times1000, DoubleType())
@pandas_udf('double', PandasUDFType.SCALAR)
def pandasUdf_times1000(field):
    return field * 1000
#1.2 minutes
df.select(udfTimes1000(df.lat).alias("output")).agg(avg("output")).show()
#32 seconds
df.select(pandasUdf_times1000(df.lat).alias("output")).agg(avg("output")).show()
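#for reference (not timed in the original gist): the same aggregation with a
#built-in column expression, which stays in the JVM and needs no UDF at all
df.select((df.lat * 1000.0).alias("output")).agg(avg("output")).show()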
#PYTHON GEOHASH
df = spark.read.load("randomDF_geo").cache()
df.count()
import pandas as pd
from pyspark.sql.functions import pandas_udf, PandasUDFType
from pyspark.sql.functions import avg, udf, substring, col
from pyspark.sql.types import StringType, DoubleType
import geohash
def geohash_pyspark(lat, lon):
    return geohash.encode(lat, lon)
udfGeohash = udf(geohash_pyspark, StringType())
@pandas_udf('string', PandasUDFType.SCALAR)
def geohash_pandas_udf(series_lat, series_lon):
    pdf = pd.DataFrame({'lat': series_lat, 'lon': series_lon})
    return pd.Series(pdf.apply(lambda row: geohash.encode(row['lat'], row['lon']), axis=1))
df = spark.read.load("randomDF_geo").cache()
#2.7 minutes
df.select(udfGeohash(df.lat, df.lon).alias("geohash")).withColumn("first3", substring(col("geohash"), 1, 3)).groupBy("first3").count().show()
#23 minutes
df.select(geohash_pandas_udf(df.lat, df.lon).alias("geohash")).withColumn("first3", substring(col("geohash"), 1, 3)).groupBy("first3").count().show()
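#a possible variant (not benchmarked in the original gist): skip the per-row
#pd.DataFrame.apply and loop over the two series directly; geohash.encode is
#still called once per row, so this only removes the pandas apply overhead
@pandas_udf('string', PandasUDFType.SCALAR)
def geohash_pandas_udf_zip(series_lat, series_lon):
    return pd.Series([geohash.encode(lat, lon) for lat, lon in zip(series_lat, series_lon)])
df.select(geohash_pandas_udf_zip(df.lat, df.lon).alias("geohash")).withColumn("first3", substring(col("geohash"), 1, 3)).groupBy("first3").count().show()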
def strExtract(text, del1, del2):
    start = text.find(del1)
    end = text.find(del2, start)
    if start > -1 and end > -1 and end > start + len(del1):
        return text[start + len(del1):end]
    else:
        return "invalid"
import pandas as pd
from pyspark.sql.functions import pandas_udf, PandasUDFType
@pandas_udf('string', PandasUDFType.SCALAR)
def pandasUdf(series_text, series_delim1, series_delim2):
    outputs = []
    row = 0
    for text in series_text:
        del1 = series_delim1[row]
        del2 = series_delim2[row]
        outputs.append(strExtract(text, del1, del2))
        row += 1
    return pd.Series(outputs)
@pandas_udf('string', PandasUDFType.SCALAR)
def pandasUdf2(series_text, series_delim1, series_delim2):
    pdf = pd.DataFrame({'text': series_text, 'delim1': series_delim1, 'delim2': series_delim2})
    return pd.Series(pdf.apply(lambda row: strExtract(row['text'], row['delim1'], row['delim2']), axis=1))
def strExtractSplit(concat):
    parts = str(concat).split("|")
    return strExtract(parts[0], parts[1], parts[2])
@pandas_udf('string', PandasUDFType.SCALAR)
def pandasUdf3(fields_concat):
    return fields_concat.apply(lambda row: strExtractSplit(row))
from pyspark.sql.types import StringType, DoubleType
from pyspark.sql.functions import udf, concat, lit
udfExtract = udf(strExtract, StringType())
df = spark.read.load("randomDF").cache()
df.count()
#16 seconds
df.select(udfExtract(df.text, df.del1, df.del2).alias("output")).groupBy("output").count().orderBy("count", ascending=False).show()
#16 seconds
df.select(pandasUdf(df.text, df.del1, df.del2).alias("output")).groupBy("output").count().orderBy("count", ascending=False).show()
#16 seconds
df.select(pandasUdf2(df.text, df.del1, df.del2).alias("output")).groupBy("output").count().orderBy("count", ascending=False).show()
#15 seconds
df.select(pandasUdf3(concat(df.text, lit('|'), df.del1, lit('|'), df.del2)).alias("output")).groupBy("output").count().orderBy("count", ascending=False).show()
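#the timings in the comments above are from interactive runs; a rough way to
#reproduce them in a plain script (the timed() helper is an assumption, not
#part of the original gist)
import time
def timed(projected_df, label):
    start = time.time()
    projected_df.groupBy("output").count().orderBy("count", ascending=False).show()
    print("%s: %.1f seconds" % (label, time.time() - start))
timed(df.select(udfExtract(df.text, df.del1, df.del2).alias("output")), "plain Python UDF")
timed(df.select(pandasUdf2(df.text, df.del1, df.del2).alias("output")), "pandas UDF (DataFrame.apply)")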
//SCALA TIMES 1000
import org.apache.spark.sql.functions._
val df = spark.read.load("randomDF_geo").cache()
df.count()
//1 second
df.select(udf{(a: Double) => a*1000.00}.apply($"lat").alias("output")).agg(avg("output")).show()
//SCALA GEOHASH
//spark-shell --packages com.github.davidmoten:geo:0.7.1
import com.github.davidmoten.geo._
import org.apache.spark.sql.functions._
def geohash(lat: Double, lon: Double): String = GeoHash.encodeHash(lat, lon)
val udfGeohash = udf(geohash _)
val df = spark.read.load("randomDF_geo").cache()
df.count()
//23 seconds
df.select(udfGeohash($"lat", $"lon").alias("geohash")).withColumn("first3", substring(col("geohash"), 1, 3)).groupBy("first3").count().show()
def strExtract(input: String, del1: String, del2: String): String = {
  val start = input.indexOf(del1)
  val end = input.indexOf(del2, start)
  if (start > -1 && end > -1 && end > start + del1.length())
    input.substring(start + del1.length(), end)
  else
    "invalid"
}
val udfExtract = udf(strExtract _)
val df = spark.read.load("randomDF").cache()
df.count()
//10 seconds
df.select(udfExtract(df.col("text"), df.col("del1"), df.col("del2")).alias("output")).groupBy("output").count().orderBy(desc("count")).show()