Replacing a nested loop with a custom ABCD transform - https://stackoverflow.com/questions/75144985/is-it-possible-to-optimize-these-fortran-loops
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
!> Example of replacing a complex loop nest with a solution
!  built from smaller subprograms.
!
! For more context, see the following page:
!
!   https://stackoverflow.com/questions/75144985/is-it-possible-to-optimize-these-fortran-loops
!
! Usage instructions are given just above the main program.
! A Makefile is provided to build the executable.
!
! Changelog:
! - 2023-01-29: renaming
! - 2023-01-28: removed redundant lbound and ubound; added storage information
! - 2023-01-27: first re-engineered version
!
!> Building blocks for the "ABCD" contraction
!>
!>   a_ijn := a_ijn + sum_cd b1_ijcd c_cdn + sum_kl b2_ijkl d_kln
!>
!> expressed through small subprograms instead of one deep loop nest.
module abcd_transform
  implicit none
  private
  public :: abcd, dp

  !> Kind of the double precision reals used throughout.
  integer, parameter :: dp = kind(1.0d0)

contains

  !> a_ij := a_ij + beta * b_ij
  !>
  !> A and B must have the same shape.
  pure subroutine apb(A, B, beta)
    real(dp), intent(inout) :: A(:,:)
    real(dp), intent(in)    :: B(:,:)
    real(dp), intent(in)    :: beta
    A = A + beta*B
  end subroutine apb

  !> a_ij := a_ij + sum_kl b_ijkl c_kl
  !>
  !> The first two extents of B must match A. C(k,l) is read for
  !> k = 1..size(B,3) and l = 1..size(B,4).
  pure subroutine reduce_b(A, B, C)
    real(dp), intent(inout) :: A(:,:)
    real(dp), intent(in)    :: B(:,:,:,:)
    real(dp), intent(in)    :: C(:,:)
    integer :: k, l
    ! k innermost: follows Fortran's column-major element order of B and C.
    do l = 1, size(B,4)
      do k = 1, size(B,3)
        call apb(A, B(:,:,k,l), C(k,l))
      end do
    end do
  end subroutine reduce_b

  !> A_n := A_n + B1_cd C_cdn + B2_kl D_kln,   n = 1, ..., size(A,3)
  !>
  !> The elements of A_n are a_ijn
  !> The elements of B1_cd are b1_ijcd
  !> The elements of B2_kl are b2_ijkl
  !>
  !> Column n of C holds the coefficients c_cd,n linearized in column-major
  !> order (c varying fastest), i.e. size(C,1) == size(B1,3)*size(B1,4).
  !> Column n of D holds d_kl,n in the same way, relative to B2.
  !>
  !> Execution stops with an error message if the argument shapes are
  !> inconsistent. When compiled with OpenMP, the slices A(:,:,n) are
  !> updated in parallel.
  subroutine abcd(A, B1, C, B2, D)
    real(dp), intent(inout), contiguous :: A(:,:,:)
    real(dp), intent(in) :: B1(:,:,:,:)
    real(dp), intent(in) :: B2(:,:,:,:)
    real(dp), intent(in), contiguous, target :: C(:,:), D(:,:)
    ! Rank-3 views of C and D. They are deliberately not initialised in the
    ! declaration, since that would give them the SAVE attribute.
    real(dp), pointer :: p_C(:,:,:), p_D(:,:,:)
    integer :: nc, nd
    integer :: n

    nc = size(B1,3)*size(B1,4)
    nd = size(B2,3)*size(B2,4)
    if (nc /= size(C,1)) then
      error stop "FATAL ERROR: Dimension mismatch between B1 and C"
    end if
    if (nd /= size(D,1)) then
      error stop "FATAL ERROR: Dimension mismatch between B2 and D"
    end if
    ! apb adds whole (i,j) slabs, so these shapes must conform.
    if (size(B1,1) /= size(A,1) .or. size(B1,2) /= size(A,2)) then
      error stop "FATAL ERROR: Dimension mismatch between A and B1"
    end if
    if (size(B2,1) /= size(A,1) .or. size(B2,2) /= size(A,2)) then
      error stop "FATAL ERROR: Dimension mismatch between A and B2"
    end if
    ! Column n of C and D is read for every slice A(:,:,n).
    if (size(C,2) < size(A,3)) then
      error stop "FATAL ERROR: C has fewer columns than A has slices"
    end if
    if (size(D,2) < size(A,3)) then
      error stop "FATAL ERROR: D has fewer columns than A has slices"
    end if

    ! Pointer remapping of C and D to rank 3: p_C(c,d,n) == C(c + (d-1)*size(B1,3), n).
    p_C(1:size(B1,3), 1:size(B1,4), 1:size(C,2)) => C
    p_D(1:size(B2,3), 1:size(B2,4), 1:size(D,2)) => D

    ! Each iteration writes only its own slice A(:,:,n), so iterations are independent.
    !$omp parallel do default(none) shared(A, B1, p_C, B2, p_D)
    do n = 1, size(A,3)
      call reduce_b(A(:,:,n), B1, p_C(:,:,n))
      call reduce_b(A(:,:,n), B2, p_D(:,:,n))
    end do
    !$omp end parallel do
  end subroutine abcd

end module abcd_transform
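! Usage:
!
!   N=100 ORIG=0|1 ./abcd
!
! N     problem size: an integer, at least n0 + 2 (8 with the built-in
!       n0 = 6); defaults to 100 when unset. Memory use grows like N**4,
!       and the program prints the amount allocated.
! ORIG  set to 1 to also run and time the original loop nest; any other
!       value, or leaving it unset, skips it.
!
! Adjust OMP_NUM_THREADS to control how many threads are used.
!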
!> Benchmark driver for the abcd transform.
!>
!> Fills B, C and D with pseudo-random numbers from a fixed seed. It optionally
!> times the original loop nest, then times the abcd-based replacement. Both
!> variants print A(1,1,1) and A(1,1,2) so that their results can be compared.
program main
  !$ use omp_lib, only: omp_get_wtime
  use, intrinsic :: iso_fortran_env, only: int64
  use abcd_transform, only: abcd, dp
  implicit none
  ! Typical ranges explored:
  !   n0 in [2,10]
  !   N  in [100,200]
  integer, parameter :: n0 = 6
  integer, parameter :: n00 = n0*n0
  integer :: N, nVV
  real(dp), allocatable :: A(:,:,:), B(:,:,:,:), C(:,:), D(:,:)
  real(dp) :: t1, t2, t3
  integer :: nseed
  integer, allocatable :: seed(:)
  logical :: with_original

  ! --- Fixed seed: repeated runs (same compiler) generate the same data.
  call random_seed(size=nseed)
  allocate(seed(nseed))
  seed = 11327317
  call random_seed(put=seed)

  ! --- Problem size
  call read_settings(N, with_original)
  ! A(1,1,2) is printed below, which requires nVV = (N - n0)**2 >= 2.
  if (N < n0 + 2) then
    error stop "FATAL ERROR: N must be at least n0 + 2"
  end if
  nVV = (N - n0)**2

  ! --- Storage
  allocate(A(N,N,nVV))
  allocate(B(N,N,N,N))
  allocate(C(nVV,nVV))
  allocate(D(n00,nVV))
  ! All four arrays are real(dp): bytes = (bits per element / 8) * element count.
  ! storage_size is standard Fortran 2008, unlike the GNU extension sizeof.
  ! Element counts use int64 because size(B) = N**4 overflows a default
  ! integer for N >= 216.
  print *, "Memory occupied (MB): ", &
    real(storage_size(A), dp) / 8 &
    * real(size(A, kind=int64) + size(B, kind=int64) &
           + size(C, kind=int64) + size(D, kind=int64), dp) &
    / 1024._dp**2
  call random_number(B)
  call random_number(C)
  call random_number(D)

  ! --- Reference: the original loop nest (only when ORIG=1)
  if (with_original) then
    call cpu_time(t1)
    !$ t1 = omp_get_wtime()
    A = 0.0_dp
    call original(N, nVV, n0, n00, A, B, C, D)
    print *, A(1,1,1), A(1,1,2)
    call cpu_time(t2)
    !$ t2 = omp_get_wtime()
    print *, "Time original: ", t2 - t1
  end if

  ! --- Replacement based on the abcd transform
  call cpu_time(t1)
  !$ t1 = omp_get_wtime()
  A = 0.0_dp
  ! B1 is the block of B with both trailing indices > n0; B2 the block with both <= n0.
  call abcd(A=A, &
            B1=B(:,:,n0+1:N,n0+1:N), &
            B2=B(:,:,1:n0,1:n0), &
            C=C, &
            D=D)
  print *, A(1,1,1), A(1,1,2)
  call cpu_time(t3)
  !$ t3 = omp_get_wtime()
  print *, "Time new: ", t3 - t1

  deallocate(A, B, C, D)

contains

  !> Read the problem settings from environment variables
  !> (simpler than parsing command-line arguments).
  !>
  !> N:   problem size; 100 if the variable is unset or blank.
  !> ORIG: with_orig is .true. only when the variable equals '1'.
  !>
  !> Stops with an error message if N cannot be read as an integer.
  subroutine read_settings(N, with_orig)
    integer, intent(out) :: N
    logical, intent(out) :: with_orig
    character(len=16) :: with_orig_str, N_str
    integer :: ios

    call get_environment_variable("ORIG", with_orig_str)
    with_orig = trim(with_orig_str) == '1'

    call get_environment_variable("N", N_str)
    N = 100
    if (len_trim(N_str) > 0) then
      read(N_str, *, iostat=ios) N
      if (ios /= 0) then
        error stop "FATAL ERROR: environment variable N must be an integer"
      end if
    end if
  end subroutine read_settings

  !> This was the original loop nest given in the StackOverflow question.
  !>
  !> The changes I made were:
  !> - prepending the letter i to the loop variables c,d,l,k
  !> - switching the order of the loops id,ic and il,ik; this has the
  !>   effect of "transposing" C along the first dimension, which
  !>   is actually used as a linearized subscript of a rank-2 array
  !>
  !> The loop order (p, q outermost) is kept as in the question, so that this
  !> routine stays a faithful reference implementation.
  subroutine original(N, nVV, n0, n00, A, B, C, D)
    implicit none
    integer, intent(in) :: N, nVV, n0, n00
    real(dp), intent(inout) :: A(N,N,nVV)
    real(dp), intent(in) :: B(N,N,N,N)
    real(dp), intent(in) :: C(nVV,nVV), D(n00,nVV)
    integer :: p, q, ab
    integer :: cd, kl
    integer :: ic, id, ik, il
    do p = 1, N
      do q = 1, N
        do ab = 1, nVV
          cd = 0
          do id = n0+1, N
            do ic = n0+1, N
              cd = cd + 1
              A(p,q,ab) = A(p,q,ab) + B(p,q,ic,id)*C(cd,ab)
            end do
          end do
          kl = 0
          do il = 1, n0
            do ik = 1, n0
              kl = kl + 1
              A(p,q,ab) = A(p,q,ab) + B(p,q,ik,il)*D(kl,ab)
            end do
          end do
        end do
      end do
    end do
  end subroutine original

end program main
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Build the abcd benchmark (single source file, OpenMP enabled).
FC = gfortran
FCFLAGS = -Wall -O2 -fopenmp

# The special target must be spelled .PHONY (upper case). A lowercase
# ".phony" is just an ordinary target, and then a stray file named
# "all" or "clean" would suppress those rules.
.PHONY: all clean

all: abcd

abcd: abcd.F90
	$(FC) $(FCFLAGS) -o $@ $<

clean:
	rm -f abcd *.mod
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment