@navidcy
Last active May 2, 2020 06:52
Gist: navidcy/5ff351ad5b63b64cd428ed0cafe99629
A script to investigate the slowdown caused by computing diagnostics when running on a GPU.
using FourierFlows, Printf
using Random: seed!
import GeophysicalFlows.TwoDNavierStokes
import GeophysicalFlows.TwoDNavierStokes: energy, enstrophy
import GeophysicalFlows: peakedisotropicspectrum
# ## Choosing a device: CPU or GPU
dev = GPU() # Device (CPU/GPU)
n, L = 256, 2π # grid resolution and domain length
dt = 2e-3 # timestep
nsteps = 2000 # total number of steps
nsubs = 400 # number of steps between successive progress printouts
prob = TwoDNavierStokes.Problem(; nx=n, Lx=L, ny=n, Ly=L, dt=dt, stepper="FilteredRK4", dev=dev)
sol, cl, vs, gr, filter = prob.sol, prob.clock, prob.vars, prob.grid, prob.timestepper.filter
# Initial condition
seed!(1234)
k0, E0 = 6, 0.5 # peak wavenumber and energy of the initial isotropic spectrum
zetai = peakedisotropicspectrum(gr, k0, E0, mask=filter)
# Diagnostics
E = Diagnostic(energy, prob; nsteps=nsteps)
Z = Diagnostic(enstrophy, prob; nsteps=nsteps)
diags = [E, Z] # diagnostics in this list are updated at every time step when passed to stepforward!
println("time-stepping without diagnostics")
TwoDNavierStokes.set_zeta!(prob, zetai)
startwalltime = time()
while cl.step < nsteps
    stepforward!(prob, nsubs)
    logstring = @sprintf("step: %04d, t: %.2f, walltime: %.2f sec", cl.step, cl.t, time() - startwalltime)
    println(logstring)
end
println()
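println("re-initializing the problem and running again with diagnostics")
prob = TwoDNavierStokes.Problem(; nx=n, Lx=L, ny=n, Ly=L, dt=dt, stepper="FilteredRK4", dev=dev)
sol, cl, vs, gr, filter = prob.sol, prob.clock, prob.vars, prob.grid, prob.timestepper.filter
# Rebuild the diagnostics for the new problem: each Diagnostic keeps a reference to the
# problem it was constructed with, so the E and Z defined above still point to the old `prob`.
E = Diagnostic(energy, prob; nsteps=nsteps)
Z = Diagnostic(enstrophy, prob; nsteps=nsteps)
diags = [E, Z]
TwoDNavierStokes.set_zeta!(prob, zetai)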
startwalltime = time()
while cl.step < nsteps
    stepforward!(prob, diags, nsubs)
    logstring = @sprintf("step: %04d, t: %.2f, walltime: %.2f sec", cl.step, cl.t, time() - startwalltime)
    println(logstring)
end
println()
println("finished")
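The comparison above can also be isolated from FourierFlows in a toy, Base-only form. The sketch below is hypothetical: `run_loop` and `toy_energy` are illustrative names, not FourierFlows or GeophysicalFlows API, and it runs on the CPU only. It shows the measurement pattern of timing the same loop with and without a scalar diagnostic evaluated after every step; it does not reproduce the GPU behavior itself.

```julia
using Printf

# Hypothetical toy model (not FourierFlows API): a CPU-only stand-in for the
# script above, used to time a loop with and without a per-step diagnostic.

"""
    run_loop(nsteps; diagnostic=nothing)

Advance a toy state array `nsteps` times. If `diagnostic` is a function, evaluate
it on the state after every step and store the scalar it returns.
"""
function run_loop(nsteps::Int; diagnostic=nothing)
    state = rand(256, 256)
    record = Float64[]
    for _ in 1:nsteps
        state .*= 0.999                          # stand-in for one time step
        if diagnostic !== nothing
            push!(record, diagnostic(state))     # scalar reduction after every step
        end
    end
    return record
end

# Toy "energy": half the sum of squares of the state.
toy_energy(q) = 0.5 * sum(abs2, q)

# Warm up first so that compilation time is not counted in the measurements.
run_loop(1)
run_loop(1; diagnostic=toy_energy)

t_plain = @elapsed run_loop(2000)
t_diag = @elapsed run_loop(2000; diagnostic=toy_energy)
@printf("without diagnostics: %.3f sec, with diagnostics: %.3f sec\n", t_plain, t_diag)
```

On a GPU, a full reduction such as `sum` returns a scalar to the host, which forces a synchronization with the device at every step; this is one plausible source of the slowdown the script measures, not a confirmed diagnosis.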