Created October 9, 2017 21:43
-
-
Save anonymous/a711f7dcf8d51492c39ad623246f31c5 to your computer and use it in GitHub Desktop.
CoreML bug?
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
void DummyTest() | |
{ | |
using namespace CoreML::Specification; | |
Model model; | |
model.set_specificationversion( 1 ); | |
// first layer: swapping Seq with Channels | |
NeuralNetworkLayer* permute = model.mutable_neuralnetwork()->add_layers(); | |
permute->set_name( "permute" ); | |
permute->add_input( "input" ); | |
permute->add_output( "permutated" ); | |
permute->mutable_permute()->add_axis( 1 ); | |
permute->mutable_permute()->add_axis( 0 ); | |
permute->mutable_permute()->add_axis( 2 ); | |
permute->mutable_permute()->add_axis( 3 ); | |
// second layer: slicing channels | |
NeuralNetworkLayer* slice = model.mutable_neuralnetwork()->add_layers(); | |
slice->set_name( "slice" ); | |
slice->add_input( "permutated" ); | |
slice->add_output( "sliced" ); | |
slice->mutable_slice()->set_startindex( 1 ); // it works if startindex is 0, but I need another | |
slice->mutable_slice()->set_endindex( 2 ); // slicing 1 element | |
slice->mutable_slice()->set_stride( 1 ); // I need every element | |
slice->mutable_slice()->set_axis( SliceLayerParams_SliceAxis_CHANNEL_AXIS ); // channels (ex-sequence) | |
// third layer: permuting back ex-Channels x ex-Sequence x H x W -> ex-Seq x ex-Channels x H x W | |
NeuralNetworkLayer* permuteBack = model.mutable_neuralnetwork()->add_layers(); | |
permuteBack->set_name( "permute_back" ); | |
permuteBack->add_input( "sliced" ); | |
permuteBack->add_output( "output" ); | |
permuteBack->mutable_permute()->add_axis( 1 ); | |
permuteBack->mutable_permute()->add_axis( 0 ); | |
permuteBack->mutable_permute()->add_axis( 2 ); | |
permuteBack->mutable_permute()->add_axis( 3 ); | |
// adding input | |
FeatureDescription* inputDesc = model.mutable_description()->add_input(); | |
inputDesc->set_name( "input" ); | |
ArrayFeatureType* inType = inputDesc->mutable_type()->mutable_multiarraytype(); | |
inType->set_datatype( ArrayFeatureType_ArrayDataType_FLOAT32 ); | |
inType->add_shape( 11 ); // channels | |
inType->add_shape( 9 ); // height | |
inType->add_shape( 10 ); // width | |
// adding output | |
FeatureDescription* outputDesc = model.mutable_description()->add_output(); | |
outputDesc->set_name( "output" ); | |
ArrayFeatureType* outType = outputDesc->mutable_type()->mutable_multiarraytype(); | |
outType->set_datatype( ArrayFeatureType_ArrayDataType_FLOAT32 ); | |
outType->add_shape( 11 ); // channels | |
outType->add_shape( 9 ); // height | |
outType->add_shape( 10 ); // width | |
// serialization | |
std::ofstream outStream( "output.mlmodel", std::ios::out | std::ios::binary ); | |
google::protobuf::io::OstreamOutputStream rawOutput( &outStream ); | |
model.SerializePartialToZeroCopyStream( &rawOutput ); | |
model.Clear(); | |
} |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.