Created
December 11, 2017 06:39
-
-
Save y-ich/db973fc2a1a1736adf5570a22e87902f to your computer and use it in GitHub Desktop.
Investigating a model-conversion error caused by channel padding
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Build script for reproducing the conversion problem: generates model.h5
# from model.py, then runs WebDNN's Keras converter on it.

# Location of the converter script. This is a machine-specific path;
# override it on the command line, e.g. `make WEBDNN=/path/to/convert_keras.py`.
WEBDNN = /Users/yuji/OpenSources/webdnn/bin/convert_keras.py
# Plugin module handed to the converter.
PLUGIN = padding_channel_plugin.py
# Derived from PLUGIN so the rule's prerequisite and the --plugin argument
# can never drift apart.
PLUGIN_OPTION = --plugin $(PLUGIN)
OUTDIR = output

# `$<` expands to the first prerequisite (model.h5).
$(OUTDIR)/kernels_webassembly.js: model.h5 $(PLUGIN)
	DEBUG=1 CPLUS_INCLUDE_PATH=/usr/local/include/eigen3 python3 $(WEBDNN) $< $(PLUGIN_OPTION) --input_shape '(1,4,1)' --out $(OUTDIR)

# Regenerate the saved model whenever the model script or the layer changes.
model.h5: model.py padding_channel.py
	python3 model.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import sys | |
import tensorflow as tf | |
from keras import backend as K | |
from keras.models import Model | |
from keras.layers import Input, Reshape, Lambda | |
from padding_channel import PaddingChannel | |
def model_can_be_converted1():
    """Build a model that feeds the input straight into PaddingChannel.

    Author's note: conversion succeeds when the input is connected
    directly to the padding layer.
    """
    source = Input(name="x", shape=(4, 1))
    padded = PaddingChannel(padding=1)(source)
    return Model(inputs=[source], outputs=[padded])
def model_cannot_be_converted1():
    """Build a model that reshapes the input before PaddingChannel.

    Author's note: conversion fails when a Reshape layer precedes the
    padding layer.
    """
    source = Input(name="x", shape=(4, 1))
    # Chain Reshape -> PaddingChannel in a single expression.
    padded = PaddingChannel(padding=1)(Reshape((2, 2, 1))(source))
    return Model(inputs=[source], outputs=[padded])
def model_can_be_converted_reshape_only():
    """Build a model consisting of a single Reshape layer.

    Author's note: confirmed that a model with only a Reshape converts
    successfully.

    This builder was originally named ``model_can_be_converted2``, a name
    that is defined a second time later in this module. The later
    definition replaced this one, leaving it unreachable, so it gets a
    unique name here. ``model_can_be_converted2`` keeps referring to the
    later definition, exactly as before.
    """
    x = Input(shape=(4, 1), name="x")
    y = Reshape((2, 2, 1))(x)
    return Model(inputs=[x], outputs=[y])
def model_can_be_converted2():
    """Build a model that applies PaddingChannel to a (2, 2, 1) input.

    Author's note: confirmed that conversion works for the
    two-dimensional input case.
    """
    source = Input(name="x", shape=(2, 2, 1))
    padded = PaddingChannel(padding=1)(source)
    return Model(inputs=[source], outputs=[padded])
if __name__ == '__main__':
    # Writes the model under investigation to model.h5. Swap the builder
    # call below to save a different variant.
    #model = model_can_be_converted2()
    model_cannot_be_converted1().save('model.h5')
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf | |
from keras.engine import Layer | |
class PaddingChannel(Layer):
    """Keras layer that pads the end of the last axis of its input.

    ``padding`` extra entries are appended after the existing ones on the
    last axis via ``tf.pad`` (which pads with zeros by default); every
    other axis is left unchanged.

    Args:
        padding: Number of entries to append on the last axis.
        **kwargs: Forwarded to the base ``Layer`` constructor.
    """

    def __init__(self, padding=0, **kwargs):
        super(PaddingChannel, self).__init__(**kwargs)
        self.padding = padding

    def call(self, x):
        """Return ``x`` with ``self.padding`` entries appended on the last axis."""
        # One [before, after] pair per axis; only the "after" amount of the
        # last axis is non-zero.
        paddings = [[0, 0] for _ in range(len(x.shape))]
        paddings[-1][1] = self.padding
        return tf.pad(x, paddings)

    def compute_output_shape(self, input_shape):
        """Return ``input_shape`` with its last dimension grown by ``padding``.

        Raises:
            AssertionError: If ``input_shape`` has fewer than two dimensions
                or its last dimension is falsy (e.g. ``None`` or 0).
        """
        assert input_shape and len(input_shape) >= 2, \
            'input_shape must have at least 2 dimensions'
        assert input_shape[-1], 'last dimension of input_shape must be known'
        output_shape = list(input_shape)
        output_shape[-1] = input_shape[-1] + self.padding
        return tuple(output_shape)

    def get_config(self):
        """Return the layer configuration, including ``padding``.

        Without this, a saved model would not record ``padding``, and a
        layer rebuilt from the saved configuration would fall back to the
        constructor default of 0.
        """
        config = super(PaddingChannel, self).get_config()
        config['padding'] = self.padding
        return config
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Plugin module passed to the converter through its --plugin option.
import sys

# Put the current working directory on the module search path so the
# import below can find padding_channel.py.
sys.path.append('.')

# Imported only for its side effect of loading the custom layer class;
# the name is not used further in this file.
from padding_channel import PaddingChannel  # noqa: E402,F401
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment