Last active
July 14, 2024 13:24
-
-
Save 2b-t/50d85115db8b12ed263f8231abf07fa2 to your computer and use it in GitHub Desktop.
MPI data type template in C++17
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/**
 * \file mpi_type.hpp
 * \brief Compile-time mapping from C++ data types to MPI datatypes
 * \mainpage Contains a template function that determines the MPI message data type
 * corresponding to a C++ type (see https://www.mpich.org/static/docs/latest/www3/Constants.html)
 * as a constexpr, so that the selection can already be resolved at compile time.
 */
#ifndef MPI_TYPE_HPP_INCLUDED | |
#define MPI_TYPE_HPP_INCLUDED | |
#pragma once | |
#include <cassert> | |
#include <complex> | |
#include <cstdint> | |
#include <type_traits> | |
#include <mpi.h> | |
/**\fn mpi_get_type | |
* \brief Small template function to return the correct MPI_DATATYPE | |
* data type need for an MPI message as a constexpr at compile time | |
* https://www.mpich.org/static/docs/latest/www3/Constants.html | |
* Call in a template function with mpi_get_type<T>() | |
* | |
* \tparam T The C++ data type used in the MPI function | |
* \return The MPI_Datatype belonging to the template C++ data type T | |
*/ | |
template <typename T> | |
[[nodiscard]] constexpr MPI_Datatype mpi_get_type() noexcept { | |
MPI_Datatype mpi_type = MPI_DATATYPE_NULL; | |
if constexpr (std::is_same_v<T, char>) { | |
mpi_type = MPI_CHAR; | |
} else if constexpr (std::is_same_v<T, signed char>) { | |
mpi_type = MPI_SIGNED_CHAR; | |
} else if constexpr (std::is_same_v<T, unsigned char>) { | |
mpi_type = MPI_UNSIGNED_CHAR; | |
} else if constexpr (std::is_same_v<T, wchar_t>) { | |
mpi_type = MPI_WCHAR; | |
} else if constexpr (std::is_same_v<T, signed short>) { | |
mpi_type = MPI_SHORT; | |
} else if constexpr (std::is_same_v<T, unsigned short>) { | |
mpi_type = MPI_UNSIGNED_SHORT; | |
} else if constexpr (std::is_same_v<T, signed int>) { | |
mpi_type = MPI_INT; | |
} else if constexpr (std::is_same_v<T, unsigned int>) { | |
mpi_type = MPI_UNSIGNED; | |
} else if constexpr (std::is_same_v<T, signed long int>) { | |
mpi_type = MPI_LONG; | |
} else if constexpr (std::is_same_v<T, unsigned long int>) { | |
mpi_type = MPI_UNSIGNED_LONG; | |
} else if constexpr (std::is_same_v<T, signed long long int>) { | |
mpi_type = MPI_LONG_LONG; | |
} else if constexpr (std::is_same_v<T, unsigned long long int>) { | |
mpi_type = MPI_UNSIGNED_LONG_LONG; | |
} else if constexpr (std::is_same_v<T, float>) { | |
mpi_type = MPI_FLOAT; | |
} else if constexpr (std::is_same_v<T, double>) { | |
mpi_type = MPI_DOUBLE; | |
} else if constexpr (std::is_same_v<T, long double>) { | |
mpi_type = MPI_LONG_DOUBLE; | |
} else if constexpr (std::is_same_v<T, std::int8_t>) { | |
mpi_type = MPI_INT8_T; | |
} else if constexpr (std::is_same_v<T, std::int16_t>) { | |
mpi_type = MPI_INT16_T; | |
} else if constexpr (std::is_same_v<T, std::int32_t>) { | |
mpi_type = MPI_INT32_T; | |
} else if constexpr (std::is_same_v<T, std::int64_t>) { | |
mpi_type = MPI_INT64_T; | |
} else if constexpr (std::is_same_v<T, std::uint8_t>) { | |
mpi_type = MPI_UINT8_T; | |
} else if constexpr (std::is_same_v<T, std::uint16_t>) { | |
mpi_type = MPI_UINT16_T; | |
} else if constexpr (std::is_same_v<T, std::uint32_t>) { | |
mpi_type = MPI_UINT32_T; | |
} else if constexpr (std::is_same_v<T, std::uint64_t>) { | |
mpi_type = MPI_UINT64_T; | |
} else if constexpr (std::is_same_v<T, bool>) { | |
mpi_type = MPI_C_BOOL; | |
} else if constexpr (std::is_same_v<T, std::complex<float>>) { | |
mpi_type = MPI_C_COMPLEX; | |
} else if constexpr (std::is_same_v<T, std::complex<double>>) { | |
mpi_type = MPI_C_DOUBLE_COMPLEX; | |
} else if constexpr (std::is_same_v<T, std::complex<long double>>) { | |
mpi_type = MPI_C_LONG_DOUBLE_COMPLEX; | |
} | |
assert(mpi_type != MPI_DATATYPE_NULL); | |
return mpi_type; | |
} | |
#endif // MPI_TYPE_HPP_INCLUDED |
For anyone wondering about a C++11 version of this snippet, the following code works for me. Compiling with `mpic++ -std=c++11 -g main.c` and inspecting the binary with `objdump -d ./a.out`, I have confirmed that GCC 9 generates two short functions (about 7 instructions each) with `-O2`, and no function at all (i.e. it directly produces the constant values) with `-O3`.
#include <complex>
#include <cstdint>
#include <type_traits>
#include <mpi.h>
/**
 * Return the predefined MPI_Datatype that corresponds to the C++ type T
 * (C++11 variant, usable without `if constexpr`).
 *
 * Top-level const/volatile qualifiers on T are ignored, so
 * mpi_get_type<const int>() yields MPI_INT. Returns MPI_DATATYPE_NULL when T
 * has no predefined MPI counterpart.
 *
 * Every branch condition is a compile-time constant, so an optimizing compiler
 * can fold the whole chain down to the selected handle.
 */
template <typename T>
static inline MPI_Datatype mpi_get_type() noexcept {
  // Strip top-level cv-qualifiers so that e.g. `const int` maps like `int`.
  typedef typename std::remove_cv<T>::type U;

  MPI_Datatype mpi_type = MPI_DATATYPE_NULL;
  if (std::is_same<U, char>::value) {
    mpi_type = MPI_CHAR;
  } else if (std::is_same<U, signed char>::value) {
    mpi_type = MPI_SIGNED_CHAR;
  } else if (std::is_same<U, unsigned char>::value) {
    mpi_type = MPI_UNSIGNED_CHAR;
  } else if (std::is_same<U, wchar_t>::value) {
    mpi_type = MPI_WCHAR;
  } else if (std::is_same<U, signed short>::value) {
    mpi_type = MPI_SHORT;
  } else if (std::is_same<U, unsigned short>::value) {
    mpi_type = MPI_UNSIGNED_SHORT;
  } else if (std::is_same<U, signed int>::value) {
    mpi_type = MPI_INT;
  } else if (std::is_same<U, unsigned int>::value) {
    mpi_type = MPI_UNSIGNED;
  } else if (std::is_same<U, signed long int>::value) {
    mpi_type = MPI_LONG;
  } else if (std::is_same<U, unsigned long int>::value) {
    mpi_type = MPI_UNSIGNED_LONG;
  } else if (std::is_same<U, signed long long int>::value) {
    mpi_type = MPI_LONG_LONG;
  } else if (std::is_same<U, unsigned long long int>::value) {
    mpi_type = MPI_UNSIGNED_LONG_LONG;
  } else if (std::is_same<U, float>::value) {
    mpi_type = MPI_FLOAT;
  } else if (std::is_same<U, double>::value) {
    mpi_type = MPI_DOUBLE;
  } else if (std::is_same<U, long double>::value) {
    mpi_type = MPI_LONG_DOUBLE;
  // The <cstdint> aliases are typedefs; if one names a fundamental type tested
  // above, the earlier branch wins and the matching branch below is never taken.
  } else if (std::is_same<U, std::int8_t>::value) {
    mpi_type = MPI_INT8_T;
  } else if (std::is_same<U, std::int16_t>::value) {
    mpi_type = MPI_INT16_T;
  } else if (std::is_same<U, std::int32_t>::value) {
    mpi_type = MPI_INT32_T;
  } else if (std::is_same<U, std::int64_t>::value) {
    mpi_type = MPI_INT64_T;
  } else if (std::is_same<U, std::uint8_t>::value) {
    mpi_type = MPI_UINT8_T;
  } else if (std::is_same<U, std::uint16_t>::value) {
    mpi_type = MPI_UINT16_T;
  } else if (std::is_same<U, std::uint32_t>::value) {
    mpi_type = MPI_UINT32_T;
  } else if (std::is_same<U, std::uint64_t>::value) {
    mpi_type = MPI_UINT64_T;
  // C++ `bool` and `std::complex<>` have dedicated MPI-3 datatypes; the
  // MPI_C_* handles describe the C-language `_Bool` / `_Complex` types.
  } else if (std::is_same<U, bool>::value) {
    mpi_type = MPI_CXX_BOOL;
  } else if (std::is_same<U, std::complex<float> >::value) {
    mpi_type = MPI_CXX_FLOAT_COMPLEX;
  } else if (std::is_same<U, std::complex<double> >::value) {
    mpi_type = MPI_CXX_DOUBLE_COMPLEX;
  } else if (std::is_same<U, std::complex<long double> >::value) {
    mpi_type = MPI_CXX_LONG_DOUBLE_COMPLEX;
  }
  return mpi_type;
}
#include <stdio.h>
int main()
{
  // MPI_Datatype is an opaque handle whose underlying C type is
  // implementation-defined (it is not guaranteed to be `int`), so printing it
  // with "%d" is not portable. Check the mapping against the expected
  // predefined handles instead; the comparisons still fold to constants.
  printf("int    -> MPI_INT:    %s\n",
         mpi_get_type<int>() == MPI_INT ? "ok" : "MISMATCH");
  printf("double -> MPI_DOUBLE: %s\n",
         mpi_get_type<double>() == MPI_DOUBLE ? "ok" : "MISMATCH");
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Example usage in a templated `MPI_Alltoall`: