import torch
import operator
import itertools
import sys
from typing import Tuple
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.experimental.symbolic_shapes import ShapeEnv
from torch._refs import _maybe_broadcast
from torch._prims_common import is_same_shape, make_contiguous_strides_for
"""
How to model check two meta function implementations?
Limit ourselves ONLY to sizes (dtype/device can be checked through exhaustive
enumeration / is quite a bit easier to solve for.)
General recipe: things that are symbolically represented, use Z3. Things
that are specialized (including guards), test combinatorially.
Specialized things:
- Number of dimensions
- 0/1 size dims
- ~~Duck sizing~~ turned off
Iterate through all specialized things (in particular, pick a size [0, 1, 2]).
For each configuration, run with Z3. The guard configuration says "what we've
tested". Invert it and ask Z3 for another example that doesn't match this.
Keep going (OR'ing together) until nothing less. Move onto next configuration.
"""
def gen_size():
    dim = MAX_DIM
    for size in itertools.product([1, 2], repeat=dim):
        yield size
class TensorSpec:
    __slots__ = ['_size', '_stride']

    def __init__(self, size, stride):
        self._size = size
        self._stride = stride

    def numel(self):
        r = 1
        for s in self._size:
            r *= s
        return r

    @property
    def ndim(self):
        return len(self._size)

    @property
    def shape(self):
        return self._size

    def size(self):
        return self._size

    def stride(self):
        return self._stride
def check_same_shape(*args):
    shape = None
    for arg in args:
        if shape is None:
            shape = arg.shape
        assert is_same_shape(shape, arg.shape)
# NOTE: Based on the implementation in TensorIterator.cpp. Note that
# the note [Computing output strides] there is wrong: it claims that
# input strides are preserved even when they are not "non overlapping
# and dense", but in fact the output of elementwise operations is
# always given non overlapping and dense strides.
# This model is also INCORRECT because it does not model TensorIterator's
# short-circuit, which can cause different strides.
def compute_elementwise_output_permutation(*tensors) -> Tuple[int, ...]:
    check_same_shape(*tensors)

    # Short-circuits for shapes with zero or one dimensions
    # TODO: are these necessary?
    ndim = tensors[0].ndim
    if ndim == 0:
        return ()
    if ndim == 1:
        return (1,)

    shape = tensors[0].shape

    def should_swap(idx_a, idx_b):
        for tensor in tensors:
            stride_a = tensor.stride()[idx_a]
            stride_b = tensor.stride()[idx_b]
            if stride_a == 0 or stride_b == 0:
                continue
            if stride_a < stride_b:
                return -1
            if stride_a > stride_b:
                return 1
            # stride_a == stride_b
            if shape[idx_a] > shape[idx_b]:
                return 1
        # Note: this case is hit if all strides are zero,
        # or all strides are equal and all dimensions have the same length
        return 0

    # In this representation the identity permutation is
    # reversed(range(ndim)), e.g. [2, 1, 0] for ndim == 3
    perm = list(reversed(range(ndim)))

    # insertion sort with support for ambiguous comparisons
    for i in range(1, ndim):
        dim1 = i
        for dim0 in reversed(range(i)):
            comparison = should_swap(perm[dim0], perm[dim1])
            if comparison > 0:
                perm[dim0], perm[dim1] = perm[dim1], perm[dim0]
                dim1 = dim0
            elif comparison < 0:
                break

    return perm
def compute_elementwise_output_strides(*tensors) -> Tuple[int, ...]:
    perm = compute_elementwise_output_permutation(*tensors)
    shape = tensors[0].shape
    return get_permuted_strides_for(perm, shape)
def get_permuted_strides_for(perm, shape):
    ndim = len(shape)
    permuted_shape = [-1] * ndim
    for idx, x in enumerate(reversed(perm)):
        permuted_shape[idx] = shape[x]
    new_strides = make_contiguous_strides_for(permuted_shape)
    permuted_strides = [-1] * ndim
    for idx, x in enumerate(reversed(perm)):
        permuted_strides[x] = new_strides[idx]
    return tuple(permuted_strides)
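
# Quick sanity checks with concrete ints (the harness below feeds SymInts):
# contiguous strides map to the identity permutation ([2, 1, 0] for 3 dims),
# and a fully transposed layout is preserved through the round trip.
assert compute_elementwise_output_permutation(TensorSpec((2, 2, 2), (4, 2, 1))) == [2, 1, 0]
assert compute_elementwise_output_strides(TensorSpec((2, 2, 2), (1, 2, 4))) == (1, 2, 4)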
# This function is equivalent to compute_contiguous() from TensorImpl.cpp
def is_contiguous(a) -> bool:
    """
    Tests whether a tensor is contiguous or not.

    Tensors are contiguous when they have no elements,
    one element, or when they have "nested" strides.
    """
    if a.numel() < 2:
        return True

    expected_stride = 1
    for x, y in reversed(tuple(zip(a.shape, a.stride()))):
        # Skips checking strides when a dimension has length 1
        if x == 1:
            continue
        if y != expected_stride:
            return False
        expected_stride = expected_stride * x
    return True
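
# Sanity checks with concrete ints: row-major strides are contiguous,
# column-major strides for the same shape are not.
assert is_contiguous(TensorSpec((2, 3), (3, 1)))
assert not is_contiguous(TensorSpec((2, 3), (1, 2)))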
# This function is equivalent to compute_channels_last_contiguous_2d() in TensorImpl.cpp
def is_channels_last_contiguous_2d(a) -> bool:
    # Only 4-dim (NHWC) tensors can be channels-last 2D contiguous
    if a.ndim != 4:
        return False

    expected_stride = 1
    for idx in (1, 3, 2, 0):
        length = a.shape[idx]
        if length == 1:
            continue
        stride = a.stride()[idx]
        if stride != expected_stride:
            return False
        expected_stride *= length
    return True
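
# Sanity check: an NCHW-shaped spec with channels-last strides, e.g.
# shape (2, 3, 4, 5) with strides (60, 1, 15, 3), passes this test.
assert is_channels_last_contiguous_2d(TensorSpec((2, 3, 4, 5), (60, 1, 15, 3)))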
def fast_path(*operands):
    ndim = len(operands[0].shape)
    is_contiguous_ = True
    is_channels_last = True
    # TODO: is_non_overlapping_and_dense (not bound from Python)
    # no inplace, no out, everything defined
    for op in operands:
        is_contiguous_ = is_contiguous_ and is_contiguous(op)
        is_channels_last = is_channels_last and is_channels_last_contiguous_2d(op)
    if is_contiguous_:
        return get_permuted_strides_for(list(reversed(range(ndim))), operands[0].shape)
    # if is_channels_last: (channels-last output path not modeled here)
    return None
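
# Sanity checks: two contiguous operands take the fast path and get
# contiguous output strides; a non-contiguous operand falls through to None.
assert fast_path(TensorSpec((2, 2), (2, 1)), TensorSpec((2, 2), (2, 1))) == (2, 1)
assert fast_path(TensorSpec((2, 2), (2, 1)), TensorSpec((2, 2), (1, 2))) is None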
from torch._dynamo.source import LocalSource

MAX_DIM = 3

for size in gen_size():
    for astride in gen_size():
        for bstride in gen_size():
            shape_env = ShapeEnv()

            # Wrap each concrete int in a fresh symbol so guards get
            # recorded against it
            def inflate(prefix, xs):
                return tuple(
                    shape_env.create_symintnode(
                        shape_env.create_symbol(x, LocalSource(f"{prefix}{i}"))
                    )
                    for i, x in enumerate(xs)
                )

            i_size = inflate("s", size)
            a = TensorSpec(i_size, inflate("a", astride))
            b = TensorSpec(i_size, inflate("b", bstride))
            r1 = compute_elementwise_output_strides(a, b)
            r2 = fast_path(a, b)

            # Map SymInts back to their concrete hints for printing
            def deflate(xs):
                return tuple(
                    shape_env.size_hint(x.node.expr)
                    if isinstance(x, torch.SymInt) else x
                    for x in xs
                )

            # Strides only need to agree on dimensions of length > 1
            def check_significant_strides(size, astride, bstride):
                for idx in range(a.ndim):
                    if astride[idx] != bstride[idx] and size[idx] > 1:
                        return False
                return True

            if r2 is not None:
                matches = check_significant_strides(size, r1, r2)
                if not matches:
                    # Cross-check what eager PyTorch actually produces
                    print((torch.empty_strided(size, astride) + torch.empty_strided(size, bstride)).stride())
                    raise RuntimeError(f"{deflate(r1)} != {deflate(r2)} for {size} {astride} {bstride}")
            # print([g.expr for g in shape_env.guards])