Skip to content

Instantly share code, notes, and snippets.

What would you like to do?
from keras import backend
from keras.backend import numpy_backend
import numpy as np
import tensorflow as tf
class NPTF(object):
    """Dispatch Keras-backend style calls to NumPy or to the Keras backend.

    Any attribute looked up on an instance is resolved by name against both
    ``keras.backend`` and ``keras.backend.numpy_backend``. The attribute
    evaluates to a wrapper function that picks one of the two
    implementations at call time, based on the type of the first positional
    argument.
    """

    def __getattr__(self, name):
        """Return a dual NumPy/Keras dispatcher for the symbol ``name``.

        Args:
            name: Name of a function exposed by both ``backend`` and
                ``numpy_backend``.

        Returns:
            A function forwarding ``*args, **kwargs`` to the NumPy
            implementation when the first positional argument is an
            ``np.ndarray``, ``float`` or ``int``, and to the Keras backend
            implementation otherwise (including when called with no
            positional arguments).

        Raises:
            ValueError: If ``name`` is not available in both modules.
        """
        # Guard clause: reject unknown symbols up front so the success path
        # below can never fall through into the error.
        if not (hasattr(numpy_backend, name) and hasattr(backend, name)):
            raise ValueError('Unknown symbol:', name)

        k_symbol = getattr(backend, name)
        np_symbol = getattr(numpy_backend, name)

        def wrapped(*args, **kwargs):
            # Dispatch on the first positional argument only; keyword-only
            # calls go to the Keras backend.
            if args and isinstance(args[0], (np.ndarray, float, int)):
                return np_symbol(*args, **kwargs)
            return k_symbol(*args, **kwargs)

        return wrapped
# Shared dispatcher instance: every attribute access on it goes through
# NPTF.__getattr__.
X = NPTF()
# Use like you'd use a Keras backend
def dual_rmse(x, y):
    """Return the root-mean-square error between ``x`` and ``y``.

    The squared differences are averaged over the last axis before taking
    the square root. Every operation goes through ``X``, which selects the
    NumPy or Keras backend implementation from the type of its first
    argument.
    """
    # RMSE needs the *mean* of the squared errors; summing them instead
    # would yield the Euclidean norm of the difference.
    return X.sqrt(X.mean(X.square(x - y), axis=-1))
# Demo: run the same function once on backend tensors and once on NumPy
# arrays of identical shape.
_demo_shape = (3, 4)
print(dual_rmse(tf.zeros(_demo_shape), tf.ones(_demo_shape)))
print(dual_rmse(np.zeros(_demo_shape), np.ones(_demo_shape)))
# Bonus points: also works with Theano, CNTK, MXNet
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.