Skip to content

Instantly share code, notes, and snippets.

Created October 9, 2017 21:43
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save anonymous/a711f7dcf8d51492c39ad623246f31c5 to your computer and use it in GitHub Desktop.
Save anonymous/a711f7dcf8d51492c39ad623246f31c5 to your computer and use it in GitHub Desktop.
CoreML bug?
void DummyTest()
{
using namespace CoreML::Specification;
Model model;
model.set_specificationversion( 1 );
// first layer: swapping Seq with Channels
NeuralNetworkLayer* permute = model.mutable_neuralnetwork()->add_layers();
permute->set_name( "permute" );
permute->add_input( "input" );
permute->add_output( "permutated" );
permute->mutable_permute()->add_axis( 1 );
permute->mutable_permute()->add_axis( 0 );
permute->mutable_permute()->add_axis( 2 );
permute->mutable_permute()->add_axis( 3 );
// second layer: slicing channels
NeuralNetworkLayer* slice = model.mutable_neuralnetwork()->add_layers();
slice->set_name( "slice" );
slice->add_input( "permutated" );
slice->add_output( "sliced" );
slice->mutable_slice()->set_startindex( 1 ); // it works if startindex is 0, but I need another
slice->mutable_slice()->set_endindex( 2 ); // slicing 1 element
slice->mutable_slice()->set_stride( 1 ); // I need every element
slice->mutable_slice()->set_axis( SliceLayerParams_SliceAxis_CHANNEL_AXIS ); // channels (ex-sequence)
// third layer: permuting back ex-Channels x ex-Sequence x H x W -> ex-Seq x ex-Channels x H x W
NeuralNetworkLayer* permuteBack = model.mutable_neuralnetwork()->add_layers();
permuteBack->set_name( "permute_back" );
permuteBack->add_input( "sliced" );
permuteBack->add_output( "output" );
permuteBack->mutable_permute()->add_axis( 1 );
permuteBack->mutable_permute()->add_axis( 0 );
permuteBack->mutable_permute()->add_axis( 2 );
permuteBack->mutable_permute()->add_axis( 3 );
// adding input
FeatureDescription* inputDesc = model.mutable_description()->add_input();
inputDesc->set_name( "input" );
ArrayFeatureType* inType = inputDesc->mutable_type()->mutable_multiarraytype();
inType->set_datatype( ArrayFeatureType_ArrayDataType_FLOAT32 );
inType->add_shape( 11 ); // channels
inType->add_shape( 9 ); // height
inType->add_shape( 10 ); // width
// adding output
FeatureDescription* outputDesc = model.mutable_description()->add_output();
outputDesc->set_name( "output" );
ArrayFeatureType* outType = outputDesc->mutable_type()->mutable_multiarraytype();
outType->set_datatype( ArrayFeatureType_ArrayDataType_FLOAT32 );
outType->add_shape( 11 ); // channels
outType->add_shape( 9 ); // height
outType->add_shape( 10 ); // width
// serialization
std::ofstream outStream( "output.mlmodel", std::ios::out | std::ios::binary );
google::protobuf::io::OstreamOutputStream rawOutput( &outStream );
model.SerializePartialToZeroCopyStream( &rawOutput );
model.Clear();
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment