Skip to content

Instantly share code, notes, and snippets.

@benman1
Created October 22, 2019 08:14
Show Gist options
  • Save benman1/7a07cc288110fa2588d84fd925fe4b66 to your computer and use it in GitHub Desktop.
Save benman1/7a07cc288110fa2588d84fd925fe4b66 to your computer and use it in GitHub Desktop.
Convert one-hot encoded vectors to class indices with an Argmax Keras layer (should work with TensorFlow >= 1.13)
import numpy as np
from tensorflow.keras import backend as K
from tensorflow.python.keras.layers import InputSpec, Layer
class Argmax(Layer):
    """Keras layer that replaces one axis of its input with its argmax index.

    Typical use: turn one-hot (or score/probability) vectors into integer
    class indices inside a model.

    Based on https://github.com/YerevaNN/R-NET-in-Keras/blob/master/layers/Argmax.py

    Args:
        axis: Axis along which the argmax is taken; that axis is removed from
            the output shape. Defaults to the last axis (-1).
        **kwargs: Forwarded unchanged to ``Layer.__init__``.
    """

    def __init__(self, axis=-1, **kwargs):
        super(Argmax, self).__init__(**kwargs)
        # Accept masked inputs; the mask itself is not propagated
        # (see compute_mask, which always returns None).
        self.supports_masking = True
        self.axis = axis

    def call(self, inputs, mask=None):
        """Return the index of the maximum of ``inputs`` along ``self.axis``."""
        # `mask` is accepted for API compatibility but is not used here.
        return K.argmax(inputs, axis=self.axis)

    def compute_output_shape(self, input_shape):
        """Return ``input_shape`` with the reduced ``self.axis`` dimension removed."""
        input_shape = list(input_shape)
        del input_shape[self.axis]
        return tuple(input_shape)

    def compute_mask(self, x, mask=None):
        """Return None: no mask is propagated past the argmax."""
        # The `mask=None` default matches the Keras base-class signature
        # `compute_mask(inputs, mask=None)`, so single-argument calls work.
        return None

    def get_config(self):
        """Return the base layer config extended with ``axis`` for serialization."""
        config = dict(super(Argmax, self).get_config())
        # `axis` is written last so it overrides any same-named base entry,
        # matching the original merge order.
        config['axis'] = self.axis
        return config
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment