Skip to content

Instantly share code, notes, and snippets.

@2b-t
Last active July 14, 2024 13:24
Show Gist options
  • Save 2b-t/50d85115db8b12ed263f8231abf07fa2 to your computer and use it in GitHub Desktop.
Save 2b-t/50d85115db8b12ed263f8231abf07fa2 to your computer and use it in GitHub Desktop.
MPI data type template in C++17
/**
* \file mpi_type.hpp
* \brief Function for automatically determining the MPI data type that corresponds to a C++ type, as a constexpr
* \mainpage Contains a template function that determines the MPI message data type (see
* https://www.mpich.org/static/docs/latest/www3/Constants.html) corresponding to a given C++ type
* as a constexpr. This way the code may already be simplified at compile time.
*/
#ifndef MPI_TYPE_HPP_INCLUDED
#define MPI_TYPE_HPP_INCLUDED
#pragma once
#include <cassert>
#include <complex>
#include <cstdint>
#include <type_traits>
#include <mpi.h>
/**\fn mpi_get_type
* \brief Small template function returning the MPI_Datatype handle that
* corresponds to the C++ type T, for use in MPI message calls.
* https://www.mpich.org/static/docs/latest/www3/Constants.html
* Call in a template function with mpi_get_type<T>()
*
* - const/volatile qualifiers on T are ignored, so mpi_get_type<const double>()
*   yields MPI_DOUBLE.
* - C++ bool and std::complex<float/double/long double> map to the MPI_CXX_*
*   datatypes that the MPI standard defines for these C++ types (the MPI_C_*
*   datatypes describe the C types _Bool and _Complex instead).
* - The fixed-width aliases (std::int8_t, ..., std::uint64_t) are aliases of
*   fundamental integer types. On common platforms they are therefore matched by
*   the earlier fundamental-type branches; their own branches only apply where
*   an alias names a type not listed above.
* - NOTE(review): whether a call can actually be evaluated in a constant
*   expression depends on how the MPI implementation defines its handles
*   (integer constants are usable, pointer casts presumably are not) - confirm
*   for the target MPI library.
*
* \tparam T The C++ data type used in the MPI function
* \return The MPI_Datatype belonging to T, or MPI_DATATYPE_NULL if T is not
*         supported (in builds without NDEBUG the assert fires in that case)
*/
template <typename T>
[[nodiscard]] constexpr MPI_Datatype mpi_get_type() noexcept {
  // Strip cv-qualifiers so that e.g. `const int` maps exactly like `int`.
  using U = std::remove_cv_t<T>;
  MPI_Datatype mpi_type = MPI_DATATYPE_NULL;
  if constexpr (std::is_same_v<U, char>) {
    mpi_type = MPI_CHAR;
  } else if constexpr (std::is_same_v<U, signed char>) {
    mpi_type = MPI_SIGNED_CHAR;
  } else if constexpr (std::is_same_v<U, unsigned char>) {
    mpi_type = MPI_UNSIGNED_CHAR;
  } else if constexpr (std::is_same_v<U, wchar_t>) {
    mpi_type = MPI_WCHAR;
  } else if constexpr (std::is_same_v<U, signed short>) {
    mpi_type = MPI_SHORT;
  } else if constexpr (std::is_same_v<U, unsigned short>) {
    mpi_type = MPI_UNSIGNED_SHORT;
  } else if constexpr (std::is_same_v<U, signed int>) {
    mpi_type = MPI_INT;
  } else if constexpr (std::is_same_v<U, unsigned int>) {
    mpi_type = MPI_UNSIGNED;
  } else if constexpr (std::is_same_v<U, signed long int>) {
    mpi_type = MPI_LONG;
  } else if constexpr (std::is_same_v<U, unsigned long int>) {
    mpi_type = MPI_UNSIGNED_LONG;
  } else if constexpr (std::is_same_v<U, signed long long int>) {
    mpi_type = MPI_LONG_LONG;
  } else if constexpr (std::is_same_v<U, unsigned long long int>) {
    mpi_type = MPI_UNSIGNED_LONG_LONG;
  } else if constexpr (std::is_same_v<U, float>) {
    mpi_type = MPI_FLOAT;
  } else if constexpr (std::is_same_v<U, double>) {
    mpi_type = MPI_DOUBLE;
  } else if constexpr (std::is_same_v<U, long double>) {
    mpi_type = MPI_LONG_DOUBLE;
  } else if constexpr (std::is_same_v<U, std::int8_t>) {
    // Usually unreachable: the fixed-width aliases normally name one of the
    // fundamental types handled above.
    mpi_type = MPI_INT8_T;
  } else if constexpr (std::is_same_v<U, std::int16_t>) {
    mpi_type = MPI_INT16_T;
  } else if constexpr (std::is_same_v<U, std::int32_t>) {
    mpi_type = MPI_INT32_T;
  } else if constexpr (std::is_same_v<U, std::int64_t>) {
    mpi_type = MPI_INT64_T;
  } else if constexpr (std::is_same_v<U, std::uint8_t>) {
    mpi_type = MPI_UINT8_T;
  } else if constexpr (std::is_same_v<U, std::uint16_t>) {
    mpi_type = MPI_UINT16_T;
  } else if constexpr (std::is_same_v<U, std::uint32_t>) {
    mpi_type = MPI_UINT32_T;
  } else if constexpr (std::is_same_v<U, std::uint64_t>) {
    mpi_type = MPI_UINT64_T;
  } else if constexpr (std::is_same_v<U, bool>) {
    // C++ bool: MPI_CXX_BOOL (MPI_C_BOOL is the C _Bool datatype).
    mpi_type = MPI_CXX_BOOL;
  } else if constexpr (std::is_same_v<U, std::complex<float>>) {
    mpi_type = MPI_CXX_FLOAT_COMPLEX;
  } else if constexpr (std::is_same_v<U, std::complex<double>>) {
    mpi_type = MPI_CXX_DOUBLE_COMPLEX;
  } else if constexpr (std::is_same_v<U, std::complex<long double>>) {
    mpi_type = MPI_CXX_LONG_DOUBLE_COMPLEX;
  }
  // Unsupported T: MPI_DATATYPE_NULL is returned; debug builds stop here.
  assert(mpi_type != MPI_DATATYPE_NULL);
  return mpi_type;
}
#endif // MPI_TYPE_HPP_INCLUDED
@mayeths
Copy link

mayeths commented Aug 22, 2022

For anyone wondering about a C++11 version of this snippet, the following code works for me. Compiling with mpic++ -std=c++11 -g main.c and inspecting the binary with objdump -d ./a.out, I have confirmed that GCC 9 generates two short functions (about 7 instructions each) with -O2, and no function at all (i.e. the constant values are produced directly) with -O3.

#include <type_traits>
#include <complex>
#include <cstdint>
#include <mpi.h>
/**
 * C++11 variant of mpi_get_type: returns the MPI_Datatype handle that
 * corresponds to the C++ type T, or MPI_DATATYPE_NULL if T is not supported.
 *
 * A plain `if` chain is used because `if constexpr` requires C++17. Every
 * branch is well-formed for every T and each condition is a compile-time
 * constant, so the optimiser can fold the chain to a single value.
 *
 * - const/volatile qualifiers on T are ignored (`const int` maps like `int`).
 * - C++ bool and std::complex<...> map to the MPI_CXX_* datatypes defined for
 *   these C++ types (MPI_C_* are the datatypes for C's _Bool / _Complex).
 * - Unlike the C++17 version, unsupported types are not asserted on; callers
 *   must check for MPI_DATATYPE_NULL themselves.
 *
 * \tparam T The C++ data type used in the MPI function
 * \return The matching MPI_Datatype, or MPI_DATATYPE_NULL for unsupported T
 */
template <typename T>
static inline MPI_Datatype mpi_get_type() noexcept {
  // Strip cv-qualifiers so that e.g. `const double` maps like `double`.
  using U = typename std::remove_cv<T>::type;
  MPI_Datatype mpi_type = MPI_DATATYPE_NULL;
  if (std::is_same<U, char>::value) {
    mpi_type = MPI_CHAR;
  } else if (std::is_same<U, signed char>::value) {
    mpi_type = MPI_SIGNED_CHAR;
  } else if (std::is_same<U, unsigned char>::value) {
    mpi_type = MPI_UNSIGNED_CHAR;
  } else if (std::is_same<U, wchar_t>::value) {
    mpi_type = MPI_WCHAR;
  } else if (std::is_same<U, signed short>::value) {
    mpi_type = MPI_SHORT;
  } else if (std::is_same<U, unsigned short>::value) {
    mpi_type = MPI_UNSIGNED_SHORT;
  } else if (std::is_same<U, signed int>::value) {
    mpi_type = MPI_INT;
  } else if (std::is_same<U, unsigned int>::value) {
    mpi_type = MPI_UNSIGNED;
  } else if (std::is_same<U, signed long int>::value) {
    mpi_type = MPI_LONG;
  } else if (std::is_same<U, unsigned long int>::value) {
    mpi_type = MPI_UNSIGNED_LONG;
  } else if (std::is_same<U, signed long long int>::value) {
    mpi_type = MPI_LONG_LONG;
  } else if (std::is_same<U, unsigned long long int>::value) {
    mpi_type = MPI_UNSIGNED_LONG_LONG;
  } else if (std::is_same<U, float>::value) {
    mpi_type = MPI_FLOAT;
  } else if (std::is_same<U, double>::value) {
    mpi_type = MPI_DOUBLE;
  } else if (std::is_same<U, long double>::value) {
    mpi_type = MPI_LONG_DOUBLE;
  } else if (std::is_same<U, std::int8_t>::value) {
    // Usually shadowed by the fundamental-type branches above, since the
    // fixed-width aliases normally name one of those types.
    mpi_type = MPI_INT8_T;
  } else if (std::is_same<U, std::int16_t>::value) {
    mpi_type = MPI_INT16_T;
  } else if (std::is_same<U, std::int32_t>::value) {
    mpi_type = MPI_INT32_T;
  } else if (std::is_same<U, std::int64_t>::value) {
    mpi_type = MPI_INT64_T;
  } else if (std::is_same<U, std::uint8_t>::value) {
    mpi_type = MPI_UINT8_T;
  } else if (std::is_same<U, std::uint16_t>::value) {
    mpi_type = MPI_UINT16_T;
  } else if (std::is_same<U, std::uint32_t>::value) {
    mpi_type = MPI_UINT32_T;
  } else if (std::is_same<U, std::uint64_t>::value) {
    mpi_type = MPI_UINT64_T;
  } else if (std::is_same<U, bool>::value) {
    mpi_type = MPI_CXX_BOOL;
  } else if (std::is_same<U, std::complex<float>>::value) {
    mpi_type = MPI_CXX_FLOAT_COMPLEX;
  } else if (std::is_same<U, std::complex<double>>::value) {
    mpi_type = MPI_CXX_DOUBLE_COMPLEX;
  } else if (std::is_same<U, std::complex<long double>>::value) {
    mpi_type = MPI_CXX_LONG_DOUBLE_COMPLEX;
  }
  return mpi_type;
}

#include <stdio.h>
int main()
{
  printf("%d\n", mpi_get_type<int>());
  printf("%d\n", mpi_get_type<double>());
}

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment