@FavioVazquez
Last active January 22, 2020 15:01
from sparkdl import KerasTransformer
from keras.models import Sequential
from keras.layers import Dense
from pyspark.sql.types import StructType, StructField, ArrayType, FloatType
import numpy as np

# Assumes an active SparkSession named `spark` (e.g. in a Databricks notebook)

# Generate random input data
num_features = 10
num_examples = 100
input_data = [{"features": np.random.randn(num_features).astype(float).tolist()} for i in range(num_examples)]
schema = StructType([StructField("features", ArrayType(FloatType()), True)])
input_df = spark.createDataFrame(input_data, schema)

# Create and save a single-hidden-layer Keras model for binary classification
# NOTE: In a typical workflow, we'd train the model before exporting it to disk,
# but we skip that step here for brevity
model = Sequential()
model.add(Dense(units=20, input_shape=[num_features], activation='relu'))
model.add(Dense(units=1, activation='sigmoid'))
model_path = "simple-binary-classification"
model.save(model_path)

# Create the transformer and apply it to our input data
transformer = KerasTransformer(inputCol="features", outputCol="predictions", modelFile=model_path)
final_df = transformer.transform(input_df)
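
The snippet above assumes the spark-deep-learning assembly (and its tensorframes dependency) is already on the Spark classpath, as it is on a Databricks cluster with the library attached. If you run it elsewhere, a minimal sketch of building such a session yourself follows; the package coordinate and repository URL are assumptions and must match your Spark/Scala versions (the traceback further down in this thread was produced with the 1.5.0-spark2.4 / Scala 2.11 build).

from pyspark.sql import SparkSession

# Sketch only: pull the spark-deep-learning package (which bundles tensorframes)
# onto the driver and executor classpath when the session is created.
spark = (
    SparkSession.builder
    .appName("keras-transformer-demo")
    .config("spark.jars.packages",
            "databricks:spark-deep-learning:1.5.0-spark2.4-s_2.11")  # assumed coordinate
    .config("spark.jars.repositories", "https://repos.spark-packages.org")  # may be needed to resolve spark-packages artifacts
    .getOrCreate()
)
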
@vhutse commented Dec 8, 2018

Hi @FavioVazquez, I'm having a problem importing the KerasTransformer into Databricks. When running the command `from sparkdl import KerasTransformer`, I get the following error message:

"AttributeError: module 'sparkdl' has no attribute 'graph' "

Do you have any experience with this?

[screenshot attached: 2018-12-08 at 19:43]

@ArnabRaxit

@vhutse: I too got the same issue while trying this in python 3.6. Did you get a resolution?

@ronitshaw

Initially I was getting the same error.
I tried installing different versions of sparkdl, as well as different dependencies, to see if that would fix it, but the issue is still not resolved.

Currently I'm getting the following error:

<ipython-input-36-17e37754773e> in <module>()
     23 # Create transformer and apply it to our input data
     24 transformer = KerasTransformer(inputCol="features", outputCol="predictions", modelFile=model_path)
---> 25 final_df = transformer.transform(input_df)

/Users/ronit_ibm/anaconda3/lib/python3.6/site-packages/pyspark/ml/base.py in transform(self, dataset, params)
    171                 return self.copy(params)._transform(dataset)
    172             else:
--> 173                 return self._transform(dataset)
    174         else:
    175             raise ValueError("Params must be a param map but got %s." % type(params))

/Users/ronit_ibm/Documents/temp_simple_moedel/spark-deep-learning-1.5.0/target/scala-2.11/spark-deep-learning-assembly-1.5.0-spark2.4.jar/sparkdl/transformers/keras_tensor.py in _transform(self, dataset)
     63                                     inputMapping={self.getInputCol(): inputTensorName},
     64                                     outputMapping={outputTensorName: self.getOutputCol()})
---> 65         return transformer.transform(dataset)

/Users/ronit_ibm/anaconda3/lib/python3.6/site-packages/pyspark/ml/base.py in transform(self, dataset, params)
    171                 return self.copy(params)._transform(dataset)
    172             else:
--> 173                 return self._transform(dataset)
    174         else:
    175             raise ValueError("Params must be a param map but got %s." % type(params))

/Users/ronit_ibm/Documents/temp_simple_moedel/spark-deep-learning-1.5.0/target/scala-2.11/spark-deep-learning-assembly-1.5.0-spark2.4.jar/sparkdl/transformers/tf_tensor.py in _transform(self, dataset)
    104         graph = tf.Graph()
    105         with tf.Session(graph=graph):
--> 106             analyzed_df = tfs.analyze(dataset)
    107             out_tnsr_op_names = [tfx.op_name(tnsr_name) for tnsr_name, _ in output_mapping]
    108             # Load graph

/Users/ronit_ibm/Documents/temp_simple_moedel/spark-deep-learning-1.5.0/target/scala-2.11/spark-deep-learning-assembly-1.5.0-spark2.4.jar/tensorframes/core.py in analyze(dframe)
    377     :return: a Spark DataFrame with metadata information embedded.
    378     """
--> 379     return DataFrame(_java_api().analyze(dframe._jdf), _sql)
    380 
    381 def append_shape(dframe, col, shape):

/Users/ronit_ibm/Documents/temp_simple_moedel/spark-deep-learning-1.5.0/target/scala-2.11/spark-deep-learning-assembly-1.5.0-spark2.4.jar/tensorframes/core.py in _java_api()
     33     # You cannot simply call the creation of the the class on the _jvm due to classloader issues
     34     # with Py4J.
---> 35     return _jvm.Thread.currentThread().getContextClassLoader().loadClass(javaClassName) \
     36         .newInstance()
     37 

/Users/ronit_ibm/anaconda3/lib/python3.6/site-packages/py4j/java_gateway.py in __call__(self, *args)
   1255         answer = self.gateway_client.send_command(command)
   1256         return_value = get_return_value(
-> 1257             answer, self.gateway_client, self.target_id, self.name)
   1258 
   1259         for temp_arg in temp_args:

/Users/ronit_ibm/anaconda3/lib/python3.6/site-packages/pyspark/sql/utils.py in deco(*a, **kw)
     61     def deco(*a, **kw):
     62         try:
---> 63             return f(*a, **kw)
     64         except py4j.protocol.Py4JJavaError as e:
     65             s = e.java_exception.toString()

/Users/ronit_ibm/anaconda3/lib/python3.6/site-packages/py4j/protocol.py in get_return_value(answer, gateway_client, target_id, name)
    326                 raise Py4JJavaError(
    327                     "An error occurred while calling {0}{1}{2}.\n".
--> 328                     format(target_id, ".", name), value)
    329             else:
    330                 raise Py4JError(

Py4JJavaError: An error occurred while calling o96.loadClass.
: java.lang.ClassNotFoundException: org.tensorframes.impl.DebugRowOps
	at java.net.URLClassLoader.findClass(URLClassLoader.java:382)
	at java.lang.ClassLoader.loadClass(ClassLoader.java:424)
	at java.lang.ClassLoader.loadClass(ClassLoader.java:357)
	at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
	at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
	at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
	at java.lang.reflect.Method.invoke(Method.java:498)
	at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
	at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
	at py4j.Gateway.invoke(Gateway.java:282)
	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
	at py4j.commands.CallCommand.execute(CallCommand.java:79)
	at py4j.GatewayConnection.run(GatewayConnection.java:238)
	at java.lang.Thread.run(Thread.java:748)

This is similar to an existing issue: databricks/tensorframes#49
@FavioVazquez, any leads on the above issues?
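
A note on the Py4JJavaError above: the ClassNotFoundException for org.tensorframes.impl.DebugRowOps means the tensorframes classes inside the spark-deep-learning assembly jar never made it onto the JVM classpath. A small diagnostic sketch (assuming an active SparkSession named spark) that repeats the same classloader lookup the library performs internally:

# Diagnostic sketch: mirrors the loadClass call in tensorframes/core.py's _java_api().
# It raises the same ClassNotFoundException when the assembly jar is not on the
# driver classpath, and succeeds once the jar/package is correctly attached.
jvm = spark.sparkContext._jvm
loader = jvm.Thread.currentThread().getContextClassLoader()
loader.loadClass("org.tensorframes.impl.DebugRowOps")
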

@taiwotman commented Sep 26, 2019

keras_image.py, where the error comes from, has been modified and no longer uses the graph attribute.

@vivek-bombatkar

I am using sparkdl 0.2.2 and I am still getting the same error:

---> 22 import sparkdl.graph.utils as tfx
     23 from sparkdl.transformers.keras_utils import KSessionWrap
     24 from sparkdl.param import (

AttributeError: module 'sparkdl' has no attribute 'graph'
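
One thing worth checking here (a diagnostic sketch, not an official fix): the sparkdl distribution installed with pip is not necessarily the same code as the sparkdl Python module bundled inside the spark-deep-learning assembly jar, so printing where the import actually resolves from shows which copy is being used.

import sparkdl

# Diagnostic sketch: show which installed copy of sparkdl Python imports.
# If this points at a pip-installed site-packages directory rather than the
# spark-deep-learning assembly jar, the module layout may not match the jar
# on the Spark classpath, which can surface as the missing `graph` attribute.
print(sparkdl.__file__)
print(getattr(sparkdl, "__version__", "no __version__ attribute"))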
