# Partial solution guest lecture ML
import pyspark
from pyspark.sql import SparkSession
from pyspark.ml import Pipeline
from pyspark.ml.feature import StringIndexer, VectorAssembler, IndexToString
from pyspark.ml.classification import LogisticRegression

# Start a Spark session and load the mushroom data set.
spark = SparkSession.builder.appName('pandasToSparkDF').getOrCreate()
df = spark.read.csv('data/mushrooms.csv', header=True)
# Names of the target column and of the encoded feature columns.
target_column = "label"
encoded_feature_columns = [column + "_index" for column in df.columns if column != "class"]

# Encode the string target column "class" as a numeric "label" column.
label_indexer = StringIndexer(inputCol="class", outputCol="label").fit(df)
df = label_indexer.transform(df)

# Encode every categorical feature column as a numeric "<column>_index" column.
indexers = [StringIndexer(inputCol=column, outputCol=column + "_index").fit(df)
            for column in df.columns if column != "class"]
pipeline = Pipeline(stages=indexers)
df_r = pipeline.fit(df).transform(df)
# Assemble all encoded feature columns into a single "features" vector column.
assembler = VectorAssembler(
    inputCols=encoded_feature_columns,
    outputCol="features")
# Split the data into a training set (80%) and a test set (20%).
splits = df_r.randomSplit([0.8, 0.2])
training_df = splits[0]
test_df = splits[1]

print("The training set contains %d examples." % training_df.count())
print("The test set contains %d examples." % test_df.count())
print("The full data set contains %d examples." % df.count())
print("The training set contains %.2f%% of all examples." % (100.0 * training_df.count() / df.count()))
# Logistic regression classifier with an elastic-net penalty.
lr = LogisticRegression(maxIter=10, regParam=0.3, elasticNetParam=0.8)

# Map the numeric predictions back to the original class labels.
converter = IndexToString(inputCol="prediction", outputCol="predictedLabel", labels=label_indexer.labels)

# Assemble features, train the classifier, and decode the predictions in one pipeline.
pipeline = Pipeline(stages=[assembler, lr, converter])
model = pipeline.fit(training_df)

predictions = model.transform(test_df)
predictions.show()
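
# A minimal follow-up sketch, not part of the original solution: scoring the test
# predictions with MulticlassClassificationEvaluator. It assumes the default
# "label" and "prediction" columns produced by the pipeline above.
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

evaluator = MulticlassClassificationEvaluator(
    labelCol="label", predictionCol="prediction", metricName="accuracy")
accuracy = evaluator.evaluate(predictions)
print("Test set accuracy: %.4f" % accuracy)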