@cartershanklin
Created April 21, 2020 16:43
import argparse
import oci
import os
import pathlib
import pandas as pd
from pyspark import SparkConf
from pyspark.sql import SparkSession
from sklearn import svm, preprocessing
from sklearn.model_selection import GridSearchCV as GridSearchCVNative
from spark_sklearn import GridSearchCV as GridSearchCVSpark
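
# Train an SVM to classify mushrooms as edible or poisonous using the UCI
# agaricus-lepiota dataset, distributing the scikit-learn grid search across
# Spark executors with spark_sklearn. A minimal sketch of a local launch,
# assuming the OCI HDFS connector jar and a bucket/namespace of your own
# (the jar path, script name, and placeholders below are assumptions, not
# part of the original gist):
#
#   spark-submit \
#     --jars /path/to/oci-hdfs-full.jar \
#     this_script.py -p oci://<bucket>@<namespace>/agaricus-lepiota.csv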

def main():
    use_spark = True
    oci_path = "oci://sample-data@paasdevssstest/agaricus-lepiota.csv"
    local_path = os.path.join(
        pathlib.Path(__file__).parent.absolute(), "agaricus-lepiota.csv"
    )
    # Set up Spark.
    conf = SparkConf()
    # Check to see if we're in Data Flow or not: Data Flow sets HOME to
    # /home/dataflow inside its containers.
    if os.environ.get("HOME") == "/home/dataflow":
        mode = "cluster"
        path = oci_path
        print("Running in cluster mode")
    else:
        mode = "local"
        path = local_path
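        # For local runs, oci.config.from_file() reads ~/.oci/config. A
        # typical profile looks like the sketch below (all values are
        # placeholders, not part of the original gist):
        #
        #   [DEFAULT]
        #   user=ocid1.user.oc1..<unique_id>
        #   fingerprint=aa:bb:cc:...
        #   tenancy=ocid1.tenancy.oc1..<unique_id>
        #   region=us-phoenix-1
        #   key_file=~/.oci/oci_api_key.pem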
        oci_config = oci.config.from_file()
        conf.set("fs.oci.client.auth.tenantId", oci_config["tenancy"])
        conf.set("fs.oci.client.auth.userId", oci_config["user"])
        conf.set("fs.oci.client.auth.fingerprint", oci_config["fingerprint"])
        conf.set("fs.oci.client.auth.pemfilepath", oci_config["key_file"])
        conf.set(
            "fs.oci.client.hostname",
            "https://objectstorage.{0}.oraclecloud.com".format(oci_config["region"]),
        )
    spark_session = (
        SparkSession.builder.appName("svc_mushroom").config(conf=conf).getOrCreate()
    )
    print("Running in {} mode".format(mode))
    spark_context = spark_session.sparkContext
    # Handle arguments.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-p", "--path", help="Input CSV path (local or oci://)", default=path
    )
    args = parser.parse_args()
    assert args.path is not None, "Need -p / --path"
    X, y = load_mushroom_dataframes(args.path, spark_session)
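    # The agaricus-lepiota data has 8,124 rows: one "class" label
    # (edible/poisonous) plus 22 categorical features, all label-encoded
    # by load_mushroom_dataframes before fitting.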
    # Grid search over the SVC hyperparameters: 2 kernels x 9 values of C
    # x 2 shrinking settings = 36 candidate models. With use_spark the
    # combinations are evaluated in parallel on the Spark cluster.
    svr = svm.SVC(gamma="auto")
    parameters = {
        "kernel": ("linear", "rbf"),
        "C": range(1, 10),
        "shrinking": [False, True],
    }
    if use_spark:
        clf = GridSearchCVSpark(spark_context, svr, parameters)
    else:
        clf = GridSearchCVNative(svr, parameters)
    clf.fit(X, y)
    new_dataframe = pd.DataFrame(clf.cv_results_)
    print(new_dataframe.to_csv())

def load_mushroom_csv(path, spark_session):
    print("Reading data from " + path)
    if path.startswith("/"):
        # Local absolute path: read directly with pandas.
        with open(path, "rt") as fd:
            return pd.read_csv(fd)
    else:
        # oci:// or other Hadoop-supported path: read through Spark.
        spark_df = spark_session.read.csv(path, header=True)
        return spark_df.toPandas()

def load_mushroom_dataframes(path, spark_session):
    unencoded_data = load_mushroom_csv(path, spark_session)
    everything = set(
        [
            "class",
            "cap-shape",
            "cap-surface",
            "cap-color",
            "bruises",
            "odor",
            "gill-attachment",
            "gill-spacing",
            "gill-size",
            "gill-color",
            "stalk-shape",
            "stalk-root",
            "stalk-surface-above-ring",
            "stalk-surface-below-ring",
            "stalk-color-above-ring",
            "stalk-color-below-ring",
            "veil-type",
            "veil-color",
            "ring-number",
            "ring-type",
            "spore-print-color",
            "population",
            "habitat",
        ]
    )
    # Convert categories to numbers within the dataframe. A fresh encoding
    # is fitted per column, so codes are only comparable within a column.
    data = unencoded_data.copy()
    le = preprocessing.LabelEncoder()
    for i in range(data.shape[1]):
        data.iloc[:, i] = data.iloc[:, i].fillna("")
        data.iloc[:, i] = le.fit_transform(data.iloc[:, i])
    # Set up the data to predict: everything except "class" is a feature.
    predict_attribute = "class"
    independent_vars = everything - {predict_attribute}
    X = data[list(independent_vars)]
    y = data[predict_attribute]
    return X, y

if __name__ == "__main__":
    main()
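
# Note: setting use_spark = False runs the identical grid search through
# scikit-learn's own GridSearchCV on a single node, a quick way to sanity
# check results before submitting the job to Data Flow.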