Skip to content

Instantly share code, notes, and snippets.

@y-ich
Created December 11, 2017 06:39
Show Gist options
  • Save y-ich/db973fc2a1a1736adf5570a22e87902f to your computer and use it in GitHub Desktop.
Save y-ich/db973fc2a1a1736adf5570a22e87902f to your computer and use it in GitHub Desktop.
Investigating a WebDNN conversion error involving a channel-padding layer
# Path to WebDNN's Keras converter script (machine-specific).
WEBDNN = /Users/yuji/OpenSources/webdnn/bin/convert_keras.py
# Converter plugin that makes the custom PaddingChannel layer importable.
PLUGIN = padding_channel_plugin.py
# Derived from PLUGIN so renaming the plugin file needs only one edit.
PLUGIN_OPTION = --plugin $(PLUGIN)
OUTDIR = output

# Convert model.h5 ($< = first prerequisite) with the plugin; rebuilt whenever
# the model or the plugin changes. The input shape presumably is
# (batch=1, 4, 1), matching Input(shape=(4, 1)) in model.py -- confirm when
# switching models.
# NOTE: recipe lines must start with a TAB character, not spaces.
$(OUTDIR)/kernels_webassembly.js: model.h5 $(PLUGIN)
	DEBUG=1 CPLUS_INCLUDE_PATH=/usr/local/include/eigen3 python3 $(WEBDNN) $< $(PLUGIN_OPTION) --input_shape '(1,4,1)' --out $(OUTDIR)

# Regenerate the Keras model whenever the model script or the layer changes.
model.h5: model.py padding_channel.py
	python3 model.py
import sys
import tensorflow as tf
from keras import backend as K
from keras.models import Model
from keras.layers import Input, Reshape, Lambda
from padding_channel import PaddingChannel
# Conversion succeeds when PaddingChannel is connected directly to the input.
def model_can_be_converted1():
    """Build a model that feeds a (4, 1) input straight into PaddingChannel.

    Returns:
        A Keras ``Model`` with a single input named ``"x"`` and a single
        output produced by ``PaddingChannel(padding=1)``.
    """
    inp = Input(shape=(4, 1), name="x")
    padded = PaddingChannel(padding=1)(inp)
    return Model(inputs=[inp], outputs=[padded])
# Conversion fails when a Reshape layer precedes PaddingChannel.
def model_cannot_be_converted1():
    """Build the failing case: Input(4, 1) -> Reshape(2, 2, 1) -> PaddingChannel.

    Returns:
        A Keras ``Model`` with a single input named ``"x"``; its output is the
        reshaped tensor padded by ``PaddingChannel(padding=1)``.
    """
    source = Input(shape=(4, 1), name="x")
    reshaped = Reshape((2, 2, 1))(source)
    return Model(inputs=[source], outputs=[PaddingChannel(padding=1)(reshaped)])
# Confirmed that conversion succeeds with Reshape alone (no PaddingChannel).
def model_can_be_converted_reshape_only():
    """Build a model that only reshapes a (4, 1) input to (2, 2, 1).

    Renamed from ``model_can_be_converted2``: a later definition in this
    module reuses that name, which silently shadowed this function and made it
    unreachable. The new name keeps both experiments callable.

    Returns:
        A Keras ``Model`` with a single input named ``"x"`` and the reshaped
        tensor as its output.
    """
    x = Input(shape=(4, 1), name="x")
    y = Reshape((2, 2, 1))(x)
    return Model(inputs=[x], outputs=[y])
# Confirmed that conversion succeeds when the input is already 2-D (2, 2, 1).
def model_can_be_converted2():
    """Build a model whose (2, 2, 1) input feeds PaddingChannel directly.

    Returns:
        A Keras ``Model`` with a single input named ``"x"`` and the output of
        ``PaddingChannel(padding=1)`` applied to it.
    """
    image = Input(shape=(2, 2, 1), name="x")
    return Model(inputs=[image], outputs=[PaddingChannel(padding=1)(image)])
if __name__ == '__main__':
    # Export one of the model builders defined in this module to model.h5 for
    # the WebDNN converter. The builder is chosen by name on the command line,
    # e.g. `python3 model.py model_can_be_converted2`. With no argument, the
    # failing Reshape -> PaddingChannel case is exported, so a plain `make`
    # reproduces the conversion error exactly as before.
    # Snapshot globals() before filtering so iteration never sees the dict change.
    builders = {
        name: fn
        for name, fn in list(globals().items())
        if name.startswith('model_') and callable(fn)
    }
    builder_name = sys.argv[1] if len(sys.argv) > 1 else 'model_cannot_be_converted1'
    if builder_name not in builders:
        sys.exit('unknown model builder %r; choose one of: %s'
                 % (builder_name, ', '.join(sorted(builders))))
    model = builders[builder_name]()
    model.save('model.h5')
import tensorflow as tf
from keras.engine import Layer
class PaddingChannel(Layer):
    """Keras layer that appends zeros to the end of the last (channel) axis.

    Only the trailing side of the last axis is padded; all other axes are left
    unchanged. For example, with ``padding=1`` an input of shape
    ``(batch, 4, 1)`` becomes ``(batch, 4, 2)``.

    Args:
        padding: Number of zeros appended to the last axis. Must be >= 0.
        **kwargs: Forwarded to the base ``Layer`` constructor.

    Raises:
        ValueError: If ``padding`` is negative.
    """

    def __init__(self, padding=0, **kwargs):
        super(PaddingChannel, self).__init__(**kwargs)
        # Fail early: tf.pad rejects negative paddings only later, at call time.
        if padding < 0:
            raise ValueError('padding must be non-negative, got %r' % (padding,))
        self.padding = padding

    def call(self, x):
        """Zero-pad the end of the last axis of ``x`` by ``self.padding``."""
        # tf.pad takes one [before, after] pair per axis; only the last axis
        # receives trailing padding.
        paddings = [[0, 0] for _ in range(len(x.shape))]
        paddings[-1][1] = self.padding
        return tf.pad(x, paddings)

    def compute_output_shape(self, input_shape):
        """Return ``input_shape`` with the last dimension grown by ``padding``.

        Raises:
            ValueError: If the shape has fewer than two dimensions or its last
                dimension is unknown/zero.
        """
        # Explicit checks instead of `assert`, which is stripped under `python -O`.
        if not input_shape or len(input_shape) < 2:
            raise ValueError('PaddingChannel expects an input of rank >= 2, got %r'
                             % (input_shape,))
        if not input_shape[-1]:
            raise ValueError('PaddingChannel needs a known, non-zero last dimension, got %r'
                             % (input_shape,))
        output_shape = list(input_shape)
        output_shape[-1] = input_shape[-1] + self.padding
        return tuple(output_shape)

    def get_config(self):
        """Return the layer config including ``padding``.

        Keras rebuilds layers from this config when loading a saved model.
        Without this override, ``padding`` is not written to the saved file,
        and a reloaded layer falls back to the default ``padding=0``.
        """
        config = super(PaddingChannel, self).get_config()
        config['padding'] = self.padding
        return config
# WebDNN converter plugin, passed to convert_keras.py via `--plugin`.
# Importing this module makes the custom PaddingChannel layer importable.
import sys
# Append the current working directory so that `padding_channel` resolves.
# This assumes the converter is run from the directory containing
# padding_channel.py. It must run before the import below.
sys.path.append('.')
from padding_channel import PaddingChannel
# NOTE(review): as shown, this plugin only imports the layer; no converter
# handler for PaddingChannel is visible here -- confirm one is registered.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment