Trevor Killeen (killeent)
// From generic/Tensor.cpp: this "generic" file is compiled once per
// underlying scalar type, with TH_REAL_IS_* selecting the matching NumPy enum
#ifndef TH_GENERIC_FILE
#define TH_GENERIC_FILE "generic/Tensor.cpp"
#else

#ifdef WITH_NUMPY

#ifdef TH_REAL_IS_DOUBLE
#define NUMPY_TYPE_ENUM NPY_DOUBLE
#endif
#ifdef TH_REAL_IS_FLOAT
#define NUMPY_TYPE_ENUM NPY_FLOAT
#endif

A Tour of PyTorch Internals (Part I)

The fundamental unit in PyTorch is the Tensor. This post is an overview of how PyTorch implements the Tensor, such that users can interact with it from the Python shell. In particular, we want to answer four main questions:

  1. How does PyTorch extend the Python interpreter to define a Tensor type that can be manipulated from Python code?
  2. How does PyTorch wrap the C libraries that actually define the Tensor's properties and methods?
  3. How does PyTorch's cwrap tool generate code for Tensor methods?
  4. How does PyTorch's build system compile all of these components and generate a workable package?

Extending the Python Interpreter

PyTorch defines a new package, torch. In this post we will consider the ._C module, an "extension module": a Python module written in C. Such modules allow us to define new built-in object types (e.g. the Tensor) and to call C/C++ functions.
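To get a feel for what a C-defined built-in type looks like from the Python side without touching torch at all, CPython's own stdlib array module serves as a familiar stand-in for the role torch._C plays for the Tensor:

```python
# `array` is a C extension module in CPython, just as torch._C is for PyTorch.
import array

a = array.array('d', [1.0, 2.0])
print(type(a))  # a type defined in C, not in Python source

# Types defined in C typically have no per-instance __dict__, so setting
# arbitrary attributes fails, unlike instances of pure-Python classes.
try:
    a.color = "red"
except AttributeError as e:
    print("AttributeError:", e)
```

This is why extension modules must explicitly declare every method and slot of a new type: there is no dynamic instance dictionary to fall back on.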

// cwrap-generated Python wrapper for addmv on a THPTensor
#if !defined(TH_REAL_IS_HALF)
PyObject * THPTensor_(addmv)(PyObject *self, PyObject *args, PyObject *kwargs)
{
  PyObject *__kw_beta = NULL;
  PyObject *__kw_alpha = NULL;
  PyObject *__kw_mat = NULL;
  PyObject *__kw_vec = NULL;
  if (kwargs) {
    __kw_beta = PyDict_GetItemString(kwargs, "beta");
    __kw_alpha = PyDict_GetItemString(kwargs, "alpha");
    __kw_mat = PyDict_GetItemString(kwargs, "mat");
    __kw_vec = PyDict_GetItemString(kwargs, "vec");
  }
  // ... (argument matching and dispatch to the TH call follow)
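The generated wrapper above pulls optional arguments out of kwargs by name (the PyDict_GetItemString calls) before matching the remaining positional arguments. As a rough, simplified Python model of that argument-handling step (the real wrapper matches several complete signatures; this toy only fills missing slots in order):

```python
def parse_addmv_args(args, kwargs):
    """Toy model of a cwrap-generated wrapper's argument handling:
    keyword lookups first (like PyDict_GetItemString), then fill the
    still-missing slots from positional arguments, in declaration order.
    This is a sketch, not the real THPTensor_(addmv) dispatch."""
    remaining = list(args)
    parsed = []
    for name in ("beta", "alpha", "mat", "vec"):
        val = kwargs.get(name)          # like __kw_beta = PyDict_GetItemString(...)
        if val is None and remaining:
            val = remaining.pop(0)      # positional fallback
        parsed.append(val)
    return tuple(parsed)

print(parse_addmv_args((2.0, 0.5), {"mat": "M", "vec": "v"}))
# -> (2.0, 0.5, 'M', 'v')
```

The real generated code additionally validates types against each accepted signature and raises a descriptive TypeError when none match.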
bool THPUtils_checkAdvancedIndexing(PyObject *arg) {
  // Checks whether the specified selection object should trigger advanced
  // indexing

  // Case 1: arg is a non-tuple sequence object
  if (PyList_Check(arg) || PyRange_Check(arg)) return true;

#ifdef WITH_NUMPY
  // Case 2: arg is an nd-array with type integer or bool
  if (PyArray_Check(arg) &&
      (PyArray_TYPE((PyArrayObject*)arg) == NPY_INT64 ||
       PyArray_TYPE((PyArrayObject*)arg) == NPY_BOOL)) return true;
#endif

  // Case 3: arg is a tuple containing at least one sequence object,
  // ndarray, or LongTensor
  if (PyTuple_Check(arg)) {
    // ... (inspect each tuple element)
  }
  return false;
}
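The semantics of this C predicate can be paraphrased in pure Python. The sketch below covers cases 1 and 3 only (case 2 needs NumPy, and the LongTensor branch needs torch, so both are omitted here):

```python
def triggers_advanced_indexing(arg):
    """Pure-Python paraphrase of THPUtils_checkAdvancedIndexing (sketch)."""
    # Case 1: non-tuple sequence objects (lists, ranges) trigger it
    if isinstance(arg, (list, range)):
        return True
    # Case 2 (ndarray of int64/bool) omitted: requires NumPy
    # Case 3: a tuple triggers it if any element would on its own
    if isinstance(arg, tuple):
        return any(triggers_advanced_indexing(e) for e in arg)
    return False

print(triggers_advanced_indexing([0, 2]))              # True:  t[[0, 2]]
print(triggers_advanced_indexing((slice(None), [1])))  # True:  t[:, [1]]
print(triggers_advanced_indexing(5))                   # False: t[5]
```

In other words, plain integer and slice indexing stays on the fast path, while list-like selection objects anywhere in the index expression route to the advanced-indexing code.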
#ifndef THC_REDUCE_APPLY_UTILS_INC
#define THC_REDUCE_APPLY_UTILS_INC
#include <algorithm>
#include <cuda.h>
#include <assert.h>
#include "THCGeneral.h"
#include "THCTensor.h"
#include "THCDeviceUtils.cuh"
#include "THCTensorInfo.cuh"
Testing average duration for 10 loops
Testing 1D Tensor of size 8: 6 usec (TH), 51 usec (THC)
Testing 1D Tensor of size 16: 1 usec (TH), 43 usec (THC)
Testing 1D Tensor of size 32: 1 usec (TH), 42 usec (THC)
Testing 1D Tensor of size 64: 1 usec (TH), 57 usec (THC)
Testing 1D Tensor of size 128: 3 usec (TH), 60 usec (THC)
Testing 1D Tensor of size 256: 4 usec (TH), 99 usec (THC)
Testing 1D Tensor of size 512: 10 usec (TH), 128 usec (THC)
Testing 1D Tensor of size 1024: 24 usec (TH), 130 usec (THC)
Testing 1D Tensor of size 2048: 52 usec (TH), 1723 usec (THC)
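The harness that produced these timings is not shown here; as a rough sketch of how such per-size averages can be measured (pure Python, with the builtin sum standing in for a tensor op), one might write:

```python
import timeit

def bench(op, sizes, loops=10):
    """Average per-call duration of `op` over inputs of each size, in usec."""
    results = {}
    for n in sizes:
        data = list(range(n))
        total = timeit.timeit(lambda: op(data), number=loops)
        results[n] = total / loops * 1e6  # seconds -> microseconds
    return results

for n, usec in bench(sum, [8, 64, 512]).items():
    print(f"Testing 1D list of size {n}: {usec:.1f} usec")
```

The pattern in the numbers above, with THC slower than TH at every small size, is consistent with a fixed per-kernel-launch overhead dominating small GPU workloads.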
// Block-wide reduction in shared memory helper; only threadIdx.x == 0 will
// return the reduced value
template <typename T, typename ReduceOp>
__device__ T reduceBlock(T* smem,
                         int numVals,
                         T threadVal,
                         ReduceOp reduceOp,
                         T init) {
  if (numVals == 0) {
    return init;
  }
  // ... (each thread stores threadVal into smem, then a tree reduction
  // in shared memory leaves the result with thread 0)
}
// Block-wide reduction where each thread locally reduces N
// values before letting a single warp take over
template <typename T, typename ReduceOp, int N>
__device__ T reduceBlockN(T *smem,
                          int numVals,
                          ReduceOp reduceOp,
                          T init) {
  T local = threadIdx.x < numVals ? smem[threadIdx.x] : init;
#pragma unroll
  // ... (the unrolled loop folds each thread's remaining strided values
  // into `local`, then the final warp-level reduction follows)
}
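The shared-memory tree reduction that reduceBlock performs can be modeled serially. The sketch below is a Python stand-in for what the block's threads do in parallel, assuming the array is padded to a power of two with the identity value, just as inactive lanes contribute init in the kernel:

```python
def reduce_block(smem, reduce_op, init):
    """Serial model of a block-wide tree reduction over shared memory."""
    if not smem:
        return init
    vals = list(smem)
    # Pad to a power of two with the identity element.
    n = 1
    while n < len(vals):
        n *= 2
    vals += [init] * (n - len(vals))
    # Each round halves the number of active "threads": thread i folds
    # element i + stride into element i, then all threads synchronize.
    stride = n // 2
    while stride > 0:
        for i in range(stride):
            vals[i] = reduce_op(vals[i], vals[i + stride])
        stride //= 2
    return vals[0]  # only "thread 0" holds the final result

print(reduce_block([3, 1, 4, 1, 5], max, float("-inf")))   # 5
print(reduce_block([1, 2, 3], lambda a, b: a + b, 0))      # 6
```

The real kernel interleaves __syncthreads() between rounds, since on a GPU the "for i in range(stride)" body runs concurrently across threads.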
Testing average duration for 10 loops
Testing 1D Tensor of size 8: 6 usec (TH), 49 usec (THC)
Testing 1D Tensor of size 16: 1 usec (TH), 39 usec (THC)
Testing 1D Tensor of size 32: 1 usec (TH), 41 usec (THC)
Testing 1D Tensor of size 64: 1 usec (TH), 51 usec (THC)
Testing 1D Tensor of size 128: 2 usec (TH), 53 usec (THC)
Testing 1D Tensor of size 256: 6 usec (TH), 83 usec (THC)
Testing 1D Tensor of size 512: 10 usec (TH), 108 usec (THC)
Testing 1D Tensor of size 1024: 23 usec (TH), 109 usec (THC)
Testing 1D Tensor of size 2048: 48 usec (TH), 1370 usec (THC)