Skip to content

Instantly share code, notes, and snippets.

@ashao
Last active July 10, 2024 20:42
Show Gist options
  • Save ashao/f52fb5e57e42d062aad845bdd930cbec to your computer and use it in GitHub Desktop.
Save ashao/f52fb5e57e42d062aad845bdd930cbec to your computer and use it in GitHub Desktop.
!> Ping-pong timing benchmark between pairs of MPI ranks.
!>
!> Ranks are paired as r <-> r + comm_size/2 for 0 <= r < comm_size/2 (the
!> same pairing as reshaping [0, ..., comm_size-1] into two columns).  If the
!> launcher fills one node with consecutive ranks before moving on to the next
!> (assumed, not checked here), every pair spans the two nodes, so the
!> benchmark is only meaningful on exactly two nodes.
!>
!> For every entry of `counts`, a real64 array of shape (counts(i), num_elements)
!> is sent from the lower-half rank to its partner and echoed back, `num_iter`
!> times.  Each lower-half ("original sender") rank writes one line per
!> iteration to timing_NNN.log (NNN = its rank): the count, the iteration
!> number and the roundtrip wall-clock time in seconds from mpi_wtime.
program mpi_array_exchange
  use mpi
  use iso_fortran_env, only : real64, error_unit
  implicit none
  integer, parameter :: num_count = 9
  integer, parameter :: num_iter = 50
  ! How many rows of num_elements values to send in each experiment
  integer, dimension(num_count), parameter :: counts = [1, 2, 4, 8, 16, 32, 64, 128, 256]
  ! 64000 real64 values = 512,000 bytes per row
  integer, parameter :: num_elements = 64000
  ! Both legs use the same tag; they are strictly sequential and go in
  ! opposite directions, so messages cannot be confused.
  integer, parameter :: msg_tag = 1
  integer :: ierr, ios, rank, comm_size, half
  integer :: leg1_send, leg1_receive
  integer :: leg2_send, leg2_receive
  real(kind=real64), dimension(:,:), allocatable :: send_array, recv_array
  integer :: i, j
  integer :: timing_unit
  character(len=50) :: filename
  real(kind=real64) :: start_time, end_time
  real(kind=real64) :: elapsed_time

  call mpi_init(ierr)
  call mpi_comm_rank(MPI_COMM_WORLD, rank, ierr)
  call mpi_comm_size(MPI_COMM_WORLD, comm_size, ierr)
  if (mod(comm_size, 2) /= 0) then
    print *, "Must run this with an even number of processors"
    call mpi_abort(MPI_COMM_WORLD, 1, ierr)
  end if

  ! Pair each lower-half rank with the rank half a communicator away.
  ! comm_size is even (checked above), so half >= 1.
  half = comm_size / 2
  leg1_send = mod(rank, half)        ! lower-half rank starts the exchange
  leg1_receive = leg1_send + half    ! its upper-half partner
  leg2_send = leg1_receive           ! return leg goes the other way
  leg2_receive = leg1_send

  ! Write the roundtrip time only on the original senders
  if (rank == leg1_send) then
    write(filename, "(A,I0.3,A)") "timing_", rank, ".log"
    open(newunit=timing_unit, file=trim(filename), status="replace", &
         action="write", iostat=ios)
    if (ios /= 0) then
      write(error_unit, "(A,A)") "Could not open timing file ", trim(filename)
      call mpi_abort(MPI_COMM_WORLD, 1, ierr)
    end if
    write(timing_unit, "(A)") "count iteration elapsed_time"
  end if

  ! Loop over the different array counts
  do i = 1, size(counts)
    allocate(send_array(counts(i), num_elements))
    allocate(recv_array(counts(i), num_elements))
    do j = 1, num_iter
      call random_number(send_array)
      recv_array(:,:) = 0.0_real64
      call mpi_barrier(MPI_COMM_WORLD, ierr)
      start_time = mpi_wtime()

      ! Outbound leg: lower-half rank -> upper-half partner.
      ! MPI_DOUBLE_PRECISION is the Fortran datatype handle; MPI_DOUBLE is
      ! the C one and is not guaranteed to be provided by the Fortran bindings.
      ! NOTE(review): assumes real64 has the same representation as
      ! double precision on this platform (true for common compilers).
      if (rank == leg1_send) then
        call mpi_send(send_array, size(send_array), MPI_DOUBLE_PRECISION, &
                      leg1_receive, msg_tag, MPI_COMM_WORLD, ierr)
      else
        call mpi_recv(recv_array, size(recv_array), MPI_DOUBLE_PRECISION, &
                      leg1_send, msg_tag, MPI_COMM_WORLD, MPI_STATUS_IGNORE, ierr)
      end if
      if (ierr /= 0) call mpi_abort(MPI_COMM_WORLD, 1, ierr)

      ! Return leg: the partner echoes back exactly what it received
      if (rank == leg2_send) then
        call mpi_send(recv_array, size(recv_array), MPI_DOUBLE_PRECISION, &
                      leg2_receive, msg_tag, MPI_COMM_WORLD, ierr)
      else
        call mpi_recv(recv_array, size(recv_array), MPI_DOUBLE_PRECISION, &
                      leg2_send, msg_tag, MPI_COMM_WORLD, MPI_STATUS_IGNORE, ierr)
      end if
      if (ierr /= 0) call mpi_abort(MPI_COMM_WORLD, 1, ierr)

      end_time = mpi_wtime()
      elapsed_time = end_time - start_time
      if (rank == leg1_send) then
        ! I0 gives minimal-width integers; a width-less "I" is a
        ! non-standard extension that gfortran rejects.
        write(timing_unit, "(I0,1X,I0,1X,E15.7)") counts(i), j, elapsed_time
      end if
    end do
    deallocate(send_array)
    deallocate(recv_array)
  end do

  if (rank == leg1_send) close(timing_unit)
  call mpi_finalize(ierr)
end program mpi_array_exchange
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment