Skip to content

Instantly share code, notes, and snippets.

@fchollet
Created April 6, 2019 19:24
Show Gist options
  • Star 2 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save fchollet/c5802dd2660a337a6590ff0cca54589f to your computer and use it in GitHub Desktop.
Save fchollet/c5802dd2660a337a6590ff0cca54589f to your computer and use it in GitHub Desktop.
from keras import backend
from keras.backend import numpy_backend
import numpy as np
import tensorflow as tf
class UnknownSymbolError(AttributeError, ValueError):
    """Raised when a name is not provided by both Keras backends.

    Subclasses ``AttributeError`` so that ``hasattr``, ``getattr(obj, name,
    default)`` and protocol probes (``copy``, ``pickle``) treat a missing
    symbol as absent. Also subclasses ``ValueError`` so callers that caught
    the ``ValueError`` this lookup used to raise keep working.
    """


# Sentinel distinguishing "attribute missing" from any real attribute value.
_MISSING = object()


class NPTF(object):
    """Dispatch Keras-backend calls to NumPy or to the active Keras backend.

    ``NPTF().name`` returns a callable for any ``name`` that exists in both
    ``keras.backend`` and ``keras.backend.numpy_backend``. Each call is routed
    by the type of its first positional argument:

    * NumPy arrays and Python ``float``/``int`` values go to the NumPy
      implementation.
    * Everything else (e.g. backend tensors) goes to the Keras backend
      implementation.
    """

    def __getattr__(self, name):
        """Return a type-dispatching wrapper for backend symbol ``name``.

        Raises:
            AttributeError: for dunder names, so Python protocol lookups
                (``__deepcopy__``, ``__getstate__``, ...) are not mistaken
                for backend symbols.
            UnknownSymbolError: if ``name`` is missing from either backend.
        """
        # __getattr__ only runs for names not found the normal way. Dunder
        # probes from copy/pickle/introspection must see a plain
        # AttributeError. Otherwise they would be resolved against module
        # attributes such as ``__name__`` or ``__file__``, which exist in
        # both backend modules.
        if name.startswith('__') and name.endswith('__'):
            raise AttributeError(name)

        k_symbol = getattr(backend, name, _MISSING)
        np_symbol = getattr(numpy_backend, name, _MISSING)
        if k_symbol is _MISSING or np_symbol is _MISSING:
            # Same exception arguments as the historical ValueError.
            raise UnknownSymbolError('Unknown symbol:', name)

        def wrapped(*args, **kwargs):
            # Dispatch on the first positional argument only. NumPy scalar
            # types that do not subclass float/int (e.g. np.float32) fall
            # through to the Keras backend.
            if args and isinstance(args[0], (np.ndarray, float, int)):
                return np_symbol(*args, **kwargs)
            return k_symbol(*args, **kwargs)

        return wrapped
# Shared dispatcher instance. Attribute lookups such as ``X.sqrt`` are
# resolved by NPTF.__getattr__, which picks the NumPy or Keras
# implementation on each call based on the first argument's type.
X = NPTF()
# Use X the way you'd use a Keras backend module (K.sqrt -> X.sqrt, ...).
def dual_rmse(x, y):
    """Return the root-mean-square error between ``x`` and ``y``.

    The mean is taken over the last axis (``axis=-1``). Every operation goes
    through the ``X`` dispatcher, so the same code accepts NumPy arrays and
    Keras-backend tensors.

    The previous implementation summed the squared differences instead of
    averaging them. That computes the Euclidean (L2) distance, not an RMSE.
    """
    return X.sqrt(X.mean(X.square(x - y), axis=-1))
# Demo: run the same helper on TensorFlow tensors first, then NumPy arrays.
_DEMO_SHAPE = (3, 4)
for _make_zeros, _make_ones in ((tf.zeros, tf.ones), (np.zeros, np.ones)):
    print(dual_rmse(_make_zeros(_DEMO_SHAPE), _make_ones(_DEMO_SHAPE)))
# Bonus points: NPTF also works with the Theano, CNTK and MXNet Keras backends.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment