Last active
March 6, 2019 17:07
-
-
Save kuroko1t/c84c4149f0850e00029a96a574287eb4 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include "mpi.h" | |
#include "iostream" | |
#include "vector" | |
#include <stdio.h> | |
#include <cstdlib> | |
using namespace std; | |
// Evaluate an MPI call exactly once and abort the process if it does not
// return MPI_SUCCESS. The failing source location, the call text and the
// returned error code are written to stderr so they are not mixed into the
// program's regular stdout output.
// The do { ... } while (0) wrapper makes `MPI_CHECK(x);` a single statement,
// so it is safe inside an unbraced if/else.
#define MPI_CHECK(op)                                                     \
  do {                                                                    \
    const int mpi_result_ = (op);                                         \
    if (mpi_result_ != MPI_SUCCESS) {                                     \
      std::cerr << __FILE__ << ":" << __LINE__ << ": " << #op             \
                << " failed with error code " << mpi_result_ << std::endl; \
      std::abort();                                                       \
    }                                                                     \
  } while (0)
vector<int> allreduce(vector<int> &in, vector<int> &out, int size) { | |
MPI_Allreduce(&in.front(), &out.front(), size, MPI_INT, MPI_SUM, MPI_COMM_WORLD); | |
return out; | |
} | |
vector<int> allreduce_hierachical(vector<int> &in, vector<int> &out, | |
int size, int rank, int group) { | |
if ((size < group) || (size % group != 0) ) { | |
cout << "size:" << size << "group:" << group << endl; | |
abort(); | |
} | |
MPI_Comm local_comm; | |
int color = rank / group; | |
MPI_Comm_split(MPI_COMM_WORLD, color, rank, &local_comm); | |
int local_size, local_rank; | |
MPI_Comm_size(local_comm, &local_size); | |
MPI_Comm_rank(local_comm, &local_rank); | |
MPI_Reduce(&in.front(), &out.front(), size, MPI_INT, MPI_SUM, 0, local_comm); | |
MPI_Comm group_comm; | |
MPI_Comm_split(MPI_COMM_WORLD, local_rank, 0, &group_comm); | |
auto out1 = out; | |
if (local_rank == 0) { | |
MPI_Allreduce(&out.front(), &out1.front(), size, MPI_INT, MPI_SUM, group_comm); | |
} | |
MPI_Bcast(&out1.front(), size, MPI_INT, 0, local_comm); | |
return out1; | |
} | |
vector<int> allreduce_ring(vector<int> &in, vector<int> &out, int size, int rank) { | |
int send_rank, recv_rank; | |
auto out1 = in; | |
auto out_tmp = in; | |
if (size == rank + 1) { | |
send_rank = 0; | |
} else { | |
send_rank = rank + 1; | |
} | |
if (rank -1 == -1) { | |
recv_rank = size -1; | |
} else { | |
recv_rank = rank - 1; | |
} | |
int64_t data_send, data_recv; | |
data_send = rank; | |
data_recv = rank -1; | |
for (int i = 0;i < size-1; i++) { | |
if (data_send > size-1) { | |
data_send = 0; | |
} | |
if (data_recv < 0) { | |
data_recv = size-1; | |
} | |
if (data_recv > size -1) { | |
data_recv = 0; | |
} | |
// ring reduce for P-1 times | |
if (rank %2 == 0) { | |
MPI_CHECK(MPI_Send(&out1.front() + int64_t(data_send), | |
1, MPI_INT, send_rank, rank, MPI_COMM_WORLD)); | |
MPI_CHECK(MPI_Recv(&out_tmp.front() + int64_t(data_recv), | |
1, MPI_INT, recv_rank, recv_rank, MPI_COMM_WORLD, NULL)); | |
out1[data_recv] += out_tmp[data_recv]; | |
} else { | |
MPI_CHECK(MPI_Recv(&out_tmp.front()+ int64_t(data_recv), | |
1, MPI_INT, recv_rank, recv_rank, MPI_COMM_WORLD, NULL)); | |
MPI_CHECK(MPI_Send(&out1.front() + int64_t(data_send), | |
1, MPI_INT, send_rank, rank, MPI_COMM_WORLD)); | |
out1[data_recv] += out_tmp[data_recv]; | |
} | |
data_send = data_recv; | |
data_recv -= 1; | |
} | |
// broadcast sum data | |
int bpos = 0; | |
for (int i=0; i < size; i++) { | |
if (bpos > size -1) { | |
bpos = 0; | |
} | |
MPI_CHECK(MPI_Bcast(&out1.front() + int64_t(bpos+1), | |
1, MPI_INT, bpos, MPI_COMM_WORLD)); | |
bpos +=1; | |
} | |
return out1; | |
} | |
int main(int argc, char* argv[]) { | |
MPI_Init(&argc, &argv); | |
int rank; | |
int size; | |
MPI_Comm_rank(MPI_COMM_WORLD, &rank); | |
MPI_Comm_size(MPI_COMM_WORLD, &size); | |
vector<int> in(size); | |
vector<int> out_normal(size); | |
vector<int> out_hie(size); | |
vector<int> out_ring(size); | |
int root = 0; | |
for (int i = 0; i < size; i++) { | |
in[i] = i; | |
} | |
out_normal = allreduce(in, out_normal, size); | |
out_hie = allreduce_hierachical(in, out_hie, size, rank, 3); | |
out_ring = allreduce_ring(in, out_ring, size, rank); | |
for (int i = 0; i < size; i++) { | |
cout << "out_normal:" << i<< ":" << out_normal[i] << endl; | |
cout << "out_hie:" << i<< ":" << out_hie[i] << endl; | |
cout << "out_ring:" << i<< ":" << out_ring[i] << endl; | |
} | |
cout << endl; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment