Last active
November 8, 2020 12:10
-
-
Save Jimmy-Hu/74cf5e27f6edc42ecae58eef236aed1b to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// A Ones Function for Boost.MultiArray
#include <algorithm> | |
#include <array> | |
#include <cassert> | |
#include <chrono> | |
#include <complex> | |
#include <concepts> | |
#include <deque> | |
#include <exception> | |
#include <functional> | |
#include <iostream> | |
#include <iterator> | |
#include <list> | |
#include <map> | |
#include <numeric> | |
#include <optional> | |
#include <stdexcept> | |
#include <string> | |
#include <type_traits> | |
#include <utility> | |
#include <variant> | |
#include <vector> | |
#define USE_BOOST_MULTIDIMENSIONAL_ARRAY | |
#ifdef USE_BOOST_MULTIDIMENSIONAL_ARRAY | |
#include <boost/multi_array.hpp> | |
#include <boost/multi_array/algorithm.hpp> | |
#include <boost/multi_array/base.hpp> | |
#include <boost/multi_array/collection_concept.hpp> | |
#endif | |
template<typename T> | |
concept is_back_inserterable = requires(T x) | |
{ | |
std::back_inserter(x); | |
}; | |
#ifdef USE_BOOST_MULTIDIMENSIONAL_ARRAY | |
template<typename T> | |
concept is_multi_array = requires(T x) | |
{ | |
x.num_dimensions(); | |
x.shape(); | |
boost::multi_array(x); | |
}; | |
#endif | |
template<typename T> | |
concept is_iterable = requires(T x) | |
{ | |
*std::begin(x); | |
std::end(x); | |
}; | |
template<typename T> | |
concept is_elements_iterable = requires(T x) | |
{ | |
std::begin(x)->begin(); | |
std::end(x)->end(); | |
}; | |
template<typename T> | |
concept is_element_visitable = requires(T x) | |
{ | |
std::visit([](auto) {}, *x.begin()); | |
}; | |
template<typename T> | |
concept is_summable = requires(T x) { x + x; }; | |
/// Base case of recursive_transform: applies the callable to a single value.
///
/// The call goes through std::invoke, so besides ordinary function objects
/// (for which this is the same as `f(input)`) pointers to members are accepted
/// as well.
///
/// @param input  value handed to the callable.
/// @param f      callable to apply.
/// @return       the result of the invocation, returned by value.
template<class T, class F>
auto recursive_transform(const T& input, const F& f)
{
    return std::invoke(f, input);
}
/// std::array overload of recursive_transform.
///
/// Applies recursive_transform to every element and collects the results in a
/// std::array of the same length whose value type is whatever the recursive
/// call yields for one element.
///
/// @param input  array whose elements are transformed.
/// @param f      callable forwarded unchanged to the per-element call.
/// @return       std::array<TransformedValueType, S> holding the results.
template<class T, std::size_t S, class F>
auto recursive_transform(const std::array<T, S>& input, const F& f)
{
    // Unevaluated: only used to name the element type of the result.
    using TransformedValueType = decltype(recursive_transform(*input.cbegin(), f));
    std::array<TransformedValueType, S> output;
    std::transform(input.cbegin(), input.cend(), output.begin(),
        // Capture the callable by reference: it outlives this call, and copying
        // it into the closure is unnecessary (and impossible for move-only callables).
        [&f](const auto& element)
        {
            return recursive_transform(element, f);
        }
    );
    return output;
}
template<template<class...> class Container, class Function, class... Ts> | |
requires (is_back_inserterable<Container<Ts...>>&& is_iterable<Container<Ts...>> && !is_elements_iterable<Container<Ts...>>) | |
// non-recursive version | |
auto recursive_transform(const Container<Ts...>& input, const Function& f) | |
{ | |
using TransformedValueType = decltype(f(*input.cbegin())); | |
Container<TransformedValueType> output; | |
std::transform(input.cbegin(), input.cend(), std::back_inserter(output), f); | |
return output; | |
} | |
template<template<class...> class Container, class Function, class... Ts> | |
requires (is_back_inserterable<Container<Ts...>> && is_elements_iterable<Container<Ts...>>) | |
auto recursive_transform(const Container<Ts...>& input, const Function& f) | |
{ | |
using TransformedValueType = decltype(recursive_transform(*input.cbegin(), f)); | |
Container<TransformedValueType> output; | |
std::transform(input.cbegin(), input.cend(), std::back_inserter(output), | |
[&](auto& element) | |
{ | |
return recursive_transform(element, f); | |
} | |
); | |
return output; | |
} | |
#ifdef USE_BOOST_MULTIDIMENSIONAL_ARRAY | |
/// Boost.MultiArray overload of recursive_transform.
///
/// The result is first constructed from `input`, then every index along the
/// first dimension is overwritten with the recursively transformed slice.
template<class T, class F> requires (is_multi_array<T>)
auto recursive_transform(const T& input, const F& f)
{
    // Unary plus yields a prvalue, giving a plain (non-reference) index type.
    using SizeType = decltype(+input.shape()[0]);

    boost::multi_array result(input);
    for (SizeType slice = 0; slice < input.shape()[0]; slice++)
    {
        result[slice] = recursive_transform(input[slice], f);
    }
    return result;
}
#endif | |
#ifdef USE_BOOST_MULTIDIMENSIONAL_ARRAY | |
/// Element-wise addition of two Boost.MultiArray-like operands.
///
/// @param input1  left operand.
/// @param input2  right operand.
/// @return        a boost::multi_array constructed from input1 whose slices
///                along the first dimension hold `input1[i] + input2[i]`.
/// @throws std::logic_error if the operands differ in the number of
///         dimensions or in the extent of any dimension.
template<class T1, class T2> requires (is_multi_array<T1> && is_multi_array<T2>)
auto operator+(const T1& input1, const T2& input2)
{
    if (input1.num_dimensions() != input2.num_dimensions())     // Dimensions are different, unable to perform element-wise add operation
    {
        throw std::logic_error("Array dimensions are different");
    }
    // shape() points at one extent per dimension; dereferencing it would only
    // look at the first one, so compare the whole list before doing any work.
    const auto shape1 = input1.shape();
    if (!std::equal(shape1, shape1 + input1.num_dimensions(), input2.shape()))     // Shapes are different, unable to perform element-wise add operation
    {
        throw std::logic_error("Array shapes are different");
    }
    boost::multi_array output(input1);
    for (decltype(+input1.shape()[0]) i = 0; i < input1.shape()[0]; i++)
    {
        output[i] = input1[i] + input2[i];
    }
    return output;
}
/// Element-wise subtraction of two Boost.MultiArray-like operands.
///
/// @param input1  left operand (minuend).
/// @param input2  right operand (subtrahend).
/// @return        a boost::multi_array constructed from input1 whose slices
///                along the first dimension hold `input1[i] - input2[i]`.
/// @throws std::logic_error if the operands differ in the number of
///         dimensions or in the extent of any dimension.
template<class T1, class T2> requires (is_multi_array<T1> && is_multi_array<T2>)
auto operator-(const T1& input1, const T2& input2)
{
    if (input1.num_dimensions() != input2.num_dimensions())     // Dimensions are different, unable to perform element-wise subtraction
    {
        throw std::logic_error("Array dimensions are different");
    }
    // shape() points at one extent per dimension; dereferencing it would only
    // look at the first one, so compare the whole list before doing any work.
    const auto shape1 = input1.shape();
    if (!std::equal(shape1, shape1 + input1.num_dimensions(), input2.shape()))     // Shapes are different, unable to perform element-wise subtraction
    {
        throw std::logic_error("Array shapes are different");
    }
    boost::multi_array output(input1);
    for (decltype(+input1.shape()[0]) i = 0; i < input1.shape()[0]; i++)
    {
        output[i] = input1[i] - input2[i];
    }
    return output;
}
template<typename T> | |
concept is_multiplicable = requires(T x) | |
{ | |
x * x; | |
}; | |
/// Scalar case of element_wise_multiplication: the plain product of the two
/// operands. Each operand type must be multiplicable with itself.
template<class Lhs, class Rhs> requires (is_multiplicable<Lhs> && is_multiplicable<Rhs>)
auto element_wise_multiplication(const Lhs& input1, const Rhs& input2)
{
    return input1 * input2;
}
/// Element-wise multiplication of two Boost.MultiArray-like operands.
///
/// @param input1  left operand.
/// @param input2  right operand.
/// @return        a boost::multi_array constructed from input1 whose slices
///                along the first dimension hold
///                `element_wise_multiplication(input1[i], input2[i])`.
/// @throws std::logic_error if the operands differ in the number of
///         dimensions or in the extent of any dimension.
template<class T1, class T2> requires (is_multi_array<T1> && is_multi_array<T2>)
auto element_wise_multiplication(const T1& input1, const T2& input2)
{
    if (input1.num_dimensions() != input2.num_dimensions())     // Dimensions are different, unable to perform element-wise multiplication
    {
        throw std::logic_error("Array dimensions are different");
    }
    // shape() points at one extent per dimension; dereferencing it would only
    // look at the first one, so compare the whole list before doing any work.
    const auto shape1 = input1.shape();
    if (!std::equal(shape1, shape1 + input1.num_dimensions(), input2.shape()))     // Shapes are different, unable to perform element-wise multiplication
    {
        throw std::logic_error("Array shapes are different");
    }
    boost::multi_array output(input1);
    for (decltype(+input1.shape()[0]) i = 0; i < input1.shape()[0]; i++)
    {
        output[i] = element_wise_multiplication(input1[i], input2[i]);
    }
    return output;
}
template<typename T> | |
concept is_divisible = requires(T x) | |
{ | |
x / x; | |
}; | |
/// Scalar case of element_wise_division: the quotient `input1 / input2`.
///
/// @throws std::logic_error when the divisor compares equal to 0.
template<class Dividend, class Divisor> requires (is_divisible<Dividend> && is_divisible<Divisor>)
auto element_wise_division(const Dividend& input1, const Divisor& input2)
{
    if (!(input2 == 0))
    {
        return input1 / input2;
    }
    // Reaching this point means the divisor is zero.
    throw std::logic_error("Divide by zero exception");
}
/// Element-wise division of two Boost.MultiArray-like operands.
///
/// @param input1  left operand (dividend).
/// @param input2  right operand (divisor).
/// @return        a boost::multi_array constructed from input1 whose slices
///                along the first dimension hold
///                `element_wise_division(input1[i], input2[i])`.
/// @throws std::logic_error if the operands differ in the number of
///         dimensions or in the extent of any dimension.
template<class T1, class T2> requires (is_multi_array<T1> && is_multi_array<T2>)
auto element_wise_division(const T1& input1, const T2& input2)
{
    if (input1.num_dimensions() != input2.num_dimensions())     // Dimensions are different, unable to perform element-wise division
    {
        throw std::logic_error("Array dimensions are different");
    }
    // shape() points at one extent per dimension; dereferencing it would only
    // look at the first one, so compare the whole list before doing any work.
    const auto shape1 = input1.shape();
    if (!std::equal(shape1, shape1 + input1.num_dimensions(), input2.shape()))     // Shapes are different, unable to perform element-wise division
    {
        throw std::logic_error("Array shapes are different");
    }
    boost::multi_array output(input1);
    for (decltype(+input1.shape()[0]) i = 0; i < input1.shape()[0]; i++)
    {
        output[i] = element_wise_division(input1[i], input2[i]);
    }
    return output;
}
// ones function | |
// Reference: https://stackoverflow.com/q/46595857/6667035 | |
template<class T, std::size_t NumDims> | |
auto ones(boost::detail::multi_array::extent_gen<NumDims> size) | |
{ | |
boost::multi_array<T, NumDims> output(size); | |
return recursive_transform(output, [](auto& x) { return 1; }); | |
} | |
// Sin function | |
template<typename T> | |
concept is_sinable = requires(T x) | |
{ | |
std::sin(x); | |
}; | |
/// Scalar case of sin: forwards to std::sin.
template<class Scalar> requires (is_sinable<Scalar>)
auto sin(const Scalar& input)
{
    return std::sin(input);
}
/// Boost.MultiArray case of sin.
///
/// The result is first constructed from `input`, then every index along the
/// first dimension is overwritten with `sin` of the corresponding slice.
template<class T> requires (is_multi_array<T>)
auto sin(const T& input)
{
    // Unary plus yields a prvalue, giving a plain (non-reference) index type.
    using SizeType = decltype(+input.shape()[0]);

    boost::multi_array result(input);
    for (SizeType slice = 0; slice < input.shape()[0]; slice++)
    {
        result[slice] = sin(input[slice]);
    }
    return result;
}
#endif | |
void recursive_transform_test1(); | |
int main() | |
{ | |
//recursive_transform_test1(); | |
auto A = ones<double, 3>(boost::extents[3][4][2]); | |
typedef decltype(A)::index index; | |
std::cout << "A:" << std::endl; | |
for (index i = 0; i != 3; ++i) | |
for (index j = 0; j != 4; ++j) | |
for (index k = 0; k != 2; ++k) | |
std::cout << A[i][j][k] << std::endl; | |
return 0; | |
} | |
void recursive_transform_test1() | |
{ | |
// std::vector<int> -> std::vector<std::string> | |
std::vector<int> test_vector = { | |
1, 2, 3 | |
}; | |
auto recursive_transform_result = recursive_transform( | |
test_vector, | |
[](int x)->std::string { return std::to_string(x); }); // For testing | |
std::cout << "string: " + recursive_transform_result.at(0) << std::endl; // recursive_transform_result.at(0) is a std::string | |
// std::vector<std::vector<int>> -> std::vector<std::vector<std::string>> | |
std::vector<decltype(test_vector)> test_vector2 = { | |
test_vector, test_vector, test_vector | |
}; | |
auto recursive_transform_result2 = recursive_transform( | |
test_vector2, | |
[](int x)->std::string { return std::to_string(x); }); // For testing | |
std::cout << "string: " + recursive_transform_result2.at(0).at(0) << std::endl; // recursive_transform_result.at(0).at(0) is also a std::string | |
// std::deque<int> -> std::deque<std::string> | |
std::deque<int> test_deque; | |
test_deque.push_back(1); | |
test_deque.push_back(1); | |
test_deque.push_back(1); | |
auto recursive_transform_result3 = recursive_transform( | |
test_deque, | |
[](int x)->std::string { return std::to_string(x); }); // For testing | |
std::cout << "string: " + recursive_transform_result3.at(0) << std::endl; | |
// std::deque<std::deque<int>> -> std::deque<std::deque<std::string>> | |
std::deque<decltype(test_deque)> test_deque2; | |
test_deque2.push_back(test_deque); | |
test_deque2.push_back(test_deque); | |
test_deque2.push_back(test_deque); | |
auto recursive_transform_result4 = recursive_transform( | |
test_deque2, | |
[](int x)->std::string { return std::to_string(x); }); // For testing | |
std::cout << "string: " + recursive_transform_result4.at(0).at(0) << std::endl; | |
// std::array<int, 10> -> std::array<std::string, 10> | |
std::array<int, 10> test_array; | |
for (int i = 0; i < 10; i++) | |
{ | |
test_array[i] = 1; | |
} | |
auto recursive_transform_result5 = recursive_transform( | |
test_array, | |
[](int x)->std::string { return std::to_string(x); }); // For testing | |
std::cout << "string: " + recursive_transform_result5.at(0) << std::endl; | |
// std::array<std::array<int, 10>, 10> -> std::array<std::array<std::string, 10>, 10> | |
std::array<std::array<int, 10>, 10> test_array2; | |
for (int i = 0; i < 10; i++) | |
{ | |
test_array2[i] = test_array; | |
} | |
auto recursive_transform_result6 = recursive_transform( | |
test_array2, | |
[](int x)->std::string { return std::to_string(x); }); // For testing | |
std::cout << "string: " + recursive_transform_result6.at(0).at(0) << std::endl; | |
// std::list<int> -> std::list<std::string> | |
std::list<int> test_list = { 1, 2, 3, 4 }; | |
auto recursive_transform_result7 = recursive_transform( | |
test_list, | |
[](int x)->std::string { return std::to_string(x); }); // For testing | |
std::cout << "string: " + recursive_transform_result7.front() << std::endl; | |
// std::list<std::list<int>> -> std::list<std::list<std::string>> | |
std::list<std::list<int>> test_list2 = { test_list, test_list, test_list, test_list }; | |
auto recursive_transform_result8 = recursive_transform( | |
test_list2, | |
[](int x)->std::string { return std::to_string(x); }); // For testing | |
std::cout << "string: " + recursive_transform_result8.front().front() << std::endl; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment