@icexelloss
Created September 5, 2018 21:47
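
# Compare a rolling-window mean computed three ways in PySpark:
# a pandas GROUPED_AGG UDF, a numba-compiled GROUPED_AGG UDF, and the built-in mean().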
df = spark.range(0, 1000 * 1000).toDF('v')
df.cache()
df.count()  # action that materializes the cached DataFrame
from pyspark.sql import Window
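# Window frame covering the preceding 1,000 rows up to and including the current row.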
w = Window.rowsBetween(-1000, 0)
from pyspark.sql.functions import sum, mean, count
from pyspark.sql.functions import pandas_udf, PandasUDFType
@pandas_udf('double', PandasUDFType.GROUPED_AGG)
def my_numpy_udf(v):
    # v is a numpy.ndarray of the values in the current window frame
    return v.mean()
import numba

@numba.njit
def numba_mean(v):
    # Plain loop mean, compiled to machine code by numba
    s = 0
    c = 0
    for i in v:
        s += i
        c += 1
    return s / c

@pandas_udf('double', PandasUDFType.GROUPED_AGG)
def my_numba_udf(v):
    return numba_mean(v)
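
# Run each variant over the same window and sum the per-row means into a single comparable value.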
df.withColumn('v_sum', my_numpy_udf(df['v']).over(w)).agg(sum('v_sum')).show()
df.withColumn('v_sum', my_numba_udf(df['v']).over(w)).agg(sum('v_sum')).show()
df.withColumn('v_sum', mean('v').over(w)).agg(sum('v_sum')).show()
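
# A minimal sketch (not part of the original gist) of how one might compare wall-clock
# time for the three variants in the same session; time_action is a hypothetical helper.
import time

def time_action(label, action):
    # Run the Spark action once and print elapsed wall-clock seconds.
    start = time.time()
    action()
    print(label, time.time() - start)

time_action('pandas/numpy udf', lambda: df.withColumn('v_sum', my_numpy_udf(df['v']).over(w)).agg(sum('v_sum')).collect())
time_action('numba udf', lambda: df.withColumn('v_sum', my_numba_udf(df['v']).over(w)).agg(sum('v_sum')).collect())
time_action('built-in mean', lambda: df.withColumn('v_sum', mean('v').over(w)).agg(sum('v_sum')).collect())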