Trevor Killeen (killeent)

A Tour of PyTorch Internals (Part I)

The fundamental unit in PyTorch is the Tensor. This post gives an overview of how we implement Tensors in PyTorch so that the user can interact with them from the Python shell. In particular, we want to answer four main questions:

  1. How does PyTorch extend the Python interpreter to define a Tensor type that can be manipulated from Python code?
  2. How does PyTorch wrap the C libraries that actually define the Tensor's properties and methods?
  3. How does PyTorch's cwrap tool generate code for Tensor methods?
  4. How does PyTorch's build system take all of these components and compile them into a working application?

Extending the Python Interpreter

PyTorch defines a new package, torch. In this post we will consider its ._C module. This module is known as an "extension module": a Python module written in C. Such modules allow us to define new built-in object types (e.g. the Tensor) and to call C/C++ functions.
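From the Python side, one way to see what "extension module" means is to ask the import system how a module was loaded. The sketch below uses only standard-library facilities (`importlib.machinery.ExtensionFileLoader` and `sys.builtin_module_names`) and assumes CPython; the helper name `module_kind` is hypothetical. On a typical installation, `module_kind(torch._C)` should report "extension", since torch._C is normally shipped as a compiled shared library, but that depends on how PyTorch was built.

```python
import importlib.machinery
import sys
import types


def module_kind(mod: types.ModuleType) -> str:
    """Classify a module by how the import system loaded it (hypothetical helper).

    Returns "built-in" for modules compiled into the interpreter itself,
    "extension" for modules loaded from a compiled shared library (.so/.pyd),
    and "pure Python" for everything else (source, bytecode, frozen modules).
    """
    if mod.__name__ in sys.builtin_module_names:
        return "built-in"
    spec = getattr(mod, "__spec__", None)
    loader = getattr(spec, "loader", None)
    if isinstance(loader, importlib.machinery.ExtensionFileLoader):
        return "extension"
    return "pure Python"


if __name__ == "__main__":
    import json

    print(module_kind(sys))   # sys is compiled into CPython
    print(module_kind(json))  # the json package is written in Python
```

Checking the loader rather than the file suffix keeps the test independent of platform-specific extensions such as .so versus .pyd.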

WORK IN PROGRESS

PyTorch Internals Part II - The Build System

In the first post I explained how we generate a torch.Tensor object that you can use in your Python interpreter. Next, I will explore the build system for PyTorch. The PyTorch codebase has a variety of components:

  • The core Torch libraries: TH, THC, THNN, THCUNN
  • Vendor libraries: cuDNN, NCCL
  • Python Extension libraries
  • Additional third-party libraries: NumPy, MKL, LAPACK
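
#if !defined(TH_REAL_IS_HALF)
// cwrap-generated Python binding for Tensor.addmv. It first looks up the optional
// keyword arguments; PyDict_GetItemString returns a borrowed reference, or NULL if absent.
PyObject * THPTensor_(addmv)(PyObject *self, PyObject *args, PyObject *kwargs)
{
  PyObject *__kw_beta = NULL;
  PyObject *__kw_alpha = NULL;
  PyObject *__kw_mat = NULL;
  PyObject *__kw_vec = NULL;
  if (kwargs) {
    __kw_beta = PyDict_GetItemString(kwargs, "beta");
    __kw_alpha = PyDict_GetItemString(kwargs, "alpha");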
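The keyword arguments looked up above (`beta`, `alpha`, `mat`, `vec`) are the operands of addmv, which computes `beta * self + alpha * (mat @ vec)` for a vector `self`, a matrix `mat` and a vector `vec`. As a plain-Python reference for what the operation computes, here is a minimal sketch on lists; the function and the parameter name `self_vec` are illustrative only, and the defaults of 1 for `beta` and `alpha` follow `torch.addmv`'s documented defaults.

```python
from typing import List, Sequence


def addmv(self_vec: Sequence[float], mat: Sequence[Sequence[float]],
          vec: Sequence[float], beta: float = 1.0, alpha: float = 1.0) -> List[float]:
    """Return beta * self_vec + alpha * (mat @ vec) using plain lists (illustrative only)."""
    if len(mat) != len(self_vec):
        raise ValueError("mat must have one row per element of self_vec")
    result = []
    for i, row in enumerate(mat):
        if len(row) != len(vec):
            raise ValueError("each row of mat must have len(vec) columns")
        # Matrix-vector product for row i, then the scaled sum.
        dot = sum(m * v for m, v in zip(row, vec))
        result.append(beta * self_vec[i] + alpha * dot)
    return result


if __name__ == "__main__":
    print(addmv([1.0, 1.0], [[1, 2], [3, 4]], [1, 1], beta=2, alpha=0.5))  # [3.5, 5.5]
```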

diff --git a/aten/CMakeLists.txt b/aten/CMakeLists.txt
index 9cccd34..136ce27 100644
--- a/aten/CMakeLists.txt
+++ b/aten/CMakeLists.txt
@@ -70,5 +70,10 @@ include_directories(
${CMAKE_CURRENT_SOURCE_DIR}/src
${CMAKE_CURRENT_BINARY_DIR}/src/ATen)
add_subdirectory(src/ATen/test)
-add_subdirectory(contrib/data)
-add_subdirectory(contrib/meter)
diff --git a/torch/csrc/distributed/Module.cpp b/torch/csrc/distributed/Module.cpp
index a985509..293a4e1 100644
--- a/torch/csrc/distributed/Module.cpp
+++ b/torch/csrc/distributed/Module.cpp
@@ -186,8 +186,8 @@ THDTensorDescriptor THDPModule_makeDescriptor(PyObject *obj)
PyObject *type = (PyObject*)Py_TYPE(obj);
#define REGISTER_TH_DESCRIPTOR(TYPE, REAL) \
if (type == THP##TYPE##Class) \
- return at::CPU(REAL).unsafeTensorFromTH(((THP##TYPE*)obj)->cdata, true);

#!/usr/bin/env bash
set -e
PYCMD=${PYCMD:="python"}
COVERAGE=0
while [[ "$#" -gt 0 ]]; do
  case "$1" in
    -p|--python) PYCMD=$2; shift 2 ;;
    -c|--coverage) COVERAGE=1; shift 1 ;;
    --) shift; break ;;

// Assumes:
// - input is (N, C, H, W)
// - gradOutput is (N, C, goH, goW)
// - gradWeight is (C, 1, kH, kW) --> (C, kH, kW)
// Naive loop: no striding, padding, or dilation handled
// These three loops would be parallelized, such that each is computed by a single block
for (int ch = 0; ch < C; ++ch) {
  for (int gw_h_offset = 0; gw_h_offset < kH; ++gw_h_offset) {
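Under the assumptions in the comments above (no stride, padding or dilation, so goH = H - kH + 1 and goW = W - kW + 1), and reading the shapes as a per-channel (depthwise) convolution, each weight gradient is gradWeight[c][i][j] = sum over n, oh, ow of gradOutput[n][c][oh][ow] * input[n][c][oh + i][ow + j]. The following pure-Python reference is a sketch of that naive computation, handy for checking a kernel against; the function name and the nested-list layout are assumptions, and its three outer loops presumably correspond to the three loops the comment proposes to parallelize.

```python
from typing import List

# Nested-list stand-ins for tensors (an assumption of this sketch):
# input/gradOutput are [N][C][rows][cols]; the result is [C][kH][kW].
Tensor4D = List[List[List[List[float]]]]
Tensor3D = List[List[List[float]]]


def depthwise_grad_weight(inp: Tensor4D, grad_out: Tensor4D, kH: int, kW: int) -> Tensor3D:
    """Naive weight gradient for a per-channel convolution with stride 1,
    no padding and no dilation (hypothetical reference helper)."""
    N, C = len(inp), len(inp[0])
    H, W = len(inp[0][0]), len(inp[0][0][0])
    goH, goW = len(grad_out[0][0]), len(grad_out[0][0][0])
    # Without padding/stride/dilation the forward pass is a "valid" convolution.
    if (goH, goW) != (H - kH + 1, W - kW + 1):
        raise ValueError("gradOutput shape does not match a valid stride-1 convolution")

    grad_w = [[[0.0] * kW for _ in range(kH)] for _ in range(C)]
    # Each (ch, gw_h_offset, gw_w_offset) entry is independent of the others.
    for ch in range(C):
        for gw_h_offset in range(kH):
            for gw_w_offset in range(kW):
                acc = 0.0
                # Reduce over the batch and every output position.
                for n in range(N):
                    for oh in range(goH):
                        for ow in range(goW):
                            acc += (grad_out[n][ch][oh][ow]
                                    * inp[n][ch][oh + gw_h_offset][ow + gw_w_offset])
                grad_w[ch][gw_h_offset][gw_w_offset] = acc
    return grad_w


if __name__ == "__main__":
    image = [[[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]]   # N=1, C=1, H=W=3
    ones = [[[[1, 1], [1, 1]]]]                       # goH=goW=2 for a 2x2 kernel
    print(depthwise_grad_weight(image, ones, 2, 2))   # [[[12.0, 16.0], [24.0, 28.0]]]
```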

PyTorch now supports a subset of NumPy-style advanced indexing. This lets users select arbitrary indices at each dimension of the Tensor, including non-adjacent and duplicate indices, using the same []-style operation. The result is a more flexible indexing strategy that does not require calls to PyTorch's Index[Select, Add, ...] functions.

import torch

x = torch.Tensor(5, 5, 5)  # uninitialized 5x5x5 Tensor

# Pure Integer Array Indexing - specify arbitrary indices at each dim
x[[1, 2], [3, 2], [1, 0]]
# --> yields a 2-element Tensor (x[1][3][1], x[2][2][0])

# also supports broadcasting, duplicates
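The semantics of pure integer array indexing can be spelled out without PyTorch: the i-th index list supplies the coordinate along dimension i, the lists are walked in lockstep, and a bare integer (or a length-1 list) is broadcast to the common length. The sketch below implements that rule for nested Python lists; it is a simplified model assuming exactly one index per dimension and only 1-D index lists, not PyTorch's implementation, and the name `advanced_index` is hypothetical. Unlike NumPy and PyTorch, it always returns a list, even when every index is a scalar.

```python
from typing import Any, List, Sequence, Union

Index = Union[int, Sequence[int]]


def advanced_index(x: Any, *indices: Index) -> List[Any]:
    """Model of pure integer array indexing on nested lists (hypothetical helper).

    Takes exactly one index per dimension. Each index is an int or a 1-D list of
    ints; ints and length-1 lists are broadcast to the common list length.
    """
    lengths = {len(ix) for ix in indices if not isinstance(ix, int)}
    lengths.discard(1)  # length-1 lists broadcast like scalars
    if len(lengths) > 1:
        raise IndexError(f"index lists of lengths {sorted(lengths)} cannot be broadcast together")
    n = lengths.pop() if lengths else 1

    def expand(ix: Index) -> List[int]:
        if isinstance(ix, int):
            return [ix] * n
        return list(ix) * n if len(ix) == 1 else list(ix)

    columns = [expand(ix) for ix in indices]
    result = []
    for k in range(n):
        # Walk down one dimension at a time using the k-th coordinate of each list.
        elem = x
        for column in columns:
            elem = elem[column[k]]
        result.append(elem)
    return result


if __name__ == "__main__":
    # x[i][j][k] = 100*i + 10*j + k, so each value spells out its own coordinates.
    x = [[[100 * i + 10 * j + k for k in range(5)] for j in range(5)] for i in range(5)]
    print(advanced_index(x, [1, 2], [3, 2], [1, 0]))   # [131, 220]
    print(advanced_index(x, [0, 0, 4], 1, [2, 3, 2]))  # [12, 13, 412]
```

The second call shows both features mentioned above: the scalar 1 is broadcast across all three positions, and the duplicate 0 in the first list selects the same slice twice.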

Advanced Indexing

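// ATen binding: concatenate `tensors` along `dim`, writing the result into `self`.
Tensor & CPUByteType::cat_out(TensorList tensors, int dim, Tensor & self) {
  // Check that `self` (argument 0) is a CPU byte tensor and unwrap it.
  auto self_ = checked_cast<CPUByteTensor>(self.pImpl,"self",0);
  // Check and unwrap every element of `tensors` (argument 1) into THByteTensor pointers.
  auto tensors_ = tensor_list_checked_cast<CPUByteTensor, Tensor, THByteTensor>(tensors,"tensors",1);
  // Delegate to the TH library's concatenation routine.
  THByteTensor_catArray(self_->tensor, tensors_.data(), tensors_.size(), dim);
  return self;
}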
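Semantically, cat_out concatenates the input tensors along `dim`, and every other dimension must agree in size. A minimal pure-Python sketch of that rule on nested lists follows; the helper name `cat` is hypothetical, it returns a new list instead of filling a preallocated `self` the way the `_out` variant does, and it only accepts a non-negative `dim`.

```python
from typing import Any, List, Sequence


def cat(tensors: Sequence[List[Any]], dim: int) -> List[Any]:
    """Concatenate nested lists along dimension `dim` (hypothetical helper).

    Inputs are expected to match in every dimension except `dim`; this sketch
    only validates the dimensions outside `dim`, not those nested inside it.
    """
    if not tensors:
        raise ValueError("cat expects a non-empty sequence of tensors")
    if dim < 0:
        raise ValueError("this sketch only handles non-negative dim")
    if dim == 0:
        # Concatenating along the outermost dimension is list concatenation.
        out: List[Any] = []
        for t in tensors:
            out.extend(t)
        return out
    # Otherwise the outer dimension must agree; recurse one level down.
    outer = len(tensors[0])
    if any(len(t) != outer for t in tensors):
        raise ValueError("sizes must match except in the concatenation dimension")
    return [cat([t[i] for t in tensors], dim - 1) for i in range(outer)]


if __name__ == "__main__":
    a = [[1, 2], [3, 4]]
    print(cat([a, [[5, 6]]], 0))    # [[1, 2], [3, 4], [5, 6]]
    print(cat([a, [[7], [8]]], 1))  # [[1, 2, 7], [3, 4, 8]]
```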