Skip to content

Instantly share code, notes, and snippets.

@mikaem
Created February 26, 2019 10:53
Show Gist options
  • Save mikaem/e093894fba07234c01f23bb9176830b4 to your computer and use it in GitHub Desktop.
Save mikaem/e093894fba07234c01f23bb9176830b4 to your computer and use it in GitHub Desktop.
Test the speed of the subclassed NumPy array DistributedArray compared to a pure NumPy array
import numpy as np
from mpi4py import MPI
from mpi4py_fft import PFFT, DistributedArray, newDarray, Function
# Benchmark configuration: `rank` is passed to newDarray as the rank of the
# arrays, and N is the global mesh handed to PFFT.
rank = 1
N = (32, 32, 32)
FFT = PFFT(MPI.COMM_WORLD, N)
# Global shape of the arrays: `rank` leading axes of length 3 followed by N.
M = (3,)*rank + N

# Create 3 versions of each operand:
#   x0 - DistributedArray created by newDarray (the ndarray subclass under test)
#   x1 - whatever x0.__array__() returns (presumably a plain ndarray over the
#        same data -- TODO confirm against mpi4py_fft)
#   x2 - an independent plain numpy array
# The plain arrays take their shape and dtype from the DistributedArray rather
# than from the global shape M: when run on more than one MPI process the
# distributed array presumably holds only the local part of the mesh, and
# the three versions must have identical sizes for the timings to be comparable.
a0 = newDarray(FFT, False, val=0, rank=rank)
a1 = a0.__array__()
a2 = np.zeros(a0.shape, dtype=a0.dtype)

b0 = newDarray(FFT, False, val=1, rank=rank)
b1 = b0.__array__()
b2 = np.ones(b0.shape, dtype=b0.dtype)

c0 = newDarray(FFT, False, val=1, rank=rank)
c1 = c0.__array__()
c2 = np.ones(c0.shape, dtype=c0.dtype)
def add(a, b, c):
    """Add the product ``b*c`` to ``a`` and return the result.

    The update uses ``+=``, so a mutable ``a`` (e.g. a numpy array) is
    modified in place and the very same object is returned; for an immutable
    ``a`` (e.g. a Python number) a new value is returned.

    Parameters
    ----------
    a : array or number
        Accumulator that receives the product.
    b, c : array or number
        Factors of the product added to ``a``.

    Returns
    -------
    The updated accumulator.
    """
    a += b*c
    return a
def add2(a, b, c):
    """Add ``2.0*b*c`` to ``a`` and return the result.

    Same as ``add`` except for the extra leading constant ``2.0``; the
    product is evaluated left to right, i.e. as ``(2.0*b)*c``. The update uses
    ``+=``, so a mutable ``a`` (e.g. a numpy array) is modified in place and
    the very same object is returned.

    Parameters
    ----------
    a : array or number
        Accumulator that receives the scaled product.
    b, c : array or number
        Factors of the product added to ``a``.

    Returns
    -------
    The updated accumulator.
    """
    a += 2.0*b*c
    return a
if __name__ == '__main__':
    import timeit

    # Each case is (printed label, second operand, third operand). An operand
    # is either the literal "1" or the base name of an array ("b" or "c");
    # the array version suffix (0, 1 or 2) is appended when the case is run.

    # I see no penalty for the add function
    add_cases = [("a += 1*1", "1", "1"),
                 ("a += b*1", "b", "1"),
                 ("a += b*c", "b", "c")]

    # However, adding an extra constant 2.0 in front and the DistributedArray approach sees a significant penalty
    # in the last two calls. Why?
    add2_cases = [("a += 2.0*1*1", "1", "1"),
                  ("a += 2.0*b*1", "b", "1"),
                  ("a += 2.0*b*c", "b", "c")]

    for func, cases in (("add", add_cases), ("add2", add2_cases)):
        for label, second, third in cases:
            print(label)
            # One timing per array version, in the order x0, x1, x2.
            for suffix in "012":
                target = "a" + suffix
                operands = [op if op == "1" else op + suffix
                            for op in (second, third)]
                # The setup imports exactly the names the statement uses.
                names = [func, target] + [op for op in operands if op != "1"]
                stmt = "{0}={1}({0}, {2}, {3})".format(target, func, *operands)
                setup = "from __main__ import " + ", ".join(names)
                print(timeit.timeit(stmt, number=1000, setup=setup))
@mikaem
Copy link
Author

mikaem commented Feb 26, 2019

Running this on Ubuntu gives me only minor differences up to the last two cases, where the difference is significant:

a += 1*1
0.04665716199997405
0.0434702840002501
0.04557729599991944
a += b*1
0.11887742299995807
0.11035974699962026
0.11472923100018306
a += b*c
0.1345189290000235
0.12881515299977764
0.12775983599976826
a += 2.0*1*1
0.04476622500033045
0.04491703100029554
0.04409783199980666
a += 2.0*b*1
0.5386193060003279
0.15911475499979133
0.16141825300019264
a += 2.0*b*c
0.5064503859998695
0.1909908840002572
0.1974248540000189

If I make the array smaller (16, 16, 16), then the penalty is seen for all tests:

a += 1*1
0.008763114999965183
0.007290302999990672
0.007404803000099491
a += b*1
0.019433665999713412
0.01568768600009207
0.017227625000032276
a += b*c
0.02224193300025945
0.017224629999873287
0.019373073000224394
a += 2.0*1*1
0.008524932000000263
0.007692241999848193
0.007385493999663595
a += 2.0*b*1
0.02937988600024255
0.022421208000196202
0.02256454299958932
a += 2.0*b*c
0.03133991199956654
0.024102769999899465
0.024605184999927587

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment