@davidhughhenrymack
Last active December 24, 2017 15:31
A proposal for adding type-based symbolic shapes to keras
# This is a very rough early draft of something I think would speed up my coding
# in Keras. I spend a fair amount of time reading the source code to work out
# how a given method treats different dimensions, which seems to be largely by
# convention and only semi-documented. As someone new to these libraries, it would
# help me a lot to make this explicit.
# The rough idea has three parts:
# - Have "nanotypes" representing commonly used dimensions, e.g. batch size
# - Allow Dimensions, and Shapes (a list of dimensions), to be type enforced
# - Let the programmer specify how operations are to be performed in terms of dimension mapping
# I'm new to Python and Keras, and am still thinking through the practical implementation of this,
# so apologies for all my gross mistakes!
# keras.dimensions and keras.shapes are proposed modules, not part of Keras
from keras.dimensions import BatchSize, Dim, NewDim
from keras.shapes import Shape, MinShape, NewShape
from keras.layers import *
import keras.backend as K
# For some imaginary RNN network
SequenceLength = NewDim('SequenceLength')
WordLength = NewDim('WordLength')
HalfWord = NewDim('HalfWord')  # used below; assumed to be half of WordLength
RNNInputShape = Shape[BatchSize, SequenceLength, WordLength]
data = Input<RNNInputShape>(name='input')
def mySpecialRNN<Output_Width:Dim>(input:RNNInputShape):
    rnn = SimpleRNN<Output_Width>()
    n = rnn(input)
    # ...
    n = Reshape(Shape[SequenceLength, HalfWord, 2])(n)
def combine_some_tensors(x: MinShape[BatchSize, HalfWord], y: MinShape[BatchSize, HalfWord, _], z):
    # Reshape, with type checking so we know it's possible
    y = K.reshape[BatchSize, _, HalfWord](y)
    # ... later, we have some more tensors to play with and are definin
    # I'd like to multiply along just the HalfWord axis and scan the BatchSize axis
    r = K.batch_dot[[BatchSize], HalfWord](x, y)
    # Find the mean along the batch axis please
    r = K.mean[BatchSize](r)
    return r
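
The draft above uses syntax that Python does not accept (for example `Input<RNNInputShape>`), so here is a minimal sketch of how the `NewDim` and `Shape[...]` parts could be approximated in plain Python 3.7+ today. Everything in it is hypothetical: `Dim`, `NewDim`, `Shape`, and the `axis` and `without` helpers are illustrative stand-ins, not Keras classes, and `without` is only one guess at how the result shape of a named-axis reduction such as `K.mean[BatchSize]` might be computed.

```python
class Dim:
    """A named axis ("nanotype") such as BatchSize. Hypothetical, not a Keras class."""

    def __init__(self, name):
        self.name = name

    def __repr__(self):
        return self.name


def NewDim(name):
    """Create a fresh named axis; two calls never produce equal Dims."""
    return Dim(name)


class Shape:
    """An ordered tuple of named axes, built with Shape[A, B, C]."""

    def __init__(self, dims):
        self.dims = tuple(dims)

    def __class_getitem__(cls, dims):
        # Shape[A] passes a single object, Shape[A, B] passes a tuple.
        if not isinstance(dims, tuple):
            dims = (dims,)
        return cls(dims)

    def axis(self, dim):
        """Return the integer position of a named axis, or fail loudly."""
        if dim not in self.dims:
            raise TypeError(f"{dim!r} is not an axis of {self!r}")
        return self.dims.index(dim)

    def without(self, dim):
        """Shape left after reducing over `dim` (e.g. a mean along that axis)."""
        i = self.axis(dim)
        return Shape(self.dims[:i] + self.dims[i + 1:])

    def __eq__(self, other):
        return isinstance(other, Shape) and self.dims == other.dims

    def __repr__(self):
        return "Shape[" + ", ".join(repr(d) for d in self.dims) + "]"


BatchSize = NewDim("BatchSize")
SequenceLength = NewDim("SequenceLength")
WordLength = NewDim("WordLength")

RNNInputShape = Shape[BatchSize, SequenceLength, WordLength]

print(RNNInputShape.axis(SequenceLength))  # 1
print(RNNInputShape.without(BatchSize))    # Shape[SequenceLength, WordLength]
```

A wrapper around a backend function could use `axis` to translate a named axis into the integer index that the existing API expects, which would give the "mean along the batch axis" call without the caller having to remember which position holds the batch.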