from keras import backend | |
from keras.backend import numpy_backend | |
import numpy as np | |
import tensorflow as tf | |
class _UnknownSymbolError(AttributeError, ValueError):
    """Raised when a requested symbol is not provided by both backends.

    Derives from ``AttributeError`` so that ``hasattr`` and
    ``getattr(obj, name, default)`` behave normally on ``NPTF`` instances,
    and from ``ValueError`` so that callers already catching ``ValueError``
    keep working.
    """


class NPTF(object):
    """Dispatch backend calls to NumPy or to the Keras backend by argument type.

    Attribute access returns a wrapper around the function called ``name``
    that exists in both ``keras.backend`` and ``keras.backend.numpy_backend``.
    When called, the wrapper inspects the first positional argument and
    forwards the call to the NumPy implementation for NumPy arrays, NumPy
    scalars and Python numbers, and to the Keras backend implementation for
    everything else.
    """

    def __getattr__(self, name):
        """Return a dispatching wrapper for the backend symbol ``name``.

        Args:
            name: Name of a symbol that must exist in both ``backend`` and
                ``numpy_backend``.

        Returns:
            A callable forwarding ``*args, **kwargs`` to the NumPy or the
            Keras implementation depending on the type of ``args[0]``.

        Raises:
            _UnknownSymbolError: (an ``AttributeError`` and ``ValueError``)
                if ``name`` is missing from either backend.
        """
        if not (hasattr(numpy_backend, name) and hasattr(backend, name)):
            # Must be an AttributeError for hasattr()/getattr-with-default to
            # work; the args tuple is kept identical to the previous error.
            raise _UnknownSymbolError('Unknown symbol:', name)
        k_symbol = getattr(backend, name)
        np_symbol = getattr(numpy_backend, name)

        def wrapped(*args, **kwargs):
            # np.generic covers NumPy scalars (e.g. np.float32 produced by a
            # reduction), which are not all subclasses of float/int.
            if args and isinstance(
                    args[0], (np.ndarray, np.generic, float, int)):
                return np_symbol(*args, **kwargs)
            # No positional args, or a non-NumPy first argument: Keras backend.
            return k_symbol(*args, **kwargs)

        return wrapped
# Module-level dispatcher instance; attribute lookups on it are resolved by
# NPTF.__getattr__.
X = NPTF()
# Use like you'd use a Keras backend
def dual_rmse(x, y):
    """Return the root of the summed squared differences along the last axis.

    All operations go through ``X``, so the same code runs on whichever
    implementation ``X`` selects for the given inputs.

    Note: the squared differences are summed, not averaged, over ``axis=-1``.
    """
    squared_diff = X.square(x - y)
    summed = X.sum(squared_diff, axis=-1)
    return X.sqrt(summed)
# Demo: call the same function first with TensorFlow inputs, then with NumPy
# arrays of the same shape.
print(dual_rmse(tf.zeros((3, 4)), tf.ones((3, 4))))
print(dual_rmse(np.zeros((3, 4)), np.ones((3, 4))))
# Bonus points: also works with Theano, CNTK, MXNet
# Sign up for free to join this conversation on GitHub.
# Already have an account? Sign in to comment.