Skip to content

Instantly share code, notes, and snippets.

@cwfitzgerald
Last active September 13, 2015 21:54
Show Gist options
  • Save cwfitzgerald/b642be7d13b212b6baa2 to your computer and use it in GitHub Desktop.
Save cwfitzgerald/b642be7d13b212b6baa2 to your computer and use it in GitHub Desktop.
OpenCL Implementation of Rule 90 making a Sierpinski Triangle
# Call this with two arguments: the size (number of generations) and the name of the PNG file to write the result to,
# e.g. python Rule90OpenCLSerpinki.py 1000 rule90serpinki.png
# Requires PyOpenCL, NumPy, and Matplotlib.
from time import time
timer = time()
import pyopencl as cl
from pyopencl import array
import numpy as np
from sys import argv
import matplotlib.pyplot as plt
##### Time Variables #####
# Seconds spent on the imports above; `timer` was started just before them.
import_time = time() - timer
# Accumulators, in seconds, for the OpenCL run (kernel execution and copies).
kernel_time = cache_time = 0
# Accumulators, in seconds, for the CPU reference run.
cpukernel_time = cpucache_time = 0
# Placeholders for the plotting timings, overwritten once the figure is made.
plot_prep_time = plot_time = 0
##### Setup Array #####
# Number of generations to simulate, read from the first command-line argument.
height = int(argv[1])
if height < 0:
    raise ValueError("size must be non-negative, got %d" % height)
# Generation 0 is a single live cell padded with `height` dead cells on each
# side, so the pattern can grow for `height` steps without hitting the edges.
starting_state = "0" * height + "1" + "0" * height
width = len(starting_state)
# np.bool_ is used instead of the np.bool alias, which some NumPy releases do
# not provide.
last_state_host = np.zeros((width,), dtype=np.bool_)
curr_state_host = np.zeros((width,), dtype=np.bool_)
# One row per generation, row 0 being the starting state.
final_array = np.zeros((height + 1, width), dtype=np.bool_)
# The only '1' of the starting state sits at index `height`.
last_state_host[height] = True
final_array[0, height] = True
##### Kernel Code #####
# OpenCL C source of the Rule 90 step. The cell buffers are addressed as uchar
# rather than bool: sizeof(bool) is implementation-defined in OpenCL C, whereas
# the host arrays hold one byte (0 or 1) per cell.
kernelsource = """
__kernel
void Rule90(__global const uchar *lastState, __global uchar *currState, const uint size){
    // One work-item per cell; work-items beyond the end of the row do nothing.
    size_t gid = get_global_id(0);
    if (gid < size){
        // A neighbour that would fall outside the row counts as dead (0), so
        // the edge cells simply copy their single in-range neighbour.
        uchar left = (gid > 0) ? lastState[gid-1] : (uchar)0;
        uchar right = (gid+1 < size) ? lastState[gid+1] : (uchar)0;
        currState[gid] = left ^ right;
    }
}
"""
def cpukernel(lastState, size):
    """Compute the next Rule 90 generation on the CPU with NumPy.

    Each of the first ``size`` cells becomes the XOR of its left and right
    neighbours in ``lastState``. A neighbour that would fall outside those
    ``size`` cells counts as dead, so the first cell copies its right
    neighbour and the last one copies its left neighbour.

    Args:
        lastState: 1-D NumPy array holding the current generation.
        size: number of leading cells to update.

    Returns:
        A new boolean array with the shape of ``lastState``. Cells at index
        ``size`` and beyond are False.

    Raises:
        IndexError: if ``size`` exceeds the length of ``lastState``.
    """
    if size > len(lastState):
        raise IndexError(
            "size %d exceeds state length %d" % (size, len(lastState)))
    # Zero-filled (not np.empty) so cells past `size` are deterministic;
    # np.bool_ because the np.bool alias is missing from some NumPy releases.
    currState = np.zeros(lastState.shape, dtype=np.bool_)
    if size < 2:
        # Nothing to update, or a lone cell with no neighbours: it stays dead.
        return currState
    # Edge cells have a single in-range neighbour.
    currState[0] = lastState[1]
    currState[size - 1] = lastState[size - 2]
    # Interior cells: left neighbour XOR right neighbour, all at once.
    currState[1:size - 1] = lastState[:size - 2] ^ lastState[2:size]
    return currState
##### OpenCL Setup #####
timer = time()
# Pick the first device of the first platform the OpenCL runtime reports.
platform = cl.get_platforms()[0]
device = platform.get_devices()[0]
# Context and command queue bound to that single device.
context = cl.Context([device])
queue = cl.CommandQueue(context)
# Compile the kernel source; `kernel` is the built program exposing Rule90.
kernel = cl.Program(context, kernelsource).build()
# Device buffers created from the host arrays: the kernel only reads the
# previous generation and only writes the current one.
last_state_client = cl.Buffer(
    context,
    cl.mem_flags.READ_ONLY | cl.mem_flags.COPY_HOST_PTR,
    hostbuf=last_state_host,
)
curr_state_client = cl.Buffer(
    context,
    cl.mem_flags.WRITE_ONLY | cl.mem_flags.COPY_HOST_PTR,
    hostbuf=curr_state_host,
)
# The two buffer arguments need no dtype; the row length is passed as uint32.
kernel.Rule90.set_scalar_arg_dtypes([None, None, np.uint32])
OpenCL_prep_time = time() - timer
##### Simulation Loop #####
# Every iteration computes generation i+1 twice, once on the OpenCL device and
# once on the CPU, and times each stage separately.
for i in range(height):
    # Device: run one work-item per cell and wait for the kernel to finish.
    timer = time()
    kernel.Rule90(queue, curr_state_host.shape, None, last_state_client, curr_state_client, np.uint32(width))
    queue.finish()
    kernel_time += time() - timer
    # Device: read the new row back to the host, then copy it into the input
    # buffer so it becomes the previous generation of the next iteration.
    timer = time()
    cl.enqueue_copy(queue, final_array[i+1], curr_state_client)
    cl.enqueue_copy(queue, last_state_client, curr_state_client)
    queue.finish()
    cache_time += time() - timer
    # CPU: recompute the same row; this overwrites the row just read back
    # from the device.
    timer = time()
    final_array[i+1] = cpukernel(last_state_host, width)
    cpukernel_time += time() - timer
    # CPU: the stored row is the input of the next generation.
    timer = time()
    last_state_host = final_array[i+1]
    cpucache_time += time() - timer
##### Plot #####
timer = time()
fig = plt.figure(frameon=False)
# Downscale factor for wide images. Floor division keeps the Python 2 result
# on both Python versions, and the clamp to 1 avoids the division by zero the
# unclamped factor caused for widths below 750.
scale = max(1, width // 750)
dpi = min(height, width / float(scale))
# Convert before dividing so the figure size is not truncated to whole inches.
fig.set_size_inches(float(width) / dpi, float(height) / dpi)
# A single axes covering the whole figure, with the axis decorations hidden,
# so the saved image contains only the cell matrix.
ax = plt.Axes(fig, [0., 0., 1., 1.])
ax.set_axis_off()
fig.add_axes(ax)
ax.matshow(final_array, cmap=plt.cm.gray, aspect='auto')
plot_prep_time = time() - timer
timer = time()
# The output file name is the second command-line argument.
plt.savefig(argv[2], dpi=dpi)
plot_time = time() - timer
##### Report #####
def _fmt_seconds(seconds):
    """Return a duration in seconds formatted with two decimal places."""
    # Fixed-point formatting instead of slicing str(round(...)) to 4 characters,
    # which cut digits off any value of 10000 seconds or more.
    return "%.2f" % seconds


# A single parenthesised argument prints identically with the Python 2 print
# statement and the Python 3 print function.
print("Total GPU=%s seconds: Kernel=%ss, Cache=%ss, OpenCL_Prep=%ss" % (
    _fmt_seconds(OpenCL_prep_time + kernel_time + cache_time),
    _fmt_seconds(kernel_time),
    _fmt_seconds(cache_time),
    _fmt_seconds(OpenCL_prep_time)))
print("Total CPU=%s seconds: Kernel=%ss, Cache=%ss" % (
    _fmt_seconds(cpukernel_time + cpucache_time),
    _fmt_seconds(cpukernel_time),
    _fmt_seconds(cpucache_time)))
print("TOverhead=%s seconds: Import=%ss, PlotPrep=%ss, Plot=%ss" % (
    _fmt_seconds(import_time + plot_prep_time + plot_time),
    _fmt_seconds(import_time),
    _fmt_seconds(plot_prep_time),
    _fmt_seconds(plot_time)))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment