@FavioVazquez
Last active January 22, 2020 15:01
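from sparkdl import KerasTransformer
from keras.models import Sequential
from keras.layers import Dense
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, ArrayType, FloatType
import numpy as np
# Reuse the active SparkSession (notebook environments usually predefine `spark`)
spark = SparkSession.builder.getOrCreate()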
# Generate random input data
num_features = 10
num_examples = 100
# Each row holds one 1-D array of num_features floats, matching the model's input_shape
input_data = [{"features": np.random.randn(num_features).tolist()} for _ in range(num_examples)]
schema = StructType([StructField("features", ArrayType(FloatType()), True)])
input_df = spark.createDataFrame(input_data, schema)
# Create and save a single-hidden-layer Keras model for binary classification
# NOTE: In a typical workflow, we'd train the model before exporting it to disk,
# but we skip that step here for brevity
model = Sequential()
model.add(Dense(units=20, input_shape=[num_features], activation='relu'))
model.add(Dense(units=1, activation='sigmoid'))
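# Save architecture and weights to a single HDF5 file, the format KerasTransformer's modelFile expects
model_path = "simple-binary-classification.h5"
model.save(model_path)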
# Create transformer and apply it to our input data
transformer = KerasTransformer(inputCol="features", outputCol="predictions", modelFile=model_path)
final_df = transformer.transform(input_df)
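For intuition about what `transformer.transform` computes for each row, here is a pure-Python sketch of the same 10 → 20 (ReLU) → 1 (sigmoid) network: its forward pass and its parameter count. It assumes Keras's default `Dense` initialization (Glorot-uniform kernels, zero biases), and the helper names (`dense`, `predict`, `dense_param_count`, `glorot_uniform`) are hypothetical. sparkdl itself runs the real Keras/TensorFlow graph, not anything like this.

```python
import math
import random

NUM_FEATURES = 10  # matches num_features in the gist
HIDDEN_UNITS = 20  # Dense(units=20, activation='relu')
OUTPUT_UNITS = 1   # Dense(units=1, activation='sigmoid')


def dense_param_count(fan_in, units):
    """Weights (fan_in x units) plus one bias per unit, as in a Keras Dense layer."""
    return fan_in * units + units


def glorot_uniform(fan_in, fan_out, rng):
    """Random fan_in x fan_out matrix drawn like Keras' default glorot_uniform initializer."""
    limit = math.sqrt(6.0 / (fan_in + fan_out))
    return [[rng.uniform(-limit, limit) for _ in range(fan_out)] for _ in range(fan_in)]


def dense(x, weights, bias, activation):
    """y_j = activation(sum_i x_i * W[i][j] + b_j)."""
    return [
        activation(sum(x_i * w_row[j] for x_i, w_row in zip(x, weights)) + bias[j])
        for j in range(len(bias))
    ]


def relu(z):
    return max(0.0, z)


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def predict(features, params):
    """Forward pass of the untrained 10 -> 20 (ReLU) -> 1 (sigmoid) network for one row."""
    if len(features) != NUM_FEATURES:
        raise ValueError(f"expected {NUM_FEATURES} features, got {len(features)}")
    w1, b1, w2, b2 = params
    hidden = dense(features, w1, b1, relu)
    return dense(hidden, w2, b2, sigmoid)


rng = random.Random(0)
params = (
    glorot_uniform(NUM_FEATURES, HIDDEN_UNITS, rng), [0.0] * HIDDEN_UNITS,
    glorot_uniform(HIDDEN_UNITS, OUTPUT_UNITS, rng), [0.0] * OUTPUT_UNITS,
)
row = [rng.gauss(0.0, 1.0) for _ in range(NUM_FEATURES)]
print(predict(row, params))  # a one-element list with a value in (0, 1)

total = dense_param_count(NUM_FEATURES, HIDDEN_UNITS) + dense_param_count(HIDDEN_UNITS, OUTPUT_UNITS)
print(total)  # 241
```

Because the model is untrained, the prediction is just a number between 0 and 1 with no meaning. The 241 parameters (220 in the hidden layer, 21 in the output layer) are the total `model.summary()` reports for this architecture.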
@vivek-bombatkar

I am using this library, sparkdl-0.2.2, and am still getting the same error:
`---> 22 import sparkdl.graph.utils as tfx
     23 from sparkdl.transformers.keras_utils import KSessionWrap
     24 from sparkdl.param import (

AttributeError: module 'sparkdl' has no attribute 'graph'
`
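The traceback shows the failure on `import sparkdl.graph.utils`, so a reasonable first check is which `sparkdl` the interpreter actually loads and whether that installation contains a `graph` subpackage. Below is a diagnostic sketch using only the standard-library `importlib`; `inspect_package` is a hypothetical helper, and it only reports what is installed rather than fixing anything.

```python
import importlib
import importlib.util


def inspect_package(pkg_name, submodule):
    """Report where a package is imported from and whether one of its submodules can be located."""
    try:
        pkg = importlib.import_module(pkg_name)
    except Exception as exc:  # broad on purpose: a broken install may raise AttributeError, not just ImportError
        return {"package": pkg_name, "importable": False, "error": f"{type(exc).__name__}: {exc}"}
    full_name = f"{pkg_name}.{submodule}"
    try:
        spec = importlib.util.find_spec(full_name)
    except ImportError:  # e.g. pkg_name is a plain module, not a package
        spec = None
    return {
        "package": pkg_name,
        "importable": True,
        "location": getattr(pkg, "__file__", None),  # exposes a stray local file/dir shadowing the install
        "version": getattr(pkg, "__version__", None),  # not every package defines __version__
        "submodule": full_name,
        "submodule_found": spec is not None,
    }


if __name__ == "__main__":
    print(inspect_package("sparkdl", "graph"))
```

A `location` outside the expected site-packages directory, or `submodule_found: False`, could point at a shadowing file or an incomplete install rather than at the example code itself.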
