/*

# Understanding overlapping memory transfers and kernel execution for very
simple CUDA workflows

## Executive summary

This small exploration started when Dr Jon Rogers mentioned that one could get
overlapping memory transfer and kernel execution by using device-mapped
page-locked host memory (see section 3.2.4, Page-Locked Host Memory, of the CUDA
C Programming Guide version 6.0) even for simple CUDA workflows, i.e., copying
some data from the host to device, operating on that data on the device, and
copying the data back.

My understanding (incorrect, it turns out) of overlapping transfer and execution
was that this was possible only if I broke up the calculation and concomitant
memory transfer into small chunks and overlapped those. However, this code
demonstrates that using device-mapped host memory ("device-mapped" implying
page-locked host memory) can result in substantial runtime speedups relative to
non-device-mapped, simply page-locked host memory (2x) as well as to regular
pageable host memory (another 2x).

These trends hold for very simple and fast kernels as well as slow kernels that
do a lot of calculations (i.e., transcendental function evaluations).

One unexpected wrinkle was that on the NVS 5200M compute capability 2.1 device
in my laptop, the non-device-mapped and pageable cases took much longer to
execute in a non-default stream than in the default stream. This unusual
behavior did not occur when tested on a K20c (compute capability 3.5).

## Introduction

This code exercises a single GPU along the following dimensions:

1) three separate kernels: fast, medium, and long (in terms of amount of
computation); all three kernels read from global memory, run a calculation, and
store the result in-place: three embarrassingly parallel problems;

2) in the default stream or in a new stream (recall that the default stream
supposedly does not allow any asynchronous operations);

3) page-locked or paged host memory;

3b) if page-locked, device-mapped or not.

It runs a specific subset of tests and prints out a comma-separated table
summarizing those tests and the timing results. Specifically, it will run and
report:

A- page-locked and device-mapped host memory;

B- page-locked host memory, not device-mapped, with asynchronous memory
transfer;

C- paged host memory, with asynchronous memory transfer;

D- the above three situations for each of the three kernels (fast, medium, and
long);

E- all of the above, in the default stream and in a new stream.

## Details of timing

This subset of tests is run several times and the lowest overall time is
reported. The steps that the test takes, and the locations of the timer events,
are as follows:

1) Begin the overall timer

2) CPU initializes host memory (this is included in the overall timer because I
am not sure when this data begins to be copied to the device when in
memory-mapped mode)

3) If not device-mapped memory, asynchronous memory copy from the host to device

4a) Kernel timer event starts

4b) Kernel invocation

4c) Kernel timer event stops

5) If not device-mapped memory, asynchronous copy back from the device to host

6) Stream synchronizes (waits for all the above to finish) and overall timer
stops.

## Results for NVS 5200M

There are three extra columns computed in a spreadsheet program. Note that
"DiffTime" is simply `overallTime - kernelTime`. "Ratios" is the ratio of each
row's overall time to the first row's overall time for that kernel (fast,
medium, long). All times are in microseconds, the unit the program prints.

The results from running on my Nvidia NVS 5200M follow.

| k | pl | dm | ns | O       | K      | d       | r       |
|---|----|----|----|---------|--------|---------|---------|
| 0 | 1  | 1  | 0  | 18.528  | 15.968 | 2.56    | 1       |
| 0 | 1  | 0  | 0  | 36.576  | 7.168  | 29.408  | 1.97409 |
| 0 | 0  | 0  | 0  | 68.256  | 7.36   | 60.896  | 3.68394 |
| 0 | 1  | 1  | 1  | 18.848  | 16.288 | 2.56    | 1.01727 |
| 0 | 1  | 0  | 1  | 132.16  | 26.944 | 105.216 | 7.13299 |
| 0 | 0  | 0  | 1  | 151.904 | 6.944  | 144.96  | 8.19862 |
|   |    |    |    |         |        |         |         |
| 1 | 1  | 1  | 0  | 19.488  | 16.864 | 2.624   | 1       |
| 1 | 1  | 0  | 0  | 41.632  | 12.32  | 29.312  | 2.13629 |
| 1 | 0  | 0  | 0  | 116.256 | 13.056 | 103.2   | 5.96552 |
| 1 | 1  | 1  | 1  | 22.272  | 19.552 | 2.72    | 1.14286 |
| 1 | 1  | 0  | 1  | 137.184 | 32.8   | 104.384 | 7.03941 |
| 1 | 0  | 0  | 1  | 198.592 | 13.152 | 185.44  | 10.1905 |
|   |    |    |    |         |        |         |         |
| 2 | 1  | 1  | 0  | 25.216  | 22.56  | 2.656   | 1       |
| 2 | 1  | 0  | 0  | 50.464  | 21.184 | 29.28   | 2.00127 |
| 2 | 0  | 0  | 0  | 140.608 | 21.632 | 118.976 | 5.57614 |
| 2 | 1  | 1  | 1  | 28.384  | 25.824 | 2.56    | 1.12563 |
| 2 | 1  | 0  | 1  | 185.056 | 45.024 | 140.032 | 7.33883 |
| 2 | 0  | 0  | 1  | 169.152 | 20.992 | 148.16  | 6.70812 |

- **k** kernel number
- **pl** page-locked host memory?
- **dm** device-mapped?
- **ns** new stream?
- **O** overall time (microseconds)
- **K** kernel time (microseconds)
- **d** overall minus kernel time (microseconds)
- **r** ratio of this row's overall time to the first row's for this kernel

Some discussion of these results follows.

For each of the three kernels, the fastest speed was observed with
device-mapping turned on (which implies page-locked host memory). It is
slightly more advantageous to do this in the *default* stream than in a new
stream. This surprised me because I expected the default stream to enforce some
serialization of the steps.

Dr Rogers' insight, that turning on device-mapping would improve runtime, is
correct: page-locked memory by itself, without device-mapping (and thus
requiring explicit memory copies to and from the device), took much longer to
finish than device-mapped memory.

What was surprising about this is that the runtime ratio between the
non-device-mapped and device-mapped cases was about 2x for the default stream,
but 6x to 7x for the new stream. This holds for all three kernels. This is very
mystifying.

Also observe that the absolute standard case (paged host memory) has 4x to 10x
the runtime of the fastest case for all three kernels, and again, the default
stream is faster.

One point to make is that the table for a `win32` build differs from the above
table (`x64`) in unusual and surprising ways. The *general* trends are the
same, but some combinations are much faster or slower depending on the CPU
architecture targeted. This is also very mystifying, since the overall and
kernel times are obtained via CUDA events, and should only contain memory
transfers and kernel invocations. (A possible source of this divergence could
be branch misprediction.)

## Results for K20c

Next, we have timing results for an Nvidia K20c. This was run on a powerful
workstation. Again, times are in microseconds.

| k | pl | dm | ns | O       | K      | d       | r        |
|---|----|----|----|---------|--------|---------|----------|
| 0 | 1  | 1  | 0  | 229.632 | 24.96  | 204.672 | 1        |
| 0 | 1  | 0  | 0  | 241.376 | 7.392  | 233.984 | 1.05114  |
| 0 | 0  | 0  | 0  | 259.52  | 7.392  | 252.128 | 1.13016  |
| 0 | 1  | 1  | 1  | 226.4   | 24.512 | 201.888 | 0.985925 |
| 0 | 1  | 0  | 1  | 238.336 | 8.672  | 229.664 | 1.0379   |
| 0 | 0  | 0  | 1  | 261.472 | 8.032  | 253.44  | 1.13866  |
|   |    |    |    |         |        |         |          |
| 1 | 1  | 1  | 0  | 229.856 | 25.408 | 204.448 | 1        |
| 1 | 1  | 0  | 0  | 242.24  | 8.32   | 233.92  | 1.05388  |
| 1 | 0  | 0  | 0  | 260.64  | 8.32   | 252.32  | 1.13393  |
| 1 | 1  | 1  | 1  | 222.464 | 24.736 | 197.728 | 0.967841 |
| 1 | 1  | 0  | 1  | 242.4   | 9.376  | 233.024 | 1.05457  |
| 1 | 0  | 0  | 1  | 260.8   | 7.68   | 253.12  | 1.13462  |
|   |    |    |    |         |        |         |          |
| 2 | 1  | 1  | 0  | 230.56  | 25.888 | 204.672 | 1        |
| 2 | 1  | 0  | 0  | 244.064 | 9.504  | 234.56  | 1.05857  |
| 2 | 0  | 0  | 0  | 262.368 | 9.504  | 252.864 | 1.13796  |
| 2 | 1  | 1  | 1  | 227.168 | 25.376 | 201.792 | 0.985288 |
| 2 | 1  | 0  | 1  | 244.064 | 10.24  | 233.824 | 1.05857  |
| 2 | 0  | 0  | 1  | 260.64  | 8.448  | 252.192 | 1.13046  |

Again, the key:

- **k** kernel number
- **pl** page-locked host memory?
- **dm** device-mapped?
- **ns** new stream?
- **O** overall time (microseconds)
- **K** kernel time (microseconds)
- **d** overall minus kernel time (microseconds)
- **r** ratio of this row's overall time to the first row's for this kernel

While the same trends were observed, viz., device-mapped memory is the fastest,
the slowdowns for non-mapped or paged host memory are much less pronounced,
i.e., the slowest case (pageable host memory) has only 1.13x the runtime of
device-mapped page-locked host memory. There is also a slight but significant
advantage to running the device-mapped case in a non-default stream, but no
clear difference for the other combinations.

A surprise is that the absolute overall times are so much larger for this more
capable device than for the tiny laptop GPU benchmarked above. I am at a loss
to explain this.

## Hardware and software

This code is compiled in Visual Studio 2010 with CUDA 6.0 (NVS 5200M, code
generation flags set to `compute_20,sm_21`) and CUDA 5.5 (K20c, code generation
flags set to `compute_35,sm_35`). It was compiled in Release mode, targeting
x64. It should be trivial to compile it on Linux and Mac OS.

The CUDA Occupancy Calculator was used to confirm that 256 threads per block
resulted in the highest occupancy for all three kernels (each of which used
different numbers of registers, according to the `--ptxas-options=-v` nvcc
option).

Here are the `deviceQuery` results:
```
Device 0: "NVS 5200M"
CUDA Driver Version / Runtime Version 6.0 / 6.0
CUDA Capability Major/Minor version number: 2.1
Total amount of global memory: 1024 MBytes (1073741824 bytes)
( 2) Multiprocessors, ( 48) CUDA Cores/MP: 96 CUDA Cores
GPU Clock rate: 1344 MHz (1.34 GHz)
Memory Clock rate: 1569 Mhz
Memory Bus Width: 64-bit
L2 Cache Size: 131072 bytes
Maximum Texture Dimension Size (x,y,z) 1D=(65536), 2D=(65536, 65535),
3D=(2048, 2048, 2048)
Maximum Layered 1D Texture Size, (num) layers 1D=(16384), 2048 layers
Maximum Layered 2D Texture Size, (num) layers 2D=(16384, 16384), 2048 layers
Total amount of constant memory: 65536 bytes
Total amount of shared memory per block: 49152 bytes
Total number of registers available per block: 32768
Warp size: 32
Maximum number of threads per multiprocessor: 1536
Maximum number of threads per block: 1024
Max dimension size of a thread block (x,y,z): (1024, 1024, 64)
Max dimension size of a grid size (x,y,z): (65535, 65535, 65535)
Maximum memory pitch: 2147483647 bytes
Texture alignment: 512 bytes
Concurrent copy and kernel execution: Yes with 1 copy engine(s)
Run time limit on kernels: Yes
Integrated GPU sharing Host Memory: No
Support host page-locked memory mapping: Yes
Alignment requirement for Surfaces: Yes
Device has ECC support: Disabled
CUDA Device Driver Mode (TCC or WDDM): WDDM (Windows Display Driver
Model)
Device supports Unified Addressing (UVA): Yes
Device PCI Bus ID / PCI location ID: 1 / 0
Compute Mode:
< Default (multiple host threads can use ::cudaSetDevice() with device
simultaneously) >

Device 0: "Tesla K20c"
CUDA Driver Version / Runtime Version 5.5 / 5.5
CUDA Capability Major/Minor version number: 3.5
Total amount of global memory: 4800 MBytes (5032968192 bytes)
(13) Multiprocessors, (192) CUDA Cores/MP: 2496 CUDA Cores
GPU Clock rate: 706 MHz (0.71 GHz)
Memory Clock rate: 2600 Mhz
Memory Bus Width: 320-bit
L2 Cache Size: 1310720 bytes
Maximum Texture Dimension Size (x,y,z) 1D=(65536), 2D=(65536, 65536),
3D=(4096, 4096, 4096)
Maximum Layered 1D Texture Size, (num) layers 1D=(16384), 2048 layers
Maximum Layered 2D Texture Size, (num) layers 2D=(16384, 16384), 2048 layers
Total amount of constant memory: 65536 bytes
Total amount of shared memory per block: 49152 bytes
Total number of registers available per block: 65536
Warp size: 32
Maximum number of threads per multiprocessor: 2048
Maximum number of threads per block: 1024
Max dimension size of a thread block (x,y,z): (1024, 1024, 64)
Max dimension size of a grid size (x,y,z): (2147483647, 65535, 65535)
Maximum memory pitch: 2147483647 bytes
Texture alignment: 512 bytes
Concurrent copy and kernel execution: Yes with 2 copy engine(s)
Run time limit on kernels: No
Integrated GPU sharing Host Memory: No
Support host page-locked memory mapping: Yes
Alignment requirement for Surfaces: Yes
Device has ECC support: Enabled
Device supports Unified Addressing (UVA): Yes
Device PCI Bus ID / PCI location ID: 4 / 0
Compute Mode:
< Default (multiple host threads can use ::cudaSetDevice() with device
simultaneously) >
```
*/

#include "cuda.h"
#include "cuda_runtime.h"
#include "device_launch_parameters.h"

#include <stdio.h>
#include <stdlib.h> // for exit()
#include <cmath>
#include <limits>
#include <assert.h>

/// The "fast" CUDA kernel. Performs a relatively fast in-place update.
template <typename T>
__global__ void kernelEasy(T *data) {
    int globalIdx = threadIdx.x + blockDim.x * blockIdx.x;
    data[globalIdx] = -2.3f * data[globalIdx];
}

/// The "medium" CUDA kernel. Does a bit more work than the "fast" one.
template <typename T>
__global__ void kernelMedium(T *data) {
    int globalIdx = threadIdx.x + blockDim.x * blockIdx.x;
    data[globalIdx] = asinf(data[globalIdx]);
}

/// The "slow" CUDA kernel. Performs considerably more work.
template <typename T>
__global__ void kernelHard(T *data) {
    int globalIdx = threadIdx.x + blockDim.x * blockIdx.x;
    data[globalIdx] = log10f(expf(asinf(data[globalIdx])));
}

/// Total number of elements in the data vector
const int N = 16 * 1024;
/// Number of threads per 1D block
const int BLOCK_DIM = 256;

/// CUDA return code checker
///
/// Stops execution if a return code is not success, printing a diagnostic.
///
/// \param[in] result the return code
///
/// \param[in] desc a string descriptor to print if result is not a success
///
/// \return result, unchanged (on failure the program exits instead of
/// returning)
inline cudaError_t checkErr(cudaError_t result, const char desc[]) {
    if (result != cudaSuccess) {
        fprintf(stderr, "Err: %s failed: %s\n", desc,
                cudaGetErrorString(result));
        exit(1);
    }
    return result;
}

/// Relative error between actual and expected values
///
/// \param[in] dirt actual value computed
///
/// \param[in] gold ideal expected value
///
/// \return |(dirt-gold)/gold|.
template <typename T>
T relativeError(const T dirt, const T gold) {
    return std::abs((dirt - gold) / gold);
}

/// Host function to check results of all three CUDA kernels
///
/// \param[in] whichKernel 0, 1, or 2, corresponding to the fast/easy, medium,
/// and slow/hard kernels being tested
///
/// \param[in] in pointer to a host copy of the original input array
///
/// \param[in] gpu pointer to the GPU-computed result array
///
/// \param[in] size size of the gpu array
///
/// \param[in] relTol maximum relative error to tolerate. If any single element
/// differs between CPU and GPU by more than this relative error, verification
/// fails
///
/// \param[in] verbose whether or not to print a message when comparing results
///
/// \return 1 if verification successful, 0 otherwise
template <typename T>
int kernelVerifyHost(const int whichKernel, const T *in, const T *gpu,
                     const int size,
                     const T relTol = std::numeric_limits<T>::epsilon(),
                     const bool verbose = true) {
    for (int i = 0; i < size; i++) {
        T gold;
        switch (whichKernel) {
            case 0:
                gold = -2.3f * in[i];
                break;
            case 1:
                gold = std::asin(in[i]);
                break;
            case 2:
            default:
                gold = std::log10(std::exp(std::asin(in[i])));
                break;
        }
        T relErr = relativeError(gpu[i], gold);
        if (relErr > relTol) {
            if (verbose) {
                printf(
                    "Warn: kernelVerifyHost: @ [%d], relativeError(gpu=%g, "
                    "cpu=%g) = %g > tol=%g\n",
                    i, gpu[i], gold, relErr, relTol);
            }
            return 0;
        } else if (verbose && size < 20) {
            printf("Info: relErr @ [%d] = %g.\n", i, relErr);
        }
    }
    return 1;
}

/// Initialize an array with some entirely-arbitrary data
///
/// \param[in] data the array to fill
///
/// \param[in] size size of the array
template <typename T>
void initializeData(T data[], const int size) {
    for (int i = 0; i < size; i++) {
        data[i] = std::sin(i + 3.93123f);
    }
}

/// Print an array as floating point numbers
///
/// \param[in] data the array to print
///
/// \param[in] size size of the array
template <typename T>
void printData(const T data[], const int size) {
    for (int i = 0; i < size; i++) {
        printf("[%03d] = %g\n", i, static_cast<double>(data[i]));
    }
}

/// Template specialization for printing integer arrays
///
/// Not currently used.
///
/// \param[in] data the array to print
///
/// \param[in] size size of the array
template <>
void printData<>(const int data[], const int size) {
    for (int i = 0; i < size; i++) {
        printf("[%03d] = %d\n", i, data[i]);
    }
}

/// Class to simplify (a little) the storage of timers used by this program
struct TimerResults {
    /// This time (see above) includes CPU initialization of input data, kernel
    /// invocations, and any required memory copies. In microseconds, despite
    /// the `Ms` suffix.
    float overallMs;

    /// This time is just the kernel. Also in microseconds.
    float kernelMs;

    /// This would hold memory transfer time but isn't being measured now
    float memoryTransferMs;

    TimerResults() : overallMs(-1), kernelMs(-1), memoryTransferMs(-1) {}

    /// Given start/stop events for the overall timer and kernel timer, update
    /// overallMs and kernelMs
    ///
    /// \param[in] overall_start
    ///
    /// \param[in] overall_stop
    ///
    /// \param[in] kernel_start
    ///
    /// \param[in] kernel_stop
    void processEvents(const cudaEvent_t &overall_start,
                       const cudaEvent_t &overall_stop,
                       const cudaEvent_t &kernel_start,
                       const cudaEvent_t &kernel_stop) {
        // cudaEventElapsedTime reports milliseconds; convert to microseconds
        checkErr(cudaEventElapsedTime(&overallMs, overall_start, overall_stop),
                 "cudaEventElapsedTime overall");
        overallMs *= 1000;
        checkErr(cudaEventElapsedTime(&kernelMs, kernel_start, kernel_stop),
                 "cudaEventElapsedTime kernel");
        kernelMs *= 1000;
    }
};

/// Main test function
///
/// Initializes memory, populates it, invokes a CUDA kernel on it, times
/// everything, and verifies that it produced the correct result.
///
/// \param[out] results class holding the overall and kernel times
///
/// \param[in] whichKernel 0, 1, or 2 corresponding to the fast/easy, medium,
/// and long/hard CUDA kernels
///
/// \param[in] pagelocked whether to allocate host memory as page-locked. If
/// false, the default host allocator is called.
///
/// \param[in] mapped if true, in addition to being page-locked, the host
/// memory will be device-mapped (see the CUDA C Programming Guide for the
/// details). Unused if pagelocked is false.
///
/// \param[in] nondefaultStream if true, a new stream will be created and all
/// CUDA activities will be done in it. Otherwise, the default stream is used.
///
/// \param[in] verbose whether to print a message at the end of the function
/// giving the timed results
///
/// \return 0 if success, non-zero if some problem was encountered.
template <typename T>
int runme(TimerResults &results, const int whichKernel,
          const bool pagelocked = true, const bool mapped = true,
          const bool nondefaultStream = true, const bool verbose = false) {
    // Initialize the device
    checkErr(cudaSetDevice(0), "cudaSetDevice");

    // Configure device to prefer L1 cache over shared memory
    checkErr(cudaDeviceSetCacheConfig(cudaFuncCachePreferL1),
             "cudaDeviceSetCacheConfig");

    // Create a non-default stream to operate in so asynchronous things won't
    // be serialized
    cudaStream_t stream = 0;
    if (nondefaultStream) {
        checkErr(cudaStreamCreate(&stream), "cudaStreamCreate");
    }

    // Allocate host and device memory
    T *data, *data_device;
    if (pagelocked) {
        if (mapped) {
            // Allocate on host
            checkErr(cudaHostAlloc(&data, sizeof(T) * N, cudaHostAllocMapped),
                     "cudaHostAlloc");

            // Get the mapped device pointer
            checkErr(cudaHostGetDevicePointer(&data_device, data, 0),
                     "cudaHostGetDevicePointer");
        } else {
            // Allocate on host
            checkErr(cudaHostAlloc(&data, sizeof(T) * N, cudaHostAllocDefault),
                     "cudaHostAlloc");

            // Allocate on device
            checkErr(cudaMalloc(&data_device, sizeof(T) * N), "cudaMalloc");
        }
    } else {
        data = new T[N];
        checkErr(cudaMalloc(&data_device, sizeof(T) * N), "cudaMalloc");
    }

    // Kernel sizes. Do this before the timer starts.
    const dim3 numThreads(N < BLOCK_DIM ? N : BLOCK_DIM);
    const dim3 numBlocks((N + (BLOCK_DIM - 1)) / BLOCK_DIM);

    // Set up timers
    cudaEvent_t overall_start, overall_stop, kernel_start, kernel_stop;
    checkErr(cudaEventCreate(&overall_start), "cudaEventCreate overall_start");
    checkErr(cudaEventCreate(&overall_stop), "cudaEventCreate overall_stop");
    checkErr(cudaEventCreate(&kernel_start), "cudaEventCreate kernel_start");
    checkErr(cudaEventCreate(&kernel_stop), "cudaEventCreate kernel_stop");
    // Start timer to include CPU setup of data because who knows when the
    // copies to the GPU start.
    checkErr(cudaEventRecord(overall_start, stream),
             "cudaEventRecord overall_start");

    // Initialize and view data. Included in timer in case the to-GPU copy
    // starts as the CPU writes this data.
    initializeData(data, N);
    if (N < 20) {
        printData(data, N);
    }

    // Copy to the device unless the host memory is device-mapped
    if (!(pagelocked && mapped)) {
        checkErr(cudaMemcpyAsync(data_device, data, sizeof(T) * N,
                                 cudaMemcpyHostToDevice, stream),
                 "cudaMemcpyAsync H2D");
    }

    // Kernel invocation
    checkErr(cudaEventRecord(kernel_start, stream),
             "cudaEventRecord kernel_start");

    switch (whichKernel) {
        case 0:
            kernelEasy<<<numBlocks, numThreads, 0, stream>>>(data_device);
            break;
        case 1:
            kernelMedium<<<numBlocks, numThreads, 0, stream>>>(data_device);
            break;
        case 2:
        default:
            kernelHard<<<numBlocks, numThreads, 0, stream>>>(data_device);
            break;
    }

    checkErr(cudaEventRecord(kernel_stop, stream),
             "cudaEventRecord kernel_stop");

    // Unless the host memory is device-mapped, copy the result from the
    // device back to the CPU
    if (!(pagelocked && mapped)) {
        checkErr(cudaMemcpyAsync(data, data_device, sizeof(T) * N,
                                 cudaMemcpyDeviceToHost, stream),
                 "cudaMemcpyAsync D2H");
    }

    // At this stage, `data` has been overwritten by the result.
    checkErr(cudaEventRecord(overall_stop, stream),
             "cudaEventRecord overall_stop");

    // Synchronize: wait until all the above work is done.
    checkErr(cudaEventSynchronize(overall_stop), "cudaEventSynchronize");

    // Update times
    results.processEvents(overall_start, overall_stop, kernel_start,
                          kernel_stop);

    // Verify kernel
    T *backup = new T[N];
    initializeData(backup, N);
    if (0 == kernelVerifyHost(whichKernel, backup, data, N,
                              500 * std::numeric_limits<T>::epsilon())) {
        printf("Err: GPU and CPU results did not match.\n");
        return -1;
    } else {
        if (verbose) {
            printf(
                "Info: CPU & GPU matched. Overall GPU: %g us, kernel only: %g "
                "us.\n",
                results.overallMs, results.kernelMs);
        }
    }

    delete[] backup;

    // Cleanup
    if (pagelocked) {
        cudaFreeHost(data);
        if (!mapped) {
            cudaFree(data_device);
        }
    } else {
        cudaFree(data_device);
        delete[] data;
    }
    cudaEventDestroy(overall_start);
    cudaEventDestroy(overall_stop);
    cudaEventDestroy(kernel_start);
    cudaEventDestroy(kernel_stop);
    if (nondefaultStream) {
        checkErr(cudaStreamDestroy(stream), "cudaStreamDestroy");
    }
    return 0;
}

/// Sets up the subset of tests to run, runs them several times, finding the
/// fastest overall time for each one, and prints a comma-separated table of
/// those results.
int main() {
    // Each kernel gets six tests:
    //
    // 1) page-locked device-mapped host memory
    // 2) page-locked, not device-mapped, host memory
    // 3) paged host memory (not device-mapped)
    // 4-6) like (1)-(3) except in a non-default stream.
    //
    // Therefore, what follow are six boolean arrays that encode these as
    // inputs to the `runme` test function above.

    // pagelocked | mapped | nondefaultStream
    const bool mapped_def[] = {true, true, false};
    const bool nonmap_pagelock_def[] = {true, false, false};
    const bool page_def[] = {false, false, false};

    const bool mapped_nondef[] = {true, true, true};
    const bool nonmap_pagelock_nondef[] = {true, false, true};
    const bool page_nondef[] = {false, false, true};

    // More tests can readily be defined above. However, only those 3-tuples
    // enrolled in the following `combinations` array of arrays will actually
    // be executed. This is intended to aid exploration.
    const bool *combinations[] = {mapped_def, nonmap_pagelock_def,
                                  page_def, mapped_nondef,
                                  nonmap_pagelock_nondef, page_nondef};

    // Once `combinations` is defined, find out how many experiments to run,
    // for each kernel
    const int Ncombinations = sizeof(combinations) / sizeof(combinations[0]);

    // How many trials?
    const int Ntrials = 1000;

    // CSV header
    printf("Info: CSV table of results: times (in microseconds) best of %d\n",
           Ntrials);
    printf(" k=kernel number\n");
    printf(" pl=page locked?\n");
    printf(" dm=device mapped?\n");
    printf(" ns=new stream?\n");
    printf(" O=overall GPU time\n");
    printf(" K=kernel GPU time\n");
    printf(" d=difference, O-K\n");
    printf(" r=ratio between first row of this kernel and this row\n\n");

    printf("k,pl,dm,ns,O,K,d,r\n");

    // For each kernel, execute each experiment enrolled in the `combinations`
    // array of arrays.
    for (int kernel = 0; kernel < 3; kernel++) {
        TimerResults firstCombo;

        for (int combo = 0; combo < Ncombinations; combo++) {
            // Initialize a timer class
            TimerResults minTime;
            minTime.kernelMs = 1e9;
            minTime.overallMs = 1e9;

            for (int i = 0; i < Ntrials; i++) {
                // Time this trial of this experiment for this kernel
                TimerResults time;
                int retval = runme<float>(time, kernel, combinations[combo][0],
                                          combinations[combo][1],
                                          combinations[combo][2]);

                if (retval < 0) {
                    printf("Err: runme() failed.\n");
                    return -1;
                }

                // Assuming the test didn't fail, find the fastest overall time
                if (time.overallMs < minTime.overallMs) {
                    minTime = time;
                }
            }

            if (combo == 0) {
                firstCombo = minTime;
            }

            // Print the fastest time for this experiment, for this kernel
            printf("%d,%d,%d,%d,%g,%g,%g,%g\n", kernel,
                   (int)combinations[combo][0], (int)combinations[combo][1],
                   (int)combinations[combo][2], minTime.overallMs,
                   minTime.kernelMs, minTime.overallMs - minTime.kernelMs,
                   minTime.overallMs / firstCombo.overallMs);
        }
        printf("\n");
    }

    // Do this here rather than in runme() to avoid a painful slowdown. It has
    // to be here somewhere for Visual Profiler to be happy.
    cudaDeviceReset();

    return 0;
}