Created
May 19, 2011 22:23
-
-
Save anonymous/981935 to your computer and use it in GitHub Desktop.
A test JOCL program that simulates a population of Hodgkin–Huxley (HH) neurons in parallel on an OpenCL device.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import com.jogamp.opencl.CLBuffer; | |
import com.jogamp.opencl.CLCommandQueue; | |
import com.jogamp.opencl.CLContext; | |
import com.jogamp.opencl.CLDevice; | |
import com.jogamp.opencl.CLKernel; | |
import com.jogamp.opencl.CLProgram; | |
import java.io.File; | |
import java.io.IOException; | |
import java.nio.FloatBuffer; | |
import java.nio.FloatBuffer; | |
import java.util.ArrayList; | |
import java.util.Dictionary; | |
import java.util.Hashtable; | |
import java.util.Iterator; | |
import java.util.List; | |
import java.util.Random; | |
import org.jfree.chart.ChartFactory; | |
import org.jfree.chart.ChartUtilities; | |
import org.jfree.chart.JFreeChart; | |
import org.jfree.chart.plot.PlotOrientation; | |
import org.jfree.data.xy.XYSeries; | |
import org.jfree.data.xy.XYSeriesCollection; | |
import static java.lang.System.*; | |
import static com.jogamp.opencl.CLMemory.Mem.*; | |
import static java.lang.Math.*; | |
/** | |
* JOCL Java Alpha Kernel client example. | |
*/ | |
public class AlphaHHKernel_Test { | |
public static Random randomGenerator = new Random(); | |
public static float START_TIME = -30; | |
public static float END_TIME = 100; | |
public static int ELEM_COUNT = 302; | |
public static int SAMPLES = 3; | |
public static void main(String[] args) throws IOException { | |
// set up (uses default CLPlatform and creates context for all devices) | |
CLContext context = CLContext.create(); | |
out.println("created "+ context); | |
try{ | |
// an array with available devices | |
CLDevice[] devices = context.getDevices(); | |
for(int i=0; i<devices.length; i++) | |
{ | |
out.println("device-" + i + ": " + devices[i]); | |
} | |
// have a look at the output and select a device | |
CLDevice device = devices[0]; | |
// ... or use this code to select fastest device | |
//CLDevice device = context.getMaxFlopsDevice(); | |
out.println("using "+ device); | |
// create command queue on selected device. | |
CLCommandQueue queue = device.createCommandQueue(); | |
// NOTE: do the same for all the neurons, C. elegans has 302 neurons, even if they don't fire and this is HH (squid) it will hopefully give a first indication | |
int elementCount = ELEM_COUNT; // Length of arrays to process | |
int localWorkSize = min(device.getMaxWorkGroupSize(), 256); // Local work size dimensions | |
int globalWorkSize = roundUp(localWorkSize, elementCount); // rounded up to the nearest multiple of the localWorkSize | |
// load sources, create and build program | |
CLProgram program = context.createProgram(AlphaHHKernel_Test.class.getResourceAsStream("/resource/AlphaHHKernel.cl")).build(); | |
/* constants declarations */ | |
// max conductances - no need to be buffer (same for all) | |
float maxG_K = 36; | |
float maxG_Na = 120; | |
float maxG_Leak = (float) 0.3; | |
// reverse potentials - no need to be buffers (they're the same for all) | |
float E_K = -12; | |
float E_Na = 115; | |
float E_Leak = (float) 10.613; | |
// I_ext | |
float I_ext = 0; | |
// time step | |
float dt = (float) 0.01; | |
/* constants declarations */ | |
/* input buffers declarations */ | |
CLBuffer<FloatBuffer> V_in_Buffer = context.createFloatBuffer(globalWorkSize, READ_ONLY); | |
CLBuffer<FloatBuffer> x_n_in_Buffer = context.createFloatBuffer(globalWorkSize, READ_ONLY); | |
CLBuffer<FloatBuffer> x_m_in_Buffer = context.createFloatBuffer(globalWorkSize, READ_ONLY); | |
CLBuffer<FloatBuffer> x_h_in_Buffer = context.createFloatBuffer(globalWorkSize, READ_ONLY); | |
/* input buffers declarations */ | |
/* output buffers declarations */ | |
CLBuffer<FloatBuffer> V_out_Buffer = context.createFloatBuffer(globalWorkSize, WRITE_ONLY); | |
CLBuffer<FloatBuffer> x_n_out_Buffer = context.createFloatBuffer(globalWorkSize, WRITE_ONLY); | |
CLBuffer<FloatBuffer> x_m_out_Buffer = context.createFloatBuffer(globalWorkSize, WRITE_ONLY); | |
CLBuffer<FloatBuffer> x_h_out_Buffer = context.createFloatBuffer(globalWorkSize, WRITE_ONLY); | |
/* output buffers declarations */ | |
out.println("Approx. used device memory (buffers only): " + (V_in_Buffer.getCLSize()*8)/1000000 +"MB"); | |
// fill input buffers with initial conditions | |
// NOTE: they'll all be the same for now, but initial conditions for different neurons could be different | |
initInputBuffers(V_in_Buffer.getBuffer(), x_n_in_Buffer.getBuffer(), x_m_in_Buffer.getBuffer(), x_h_in_Buffer.getBuffer()); | |
// get a reference to the kernel function with the name 'IntegrateHHStep' | |
CLKernel kernel = program.createCLKernel("IntegrateHHStep"); | |
/* SETUP SOME PLOTTING SETUP STUFF */ | |
// some dictionary for plotting | |
Hashtable<Integer, Hashtable<Float, Float>> V_by_t = new Hashtable<Integer, Hashtable<Float, Float>>(); | |
List<Integer> sampleIndexes = new ArrayList<Integer>(); | |
// Generate some random indexes in the 0 .. ELEM_COUNT range | |
for(int i = 0; i < SAMPLES; i++ ) | |
{ | |
sampleIndexes.add(randomGenerator.nextInt(ELEM_COUNT)); | |
} | |
/* SETUP SOME PLOTTING SETUP STUFF */ | |
long compuTime = nanoTime(); | |
// here we go, HH integration loop (Euler's method) | |
for (float t = START_TIME; t < END_TIME; t = t+dt) { | |
// turn current to 10 mV at t=10 | |
if (t >= 10 && t < 70) { | |
I_ext = 10; | |
} | |
// turn current off at t=70 | |
else if (t >= 70) { | |
I_ext = 0; | |
} | |
/* HH LOGIC STARTs HERE */ | |
// map the input/output buffers to its input parameters. | |
kernel.putArg(maxG_K) | |
.putArg(maxG_Na) | |
.putArg(maxG_Leak) | |
.putArg(E_K) | |
.putArg(E_Na) | |
.putArg(E_Leak) | |
.putArg(I_ext) | |
.putArg(dt) | |
.putArgs(V_in_Buffer, x_n_in_Buffer, x_m_in_Buffer, x_h_in_Buffer) | |
.putArgs(V_out_Buffer, x_n_out_Buffer, x_m_out_Buffer, x_h_out_Buffer) | |
.putArg(elementCount) | |
.rewind(); | |
// asynchronous write of data to GPU device, followed by blocking read to get the computed results back. | |
queue.putWriteBuffer(V_in_Buffer, false) | |
.putWriteBuffer(x_n_in_Buffer, false) | |
.putWriteBuffer(x_m_in_Buffer, false) | |
.putWriteBuffer(x_h_in_Buffer, false) | |
.put1DRangeKernel(kernel, 0, globalWorkSize, localWorkSize) | |
.putReadBuffer(V_out_Buffer, true) | |
.putReadBuffer(x_n_out_Buffer, true) | |
.putReadBuffer(x_m_out_Buffer, true) | |
.putReadBuffer(x_h_out_Buffer, true); | |
// set output as new input | |
V_in_Buffer = V_out_Buffer; | |
x_n_in_Buffer = x_n_out_Buffer; | |
x_m_in_Buffer = x_m_out_Buffer; | |
x_h_in_Buffer = x_h_out_Buffer; | |
/* HH LOGIC ENDs HERE */ | |
/* RECORD STUFF FOR PLOTTING */ | |
// record some values for plotting | |
Iterator<Integer> itr = sampleIndexes.iterator(); | |
while(itr.hasNext()) | |
{ | |
Integer index = itr.next(); | |
if(!V_by_t.containsKey(index)) | |
{ | |
V_by_t.put(index, new Hashtable<Float, Float>()); | |
} | |
V_by_t.get(index).put(t, V_out_Buffer.getBuffer().get(index)); | |
} | |
/* RECORD STUFF FOR PLOTTING */ | |
} | |
compuTime = nanoTime() - compuTime; | |
out.println("computation took: "+(compuTime/1000000)+"ms"); | |
/* PLOT RESULTS */ | |
// print some sampled charts to make sure we got fine-looking results. | |
// Plot results | |
Iterator<Integer> itr = sampleIndexes.iterator(); | |
while(itr.hasNext()) | |
{ | |
XYSeries series = new XYSeries("HH_Graph"); | |
Integer index = itr.next(); | |
for (float t = START_TIME; t < END_TIME; t = t + dt) { | |
series.add(t, V_by_t.get(index).get(t)); | |
} | |
// Add the series to your data set | |
XYSeriesCollection dataset = new XYSeriesCollection(); | |
dataset.addSeries(series); | |
plot(dataset, index); | |
} | |
/* PLOT RESULTS */ | |
System.out.println("end of HH simulation"); | |
}finally{ | |
// cleanup all resources associated with this context. | |
context.release(); | |
} | |
} | |
private static void initInputBuffers(FloatBuffer V_in, FloatBuffer x_n_in, FloatBuffer x_m_in, FloatBuffer x_h_in) { | |
// initial condition for V is -10 | |
while(V_in.remaining() != 0) | |
{ | |
V_in.put(-10); | |
} | |
V_in.rewind(); | |
// initial conditions for x n/m/h | |
while(x_n_in.remaining() != 0) | |
{ | |
x_n_in.put(0); | |
} | |
x_n_in.rewind(); | |
while(x_m_in.remaining() != 0) | |
{ | |
x_m_in.put(0); | |
} | |
x_m_in.rewind(); | |
while(x_h_in.remaining() != 0) | |
{ | |
x_h_in.put(1); | |
} | |
x_h_in.rewind(); | |
} | |
private static int roundUp(int groupSize, int globalSize) { | |
int r = globalSize % groupSize; | |
if (r == 0) { | |
return globalSize; | |
} else { | |
return globalSize + groupSize - r; | |
} | |
} | |
private static void plot(XYSeriesCollection dataset, int index) | |
{ | |
// Generate the graph | |
JFreeChart chart = ChartFactory.createXYLineChart("HH Chart", "time", "Voltage", dataset, PlotOrientation.VERTICAL, true, true, false); | |
try { | |
ChartUtilities.saveChartAsJPEG(new File("HH_Chart_" + index + ".jpg"), chart, 500, 300); | |
} catch (IOException e) { | |
System.err.println("Problem occurred creating chart."); | |
} | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment