Created
September 25, 2019 15:15
-
-
Save rsdubtso/6b34d71a21fce584b853938c8b4718d0 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/******************************************************************************* | |
* Copyright 2018-2019 Intel Corporation | |
* | |
* Licensed under the Apache License, Version 2.0 (the "License"); | |
* you may not use this file except in compliance with the License. | |
* You may obtain a copy of the License at | |
* | |
* http://www.apache.org/licenses/LICENSE-2.0 | |
* | |
* Unless required by applicable law or agreed to in writing, software | |
* distributed under the License is distributed on an "AS IS" BASIS, | |
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
* See the License for the specific language governing permissions and | |
* limitations under the License. | |
*******************************************************************************/ | |
/// @example rnn_training_f32.cpp | |
/// @copybrief rnn_training_f32_cpp | |
/// | |
/// @page rnn_training_f32_cpp RNN f32 training example | |
/// This C++ API example demonstrates how to build GNMT model training. | |
/// | |
/// @include rnn_training_f32.cpp | |
#include <cstring> | |
#include <math.h> | |
#include <numeric> | |
#include "example_utils.hpp" | |
using namespace dnnl;
// Sequence length (time steps per input sequence).
const int seq_length = 35;
// Minibatch size.
const int batch = 40;
// Number of channels
const int common_feature_size = 650;
// RNN primitive characteristics
const int common_n_layers = 1;
// An LSTM cell has 4 gates (input, forget, candidate, output).
const int lstm_n_gates = 4;
/// Builds and runs a single-layer LSTM forward pass and, when training,
/// the matching backward pass, on the given engine kind.
///
/// The example allocates user-side buffers in plain layouts (tnc / ldigo /
/// ldgo), reorders them to the layouts the RNN primitive prefers, executes
/// the forward primitive 100 times, then creates the backward primitive
/// (hinted by the forward primitive descriptor), executes it 100 times, and
/// finally reorders the computed diff weights / diff bias back to the
/// user-side layouts.
///
/// @param engine_kind  Engine to run on (CPU or GPU).
void simple_net(engine::kind engine_kind) {
    using tag = memory::format_tag;
    using dt = memory::data_type;

    auto eng = engine(engine_kind, 0);
    stream s(eng);

    // Toggle between a full training run (forward + backward) and a pure
    // inference forward pass.
    bool is_training = true;
    auto fwd_inf_train = is_training ? prop_kind::forward_training
                                     : prop_kind::forward_inference;

    //
    // Logical tensor shapes.
    //
    memory::dims src_layer_dims = {
            seq_length, // time
            batch, // n
            common_feature_size // c
    };
    memory::dims common_weights_layer_dims = {
            common_n_layers, // layers
            1, // directions
            common_feature_size, // input feature size
            lstm_n_gates, // gates number
            common_feature_size // output feature size
    };
    memory::dims common_weights_iter_dims = {
            common_n_layers, // layers
            1, // directions
            common_feature_size, // input feature size
            lstm_n_gates, // gates number
            common_feature_size // output feature size
    };
    memory::dims common_bias_dims = {
            common_n_layers, // layers
            1, // directions
            lstm_n_gates, // gates number
            common_feature_size // output feature size
    };
    memory::dims dst_layer_dims = {
            seq_length, // time
            batch, // n
            common_feature_size // c
    };
    // "iter" tensors carry the recurrent hidden/cell state of the primitive.
    memory::dims dst_iter_dims = {
            common_n_layers, // layers
            1, // directions
            batch, // n
            common_feature_size // c
    };
    memory::dims dst_iter_c_dims = {
            common_n_layers, // layers
            1, // directions
            batch, // n
            common_feature_size // c
    };

    // Product of all tensor dimensions, i.e. the number of f32 elements
    // to allocate on the user side.
    auto tz_volume = [=](memory::dims tz_dims) {
        return std::accumulate(tz_dims.begin(), tz_dims.end(), (memory::dim)1,
                std::multiplies<memory::dim>());
    };

    // f32 memory descriptor with a caller-chosen layout...
    auto formatted_md
            = [=](memory::dims dimensions, memory::format_tag layout) {
                  return memory::desc {{dimensions}, dt::f32, layout};
              };
    // ...and one that lets the primitive choose the layout (format_tag::any).
    auto generic_md = [=](memory::dims dimensions) {
        return formatted_md(dimensions, tag::any);
    };

    //
    // I/O memory, coming from user
    //
    // Net input, filled with a constant for reproducibility.
    std::vector<float> net_src(tz_volume(src_layer_dims), -0.1f);
    // Suppose user data is in tnc format.
    auto net_src_memory
            = dnnl::memory({{src_layer_dims}, dt::f32, tag::tnc}, eng);
    write_to_dnnl_memory(net_src.data(), net_src_memory);

    // Zero-offset sub-memory spanning the full source; kept to demonstrate
    // how a primitive's src_layer can view a slice of a larger buffer
    // (e.g. when batches of different sizes share one allocation).
    auto user_src_layer_md = net_src_memory.get_desc().submemory_desc(
            src_layer_dims, {0, 0, 0}); // t, n, c offsets
    auto src_layer_memory = net_src_memory;

    // Other user-provided memory arrays, in layouts chosen by the user.
    // We reorder below if the RNN primitive prefers a different format.
    std::vector<float> user_common_weights_layer(
            tz_volume(common_weights_layer_dims), -0.1f);
    auto user_common_weights_layer_memory = dnnl::memory(
            {common_weights_layer_dims, dt::f32, tag::ldigo}, eng);
    write_to_dnnl_memory(
            user_common_weights_layer.data(), user_common_weights_layer_memory);

    std::vector<float> user_common_weights_iter(
            tz_volume(common_weights_iter_dims), -0.1f);
    auto user_common_weights_iter_memory = dnnl::memory(
            {{common_weights_iter_dims}, dt::f32, tag::ldigo}, eng);
    // BUGFIX: was copying user_common_weights_layer.data() into the iter
    // weights memory; the iter weights must come from their own buffer.
    write_to_dnnl_memory(
            user_common_weights_iter.data(), user_common_weights_iter_memory);

    std::vector<float> user_common_bias(tz_volume(common_bias_dims), 1.0f);
    auto user_common_bias_memory
            = dnnl::memory({{common_bias_dims}, dt::f32, tag::ldgo}, eng);
    write_to_dnnl_memory(user_common_bias.data(), user_common_bias_memory);

    std::vector<float> user_dst_layer(tz_volume(dst_layer_dims), 1.0f);
    auto user_dst_layer_memory
            = dnnl::memory({{dst_layer_dims}, dt::f32, tag::tnc}, eng);
    write_to_dnnl_memory(user_dst_layer.data(), user_dst_layer_memory);

    // Describe the forward LSTM layer. There is no primitive feeding state
    // into this one, so src_iter/src_iter_c are zero memory descriptors.
    lstm_forward::desc layer_desc(fwd_inf_train, // aprop_kind
            rnn_direction::unidirectional_left2right, // direction
            user_src_layer_md, // src_layer_desc
            memory::desc(), // src_iter_desc
            memory::desc(), // src_iter_c_desc
            generic_md(common_weights_layer_dims), // weights_layer_desc
            generic_md(common_weights_iter_dims), // weights_iter_desc
            generic_md(common_bias_dims), // bias_desc
            formatted_md(dst_layer_dims, tag::tnc), // dst_layer_desc
            generic_md(dst_iter_dims), // dst_iter_desc
            generic_md(dst_iter_c_dims) // dst_iter_c_desc
    );
    auto prim_desc = dnnl::lstm_forward::primitive_desc(layer_desc, eng);

    // Output recurrent state. Allocated with the layouts the primitive
    // chose; a downstream primitive would consume these as src_iter.
    auto dst_iter_memory = dnnl::memory(prim_desc.dst_iter_desc(), eng);
    auto dst_iter_c_memory = dnnl::memory(prim_desc.dst_iter_c_desc(), eng);

    //
    // Weights and biases: reorder user memory to the RNN-friendly layouts
    // only when the primitive's preferred descriptor differs.
    //
    auto common_weights_layer_memory = user_common_weights_layer_memory;
    if (prim_desc.weights_layer_desc()
            != common_weights_layer_memory.get_desc()) {
        common_weights_layer_memory
                = dnnl::memory(prim_desc.weights_layer_desc(), eng);
        reorder(user_common_weights_layer_memory, common_weights_layer_memory)
                .execute(s, user_common_weights_layer_memory,
                        common_weights_layer_memory);
    }

    auto common_weights_iter_memory = user_common_weights_iter_memory;
    if (prim_desc.weights_iter_desc()
            != common_weights_iter_memory.get_desc()) {
        common_weights_iter_memory
                = dnnl::memory(prim_desc.weights_iter_desc(), eng);
        reorder(user_common_weights_iter_memory, common_weights_iter_memory)
                .execute(s, user_common_weights_iter_memory,
                        common_weights_iter_memory);
    }

    auto common_bias_memory = user_common_bias_memory;
    if (prim_desc.bias_desc() != common_bias_memory.get_desc()) {
        common_bias_memory = dnnl::memory(prim_desc.bias_desc(), eng);
        reorder(user_common_bias_memory, common_bias_memory)
                .execute(s, user_common_bias_memory, common_bias_memory);
    }

    //
    // Destination layer memory
    //
    auto dst_layer_memory = user_dst_layer_memory;
    if (prim_desc.dst_layer_desc() != dst_layer_memory.get_desc()) {
        dst_layer_memory = dnnl::memory(prim_desc.dst_layer_desc(), eng);
        reorder(user_dst_layer_memory, dst_layer_memory)
                .execute(s, user_dst_layer_memory, dst_layer_memory);
    }

    // Workspace memory from workspace_desc(): required for communication
    // between the forward and backward primitives during training.
    auto create_ws = [=](dnnl::lstm_forward::primitive_desc &pd) {
        return dnnl::memory(pd.workspace_desc(), eng);
    };
    auto workspace_memory = create_ws(prim_desc);

    // Construct and run the forward primitive. The loop repeats the same
    // computation 100 times (e.g. for timing); results are identical each
    // iteration since all inputs are constant.
    lstm_forward layer(prim_desc);
    for (int i = 0; i < 100; i++) {
        layer.execute(s,
                {{DNNL_ARG_SRC_LAYER, src_layer_memory},
                        {DNNL_ARG_WEIGHTS_LAYER, common_weights_layer_memory},
                        {DNNL_ARG_WEIGHTS_ITER, common_weights_iter_memory},
                        {DNNL_ARG_BIAS, common_bias_memory},
                        {DNNL_ARG_DST_LAYER, dst_layer_memory},
                        {DNNL_ARG_DST_ITER, dst_iter_memory},
                        {DNNL_ARG_DST_ITER_C, dst_iter_c_memory},
                        {DNNL_ARG_WORKSPACE, workspace_memory}});
    }

    // No backward pass for inference.
    if (!is_training) return;

    //
    // Backward pass. Reuses forward memory where possible; only the
    // training-specific buffers are allocated below.
    //
    // User-provided memory for the backward-by-data output.
    std::vector<float> net_diff_src(tz_volume(src_layer_dims), 1.0f);
    auto net_diff_src_memory
            = dnnl::memory(formatted_md(src_layer_dims, tag::tnc), eng);
    write_to_dnnl_memory(net_diff_src.data(), net_diff_src_memory);

    // diff_src follows the same layout we have for net_src.
    auto user_diff_src_layer_md
            = net_diff_src_memory.get_desc().submemory_desc(
                    src_layer_dims, {0, 0, 0}); // t, n, c offsets
    auto diff_src_layer_memory = net_diff_src_memory;

    // User-provided memory for backpropagation by weights.
    std::vector<float> user_common_diff_weights_layer(
            tz_volume(common_weights_layer_dims), 1.0f);
    auto user_common_diff_weights_layer_memory = dnnl::memory(
            formatted_md(common_weights_layer_dims, tag::ldigo), eng);
    write_to_dnnl_memory(user_common_diff_weights_layer.data(),
            user_common_diff_weights_layer_memory);

    std::vector<float> user_common_diff_bias(tz_volume(common_bias_dims), 1.0f);
    auto user_common_diff_bias_memory
            = dnnl::memory(formatted_md(common_bias_dims, tag::ldgo), eng);
    write_to_dnnl_memory(
            user_common_diff_bias.data(), user_common_diff_bias_memory);

    // User-provided input to the backward primitive. In a real topology it
    // would be produced by a cost function after the forward pass; here it
    // is filled with ones.
    memory::dims net_diff_dst_dims = {
            seq_length, // time
            batch, // n
            common_feature_size // c
    };
    // Suppose user data is in tnc format.
    std::vector<float> net_diff_dst(tz_volume(net_diff_dst_dims), 1.0f);
    auto net_diff_dst_memory
            = dnnl::memory(formatted_md(net_diff_dst_dims, tag::tnc), eng);
    write_to_dnnl_memory(net_diff_dst.data(), net_diff_dst_memory);

    auto user_diff_dst_layer_md
            = net_diff_dst_memory.get_desc().submemory_desc(
                    dst_layer_dims, {0, 0, 0}); // t, n, c offsets
    auto diff_dst_layer_memory = net_diff_dst_memory;

    // Backward primitive descriptor; hinted by the forward prim_desc below.
    lstm_backward::desc layer_bwd_desc(prop_kind::backward, // aprop_kind
            rnn_direction::unidirectional_left2right, // direction
            user_src_layer_md, // src_layer_desc
            memory::desc(), // src_iter_desc
            memory::desc(), // src_iter_c_desc
            generic_md(common_weights_layer_dims), // weights_layer_desc
            generic_md(common_weights_iter_dims), // weights_iter_desc
            generic_md(common_bias_dims), // bias_desc
            formatted_md(dst_layer_dims, tag::tnc), // dst_layer_desc
            generic_md(dst_iter_dims), // dst_iter_desc
            generic_md(dst_iter_c_dims), // dst_iter_c_desc
            user_diff_src_layer_md, // diff_src_layer_desc
            memory::desc(), // diff_src_iter_desc
            memory::desc(), // diff_src_iter_c_desc
            generic_md(common_weights_layer_dims), // diff_weights_layer_desc
            generic_md(common_weights_iter_dims), // diff_weights_iter_desc
            generic_md(common_bias_dims), // diff_bias_desc
            user_diff_dst_layer_md, // diff_dst_layer_desc
            generic_md(dst_iter_dims), // diff_dst_iter_desc
            generic_md(dst_iter_c_dims) // diff_dst_iter_c_desc
    );
    auto bwd_prim_desc
            = lstm_backward::primitive_desc(layer_bwd_desc, eng, prim_desc);

    // Gradients w.r.t. the output recurrent state, in the layouts the
    // backward primitive chose.
    auto diff_dst_iter_memory
            = dnnl::memory(bwd_prim_desc.diff_dst_iter_desc(), eng);
    auto diff_dst_iter_c_memory
            = dnnl::memory(bwd_prim_desc.diff_dst_iter_c_desc(), eng);

    //
    // Memory for the backward pass: try to reuse the forward memory, but
    // reorder when the backward primitive prefers different layouts.
    //
    // src layer uses the same memory as forward.
    auto src_layer_bwd_memory = src_layer_memory;

    auto common_weights_layer_bwd_memory = common_weights_layer_memory;
    if (bwd_prim_desc.weights_layer_desc() != prim_desc.weights_layer_desc()) {
        common_weights_layer_bwd_memory
                = memory(bwd_prim_desc.weights_layer_desc(), eng);
        reorder(common_weights_layer_memory, common_weights_layer_bwd_memory)
                .execute(s, common_weights_layer_memory,
                        common_weights_layer_bwd_memory);
    }

    auto common_weights_iter_bwd_memory = common_weights_iter_memory;
    if (bwd_prim_desc.weights_iter_desc() != prim_desc.weights_iter_desc()) {
        common_weights_iter_bwd_memory
                = memory(bwd_prim_desc.weights_iter_desc(), eng);
        reorder(common_weights_iter_memory, common_weights_iter_bwd_memory)
                .execute(s, common_weights_iter_memory,
                        common_weights_iter_bwd_memory);
    }

    auto common_bias_bwd_memory = common_bias_memory;
    if (bwd_prim_desc.bias_desc() != common_bias_memory.get_desc()) {
        common_bias_bwd_memory = dnnl::memory(bwd_prim_desc.bias_desc(), eng);
        reorder(common_bias_memory, common_bias_bwd_memory)
                .execute(s, common_bias_memory, common_bias_bwd_memory);
    }

    // diff_weights and diff_bias: compute into primitive-preferred layouts
    // and remember whether a reorder back to user layouts is needed at the
    // end of the pass.
    auto common_diff_weights_layer_memory
            = user_common_diff_weights_layer_memory;
    auto reorder_common_diff_weights_layer = false;
    if (bwd_prim_desc.diff_weights_layer_desc()
            != common_diff_weights_layer_memory.get_desc()) {
        common_diff_weights_layer_memory
                = dnnl::memory(bwd_prim_desc.diff_weights_layer_desc(), eng);
        reorder_common_diff_weights_layer = true;
    }

    auto common_diff_bias_memory = user_common_diff_bias_memory;
    auto reorder_common_diff_bias = false;
    if (bwd_prim_desc.diff_bias_desc() != common_diff_bias_memory.get_desc()) {
        common_diff_bias_memory
                = dnnl::memory(bwd_prim_desc.diff_bias_desc(), eng);
        reorder_common_diff_bias = true;
    }

    // dst_layer memory for the backward pass.
    auto dst_layer_bwd_memory = dst_layer_memory;
    if (bwd_prim_desc.dst_layer_desc() != dst_layer_bwd_memory.get_desc()) {
        dst_layer_bwd_memory
                = dnnl::memory(bwd_prim_desc.dst_layer_desc(), eng);
        reorder(dst_layer_memory, dst_layer_bwd_memory)
                .execute(s, dst_layer_memory, dst_layer_bwd_memory);
    }

    // diff weights for the recurrent (iter) connection.
    auto common_diff_weights_iter_memory
            = dnnl::memory(bwd_prim_desc.diff_weights_iter_desc(), eng);

    auto dst_iter_bwd_memory = dst_iter_memory;
    if (bwd_prim_desc.dst_iter_desc() != dst_iter_bwd_memory.get_desc()) {
        dst_iter_bwd_memory = dnnl::memory(bwd_prim_desc.dst_iter_desc(), eng);
        reorder(dst_iter_memory, dst_iter_bwd_memory)
                .execute(s, dst_iter_memory, dst_iter_bwd_memory);
    }

    auto dst_iter_c_bwd_memory = dst_iter_c_memory;
    if (bwd_prim_desc.dst_iter_c_desc() != dst_iter_c_bwd_memory.get_desc()) {
        dst_iter_c_bwd_memory
                = dnnl::memory(bwd_prim_desc.dst_iter_c_desc(), eng);
        reorder(dst_iter_c_memory, dst_iter_c_bwd_memory)
                .execute(s, dst_iter_c_memory, dst_iter_c_bwd_memory);
    }

    // Construct and run the backward primitive (again 100 identical
    // iterations, e.g. for timing).
    lstm_backward layer_bwd(bwd_prim_desc);
    for (int i = 0; i < 100; i++) {
        layer_bwd.execute(s,
                {{DNNL_ARG_SRC_LAYER, src_layer_bwd_memory},
                        {DNNL_ARG_WEIGHTS_LAYER,
                                common_weights_layer_bwd_memory},
                        {DNNL_ARG_WEIGHTS_ITER, common_weights_iter_bwd_memory},
                        {DNNL_ARG_BIAS, common_bias_bwd_memory},
                        {DNNL_ARG_DST_LAYER, dst_layer_bwd_memory},
                        {DNNL_ARG_DST_ITER, dst_iter_bwd_memory},
                        {DNNL_ARG_DST_ITER_C, dst_iter_c_bwd_memory},
                        {DNNL_ARG_DIFF_SRC_LAYER, diff_src_layer_memory},
                        {DNNL_ARG_DIFF_WEIGHTS_LAYER,
                                common_diff_weights_layer_memory},
                        {DNNL_ARG_DIFF_WEIGHTS_ITER,
                                common_diff_weights_iter_memory},
                        {DNNL_ARG_DIFF_BIAS, common_diff_bias_memory},
                        {DNNL_ARG_DIFF_DST_LAYER, diff_dst_layer_memory},
                        {DNNL_ARG_DIFF_DST_ITER, diff_dst_iter_memory},
                        {DNNL_ARG_DIFF_DST_ITER_C, diff_dst_iter_c_memory},
                        {DNNL_ARG_WORKSPACE, workspace_memory}});
    }

    // Deferred reorders of the computed gradients back to user layouts.
    if (reorder_common_diff_weights_layer) {
        reorder(common_diff_weights_layer_memory,
                user_common_diff_weights_layer_memory)
                .execute(s, common_diff_weights_layer_memory,
                        user_common_diff_weights_layer_memory);
    }
    if (reorder_common_diff_bias) {
        reorder(common_diff_bias_memory, user_common_diff_bias_memory)
                .execute(s, common_diff_bias_memory,
                        user_common_diff_bias_memory);
    }

    //
    // At this point the user would update the weights and bias using the
    // computed diffs.
    //
    s.wait();
}
int main(int argc, char **argv) { | |
try { | |
simple_net(parse_engine_kind(argc, argv)); | |
std::cout << "Simple net f32 training example passed!\n"; | |
} catch (error &e) { | |
std::cerr << "status: " << e.status << std::endl; | |
std::cerr << "message: " << e.message << std::endl; | |
return 1; | |
} | |
return 0; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment