Skip to content

Instantly share code, notes, and snippets.

Created May 19, 2011 22:23
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save anonymous/981935 to your computer and use it in GitHub Desktop.
Save anonymous/981935 to your computer and use it in GitHub Desktop.
A test JOCL program that simulates a population of Hodgkin–Huxley (HH) neurons in parallel.
import com.jogamp.opencl.CLBuffer;
import com.jogamp.opencl.CLCommandQueue;
import com.jogamp.opencl.CLContext;
import com.jogamp.opencl.CLDevice;
import com.jogamp.opencl.CLKernel;
import com.jogamp.opencl.CLProgram;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.nio.FloatBuffer;
import java.nio.FloatBuffer;
import java.util.ArrayList;
import java.util.Dictionary;
import java.util.Hashtable;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import org.jfree.chart.ChartFactory;
import org.jfree.chart.ChartUtilities;
import org.jfree.chart.JFreeChart;
import org.jfree.chart.plot.PlotOrientation;
import org.jfree.data.xy.XYSeries;
import org.jfree.data.xy.XYSeriesCollection;
import static java.lang.System.*;
import static com.jogamp.opencl.CLMemory.Mem.*;
import static java.lang.Math.*;
/**
 * Test driver that simulates a population of Hodgkin-Huxley (HH) neurons in parallel on an
 * OpenCL device through JOCL.
 *
 * <p>Each time step launches the {@code IntegrateHHStep} kernel (loaded from
 * {@code /resource/AlphaHHKernel.cl}) once over all neurons (Euler's method). After each step
 * the membrane voltage of a few randomly chosen neurons is recorded. At the end, one JPEG chart
 * per sampled neuron is saved as {@code HH_Chart_<index>.jpg}.
 */
public class AlphaHHKernel_Test {
    public static Random randomGenerator = new Random();
    public static float START_TIME = -30;
    public static float END_TIME = 100;
    public static int ELEM_COUNT = 302;
    public static int SAMPLES = 3;

    /** Classpath location of the OpenCL source that defines {@code IntegrateHHStep}. */
    private static final String KERNEL_RESOURCE = "/resource/AlphaHHKernel.cl";

    /** Time at which the external current is switched on (inclusive). */
    private static final float STIMULUS_START = 10;
    /** Time at which the external current is switched off again (exclusive). */
    private static final float STIMULUS_END = 70;
    /** Magnitude of the injected current while the stimulus is on. */
    private static final float STIMULUS_CURRENT = 10;

    /** Number of device buffers allocated: two state sets of four variables each. */
    private static final int BUFFER_COUNT = 8;

    /**
     * One complete set of per-neuron HH state variables: the voltage V and the x_n/x_m/x_h
     * variables (the HH n/m/h gates). Each is held in its own device buffer of the same size.
     */
    private static final class StateBuffers {
        final CLBuffer<FloatBuffer> V;
        final CLBuffer<FloatBuffer> x_n;
        final CLBuffer<FloatBuffer> x_m;
        final CLBuffer<FloatBuffer> x_h;

        StateBuffers(CLContext context, int size) {
            // READ_WRITE: in the ping-pong scheme every buffer alternates between being the
            // kernel's input and its output, so READ_ONLY/WRITE_ONLY flags would be violated.
            V = context.createFloatBuffer(size, READ_WRITE);
            x_n = context.createFloatBuffer(size, READ_WRITE);
            x_m = context.createFloatBuffer(size, READ_WRITE);
            x_h = context.createFloatBuffer(size, READ_WRITE);
        }
    }

    /**
     * Runs the simulation on the first device of the default OpenCL platform and writes the
     * sampled voltage charts to the working directory.
     *
     * @param args ignored
     * @throws IOException if the kernel source cannot be found or read
     */
    public static void main(String[] args) throws IOException {
        // set up (uses default CLPlatform and creates context for all devices)
        CLContext context = CLContext.create();
        out.println("created " + context);
        try {
            // list the available devices
            CLDevice[] devices = context.getDevices();
            for (int i = 0; i < devices.length; i++) {
                out.println("device-" + i + ": " + devices[i]);
            }
            // have a look at the output and select a device;
            // context.getMaxFlopsDevice() would pick the fastest one instead
            CLDevice device = devices[0];
            out.println("using " + device);
            CLCommandQueue queue = device.createCommandQueue();

            // NOTE: do the same for all the neurons, C. elegans has 302 neurons, even if they
            // don't fire and this is HH (squid) it will hopefully give a first indication
            int elementCount = ELEM_COUNT;
            int localWorkSize = min(device.getMaxWorkGroupSize(), 256);
            // The global size must be a multiple of the local size. The extra work-items are
            // padding, presumably skipped by the kernel via elementCount (TODO confirm in .cl).
            int globalWorkSize = roundUp(localWorkSize, elementCount);

            // load sources, create and build program (fails early if the .cl file is missing)
            CLProgram program = buildProgram(context);
            CLKernel kernel = program.createCLKernel("IntegrateHHStep");

            /* constants: identical for all neurons, so passed as scalar kernel args */
            // max conductances
            float maxG_K = 36f;
            float maxG_Na = 120f;
            float maxG_Leak = 0.3f;
            // reversal potentials
            float E_K = -12f;
            float E_Na = 115f;
            float E_Leak = 10.613f;
            // time step
            float dt = 0.01f;

            // Two state sets used ping-pong: each step the kernel reads 'current' and writes
            // 'next', then the roles are swapped. This avoids passing the same buffer as both
            // input and output.
            StateBuffers current = new StateBuffers(context, globalWorkSize);
            StateBuffers next = new StateBuffers(context, globalWorkSize);
            long bufferBytes = (long) BUFFER_COUNT * current.V.getCLSize();
            out.println("Approx. used device memory (buffers only): "
                    + String.format("%.3f", bufferBytes / 1e6) + "MB");

            // fill the initial conditions (the same for all neurons for now) and upload them
            // once; afterwards the full state stays on the device between steps
            initInputBuffers(current.V.getBuffer(), current.x_n.getBuffer(),
                    current.x_m.getBuffer(), current.x_h.getBuffer());
            queue.putWriteBuffer(current.V, false)
                 .putWriteBuffer(current.x_n, false)
                 .putWriteBuffer(current.x_m, false)
                 .putWriteBuffer(current.x_h, false);

            // pick a few distinct neurons whose voltage trace will be plotted
            List<Integer> sampleIndexes = pickSampleIndexes(SAMPLES, ELEM_COUNT);
            List<XYSeries> sampleSeries = new ArrayList<XYSeries>();
            for (int s = 0; s < sampleIndexes.size(); s++) {
                sampleSeries.add(new XYSeries("HH_Graph"));
            }

            long compuTime = nanoTime();
            // HH integration loop (Euler's method). t is derived from an integer step counter,
            // so float rounding error does not accumulate across thousands of steps.
            for (int step = 0; START_TIME + step * dt < END_TIME; step++) {
                float t = START_TIME + step * dt;
                float I_ext = externalCurrent(t);

                // bind scalar parameters, current state (inputs) and next state (outputs)
                kernel.putArg(maxG_K)
                      .putArg(maxG_Na)
                      .putArg(maxG_Leak)
                      .putArg(E_K)
                      .putArg(E_Na)
                      .putArg(E_Leak)
                      .putArg(I_ext)
                      .putArg(dt)
                      .putArgs(current.V, current.x_n, current.x_m, current.x_h)
                      .putArgs(next.V, next.x_n, next.x_m, next.x_h)
                      .putArg(elementCount)
                      .rewind();

                // Run one step, then do a blocking read of the voltage only. V is the only
                // variable sampled on the host; the gating variables never leave the device.
                queue.put1DRangeKernel(kernel, 0, globalWorkSize, localWorkSize)
                     .putReadBuffer(next.V, true);

                // record the sampled voltages for plotting
                FloatBuffer voltages = next.V.getBuffer();
                for (int s = 0; s < sampleIndexes.size(); s++) {
                    sampleSeries.get(s).add(t, voltages.get(sampleIndexes.get(s)));
                }

                // the freshly computed state becomes the input of the next step
                StateBuffers swap = current;
                current = next;
                next = swap;
            }
            compuTime = nanoTime() - compuTime;
            out.println("computation took: " + (compuTime / 1000000) + "ms");

            // save one chart per sampled neuron to make sure we got fine-looking results
            for (int s = 0; s < sampleIndexes.size(); s++) {
                XYSeriesCollection dataset = new XYSeriesCollection();
                dataset.addSeries(sampleSeries.get(s));
                plot(dataset, sampleIndexes.get(s));
            }
            out.println("end of HH simulation");
        } finally {
            // cleanup all resources associated with this context.
            context.release();
        }
    }

    /**
     * Loads {@link #KERNEL_RESOURCE} from the classpath and builds it for the given context.
     *
     * @throws IOException if the resource is missing or cannot be read
     */
    private static CLProgram buildProgram(CLContext context) throws IOException {
        InputStream source = AlphaHHKernel_Test.class.getResourceAsStream(KERNEL_RESOURCE);
        if (source == null) {
            throw new IOException("OpenCL kernel source not found on classpath: " + KERNEL_RESOURCE);
        }
        try {
            return context.createProgram(source).build();
        } finally {
            source.close();
        }
    }

    /**
     * Returns the stimulus current at time {@code t}. It is {@link #STIMULUS_CURRENT} for
     * {@code STIMULUS_START <= t < STIMULUS_END} and 0 otherwise. Units follow the kernel's
     * I_ext parameter, presumably uA/cm^2 (NOTE(review): confirm against the .cl source).
     */
    private static float externalCurrent(float t) {
        return (t >= STIMULUS_START && t < STIMULUS_END) ? STIMULUS_CURRENT : 0f;
    }

    /**
     * Picks up to {@code count} distinct random indexes in {@code [0, bound)}, in the order they
     * were drawn. At most {@code bound} indexes are returned, so the loop always terminates.
     */
    private static List<Integer> pickSampleIndexes(int count, int bound) {
        List<Integer> picked = new ArrayList<Integer>();
        int wanted = min(count, bound);
        while (picked.size() < wanted) {
            Integer candidate = randomGenerator.nextInt(bound);
            if (!picked.contains(candidate)) {
                picked.add(candidate);
            }
        }
        return picked;
    }

    /**
     * Fills the host-side input buffers with the initial conditions: V = -10, x_n = 0,
     * x_m = 0 and x_h = 1. Each buffer is filled from its current position to its limit and
     * then rewound.
     */
    private static void initInputBuffers(FloatBuffer V_in, FloatBuffer x_n_in, FloatBuffer x_m_in, FloatBuffer x_h_in) {
        fill(V_in, -10f);
        fill(x_n_in, 0f);
        fill(x_m_in, 0f);
        fill(x_h_in, 1f);
    }

    /** Writes {@code value} into every remaining slot of {@code buffer}, then rewinds it. */
    private static void fill(FloatBuffer buffer, float value) {
        while (buffer.hasRemaining()) {
            buffer.put(value);
        }
        buffer.rewind();
    }

    /**
     * Rounds {@code globalSize} up to the nearest multiple of {@code groupSize}.
     *
     * @param groupSize the local work-group size; must be positive
     * @param globalSize the number of elements to cover
     * @return the smallest multiple of {@code groupSize} that is {@code >= globalSize}
     */
    private static int roundUp(int groupSize, int globalSize) {
        int r = globalSize % groupSize;
        if (r == 0) {
            return globalSize;
        } else {
            return globalSize + groupSize - r;
        }
    }

    /**
     * Renders {@code dataset} as an XY line chart and saves it as a 500x300 JPEG named
     * {@code HH_Chart_<index>.jpg} in the working directory. A failure is reported on stderr
     * but does not abort the run, so the remaining charts are still produced.
     */
    private static void plot(XYSeriesCollection dataset, int index) {
        JFreeChart chart = ChartFactory.createXYLineChart("HH Chart", "time", "Voltage", dataset,
                PlotOrientation.VERTICAL, true, true, false);
        File file = new File("HH_Chart_" + index + ".jpg");
        try {
            ChartUtilities.saveChartAsJPEG(file, chart, 500, 300);
        } catch (IOException e) {
            System.err.println("Problem occurred creating chart " + file + ": " + e);
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment