Skip to content

Instantly share code, notes, and snippets.

@Jimmy-Hu
Created November 4, 2020 22:10
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Jimmy-Hu/bf1c54b54c492ce0562b15228d67beef to your computer and use it in GitHub Desktop.
Save Jimmy-Hu/bf1c54b54c492ce0562b15228d67beef to your computer and use it in GitHub Desktop.
A recursive_transform Function with boost::multi_array
#include <algorithm>
#include <array>
#include <cassert>
#include <chrono>
#include <complex>
#include <concepts>
#include <deque>
#include <exception>
#include <functional>
#include <iostream>
#include <iterator>
#include <list>
#include <map>
#include <numeric>
#include <optional>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <utility>
#include <variant>
#include <vector>
#define USE_BOOST_MULTIDIMENSIONAL_ARRAY
#ifdef USE_BOOST_MULTIDIMENSIONAL_ARRAY
#include <boost/multi_array.hpp>
#include <boost/multi_array/algorithm.hpp>
#include <boost/multi_array/base.hpp>
#include <boost/multi_array/collection_concept.hpp>
#endif
/// Satisfied when std::back_inserter can be formed from an object of type T.
/// Only the std::back_inserter call itself is checked; this concept does not
/// separately require that T provides push_back.
template<typename T>
concept is_back_inserterable = requires(T container)
{
    { std::back_inserter(container) };
};
#ifdef USE_BOOST_MULTIDIMENSIONAL_ARRAY
// Concepts used to recognise Boost.MultiArray objects. Each one requires that
// x.num_dimensions() and x.shape() are valid, and that the named Boost class
// template can be instantiated from x through class template argument deduction.
// NOTE(review): CTAD may also succeed through converting constructors, so these
// concepts can accept more than their names suggest (e.g. a (const_)sub_array
// may satisfy is_multi_array; the element-wise operators below appear to rely on
// that when they recurse) — confirm against the Boost version in use.
template<typename T>
concept is_multi_array = requires(T x)
{
x.num_dimensions();
x.shape();
boost::multi_array(x); // a boost::multi_array<...> must be deducible from x
};
// Same shape requirements; additionally boost::detail::multi_array::sub_array
// (the mutable view type produced by indexing) must be deducible from x.
template<typename T>
concept is_sub_array = requires(T x)
{
x.num_dimensions();
x.shape();
boost::detail::multi_array::sub_array(x);
};
// Same shape requirements; additionally boost::detail::multi_array::const_sub_array
// (the read-only view type) must be deducible from x.
template<typename T>
concept is_const_sub_array = requires(T x)
{
x.num_dimensions();
x.shape();
boost::detail::multi_array::const_sub_array(x);
};
#endif
/// Satisfied when std::begin / std::end accept an object of type T and the
/// iterator returned by std::begin can be dereferenced.
template<typename T>
concept is_iterable = requires(T range)
{
    { *std::begin(range) };
    { std::end(range) };
};
/// Satisfied when the elements of T themselves expose begin()/end(), i.e. T is a
/// range of ranges such as std::vector<std::vector<int>>.
/// Note: a range of std::string also qualifies, because std::string has begin()/end().
template<typename T>
concept is_elements_iterable = requires(T outer)
{
    { std::begin(outer)->begin() };
    { std::end(outer)->end() };
};
/// Satisfied when the element obtained from container.begin() can be passed to
/// std::visit, e.g. when T is a container of std::variant values.
template<typename T>
concept is_element_visitable = requires(T container)
{
    { std::visit([](auto) {}, *container.begin()) };
};
/// Satisfied when two objects of type T can be combined with operator+.
template<typename T>
concept is_summable = requires(T lhs)
{
    { lhs + lhs };
};
/// Base case of recursive_transform: the value is not treated as a range, so the
/// callable is applied to it directly.
/// @param input value to transform
/// @param f     callable invoked with input
/// @return the (decayed) result of invoking f with input
template<class T, class F>
auto recursive_transform(const T& input, const F& f)
{
    return std::invoke(f, input);
}
template<class T, std::size_t S, class F>
auto recursive_transform(const std::array<T, S>& input, const F& f)
{
using TransformedValueType = decltype(recursive_transform(*input.cbegin(), f));
std::array<TransformedValueType, S> output;
std::transform(input.cbegin(), input.cend(), output.begin(),
[f](auto& element)
{
return recursive_transform(element, f);
}
);
return output;
}
/// Non-recursive recursive_transform overload for containers of the form
/// Container<Ts...>. It is used when the constrained (is_elements_iterable)
/// overload does not apply. f is applied to every element, and the results are
/// appended in order, through std::back_inserter, to a Container of the result type.
/// @param input container to transform; it is not modified
/// @param f     callable applied to each element
/// @return Container<R>, where R is the type of f applied to a const element
template<template<class...> class Container, class Function, class... Ts>
auto recursive_transform(const Container<Ts...>& input, const Function& f)
{
    using Result = decltype(f(*input.cbegin()));
    const auto first = input.cbegin();
    const auto last = input.cend();
    Container<Result> transformed;
    std::transform(first, last, std::back_inserter(transformed), f);
    return transformed;
}
/// Recursive recursive_transform overload for containers whose elements are
/// themselves ranges (is_elements_iterable). Each element is passed back through
/// recursive_transform, and the results are appended in order to a Container of
/// the transformed element type.
/// @param input nested container to transform; it is not modified
/// @param f     callable applied to each innermost element
/// @return Container<E>, where E is the result of transforming one element
template<template<class...> class Container, class Function, class... Ts>
requires is_elements_iterable<Container<Ts...>>
auto recursive_transform(const Container<Ts...>& input, const Function& f)
{
    using Element = decltype(recursive_transform(*input.cbegin(), f));
    const auto recurse = [&f](const auto& inner)
    {
        return recursive_transform(inner, f);
    };
    Container<Element> result;
    std::transform(input.cbegin(), input.cend(), std::back_inserter(result), recurse);
    return result;
}
#ifdef USE_BOOST_MULTIDIMENSIONAL_ARRAY
/// recursive_transform overload for Boost.MultiArray objects and their sub-array views.
/// Each element (or sub-array) is passed back through recursive_transform.
/// @param input array or view to transform; it is not modified
/// @param f     callable applied to each innermost element
/// @return a boost::multi_array deduced (CTAD) from input holding the results
/// NOTE(review): the output is a copy of input, so results are stored as input's
/// element type — a transform that changes the element type (e.g. to std::string)
/// will not compile.
template<class T, class F> requires (is_multi_array<T> || is_sub_array<T> || is_const_sub_array<T>)
auto recursive_transform(const T& input, const F& f)
{
    boost::multi_array output(input);
    // Walk with iterators rather than indexing 0..shape()[0]-1: indexing assumes
    // an index base of 0 and breaks for arrays re-based with reindex().
    std::transform(input.begin(), input.end(), output.begin(),
        [&f](const auto& element)
        {
            return recursive_transform(element, f);
        }
    );
    return output;
}
#endif
#ifdef USE_BOOST_MULTIDIMENSIONAL_ARRAY
// Add operator
/// Element-wise addition of two Boost.MultiArray objects of identical shape.
/// @throws std::logic_error if the arrays differ in number of dimensions or in any extent.
/// @return a boost::multi_array deduced from input1 whose elements are
///         input1 + input2 (stored as input1's element type)
template<class T1, class T2> requires (is_multi_array<T1>&& is_multi_array<T2>)
auto operator+(const T1& input1, const T2& input2)
{
    if (input1.num_dimensions() != input2.num_dimensions()) // Dimensions are different, unable to perform element-wise add operation
    {
        throw std::logic_error("Array dimensions are different");
    }
    // Compare every extent, not only the first one.
    if (!std::equal(input1.shape(), input1.shape() + input1.num_dimensions(), input2.shape()))
    {
        throw std::logic_error("Array shapes are different");
    }
    boost::multi_array output(input1);
    // Pair elements positionally with iterators instead of indexing from 0,
    // so arrays with a non-zero index base are handled correctly.
    std::transform(input1.begin(), input1.end(), input2.begin(), output.begin(),
        [](const auto& lhs, const auto& rhs) { return lhs + rhs; });
    return output;
}
// Minus operator
/// Element-wise subtraction of two Boost.MultiArray objects of identical shape.
/// @throws std::logic_error if the arrays differ in number of dimensions or in any extent.
/// @return a boost::multi_array deduced from input1 whose elements are
///         input1 - input2 (stored as input1's element type)
template<class T1, class T2> requires (is_multi_array<T1>&& is_multi_array<T2>)
auto operator-(const T1& input1, const T2& input2)
{
    if (input1.num_dimensions() != input2.num_dimensions()) // Dimensions are different, unable to perform element-wise subtraction
    {
        throw std::logic_error("Array dimensions are different");
    }
    // Compare every extent, not only the first one.
    if (!std::equal(input1.shape(), input1.shape() + input1.num_dimensions(), input2.shape()))
    {
        throw std::logic_error("Array shapes are different");
    }
    boost::multi_array output(input1);
    // Pair elements positionally with iterators instead of indexing from 0,
    // so arrays with a non-zero index base are handled correctly.
    std::transform(input1.begin(), input1.end(), input2.begin(), output.begin(),
        [](const auto& lhs, const auto& rhs) { return lhs - rhs; });
    return output;
}
/// Satisfied when two objects of type T can be combined with operator*.
template<typename T>
concept is_multiplicable = requires(T value)
{
    { value * value };
};
// Multiplication
/// Scalar case of element_wise_multiplication.
/// @return input1 * input2
template<class T1, class T2> requires (is_multiplicable<T1> && is_multiplicable<T2>)
auto element_wise_multiplication(const T1& input1, const T2& input2)
{
    return input1 * input2;
}
/// Element-wise multiplication of two Boost.MultiArray objects of identical shape.
/// Each pair of elements (or sub-arrays) is passed back through element_wise_multiplication.
/// @throws std::logic_error if the arrays differ in number of dimensions or in any extent.
/// @return a boost::multi_array deduced from input1 holding the products
template<class T1, class T2> requires (is_multi_array<T1> && is_multi_array<T2>)
auto element_wise_multiplication(const T1& input1, const T2& input2)
{
    if (input1.num_dimensions() != input2.num_dimensions()) // Dimensions are different, unable to perform element-wise multiplication
    {
        throw std::logic_error("Array dimensions are different");
    }
    // Compare every extent, not only the first one.
    if (!std::equal(input1.shape(), input1.shape() + input1.num_dimensions(), input2.shape()))
    {
        throw std::logic_error("Array shapes are different");
    }
    boost::multi_array output(input1);
    // Iterators pair elements positionally and are independent of the index base.
    std::transform(input1.begin(), input1.end(), input2.begin(), output.begin(),
        [](const auto& lhs, const auto& rhs) { return element_wise_multiplication(lhs, rhs); });
    return output;
}
/// Satisfied when two objects of type T can be combined with operator/.
template<typename T>
concept is_divisible = requires(T x)
{
    x / x;
};
// Division
/// Scalar case of element_wise_division.
/// @return input1 / input2
/// @throws std::logic_error if input2 compares equal to zero.
template<class T1, class T2> requires (is_divisible<T1> && is_divisible<T2>)
auto element_wise_division(const T1& input1, const T2& input2)
{
    // Compare against a zero of the divisor's own type. `input2 == 0` does not
    // compile for e.g. std::complex<double> (template deduction conflict between
    // double and int), even though is_divisible accepts such types.
    if (input2 == static_cast<T2>(0))
    {
        throw std::logic_error("Divide by zero exception"); // Handle the case of dividing by zero exception
    }
    return input1 / input2;
}
/// Element-wise division of two Boost.MultiArray objects of identical shape.
/// Each pair of elements (or sub-arrays) is passed back through element_wise_division,
/// so a zero divisor element raises std::logic_error from the scalar overload.
/// @throws std::logic_error if the arrays differ in number of dimensions or in any extent.
/// @return a boost::multi_array deduced from input1 holding the quotients
template<class T1, class T2> requires (is_multi_array<T1> && is_multi_array<T2>)
auto element_wise_division(const T1& input1, const T2& input2)
{
    if (input1.num_dimensions() != input2.num_dimensions()) // Dimensions are different, unable to perform element-wise division
    {
        throw std::logic_error("Array dimensions are different");
    }
    // Compare every extent, not only the first one.
    if (!std::equal(input1.shape(), input1.shape() + input1.num_dimensions(), input2.shape()))
    {
        throw std::logic_error("Array shapes are different");
    }
    boost::multi_array output(input1);
    // Iterators pair elements positionally and are independent of the index base.
    std::transform(input1.begin(), input1.end(), input2.begin(), output.begin(),
        [](const auto& lhs, const auto& rhs) { return element_wise_division(lhs, rhs); });
    return output;
}
// Sin function
/// Satisfied when std::sin accepts a value of type T.
template<typename T>
concept is_sinable = requires(T angle)
{
    { std::sin(angle) };
};
/// Scalar case of sin: forwards to std::sin.
template<class T> requires (is_sinable<T>)
auto sin(const T& input)
{
    return std::sin(input);
}
/// Element-wise sin of a Boost.MultiArray: each element (or sub-array) is passed
/// back through sin, ending in the scalar overload.
/// @return a boost::multi_array deduced from input holding the results
template<class T> requires (is_multi_array<T>)
auto sin(const T& input)
{
    boost::multi_array output(input);
    // Iterate rather than index 0..shape()[0]-1, so arrays whose index base is
    // not 0 (e.g. after reindex()) are handled correctly.
    std::transform(input.begin(), input.end(), output.begin(),
        [](const auto& element) { return sin(element); });
    return output;
}
#endif
void recursive_transform_test1();
int main()
{
    recursive_transform_test1();
#ifdef USE_BOOST_MULTIDIMENSIONAL_ARRAY
    // The Boost demo is compiled only when Boost support is enabled, matching the
    // guard around every Boost-dependent definition above.

    // Create a 3D array that is 3 x 4 x 2
    using array_type = boost::multi_array<double, 3>;
    using index = array_type::index;
    array_type A(boost::extents[3][4][2]);

    // Prints every element of a 3 x 4 x 2 array, one per line, in i/j/k order.
    const auto print_all = [](const auto& array)
    {
        for (index i = 0; i != 3; ++i)
            for (index j = 0; j != 4; ++j)
                for (index k = 0; k != 2; ++k)
                    std::cout << array[i][j][k] << '\n';
    };

    // Assign values to the elements
    int values = 1;
    for (index i = 0; i != 3; ++i)
        for (index j = 0; j != 4; ++j)
            for (index k = 0; k != 2; ++k)
                A[i][j][k] = values++;
    print_all(A);

    // sin applied directly vs. via recursive_transform; both sides evaluate the sine
    // of the same values, so every printed difference is expected to be 0.
    auto test_result = sin(A) - recursive_transform(A, [](auto& x) {return std::sin(x); });
    print_all(test_result);
#endif
    return 0;
}
/// Smoke test for recursive_transform on standard containers. For vector, deque,
/// array and list (flat and nested one level), ints are converted to std::string,
/// and the first converted element of each result is printed as "string: <value>".
void recursive_transform_test1()
{
    // One conversion shared by every case below.
    const auto int_to_string = [](int x) -> std::string { return std::to_string(x); };

    // std::vector<int> -> std::vector<std::string>
    std::vector<int> test_vector = { 1, 2, 3 };
    auto recursive_transform_result = recursive_transform(test_vector, int_to_string);
    std::cout << "string: " + recursive_transform_result.at(0) << '\n'; // recursive_transform_result.at(0) is a std::string

    // std::vector<std::vector<int>> -> std::vector<std::vector<std::string>>
    std::vector<decltype(test_vector)> test_vector2 = { test_vector, test_vector, test_vector };
    auto recursive_transform_result2 = recursive_transform(test_vector2, int_to_string);
    std::cout << "string: " + recursive_transform_result2.at(0).at(0) << '\n'; // recursive_transform_result2.at(0).at(0) is also a std::string

    // std::deque<int> -> std::deque<std::string>
    std::deque<int> test_deque(3, 1);
    auto recursive_transform_result3 = recursive_transform(test_deque, int_to_string);
    std::cout << "string: " + recursive_transform_result3.at(0) << '\n';

    // std::deque<std::deque<int>> -> std::deque<std::deque<std::string>>
    std::deque<decltype(test_deque)> test_deque2(3, test_deque);
    auto recursive_transform_result4 = recursive_transform(test_deque2, int_to_string);
    std::cout << "string: " + recursive_transform_result4.at(0).at(0) << '\n';

    // std::array<int, 10> -> std::array<std::string, 10>
    std::array<int, 10> test_array;
    test_array.fill(1);
    auto recursive_transform_result5 = recursive_transform(test_array, int_to_string);
    std::cout << "string: " + recursive_transform_result5.at(0) << '\n';

    // std::array<std::array<int, 10>, 10> -> std::array<std::array<std::string, 10>, 10>
    std::array<std::array<int, 10>, 10> test_array2;
    test_array2.fill(test_array);
    auto recursive_transform_result6 = recursive_transform(test_array2, int_to_string);
    std::cout << "string: " + recursive_transform_result6.at(0).at(0) << '\n';

    // std::list<int> -> std::list<std::string>
    std::list<int> test_list = { 1, 2, 3, 4 };
    auto recursive_transform_result7 = recursive_transform(test_list, int_to_string);
    std::cout << "string: " + recursive_transform_result7.front() << '\n';

    // std::list<std::list<int>> -> std::list<std::list<std::string>>
    std::list<std::list<int>> test_list2 = { test_list, test_list, test_list, test_list };
    auto recursive_transform_result8 = recursive_transform(test_list2, int_to_string);
    std::cout << "string: " + recursive_transform_result8.front().front() << '\n';
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment