Skip to content

Instantly share code, notes, and snippets.

@grk
Created February 7, 2011 11:51
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save grk/814284 to your computer and use it in GitHub Desktop.
Save grk/814284 to your computer and use it in GitHub Desktop.
/*
* Copyright 1993-2006 NVIDIA Corporation. All rights reserved.
*
* NOTICE TO USER:
*
* This source code is subject to NVIDIA ownership rights under U.S. and
* international Copyright laws.
*
* This software and the information contained herein is PROPRIETARY and
* CONFIDENTIAL to NVIDIA and is being provided under the terms and
* conditions of a Non-Disclosure Agreement. Any reproduction or
* disclosure to any third party without the express written consent of
* NVIDIA is prohibited.
*
* NVIDIA MAKES NO REPRESENTATION ABOUT THE SUITABILITY OF THIS SOURCE
* CODE FOR ANY PURPOSE. IT IS PROVIDED "AS IS" WITHOUT EXPRESS OR
* IMPLIED WARRANTY OF ANY KIND. NVIDIA DISCLAIMS ALL WARRANTIES WITH
* REGARD TO THIS SOURCE CODE, INCLUDING ALL IMPLIED WARRANTIES OF
* MERCHANTABILITY, NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
* IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY SPECIAL, INDIRECT, INCIDENTAL,
* OR CONSEQUENTIAL DAMAGES, OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS
* OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE
* OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE
* OR PERFORMANCE OF THIS SOURCE CODE.
*
* U.S. Government End Users. This source code is a "commercial item" as
* that term is defined at 48 C.F.R. 2.101 (OCT 1995), consisting of
* "commercial computer software" and "commercial computer software
* documentation" as such terms are used in 48 C.F.R. 12.212 (SEPT 1995)
* and is provided to the U.S. Government only as a commercial end item.
* Consistent with 48 C.F.R.12.212 and 48 C.F.R. 227.7202-1 through
* 227.7202-4 (JUNE 1995), all U.S. Government End Users acquire the
* source code with only those rights set forth herein.
*/
/* Matrix multiplication: P = M * N.
* Device code.
*/
#ifndef _MATRIXMUL_KERNEL_H_
#define _MATRIXMUL_KERNEL_H_
#include <stdio.h>
#include "matrixmul.h"
////////////////////////////////////////////////////////////////////////////////
//! Tiled matrix multiplication kernel: P = M * N.
//!
//! Each thread block computes one TILE_WIDTH x TILE_WIDTH tile of P, and each
//! thread computes one element P[row][col]. The shared (inner) dimension is
//! walked in TILE_WIDTH-wide steps: on every step the block cooperatively stages
//! one tile of M and one tile of N in __shared__ memory, synchronizes, and
//! accumulates its partial dot product from shared memory.
//!
//! Launch requirements (not checked by the kernel):
//!   - blockDim must be (TILE_WIDTH, TILE_WIDTH): threadIdx.x / threadIdx.y
//!     index the shared tiles directly.
//!   - gridDim must be at least (ceil(P.width / TILE_WIDTH),
//!     ceil(P.height / TILE_WIDTH)) for every element of P to be written.
//!   - Matrices are addressed row-major with a row stride equal to `width`
//!     (no pitch is used). Presumably M.width == N.height,
//!     P.height == M.height and P.width == N.width. TODO: confirm against the
//!     host-side launch code.
//!
//! Sizes need not be multiples of TILE_WIDTH. Out-of-range tile elements are
//! loaded as 0.0f, so they add nothing to the dot product and no global read
//! ever leaves M or N. A zero-length shared dimension yields P filled with 0.
//!
//! @param M  left operand in global memory  (M.height x M.width)
//! @param N  right operand in global memory (N.height x N.width)
//! @param P  output in global memory; only in-range elements are written
////////////////////////////////////////////////////////////////////////////////
__global__ void MatrixMulKernel(Matrix M, Matrix N, Matrix P)
{
    const int tx = threadIdx.x;
    const int ty = threadIdx.y;
    const int bx = blockIdx.x;
    const int by = blockIdx.y;

    // Output element owned by this thread (may lie outside P for edge blocks).
    const int row = by * TILE_WIDTH + ty;
    const int col = bx * TILE_WIDTH + tx;

    __shared__ float Ms[TILE_WIDTH][TILE_WIDTH];
    __shared__ float Ns[TILE_WIDTH][TILE_WIDTH];

    // Number of tiles along the shared dimension, by integer ceil-division.
    // This is 0 when M.width == 0, so the loop is skipped and sum stays 0.
    // It depends only on M.width, so every thread in the block runs the same
    // number of iterations. That keeps the __syncthreads() calls below
    // block-uniform.
    const int numTiles = (M.width + TILE_WIDTH - 1) / TILE_WIDTH;

    float sum = 0.0f;
    for (int m = 0; m < numTiles; ++m)
    {
        // Column of M and row of N that this thread stages for tile m.
        const int mCol = m * TILE_WIDTH + tx;
        const int nRow = m * TILE_WIDTH + ty;

        // Zero-pad out-of-range elements. The inner product then needs no
        // bounds checks, and no load reads past the end of M or N.
        Ms[ty][tx] = (row < M.height && mCol < M.width)
                   ? M.elements[row * M.width + mCol]
                   : 0.0f;
        Ns[ty][tx] = (nRow < N.height && col < N.width)
                   ? N.elements[nRow * N.width + col]
                   : 0.0f;
        __syncthreads();   // both tiles fully written before anyone reads them

#pragma unroll
        for (int k = 0; k < TILE_WIDTH; ++k)
            sum += Ms[ty][k] * Ns[k][tx];
        __syncthreads();   // everyone done reading before the next overwrite
    }

    if (row < P.height && col < P.width)
        P.elements[row * P.width + col] = sum;
}
#endif // #ifndef _MATRIXMUL_KERNEL_H_
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment