@telamon
Created October 5, 2023 22:13
Beginner PyTorch tensor cheatsheet.

PyTorch Tensor Cheatsheet

1. Initialization

  • Scalar (0D tensor):

    import torch  # needed by every snippet in this cheatsheet

    x = torch.tensor(42)
    # Shape: []
  • Vector (1D tensor):

    x = torch.tensor([1, 2, 3])
    # Shape: [3]
  • Matrix (2D tensor):

    x = torch.tensor([[1, 2], [3, 4], [5, 6]])
    # Shape: [3, 2]
  • 3D tensor:

    x = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
    # Shape: [2, 2, 2]
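
The shape comments above can be checked directly on the tensor. A minimal sketch, assuming only that `torch` is installed; the variable names `scalar` and `matrix` are illustrative and not part of the original cheatsheet:

```python
import torch

scalar = torch.tensor(42)
matrix = torch.tensor([[1, 2], [3, 4], [5, 6]])

print(scalar.shape)   # torch.Size([])
print(scalar.ndim)    # 0
print(matrix.shape)   # torch.Size([3, 2])
print(matrix.ndim)    # 2
print(matrix.dtype)   # torch.int64 (Python ints become int64 by default)
print(scalar.item())  # 42, converted back to a plain Python int
```

`.shape` and `.size()` return the same information; `.item()` only works on tensors holding exactly one element.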

2. Reshaping

  • Reshape/View:

    x = torch.randn(2, 3)
    y = x.view(6)  # x.reshape(6) gives the same shape and also works on non-contiguous tensors
    # Shape of y: [6]
  • Squeeze (remove dimensions of size 1):

    x = torch.randn(1, 3, 1)
    y = x.squeeze()
    # Shape of y: [3]
  • Unsqueeze (add a dimension of size 1):

    x = torch.randn(3)
    y = x.unsqueeze(0)
    # Shape of y: [1, 3]
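
A few reshaping details that the snippets above do not show: `-1` lets PyTorch infer one dimension, `view` can fail on a non-contiguous tensor where `reshape` succeeds, and `squeeze` accepts a specific dimension. This is a rough sketch; the variable names are illustrative:

```python
import torch

x = torch.randn(2, 3)

# -1 asks PyTorch to infer that dimension from the total element count (6 / 3 = 2).
a = x.reshape(3, -1)            # shape [3, 2]

# Transposing gives a non-contiguous tensor: view() raises, reshape() copies if needed.
t = x.t()                       # shape [3, 2], non-contiguous
b = t.reshape(6)                # shape [6]
try:
    t.view(6)
    view_failed = False
except RuntimeError:
    view_failed = True          # view() cannot flatten this memory layout

# squeeze(dim) removes only the named dimension, and only if its size is 1.
s = torch.randn(1, 3, 1)
c = s.squeeze(0)                # shape [3, 1]
d = s.squeeze(-1)               # shape [1, 3]

print(a.shape, b.shape, view_failed, c.shape, d.shape)
```

Passing a dimension to `squeeze` is usually safer than calling it with no arguments, since the bare form would also drop a batch dimension that happens to have size 1.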

3. Combining Tensors

  • Stack:

    x = torch.tensor([1, 2])
    y = torch.tensor([3, 4])
    z = torch.stack((x, y))
    # Shape of z: [2, 2]
  • Concatenate:

    x = torch.tensor([[1, 2]])
    y = torch.tensor([[3, 4]])
    z = torch.cat((x, y), dim=0)
    # Shape of z: [2, 2]
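
Both examples above end with shape [2, 2], which can hide the difference: `stack` creates a new dimension, while `cat` joins along an existing one. A small sketch using the same 1D inputs for both (variable names are illustrative):

```python
import torch

x = torch.tensor([1, 2])
y = torch.tensor([3, 4])

stacked0 = torch.stack((x, y), dim=0)  # [[1, 2], [3, 4]], shape [2, 2]
stacked1 = torch.stack((x, y), dim=1)  # [[1, 3], [2, 4]], shape [2, 2]
joined = torch.cat((x, y), dim=0)      # [1, 2, 3, 4],     shape [4]

# stack is equivalent to unsqueezing each input and then concatenating.
manual = torch.cat((x.unsqueeze(0), y.unsqueeze(0)), dim=0)
print(torch.equal(stacked0, manual))   # True
```

`stack` requires all inputs to have identical shapes; `cat` only requires them to match on every dimension except the one being joined.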

4. Reduction Operations

  • Sum along a dimension:

    x = torch.tensor([[1, 2, 3], [4, 5, 6]])
    y = torch.sum(x, dim=0)
    # Shape of y: [3]
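
The `dim` argument names the dimension that gets collapsed. A short sketch comparing the variants on the same input; `keepdim=True` keeps the reduced dimension with size 1 (variable names are illustrative):

```python
import torch

x = torch.tensor([[1, 2, 3], [4, 5, 6]])

col_sums = torch.sum(x, dim=0)                # [5, 7, 9],   shape [3]
row_sums = torch.sum(x, dim=1)                # [6, 15],     shape [2]
row_kept = torch.sum(x, dim=1, keepdim=True)  # [[6], [15]], shape [2, 1]
total = torch.sum(x)                          # 21, a 0D tensor

print(col_sums, row_sums, row_kept.shape, total)
```

Keeping the dimension is handy when the result must broadcast back against the original tensor, for example when subtracting `row_kept` from `x` row by row.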

5. Expansion/Broadcasting

  • Expand dimensions:

    x = torch.randn(3)
    y = x[:, None]  # Adds a trailing dimension of size 1 (same as x.unsqueeze(1))
    # Shape of y: [3, 1]
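
Adding a size-1 dimension is what makes broadcasting possible: dimensions of size 1 are stretched to match the other operand. A minimal sketch, with illustrative names `col`, `row`, and `grid`:

```python
import torch

col = torch.arange(3)[:, None]   # shape [3, 1]: [[0], [1], [2]]
row = torch.arange(4)[None, :]   # shape [1, 4]: [[0, 1, 2, 3]]

# [3, 1] combined with [1, 4] broadcasts to [3, 4].
grid = col * 10 + row
print(grid.shape)                # torch.Size([3, 4])
print(grid[2, 3].item())         # 23

# expand() makes the stretching explicit and returns a view, not a copy.
expanded = col.expand(3, 4)
print(expanded.shape)            # torch.Size([3, 4])
```

Because `expand` returns a view over the same memory, writing into the expanded tensor in place is best avoided; call `.clone()` first if a writable copy is needed.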