/*******************************************************************************
* Copyright 2018-2019 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
/// @example rnn_training_f32.cpp
/// @copybrief rnn_training_f32_cpp
///
/// @page rnn_training_f32_cpp RNN f32 training example
/// This C++ API example demonstrates how to build training (forward and
/// backward passes) for an LSTM layer, as used in GNMT-like models.
///
/// @include rnn_training_f32.cpp
#include <cstring>
#include <cmath>
#include <numeric>
#include "example_utils.hpp"
using namespace dnnl;
const int seq_length = 35;
const int batch = 40;
// Number of channels
const int common_feature_size = 650;
// RNN primitive characteristics
const int common_n_layers = 1;
const int lstm_n_gates = 4;
void simple_net(engine::kind engine_kind) {
using tag = memory::format_tag;
using dt = memory::data_type;
auto eng = engine(engine_kind, 0);
stream s(eng);
bool is_training = true;
auto fwd_inf_train = is_training ? prop_kind::forward_training
: prop_kind::forward_inference;
std::vector<primitive> fwd_net;
std::vector<primitive> bwd_net;
// Dimensions of the input tensor: time, batch, channels.
memory::dims src_layer_dims = {
seq_length, // time
batch, // n
common_feature_size // c
};
memory::dims common_weights_layer_dims = {
common_n_layers, // layers
1, // directions
common_feature_size, // input feature size
lstm_n_gates, // gates number
common_feature_size // output feature size
};
memory::dims common_weights_iter_dims = {
common_n_layers, // layers
1, // directions
common_feature_size, // input feature size
lstm_n_gates, // gates number
common_feature_size // output feature size
};
memory::dims common_bias_dims = {
common_n_layers, // layers
1, // directions
lstm_n_gates, // gates number
common_feature_size // output feature size
};
memory::dims dst_layer_dims = {
seq_length, // time
batch, // n
common_feature_size // c
};
// The primitive emits its final hidden and cell states through the
// dst_iter and dst_iter_c parameters. In a chain of RNN primitives
// these would feed the next primitive's src_iter; here they are
// simply allocated with the shapes the primitive reports.
memory::dims dst_iter_dims = {
common_n_layers, // layers
1, // directions
batch, // n
common_feature_size // c
};
memory::dims dst_iter_c_dims = {
common_n_layers, // layers
1, // directions
batch, // n
common_feature_size // c
};
// Compute the number of elements in a tensor (product of its
// dimensions).
auto tz_volume = [=](memory::dims tz_dims) {
return std::accumulate(tz_dims.begin(), tz_dims.end(), (memory::dim)1,
std::multiplies<memory::dim>());
};
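// For example, tz_volume(src_layer_dims)
// == seq_length * batch * common_feature_size = 35 * 40 * 650 = 910000.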
// Create an auxiliary f32 memory descriptor
// based on user-supplied dimensions and layout.
auto formatted_md
= [=](memory::dims dimensions, memory::format_tag layout) {
return memory::desc {{dimensions}, dt::f32, layout};
};
// Create an auxiliary generic f32 memory descriptor
// based on supplied dimensions, with format_tag::any.
auto generic_md = [=](memory::dims dimensions) {
return formatted_md(dimensions, tag::any);
};
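// Descriptors created with tag::any let the primitive choose the layout
// it prefers. The layout actually chosen can then be queried from the
// primitive descriptor (e.g. prim_desc.weights_layer_desc()), and user
// data is reordered into it when the two layouts differ, as done below.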
//
// I/O memory, coming from user
//
// Net input
std::vector<float> net_src(tz_volume(src_layer_dims), -0.1f);
// NOTE: the destination layout is fixed to tnc in this example. If the
// primitive had been allowed to choose a different layout (via
// tag::any), a separate primitive-chosen destination memory would be
// allocated and reordered back to the user layout as needed; the
// dst_layer handling below is written to cope with either case.
// Memory allocated on the user side.
// Suppose user data is in tnc format.
auto net_src_memory
= dnnl::memory({{src_layer_dims}, dt::f32, tag::tnc}, eng);
write_to_dnnl_memory(net_src.data(), net_src_memory);
// The src_layer memory is accessed through a sub-memory descriptor.
// With zero offsets it views the whole buffer; the same mechanism
// lets a primitive view a slice of a larger padded buffer, with the
// strides computed to accommodate the padding.
auto user_src_layer_md = net_src_memory.get_desc().submemory_desc(
src_layer_dims, {0, 0, 0}); // t, n, c offsets
auto src_layer_memory = net_src_memory;
// Other user provided memory arrays, descriptors and primitives with the
// data layouts chosen by user. We'll have to reorder if RNN
// primitive prefers it in a different format.
std::vector<float> user_common_weights_layer(
tz_volume(common_weights_layer_dims), -0.1f);
auto user_common_weights_layer_memory = dnnl::memory(
{common_weights_layer_dims, dt::f32, tag::ldigo}, eng);
write_to_dnnl_memory(
user_common_weights_layer.data(), user_common_weights_layer_memory);
std::vector<float> user_common_weights_iter(
tz_volume(common_weights_iter_dims), -0.1f);
auto user_common_weights_iter_memory = dnnl::memory(
{{common_weights_iter_dims}, dt::f32, tag::ldigo}, eng);
write_to_dnnl_memory(
user_common_weights_iter.data(), user_common_weights_iter_memory);
std::vector<float> user_common_bias(tz_volume(common_bias_dims), 1.0f);
auto user_common_bias_memory
= dnnl::memory({{common_bias_dims}, dt::f32, tag::ldgo}, eng);
write_to_dnnl_memory(user_common_bias.data(), user_common_bias_memory);
std::vector<float> user_dst_layer(
tz_volume(dst_layer_dims), 1.0f);
auto user_dst_layer_memory
= dnnl::memory({{dst_layer_dims}, dt::f32, tag::tnc}, eng);
write_to_dnnl_memory(
user_dst_layer.data(), user_dst_layer_memory);
// Describe the LSTM layer, forward pass.
// There is no previous iteration state coming in, so src_iter_desc
// and src_iter_c_desc are zero memory descriptors.
lstm_forward::desc layer_desc(fwd_inf_train, // aprop_kind
rnn_direction::unidirectional_left2right, // direction
user_src_layer_md, // src_layer_desc
memory::desc(), // src_iter_desc
memory::desc(), // src_iter_c_desc
generic_md(common_weights_layer_dims), // weights_layer_desc
generic_md(common_weights_iter_dims), // weights_iter_desc
generic_md(common_bias_dims), // bias_desc
formatted_md(dst_layer_dims, tag::tnc), // dst_layer_desc
generic_md(dst_iter_dims), // dst_iter_desc
generic_md(dst_iter_c_dims) // dst_iter_c_desc
);
// Create the primitive descriptor.
auto prim_desc
= dnnl::lstm_forward::primitive_desc(layer_desc, eng);
//
// Allocate memory for the output iteration states ("iter" parameters)
// based on the shapes the RNN primitive reports.
//
auto dst_iter_memory
= dnnl::memory(prim_desc.dst_iter_desc(), eng);
auto dst_iter_c_memory
= dnnl::memory(prim_desc.dst_iter_c_desc(), eng);
//
// Weights and biases, layer memory.
// The only reorders needed are from the user layouts (ldigo/ldgo) to
// the RNN-friendly layouts chosen by the primitive.
//
auto common_weights_layer_memory = user_common_weights_layer_memory;
if (prim_desc.weights_layer_desc()
!= common_weights_layer_memory.get_desc()) {
common_weights_layer_memory
= dnnl::memory(prim_desc.weights_layer_desc(), eng);
reorder(user_common_weights_layer_memory, common_weights_layer_memory)
.execute(s, user_common_weights_layer_memory,
common_weights_layer_memory);
}
auto common_weights_iter_memory = user_common_weights_iter_memory;
if (prim_desc.weights_iter_desc()
!= common_weights_iter_memory.get_desc()) {
common_weights_iter_memory
= dnnl::memory(prim_desc.weights_iter_desc(), eng);
reorder(user_common_weights_iter_memory, common_weights_iter_memory)
.execute(s, user_common_weights_iter_memory,
common_weights_iter_memory);
}
auto common_bias_memory = user_common_bias_memory;
if (prim_desc.bias_desc() != common_bias_memory.get_desc()) {
common_bias_memory = dnnl::memory(prim_desc.bias_desc(), eng);
reorder(user_common_bias_memory, common_bias_memory)
.execute(s, user_common_bias_memory, common_bias_memory);
}
//
// Destination layer memory
//
auto dst_layer_memory = user_dst_layer_memory;
if (prim_desc.dst_layer_desc()
!= dst_layer_memory.get_desc()) {
dst_layer_memory
= dnnl::memory(prim_desc.dst_layer_desc(), eng);
reorder(user_dst_layer_memory, dst_layer_memory)
.execute(s, user_dst_layer_memory,
dst_layer_memory);
}
// We also create workspace memory based on the information from
// workspace_desc(). This is needed for internal communication between
// the forward and backward primitives during training.
auto create_ws = [=](dnnl::lstm_forward::primitive_desc &pd) {
return dnnl::memory(pd.workspace_desc(), eng);
};
auto workspace_memory = create_ws(prim_desc);
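// Note: for prop_kind::forward_inference the workspace descriptor
// would be zero-sized; a workspace is only needed to pass intermediate
// results from forward training to the backward pass.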
// Construct the RNN primitive objects
lstm_forward layer(prim_desc);
for (int i = 0; i < 100; i++) {
layer.execute(s,
{{DNNL_ARG_SRC_LAYER, src_layer_memory},
{DNNL_ARG_WEIGHTS_LAYER, common_weights_layer_memory},
{DNNL_ARG_WEIGHTS_ITER, common_weights_iter_memory},
{DNNL_ARG_BIAS, common_bias_memory},
{DNNL_ARG_DST_LAYER, dst_layer_memory},
{DNNL_ARG_DST_ITER, dst_iter_memory},
{DNNL_ARG_DST_ITER_C, dst_iter_c_memory},
{DNNL_ARG_WORKSPACE, workspace_memory}});
}
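// The same primitive is re-executed 100 times on identical inputs,
// presumably to make timing easier; a real training loop would update
// the inputs and weights between iterations.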
// No backward pass for inference
if (!is_training) return;
//
// The backward primitive reuses memory from the forward pass where
// possible; only the backward-specific pieces are allocated and
// described here. Only relevant for training.
//
// User-provided memory for the backward-by-data output (diff_src)
std::vector<float> net_diff_src(tz_volume(src_layer_dims), 1.0f);
auto net_diff_src_memory
= dnnl::memory(formatted_md(src_layer_dims, tag::tnc), eng);
write_to_dnnl_memory(net_diff_src.data(), net_diff_src_memory);
// diff_src follows the same layout we have for net_src
auto user_diff_src_layer_md
= net_diff_src_memory.get_desc().submemory_desc(
src_layer_dims, {0, 0, 0}); // t, n, c offsets
auto diff_src_layer_memory = net_diff_src_memory;
// User-provided memory for the weight and bias gradients
// (backward by weights)
std::vector<float> user_common_diff_weights_layer(
tz_volume(common_weights_layer_dims), 1.0f);
auto user_common_diff_weights_layer_memory = dnnl::memory(
formatted_md(common_weights_layer_dims, tag::ldigo), eng);
write_to_dnnl_memory(user_common_diff_weights_layer.data(),
user_common_diff_weights_layer_memory);
std::vector<float> user_common_diff_bias(tz_volume(common_bias_dims), 1.0f);
auto user_common_diff_bias_memory
= dnnl::memory(formatted_md(common_bias_dims, tag::ldgo), eng);
write_to_dnnl_memory(
user_common_diff_bias.data(), user_common_diff_bias_memory);
// User-provided input to the backward primitive.
// To be updated by the user after the forward pass using some cost
// function.
memory::dims net_diff_dst_dims = {
seq_length,
batch,
common_feature_size // c
};
// Suppose user data is in tnc format.
std::vector<float> net_diff_dst(tz_volume(net_diff_dst_dims), 1.0f);
auto net_diff_dst_memory
= dnnl::memory(formatted_md(net_diff_dst_dims, tag::tnc), eng);
write_to_dnnl_memory(net_diff_dst.data(), net_diff_dst_memory);
// The diff_dst_layer memory is accessed through a sub-memory
// descriptor, mirroring the src_layer handling above.
auto user_diff_dst_layer_md
= net_diff_dst_memory.get_desc().submemory_desc(
dst_layer_dims, {0, 0, 0}); // t, n, c offsets
auto diff_dst_layer_memory = net_diff_dst_memory;
// Backward primitive descriptor
lstm_backward::desc layer_bwd_desc(
prop_kind::backward, // aprop_kind
rnn_direction::unidirectional_left2right, // direction
user_src_layer_md, // src_layer_desc
memory::desc(), // src_iter_desc
memory::desc(), // src_iter_c_desc
generic_md(common_weights_layer_dims), // weights_layer_desc
generic_md(common_weights_iter_dims), // weights_iter_desc
generic_md(common_bias_dims), // bias_desc
formatted_md(dst_layer_dims, tag::tnc), // dst_layer_desc
generic_md(dst_iter_dims), // dst_iter_desc
generic_md(dst_iter_c_dims), // dst_iter_c_desc
user_diff_src_layer_md, // diff_src_layer_desc
memory::desc(), // diff_src_iter_desc
memory::desc(), // diff_src_iter_c_desc
generic_md(common_weights_layer_dims), // diff_weights_layer_desc
generic_md(common_weights_iter_dims), // diff_weights_iter_desc
generic_md(common_bias_dims), // diff_bias_desc
user_diff_dst_layer_md, // diff_dst_layer_desc
generic_md(dst_iter_dims), // diff_dst_iter_desc
generic_md(dst_iter_c_dims) // diff_dst_iter_c_desc
);
auto bwd_prim_desc = lstm_backward::primitive_desc(
layer_bwd_desc, eng, prim_desc);
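// Note: the forward primitive descriptor is passed as a hint so the
// library can select a backward implementation compatible with the
// forward one (in particular, one that understands its workspace).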
// Allocate memory for the incoming iteration-state gradients
// (diff_dst_iter, diff_dst_iter_c) based on the shapes the backward
// primitive reports.
auto diff_dst_iter_memory
= dnnl::memory(bwd_prim_desc.diff_dst_iter_desc(), eng);
auto diff_dst_iter_c_memory
= dnnl::memory(bwd_prim_desc.diff_dst_iter_c_desc(), eng);
//
// Memory for backward pass
//
// src layer uses the same memory as forward
auto src_layer_bwd_memory = src_layer_memory;
// Memory for weights and biases for backward pass
// Try to use the same memory between forward and backward, but
// sometimes reorders are needed.
auto common_weights_layer_bwd_memory = common_weights_layer_memory;
if (bwd_prim_desc.weights_layer_desc()
!= prim_desc.weights_layer_desc()) {
common_weights_layer_bwd_memory
= memory(bwd_prim_desc.weights_layer_desc(), eng);
reorder(common_weights_layer_memory, common_weights_layer_bwd_memory)
.execute(s, common_weights_layer_memory,
common_weights_layer_bwd_memory);
}
auto common_weights_iter_bwd_memory = common_weights_iter_memory;
if (bwd_prim_desc.weights_iter_desc()
!= prim_desc.weights_iter_desc()) {
common_weights_iter_bwd_memory
= memory(bwd_prim_desc.weights_iter_desc(), eng);
reorder(common_weights_iter_memory, common_weights_iter_bwd_memory)
.execute(s, common_weights_iter_memory,
common_weights_iter_bwd_memory);
}
auto common_bias_bwd_memory = common_bias_memory;
if (bwd_prim_desc.bias_desc() != common_bias_memory.get_desc()) {
common_bias_bwd_memory
= dnnl::memory(bwd_prim_desc.bias_desc(), eng);
reorder(common_bias_memory, common_bias_bwd_memory)
.execute(s, common_bias_memory, common_bias_bwd_memory);
}
// diff_weights and biases
auto common_diff_weights_layer_memory
= user_common_diff_weights_layer_memory;
auto reorder_common_diff_weights_layer = false;
if (bwd_prim_desc.diff_weights_layer_desc()
!= common_diff_weights_layer_memory.get_desc()) {
common_diff_weights_layer_memory = dnnl::memory(
bwd_prim_desc.diff_weights_layer_desc(), eng);
reorder_common_diff_weights_layer = true;
}
auto common_diff_bias_memory = user_common_diff_bias_memory;
auto reorder_common_diff_bias = false;
if (bwd_prim_desc.diff_bias_desc()
!= common_diff_bias_memory.get_desc()) {
common_diff_bias_memory
= dnnl::memory(bwd_prim_desc.diff_bias_desc(), eng);
reorder_common_diff_bias = true;
}
// dst_layer memory for backward pass
auto dst_layer_bwd_memory = dst_layer_memory;
if (bwd_prim_desc.dst_layer_desc()
!= dst_layer_bwd_memory.get_desc()) {
dst_layer_bwd_memory
= dnnl::memory(bwd_prim_desc.dst_layer_desc(), eng);
reorder(dst_layer_memory, dst_layer_bwd_memory)
.execute(s, dst_layer_memory,
dst_layer_bwd_memory);
}
// Memory for the iteration-weights gradients, plus the forward
// outputs (dst_iter, dst_iter_c) that the backward pass consumes;
// the latter are reordered if the backward primitive prefers
// different layouts.
auto common_diff_weights_iter_memory = dnnl::memory(
bwd_prim_desc.diff_weights_iter_desc(), eng);
auto dst_iter_bwd_memory = dst_iter_memory;
if (bwd_prim_desc.dst_iter_desc()
!= dst_iter_bwd_memory.get_desc()) {
dst_iter_bwd_memory
= dnnl::memory(bwd_prim_desc.dst_iter_desc(), eng);
reorder(dst_iter_memory, dst_iter_bwd_memory)
.execute(s, dst_iter_memory,
dst_iter_bwd_memory);
}
auto dst_iter_c_bwd_memory = dst_iter_c_memory;
if (bwd_prim_desc.dst_iter_c_desc()
!= dst_iter_c_bwd_memory.get_desc()) {
dst_iter_c_bwd_memory
= dnnl::memory(bwd_prim_desc.dst_iter_c_desc(), eng);
reorder(dst_iter_c_memory, dst_iter_c_bwd_memory)
.execute(s, dst_iter_c_memory,
dst_iter_c_bwd_memory);
}
// Construct the RNN primitive objects for backward
lstm_backward layer_bwd(bwd_prim_desc);
for (int i = 0; i < 100; i++) {
layer_bwd.execute(s,
{{DNNL_ARG_SRC_LAYER, src_layer_bwd_memory},
{DNNL_ARG_WEIGHTS_LAYER, common_weights_layer_bwd_memory},
{DNNL_ARG_WEIGHTS_ITER, common_weights_iter_bwd_memory},
{DNNL_ARG_BIAS, common_bias_bwd_memory},
{DNNL_ARG_DST_LAYER, dst_layer_bwd_memory},
{DNNL_ARG_DST_ITER, dst_iter_bwd_memory},
{DNNL_ARG_DST_ITER_C, dst_iter_c_bwd_memory},
{DNNL_ARG_DIFF_SRC_LAYER, diff_src_layer_memory},
{DNNL_ARG_DIFF_WEIGHTS_LAYER,
common_diff_weights_layer_memory},
{DNNL_ARG_DIFF_WEIGHTS_ITER,
common_diff_weights_iter_memory},
{DNNL_ARG_DIFF_BIAS, common_diff_bias_memory},
{DNNL_ARG_DIFF_DST_LAYER, diff_dst_layer_memory},
{DNNL_ARG_DIFF_DST_ITER, diff_dst_iter_memory},
{DNNL_ARG_DIFF_DST_ITER_C, diff_dst_iter_c_memory},
{DNNL_ARG_WORKSPACE, workspace_memory}});
}
if (reorder_common_diff_weights_layer) {
reorder(common_diff_weights_layer_memory,
user_common_diff_weights_layer_memory)
.execute(s, common_diff_weights_layer_memory,
user_common_diff_weights_layer_memory);
}
if (reorder_common_diff_bias) {
reorder(common_diff_bias_memory, user_common_diff_bias_memory)
.execute(s, common_diff_bias_memory,
user_common_diff_bias_memory);
}
//
// User updates weights and bias using diffs
//
s.wait();
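// A minimal sketch (not part of the original example) of how the diffs
// could be applied with plain SGD. It assumes the read_from_dnnl_memory()
// helper from example_utils.hpp and a hypothetical learning rate; the
// gradients are valid in the user buffers' plain layouts because they
// were reordered back to user memory above, and s.wait() has completed.
const float learning_rate = 0.1f; // hypothetical hyperparameter
read_from_dnnl_memory(user_common_diff_weights_layer.data(),
        user_common_diff_weights_layer_memory);
read_from_dnnl_memory(
        user_common_diff_bias.data(), user_common_diff_bias_memory);
for (size_t i = 0; i < user_common_weights_layer.size(); ++i)
    user_common_weights_layer[i]
            -= learning_rate * user_common_diff_weights_layer[i];
for (size_t i = 0; i < user_common_bias.size(); ++i)
    user_common_bias[i] -= learning_rate * user_common_diff_bias[i];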
}
int main(int argc, char **argv) {
try {
simple_net(parse_engine_kind(argc, argv));
std::cout << "Simple net f32 training example passed!\n";
} catch (error &e) {
std::cerr << "status: " << e.status << std::endl;
std::cerr << "message: " << e.message << std::endl;
return 1;
}
return 0;
}