Skip to content

Instantly share code, notes, and snippets.

@rabernat
Created February 13, 2017 20:05
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save rabernat/3c488b5b12f34a05027770e1cc6ebae6 to your computer and use it in GitHub Desktop.
Save rabernat/3c488b5b12f34a05027770e1cc6ebae6 to your computer and use it in GitHub Desktop.
Test script for xmitgcm llc read speed (deployed on Pleiades)
import random
import os
import time
import numpy as np
from memory_profiler import memory_usage
import psutil
import socket
import json
import datetime
import argparse
from collections import OrderedDict
from xmitgcm.utils import read_3d_llc_data
def random_file(do3d):
    """Pick one model output file at random and describe how to read it.

    Parameters
    ----------
    do3d : bool
        When true, choose among the 3D variable prefixes and report 90
        vertical levels; otherwise choose among the 2D prefixes and report
        a single level.

    Returns
    -------
    tuple
        ``(path, nz, nx, dtype)`` where ``path`` is
        ``<run dir>/<prefix>.<10-digit iteration>.data``, ``nz`` is the
        number of levels, ``nx`` is 4320 and ``dtype`` is ``'>f4'``.
    """
    run_dir = '/u/dmenemen/llc_4320/MITgcm/run/'
    grid_nx = 4320
    dtype = '>f4'

    # Variable names and level count depend on 2D vs 3D output.
    if do3d:
        variables, levels = ['Salt', 'Theta', 'U', 'V', 'W'], 90
    else:
        variables, levels = ['oceFWflx', 'oceQnet', 'oceQsw', 'Eta'], 1

    # Candidate iteration numbers: 3312 values starting at 10368, 144 apart.
    first_iter, stride, count = 10368, 144, 3312
    iterations = np.arange(first_iter, first_iter + stride * count, stride)

    # Draw the prefix first, then the iteration.
    variable = random.choice(variables)
    iteration = random.choice(iterations)

    filename = '%s.%010d.data' % (variable, iteration)
    return os.path.join(run_dir, filename), levels, grid_nx, dtype
def load_data(method, do3d):
    """Open a randomly chosen file with ``read_3d_llc_data``.

    ``method`` is forwarded unchanged as the reader's ``method`` keyword;
    ``do3d`` selects between 3D and 2D files via ``random_file``.  Returns
    whatever ``read_3d_llc_data`` returns.
    """
    path, levels, grid_nx, dtype = random_file(do3d)
    return read_3d_llc_data(path, levels, grid_nx, dtype=dtype, method=method)
def timing_val(func):
    """Decorate a benchmark so that each call prints one JSON record.

    The wrapped function must be called positionally as
    ``(method, do3d, num_dask_workers)``; those three values are copied
    into the record.  The record also contains the function name, an ISO
    timestamp, the hostname, the wall-clock time, the change in the
    process RSS and VMS between start and end, and the maximum of the
    samples returned by ``memory_usage``.

    The wrapper always returns ``None``: the benchmarked function's own
    return value is discarded, only the measurements are reported.
    """
    # Imported here so the decorator does not depend on a file-level import.
    import functools

    @functools.wraps(func)  # keep the benchmark's __name__/__doc__ intact
    def wrapper(*arg, **kw):
        process = psutil.Process(os.getpid())
        mi_start = process.memory_info()
        t1 = time.time()
        # memory_usage invokes func(*arg, **kw) itself while sampling
        # memory once per `interval`; t1/t2 bracket that whole call, so
        # the walltime includes the profiler's own overhead.
        mem_usage = memory_usage((func, arg, kw), interval=1)
        t2 = time.time()
        mi_end = process.memory_info()

        # Built from ordered pairs (not keyword arguments) so the JSON key
        # order is the same on every Python version.
        info = OrderedDict([
            ('function', func.__name__),
            ('method', arg[0]),
            ('do3d', arg[1]),
            ('timestamp', datetime.datetime.now().isoformat()),
            ('hostname', socket.gethostname()),
            ('num_dask_workers', arg[2]),
        ])
        info['walltime'] = t2 - t1
        info['mem_rss'] = mi_end.rss - mi_start.rss
        info['mem_vms'] = mi_end.vms - mi_start.vms
        info['max_mem_usage'] = max(mem_usage)
        print(json.dumps(info))
        return None

    return wrapper
@timing_val
def load_with_numpy(method, do3d, num_workers):
    """Baseline benchmark: read a random file directly with ``np.fromfile``.

    ``method`` and ``num_workers`` are not used by the read itself; they
    are accepted so that the ``timing_val`` wrapper can record them.
    """
    path, _nz, _nx, dtype = random_file(do3d)
    raw = np.fromfile(path, dtype=dtype)
    return raw
@timing_val
def load_simple(method, do3d, num_workers):
    """Benchmark only the ``load_data`` call, without calling ``.compute()``.

    ``num_workers`` is unused here; it is accepted so that the
    ``timing_val`` wrapper can record it.
    """
    opened = load_data(method, do3d)
    return opened
@timing_val
def load_data_in_memory(method, do3d, num_workers):
    """Benchmark loading a random file and computing the entire result.

    Calls ``.compute(num_workers=num_workers)`` on the object returned by
    ``load_data``.
    """
    return load_data(method, do3d).compute(num_workers=num_workers)
@timing_val
def compute_sum(method, do3d, num_workers):
    """Benchmark summing every value of a randomly chosen file.

    Builds ``data.sum()`` on the object returned by ``load_data`` and
    evaluates it with ``.compute(num_workers=num_workers)``.
    """
    total = load_data(method, do3d).sum()
    return total.compute(num_workers=num_workers)
@timing_val
def extract_chunk(method, do3d, num_workers):
    """Benchmark reading one 1080 x 1080 sub-region of a random file.

    The selection fixes index 2 on the third-from-last axis and takes
    ``1080:2160`` on each of the last two axes; any leading axes are kept
    whole.  The region is then evaluated with
    ``.compute(num_workers=num_workers)``.
    """
    # NOTE(review): index 2 is presumably a face/tile index - confirm
    # against the array layout produced by read_3d_llc_data.
    window = slice(1080, 2160)
    # The Ellipsis keeps the selection valid for both 2D and 3D data.
    selection = (Ellipsis, 2, window, window)
    subset = load_data(method, do3d)[selection]
    return subset.compute(num_workers=num_workers)
# Command-line interface.  --method is forwarded to read_3d_llc_data,
# --do3d switches from 2D to 3D files, and --num_dask_workers is passed to
# every .compute() call.
parser = argparse.ArgumentParser(description='Benchmark LLC File Loading')
parser.add_argument(
    '--method',
    default='smallchunks',
    choices=['bigchunks', 'smallchunks'],
)
parser.add_argument(
    '--do3d',
    help='load 3d files',
    action='store_true',
)
parser.add_argument('--num_dask_workers', default=1, type=int)
def main():
    """Parse the command line and run each benchmark once, in order.

    Every benchmark receives the same positional triple
    ``(method, do3d, num_dask_workers)`` and prints its own JSON record.
    """
    opts = parser.parse_args()
    bench_args = (opts.method, opts.do3d, opts.num_dask_workers)

    for benchmark in (load_simple, load_with_numpy, extract_chunk,
                      compute_sum):
        benchmark(*bench_args)

    # The full in-memory load is skipped for 3D files: it literally takes
    # more than 12 hours (or runs out of memory).
    if opts.do3d:
        return
    load_data_in_memory(*bench_args)
# Run the benchmarks only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment