Last active
June 13, 2016 20:22
-
-
Save JesseLivezey/12e0e320960a58c278138d402f724fca to your computer and use it in GitHub Desktop.
CorrMM OMP and BLAS timings script
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python
import argparse
import ast
import os
import subprocess
import time

import numpy as np
import theano
from theano.compat.python2x import OrderedDict
from theano.tensor.nnet.corr import (CorrMM, CorrMM_gradInputs,
                                     CorrMM_gradWeights)
# Number of timed calls that each measurement is averaged over.
n_avg = 10
# Batch size shared by both benchmark configurations.
n_images = 128

# The three shape tables below are parallel lists: index 0 is the
# 'Imagenet-like' configuration and index 1 the 'Spectrogram-like' one
# (see shape_names). Each map shape equals the image shape correlated with
# the matching filter shape without padding
# (out = in - filter + 1 along the last two axes).
image_shapes = [(n_images, 3, 128, 128), (n_images, 85, 2, 258)]
filter_shapes = [(96, 3, 5, 5), (64, 85, 2, 20)]
map_shapes = [(n_images, 96, 124, 124), (n_images, 64, 1, 239)]

# Human-readable labels for the two configurations above.
shape_names = ['Imagenet-like', 'Spectrogram-like']
# Labels for op implementations.
op_names = ['Legacy', 'CorrMM']
def time_op(data_type, direction):
    """Time one CorrMM op on random float32 data and report thread settings.

    Parameters
    ----------
    data_type : str
        'image' selects entry 0 of the module-level shape tables; any other
        value (the driver passes 'spect') selects entry 1.
    direction : str
        'gradInputs' times CorrMM_gradInputs, 'gradWeights' times
        CorrMM_gradWeights, and any other value (the driver passes
        'forward') times the forward CorrMM op.

    Returns
    -------
    tuple
        (avg_time, mkl_threads, omp_threads): the mean wall-clock seconds
        per call over n_avg calls, followed by the values of the
        MKL_NUM_THREADS and OMP_NUM_THREADS environment variables (strings,
        or None when unset).
    """
    idx = 0 if data_type == 'image' else 1
    imgshp = image_shapes[idx]
    filshp = filter_shapes[idx]
    mapshp = map_shapes[idx]

    images = np.random.randn(*imgshp).astype('float32')
    filters = np.random.randn(*filshp).astype('float32')
    maps = np.random.randn(*mapshp).astype('float32')
    images_sym = theano.shared(images)
    filters_sym = theano.shared(filters)
    maps_sym = theano.shared(maps)

    # The op result is written back into a shared variable through `updates`,
    # so the compiled function takes no inputs and returns no outputs.
    updates = OrderedDict()
    if direction == 'gradInputs':
        op = CorrMM_gradInputs()
        out = op(filters_sym, maps_sym)
        updates[images_sym] = out
    elif direction == 'gradWeights':
        # This branch previously re-tested 'gradInputs', so 'gradWeights'
        # requests silently timed the forward op instead.
        op = CorrMM_gradWeights()
        out = op(images_sym, maps_sym)
        updates[filters_sym] = out
    else:
        op = CorrMM()
        out = op(images_sym, filters_sym)
        updates[maps_sym] = out

    omp_threads = os.environ.get('OMP_NUM_THREADS', None)
    mkl_threads = os.environ.get('MKL_NUM_THREADS', None)

    f = theano.function([], [], updates=updates)
    # One untimed call before the clock starts (presumably a warm-up).
    f()
    start = time.time()
    for _ in range(n_avg):
        f()
    avg_time = (time.time() - start) / float(n_avg)
    return avg_time, mkl_threads, omp_threads
def parse(rval):
    """Parse the tuple printed by a timing subprocess.

    The child process prints ``(avg_time, mkl_threads, omp_threads)`` where
    the thread entries are either quoted integers or None.

    Parameters
    ----------
    rval : bytes or str
        Raw stdout captured from the child process. Only the last non-empty
        line is interpreted, so anything printed before the result tuple is
        ignored.

    Returns
    -------
    list
        ``[avg_time (float), mkl_threads (int or None),
        omp_threads (int or None)]``.
    """
    # subprocess.check_output returns bytes on Python 3; decode before parsing.
    if isinstance(rval, bytes):
        rval = rval.decode('utf-8', 'replace')
    lines = [line.strip() for line in rval.splitlines() if line.strip()]
    avg_time, mkl, omp = ast.literal_eval(lines[-1])
    results = [float(avg_time)]
    for value in (mkl, omp):
        results.append(None if value is None else int(value))
    return results


def print_results(blas, omp, base_t, t, notes):
    """Print one table row: thread counts, time in ms, speedup and notes.

    Speedup is ``base_t / t``. ``blas`` and ``omp`` may be None when the
    corresponding environment variable was unset; they are converted with
    str() first because None does not accept a width/alignment format spec
    on Python 3.
    """
    print('{0: <10}{1: <10}{2: <10}{3: <10.2f}{4: <10}'.format(
        str(blas), str(omp), int(t*1000), float(base_t)/t, notes))


def main():
    """Run every op/shape combination over a range of thread settings.

    Each measurement is taken in a fresh ``python -c`` subprocess so the
    OMP_NUM_THREADS / MKL_NUM_THREADS / THEANO_FLAGS settings take effect
    for that run only. The child imports ``time_op`` from a module named
    ``conv_timing`` (presumably this file saved as conv_timing.py and
    importable from the working directory).
    """
    parser = argparse.ArgumentParser(description='Time conv w.r.t. threads.')
    parser.add_argument('max_threads', type=int,
                        help='max number of threads')
    args = parser.parse_args()
    max_threads = args.max_threads
    if max_threads < 1:
        # log2 of a non-positive number cannot be converted to int below.
        parser.error('max_threads must be a positive integer')

    # Thread counts to sweep: powers of two from 2 up to (but excluding)
    # max_threads, followed by max_threads itself.
    log_threads = int(np.log2(max_threads))
    if pow(2, log_threads) == max_threads:
        log_threads -= 1
    threads = np.power(2, np.arange(1, log_threads+1)).tolist() + [max_threads]

    omp_flags = """export OMP_NUM_THREADS={0}; """
    blas_flags = """export MKL_NUM_THREADS={0}; """
    theano_flags = """THEANO_FLAGS='floatX=float32,device=cpu,openmp={0}' """
    base_cmd = """python -c 'from conv_timing import time_op; print(time_op({0}, {1}))'"""

    def run(cmd):
        # Execute one timing command through the shell and parse its output.
        return parse(subprocess.check_output(cmd, shell=True))

    # The inner double quotes are part of the Python source handed to the
    # child process, where they become string literals.
    data_types = ['"image"', '"spect"']
    directions = ['"forward"', '"gradInputs"', '"gradWeights"']

    for d in directions:
        for ii, dt in enumerate(data_types):
            print('Op: {}'.format(d.split('"')[1]))
            print('Image shape, filter shape')
            print('{} {}'.format(image_shapes[ii], filter_shapes[ii]))
            print(shape_names[ii])
            print('{0: <10}{1: <10}{2: <10}{3: <10}{4: <10}'.format('MKL', 'OpenMP', 'Time(ms)', 'Speedup', 'Notes'))

            # Baseline: BLAS threading at max_threads, Theano OpenMP off.
            cmd = (blas_flags.format(max_threads) +
                   theano_flags.format(False)+base_cmd.format(dt, d))
            base_notes = 'BLAS only, baseline'
            base_t, base_blas, base_omp = run(cmd)

            cmd = (blas_flags.format(1) +
                   theano_flags.format(False)+base_cmd.format(dt, d))
            notes = 'single thread'
            t, blas, omp = run(cmd)
            print_results(blas, omp, base_t, t, notes)
            print_results(base_blas, base_omp, base_t, base_t, base_notes)

            cmd = (omp_flags.format(max_threads) +
                   theano_flags.format(True)+base_cmd.format(dt, d))
            notes = 'only OMP set'
            t, blas, omp = run(cmd)
            print_results(blas, omp, base_t, t, notes)

            # BLAS sweep; max_threads is skipped because the baseline row
            # already covers it. Only the first row carries the label, the
            # rest use a ditto mark.
            for row, n in enumerate(threads[:-1]):
                cmd = (omp_flags.format(1) + blas_flags.format(n) +
                       theano_flags.format(False)+base_cmd.format(dt, d))
                notes = 'BLAS only' if row == 0 else '"'
                t, blas, omp = run(cmd)
                print_results(blas, omp, base_t, t, notes)

            # OpenMP sweep with BLAS pinned to a single thread.
            for row, n in enumerate(threads):
                cmd = (omp_flags.format(n) + blas_flags.format(1) +
                       theano_flags.format(True)+base_cmd.format(dt, d))
                notes = 'OMP only' if row == 0 else '"'
                t, blas, omp = run(cmd)
                print_results(blas, omp, base_t, t, notes)
            print('\n')


if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment