This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Locate the local Spark installation via findspark, then build a SparkSession
# with a large executor-memory allocation for the benchmark workload.
# (Scrape artifacts " | |" removed — they made this snippet invalid Python.)
import findspark

findspark.init()

from pyspark.sql import SparkSession

spark = (SparkSession
         .builder
         .config('spark.executor.memory', '36g')
         .getOrCreate())
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| System | Runtime | LOC changed |
|---|---|---|
| Python (single-node) | 3 hours | - |
| Dask | 11 minutes | 10 |
| Spark | 47 minutes | 100 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Run cross-validation over the taxi data and report the best metric found.
# NOTE(review): the fitted model is bound to `fitted`, but the metrics were
# read from an undefined name `results` (NameError) — read `fitted.avgMetrics`.
fitted = crossval.fit(taxi)
print(np.min(fitted.avgMetrics))  # min because metric is RMSE (lower is better)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from pyspark.ml.regression import LinearRegression | |
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder | |
from pyspark.ml.evaluation import RegressionEvaluator | |
from pyspark.ml.feature import OneHotEncoder, StringIndexer, VectorAssembler, StandardScaler | |
from pyspark.ml.pipeline import Pipeline | |
# Build one StringIndexer per categorical column so Spark ML can consume them.
# NOTE(review): this definition is truncated in this capture — the rest of the
# list (outputCol, loop over columns, closing bracket) is not visible; left
# byte-identical rather than guessing at the missing tail.
indexers = [ | |
StringIndexer( | |
inputCol=c, |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Derive numeric time-based features from the pickup timestamp, cast to
# DoubleType so Spark ML estimators can consume them directly.
# (Scrape artifacts " | |" removed; `datetime` import added — the snippet used
# datetime.datetime without a visible import.)
import datetime

import pyspark.sql.functions as F
import pyspark.sql.types as T

taxi = taxi.withColumn('pickup_weekday', F.dayofweek(taxi.tpep_pickup_datetime).cast(T.DoubleType()))
taxi = taxi.withColumn('pickup_weekofyear', F.weekofyear(taxi.tpep_pickup_datetime).cast(T.DoubleType()))
taxi = taxi.withColumn('pickup_hour', F.hour(taxi.tpep_pickup_datetime).cast(T.DoubleType()))
taxi = taxi.withColumn('pickup_minute', F.minute(taxi.tpep_pickup_datetime).cast(T.DoubleType()))
# Seconds elapsed since 2019-01-01 00:00:00 — a continuous time-of-year feature.
taxi = taxi.withColumn('pickup_year_seconds',
                       (F.unix_timestamp(taxi.tpep_pickup_datetime) -
                        F.unix_timestamp(F.lit(datetime.datetime(2019, 1, 1, 0, 0, 0)))).cast(T.DoubleType()))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Create (or reuse) a SparkSession and load one month of NYC yellow-taxi trip
# data from S3, sampling 10% of rows without replacement to keep the benchmark
# tractable. (Scrape artifacts " | |" removed — they made the snippet invalid.)
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
taxi = spark.read.csv('s3://nyc-tlc/trip data/yellow_tripdata_2019-01.csv',
                      header=True,
                      inferSchema=True,
                      timestampFormat='yyyy-MM-dd HH:mm:ss',
                      ).sample(fraction=0.1, withReplacement=False)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from dask_ml.compose import ColumnTransformer | |
from dask_ml.preprocessing import StandardScaler, DummyEncoder, Categorizer | |
from dask_ml.model_selection import GridSearchCV | |
# Dask has a slightly different way of one-hot encoding than Spark ML:
# Categorizer converts the listed columns to pandas categoricals, DummyEncoder
# then expands them to dummy columns, and ColumnTransformer standardizes the
# numeric columns.
# NOTE(review): this definition is truncated in this capture — the remainder
# of the Pipeline (closing brackets, any further steps) is not visible; code
# left byte-identical rather than guessing at the missing tail.
pipeline = Pipeline(steps=[ | |
('categorize', Categorizer(columns=categorical_feat)), | |
('onehot', DummyEncoder(columns=categorical_feat)), | |
('scale', ColumnTransformer( | |
transformers=[('num', StandardScaler(), numeric_feat)], |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Load the same taxi month with Dask, parsing the pickup/dropoff timestamp
# columns up front, and take the same 10% sample as the Spark version.
# (Scrape artifacts " | |" removed — they made the snippet invalid Python.)
import dask.dataframe as dd

taxi = dd.read_csv(
    's3://nyc-tlc/trip data/yellow_tripdata_2019-01.csv',
    parse_dates=['tpep_pickup_datetime', 'tpep_dropoff_datetime'],
).sample(frac=0.1, replace=False)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Spin up a 20-worker Dask cluster on Saturn Cloud and attach a distributed
# Client so subsequent Dask collections execute on the cluster.
# (Scrape artifacts " | |" removed — they made the snippet invalid Python.)
from dask.distributed import Client
from dask_saturn import SaturnCluster

cluster = SaturnCluster(n_workers=20)
client = Client(cluster)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Fit the grid search on the feature columns against the target column, then
# report the best cross-validated score found.
# (Scrape artifacts " | |" removed — they made the snippet invalid Python.)
grid_search.fit(taxi[features], taxi[y_col])
print(grid_search.best_score_)