Created November 4, 2020 22:10
Save Jimmy-Hu/bf1c54b54c492ce0562b15228d67beef to your computer and use it in GitHub Desktop.
A recursive_transform Function with boost::multi_array
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters.
#include <algorithm>
#include <array>
#include <cassert>
#include <chrono>
#include <cmath>
#include <complex>
#include <concepts>
#include <cstddef>
#include <deque>
#include <exception>
#include <functional>
#include <iostream>
#include <iterator>
#include <list>
#include <map>
#include <numeric>
#include <optional>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <utility>
#include <variant>
#include <vector>
#define USE_BOOST_MULTIDIMENSIONAL_ARRAY
#ifdef USE_BOOST_MULTIDIMENSIONAL_ARRAY
#include <boost/multi_array.hpp>
#include <boost/multi_array/algorithm.hpp>
#include <boost/multi_array/base.hpp>
#include <boost/multi_array/collection_concept.hpp>
#endif
template<typename T> | |
concept is_back_inserterable = requires(T x) | |
{ | |
std::back_inserter(x); | |
}; | |
#ifdef USE_BOOST_MULTIDIMENSIONAL_ARRAY
// Boost.MultiArray detection concepts. Each one requires num_dimensions() and
// shape() on a T, plus a construction of the named Boost type from that T using
// class template argument deduction.
// NOTE(review): the CTAD construction in is_multi_array presumably also accepts
// Boost sub-array views (the element-wise operators recurse on slices and rely
// on that) -- confirm before tightening any of these checks.
template<typename T>
concept is_multi_array = requires(T array)
{
    { array.num_dimensions() };
    { array.shape() };
    { boost::multi_array(array) };
};
template<typename T>
concept is_sub_array = requires(T array)
{
    { array.num_dimensions() };
    { array.shape() };
    { boost::detail::multi_array::sub_array(array) };
};
template<typename T>
concept is_const_sub_array = requires(T array)
{
    { array.num_dimensions() };
    { array.shape() };
    { boost::detail::multi_array::const_sub_array(array) };
};
#endif
template<typename T> | |
concept is_iterable = requires(T x) | |
{ | |
*std::begin(x); | |
std::end(x); | |
}; | |
template<typename T> | |
concept is_elements_iterable = requires(T x) | |
{ | |
std::begin(x)->begin(); | |
std::end(x)->end(); | |
}; | |
template<typename T> | |
concept is_element_visitable = requires(T x) | |
{ | |
std::visit([](auto) {}, *x.begin()); | |
}; | |
template<typename T> | |
concept is_summable = requires(T x) { x + x; }; | |
// Base case of recursive_transform: the value is handed straight to the
// callable and its result is returned by value (the `auto` return decays it).
template<typename Value, typename Callable>
auto recursive_transform(const Value& input, const Callable& f)
{
    return f(input);
}
// recursive_transform overload for std::array<T, S>: returns a std::array of
// the same size holding recursive_transform(element, f) for every element, so
// nested arrays are transformed down to their leaves.
// Requires the transformed value type to be default-constructible, because the
// output array is created first and then filled in place.
template<class T, std::size_t S, class F>
auto recursive_transform(const std::array<T, S>& input, const F& f)
{
    using TransformedValueType = decltype(recursive_transform(*input.cbegin(), f));
    std::array<TransformedValueType, S> output;
    // Capture f by reference. Capturing by copy duplicated the callable at every
    // nesting level and rejected callables that are not copy-constructible.
    std::transform(input.cbegin(), input.cend(), output.begin(),
        [&f](const auto& element)
        {
            return recursive_transform(element, f);
        });
    return output;
}
// Non-recursive container overload: applies f directly to every element of a
// container such as std::vector<int>, std::list<int> or std::deque<int>. The
// results are appended, in order, to the same container template instantiated
// with the result type.
//
// The element type is the decayed result of f. A callable that returns a
// reference (e.g. `const std::string&`) therefore yields a container of values
// instead of the ill-formed Container<const std::string&>. This matches the
// scalar overload, whose `auto` return also yields values.
template<template<class...> class Container, class Function, class... Ts>
auto recursive_transform(const Container<Ts...>& input, const Function& f)
{
    using TransformedValueType = std::decay_t<decltype(f(*input.cbegin()))>;
    Container<TransformedValueType> output;
    // Call f through a reference-capturing lambda instead of passing f to
    // std::transform by value, so the callable is never copied.
    std::transform(input.cbegin(), input.cend(), std::back_inserter(output),
        [&f](const auto& element) { return f(element); });
    return output;
}
template<template<class...> class Container, class Function, class... Ts> | |
requires is_elements_iterable<Container<Ts...>> | |
auto recursive_transform(const Container<Ts...>& input, const Function& f) | |
{ | |
using TransformedValueType = decltype(recursive_transform(*input.cbegin(), f)); | |
Container<TransformedValueType> output; | |
std::transform(input.cbegin(), input.cend(), std::back_inserter(output), | |
[&](auto& element) | |
{ | |
return recursive_transform(element, f); | |
} | |
); | |
return output; | |
} | |
#ifdef USE_BOOST_MULTIDIMENSIONAL_ARRAY
// recursive_transform overload for Boost.MultiArray arrays and sub-array views.
//
// The input is first copied into a boost::multi_array deduced from it, so the
// result has the input's element type, extents and index bases. Each outermost
// slice is then overwritten with recursive_transform of the matching input
// slice. For a one-dimensional input the slices are single elements, which are
// passed to f by the scalar overload.
//
// NOTE: because the result keeps the input's element type, f's result must be
// assignable to that type (e.g. a double array cannot become a string array).
template<class T, class F> requires (is_multi_array<T> || is_sub_array<T> || is_const_sub_array<T>)
auto recursive_transform(const T& input, const F& f)
{
    boost::multi_array output(input);
    // Walk the array's real index range [base, base + extent). Arrays built with
    // non-zero index bases (e.g. via reindex() or extents[range(1, 4)]) do not
    // start at 0, so a 0-based loop would index out of range.
    using index_type = typename T::index;
    const index_type first = input.index_bases()[0];
    const index_type last = first + static_cast<index_type>(input.shape()[0]);
    for (index_type i = first; i < last; ++i)
    {
        output[i] = recursive_transform(input[i], f);
    }
    return output;
}
#endif
#ifdef USE_BOOST_MULTIDIMENSIONAL_ARRAY
// Helpers shared by the element-wise Boost.MultiArray operations below.
namespace multi_array_elementwise_detail
{
    // Throws std::logic_error unless both arrays have the same number of
    // dimensions and the same extent in every dimension. All extents are compared
    // up front, not just the first one. Mismatched arrays are therefore rejected
    // before any work is done, even when the leading extent is 0.
    template<class T1, class T2>
    void check_same_shape(const T1& input1, const T2& input2)
    {
        if (input1.num_dimensions() != input2.num_dimensions()) // Dimensions are different, unable to perform an element-wise operation
        {
            throw std::logic_error("Array dimensions are different");
        }
        if (!std::equal(input1.shape(), input1.shape() + input1.num_dimensions(), input2.shape())) // Shapes are different, unable to perform an element-wise operation
        {
            throw std::logic_error("Array shapes are different");
        }
    }

    // Copies input1 into a new boost::multi_array with the same element type,
    // extents and index bases. Then overwrites each outermost slice or element
    // with op(slice of input1, corresponding slice of input2). The inputs are
    // paired by offset from their own index bases, so arrays whose index bases
    // are non-zero or differ from each other are handled correctly.
    template<class T1, class T2, class BinaryOperation>
    auto combine_outermost(const T1& input1, const T2& input2, const BinaryOperation& op)
    {
        check_same_shape(input1, input2);
        boost::multi_array output(input1);
        using index_type = typename T1::index;
        const index_type base1 = input1.index_bases()[0];
        const index_type base2 = input2.index_bases()[0];
        const index_type extent = static_cast<index_type>(input1.shape()[0]);
        for (index_type offset = 0; offset < extent; ++offset)
        {
            output[base1 + offset] = op(input1[base1 + offset], input2[base2 + offset]);
        }
        return output;
    }
}

// Add operator: element-wise sum of two Boost arrays with identical shape.
// Throws std::logic_error if the dimensions or shapes differ.
template<class T1, class T2> requires (is_multi_array<T1>&& is_multi_array<T2>)
auto operator+(const T1& input1, const T2& input2)
{
    // For multi-dimensional inputs the operands are sub-array slices, so `+`
    // recurses into this operator. Scalars use the built-in operator.
    return multi_array_elementwise_detail::combine_outermost(input1, input2,
        [](const auto& lhs, const auto& rhs) { return lhs + rhs; });
}

// Minus operator: element-wise difference of two Boost arrays with identical shape.
// Throws std::logic_error if the dimensions or shapes differ.
template<class T1, class T2> requires (is_multi_array<T1>&& is_multi_array<T2>)
auto operator-(const T1& input1, const T2& input2)
{
    return multi_array_elementwise_detail::combine_outermost(input1, input2,
        [](const auto& lhs, const auto& rhs) { return lhs - rhs; });
}

template<typename T>
concept is_multiplicable = requires(T x)
{
    x * x;
};

// Multiplication: scalar case, plain operator*.
template<class T1, class T2> requires (is_multiplicable<T1> && is_multiplicable<T2>)
auto element_wise_multiplication(const T1& input1, const T2& input2)
{
    return input1 * input2;
}

// Multiplication: element-wise product of two Boost arrays with identical shape.
// Throws std::logic_error if the dimensions or shapes differ.
template<class T1, class T2> requires (is_multi_array<T1> && is_multi_array<T2>)
auto element_wise_multiplication(const T1& input1, const T2& input2)
{
    return multi_array_elementwise_detail::combine_outermost(input1, input2,
        [](const auto& lhs, const auto& rhs) { return element_wise_multiplication(lhs, rhs); });
}

template<typename T>
concept is_divisible = requires(T x)
{
    x / x;
};

// Division: scalar case. Throws std::logic_error when the divisor compares equal
// to 0, including for floating-point types where IEEE division would otherwise
// produce inf/NaN.
template<class T1, class T2> requires (is_divisible<T1> && is_divisible<T2>)
auto element_wise_division(const T1& input1, const T2& input2)
{
    if (input2 == 0)
    {
        throw std::logic_error("Divide by zero exception"); // Handle the case of dividing by zero exception
    }
    return input1 / input2;
}

// Division: element-wise quotient of two Boost arrays with identical shape.
// Throws std::logic_error if the dimensions or shapes differ, or if any
// divisor element is 0.
template<class T1, class T2> requires (is_multi_array<T1> && is_multi_array<T2>)
auto element_wise_division(const T1& input1, const T2& input2)
{
    return multi_array_elementwise_detail::combine_outermost(input1, input2,
        [](const auto& lhs, const auto& rhs) { return element_wise_division(lhs, rhs); });
}

// Sin function
template<typename T>
concept is_sinable = requires(T x)
{
    std::sin(x);
};

// Scalar sin: forwards to std::sin.
template<class T> requires (is_sinable<T>)
auto sin(const T& input)
{
    return std::sin(input);
}

// Element-wise sin of a Boost array. The result is a copy of the input (same
// element type, extents and index bases) with every element replaced by its sine.
template<class T> requires (is_multi_array<T>)
auto sin(const T& input)
{
    boost::multi_array output(input);
    // Iterate over [base, base + extent) so non-zero index bases are honoured.
    using index_type = typename T::index;
    const index_type first = input.index_bases()[0];
    const index_type last = first + static_cast<index_type>(input.shape()[0]);
    for (index_type i = first; i < last; ++i)
    {
        output[i] = sin(input[i]);
    }
    return output;
}
#endif
void recursive_transform_test1();

// Runs the standard-container tests. When Boost.MultiArray support is enabled,
// it also prints a 3 x 4 x 2 array filled with 1, 2, 3, ... and then the
// element-wise difference between sin(A) and recursive_transform(A, std::sin).
int main()
{
    recursive_transform_test1();
#ifdef USE_BOOST_MULTIDIMENSIONAL_ARRAY
    // This section uses the Boost overloads, which only exist when
    // USE_BOOST_MULTIDIMENSIONAL_ARRAY is defined, so it is guarded the same way.
    using array_type = boost::multi_array<double, 3>;
    using index = array_type::index;

    // Calls visit(i, j, k) for every index of a 3-D array whose index bases are
    // 0, in i, j, k order, with bounds taken from the array's own shape.
    const auto for_each_index = [](const auto& array, auto&& visit)
    {
        for (index i = 0; i != static_cast<index>(array.shape()[0]); ++i)
            for (index j = 0; j != static_cast<index>(array.shape()[1]); ++j)
                for (index k = 0; k != static_cast<index>(array.shape()[2]); ++k)
                    visit(i, j, k);
    };

    // Create a 3D array that is 3 x 4 x 2
    array_type A(boost::extents[3][4][2]);

    // Assign the values 1, 2, 3, ... to the elements
    int values = 1;
    for_each_index(A, [&](index i, index j, index k) { A[i][j][k] = values++; });
    for_each_index(A, [&](index i, index j, index k) { std::cout << A[i][j][k] << '\n'; });

    auto test_result = sin(A) - recursive_transform(A, [](auto& x) {return std::sin(x); });
    for_each_index(test_result, [&](index i, index j, index k) { std::cout << test_result[i][j][k] << '\n'; });
#endif
    return 0;
}
void recursive_transform_test1() | |
{ | |
// std::vector<int> -> std::vector<std::string> | |
std::vector<int> test_vector = { | |
1, 2, 3 | |
}; | |
auto recursive_transform_result = recursive_transform( | |
test_vector, | |
[](int x)->std::string { return std::to_string(x); }); // For testing | |
std::cout << "string: " + recursive_transform_result.at(0) << std::endl; // recursive_transform_result.at(0) is a std::string | |
// std::vector<std::vector<int>> -> std::vector<std::vector<std::string>> | |
std::vector<decltype(test_vector)> test_vector2 = { | |
test_vector, test_vector, test_vector | |
}; | |
auto recursive_transform_result2 = recursive_transform( | |
test_vector2, | |
[](int x)->std::string { return std::to_string(x); }); // For testing | |
std::cout << "string: " + recursive_transform_result2.at(0).at(0) << std::endl; // recursive_transform_result.at(0).at(0) is also a std::string | |
// std::deque<int> -> std::deque<std::string> | |
std::deque<int> test_deque; | |
test_deque.push_back(1); | |
test_deque.push_back(1); | |
test_deque.push_back(1); | |
auto recursive_transform_result3 = recursive_transform( | |
test_deque, | |
[](int x)->std::string { return std::to_string(x); }); // For testing | |
std::cout << "string: " + recursive_transform_result3.at(0) << std::endl; | |
// std::deque<std::deque<int>> -> std::deque<std::deque<std::string>> | |
std::deque<decltype(test_deque)> test_deque2; | |
test_deque2.push_back(test_deque); | |
test_deque2.push_back(test_deque); | |
test_deque2.push_back(test_deque); | |
auto recursive_transform_result4 = recursive_transform( | |
test_deque2, | |
[](int x)->std::string { return std::to_string(x); }); // For testing | |
std::cout << "string: " + recursive_transform_result4.at(0).at(0) << std::endl; | |
// std::array<int, 10> -> std::array<std::string, 10> | |
std::array<int, 10> test_array; | |
for (int i = 0; i < 10; i++) | |
{ | |
test_array[i] = 1; | |
} | |
auto recursive_transform_result5 = recursive_transform( | |
test_array, | |
[](int x)->std::string { return std::to_string(x); }); // For testing | |
std::cout << "string: " + recursive_transform_result5.at(0) << std::endl; | |
// std::array<std::array<int, 10>, 10> -> std::array<std::array<std::string, 10>, 10> | |
std::array<std::array<int, 10>, 10> test_array2; | |
for (int i = 0; i < 10; i++) | |
{ | |
test_array2[i] = test_array; | |
} | |
auto recursive_transform_result6 = recursive_transform( | |
test_array2, | |
[](int x)->std::string { return std::to_string(x); }); // For testing | |
std::cout << "string: " + recursive_transform_result6.at(0).at(0) << std::endl; | |
// std::list<int> -> std::list<std::string> | |
std::list<int> test_list = { 1, 2, 3, 4 }; | |
auto recursive_transform_result7 = recursive_transform( | |
test_list, | |
[](int x)->std::string { return std::to_string(x); }); // For testing | |
std::cout << "string: " + recursive_transform_result7.front() << std::endl; | |
// std::list<std::list<int>> -> std::list<std::list<std::string>> | |
std::list<std::list<int>> test_list2 = { test_list, test_list, test_list, test_list }; | |
auto recursive_transform_result8 = recursive_transform( | |
test_list2, | |
[](int x)->std::string { return std::to_string(x); }); // For testing | |
std::cout << "string: " + recursive_transform_result8.front().front() << std::endl; | |
} |
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.