"""
This repro is associated with
https://github.com/tensorflow/tensorflow/issues/47174
This repro demonstrates that:
  (1) tf.nn.depthwise_conv2d operates deterministically* in the forward
      direction when running on both CPU and GPU;
  (2) when running on CPU, tf.nn.depthwise_conv2d operates deterministically*
      in backprop to both input and filter;
  (3) when running on GPU and when op-determinism is expected,
      tf.nn.depthwise_conv2d operates deterministically* in backprop to input;
  (4) when running on GPU and when op-determinism is expected,
      tf.nn.depthwise_conv2d operates deterministically* in backprop to filter
      when it is using cuDNN convolution (more info below);
  (5) when running on GPU and when op-determinism is expected,
      tf.nn.depthwise_conv2d operates nondeterministically* in backprop to
      filter when it is not using cuDNN convolution (more info below); and
  (6) when running on GPU and when op-determinism is not expected, operation*
      is as expected: the cuDNN backprop algorithms to both input and filter
      operate nondeterministically* by default, while the specialized
      depthwise algorithms operate deterministically* in backprop to input and
      nondeterministically* in backprop to filter.

*: for the tested parameters
"when op-determinism is expected" currently means when TF_DETERMINISTIC_OPS is
set to "true" or "1"
"when op-determinism is not expected" currently means when TF_DETERMINISTIC_OPS
either has not been set or has been set to "false" or "0".
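
For example, op-determinism is expected for a whole run when the environment
variable is set before any ops execute (as is done in the __main__ block at
the bottom of this file):

  import os
  os.environ["TF_DETERMINISTIC_OPS"] = "1"
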
cuDNN convolution is selected when there is only one input channel and cuDNN
supports the desired filter dimensions and number of output channels.
tf.keras.layers.DepthwiseConv2D uses tf.nn.depthwise_conv2d simply and directly.
Therefore, any assertions about the latter apply equally to the former.
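
For example (a minimal sketch, assuming default layer settings with no bias,
no activation, and an NHWC input tensor called input_data), the following two
computations should produce the same result:

  layer = tf.keras.layers.DepthwiseConv2D(
      kernel_size=7, strides=(1, 1), padding="same", depth_multiplier=10,
      use_bias=False)
  y_keras = layer(input_data)
  y_nn = tf.nn.depthwise_conv2d(
      input_data, layer.depthwise_kernel, strides=[1, 1, 1, 1], padding="SAME")
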
The findings above are summarised in the following table, where "D" stands
for deterministic and "ND" stands for nondeterministic:

        | col                     |  1  |  2  |  3  |  4  |  5
        | device                  | CPU | GPU | GPU | GPU | GPU
        | op-determinism expected |     |  Y  |  Y  |  N  |  N
    row | cuDNN used              |     |  Y  |  N  |  Y  |  N
    ====o=========================o=====o=====o=====o=====o=====
     1  | forward                 |  D  |  D  |  D  |  D  |  D
     2  | backprop to input       |  D  |  D  |  D  | ND  |  D
     3  | backprop to filter      |  D  |  D  | ND  | ND  | ND

Note that the only problematic case here is the cell at row 3, col 3
(GPU, op-determinism expected, cuDNN not used) because we expect op-determinism
and we're not getting it. On the other hand, we clearly don't care about
nondeterminism when op-determinism is not expected (i.e. col 4 and col 5).

For GPU functionality, these findings are consistent with the code: the CUDA
kernel for the depthwise backprop to filter uses CUDA atomicAdd, while the
CUDA kernel for the depthwise backprop to input does not. For CPU
functionality, some of the implementation code uses multi-threaded sharding,
but, not unusually, that does not appear to result in nondeterminism.
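
The underlying reason that atomicAdd leads to nondeterminism is that it allows
the accumulation order of partial sums to vary from run to run, and
floating-point addition is not associative, so different accumulation orders
can yield slightly different results. A minimal Python illustration of the
non-associativity (not specific to this op):

  a, b, c = 0.1, 0.2, 0.3
  print((a + b) + c == a + (b + c))  # False: 0.6000000000000001 vs 0.6
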
So far in this discussion, cuDNN auto-tuning has not been considered. cuDNN
auto-tuning would not (and should not) be disabled in col 4 of the table
above (GPU, op-determinism not expected, cuDNN used). Therefore, the "D" at
col 4, row 1 is not actually correct: auto-tuning could (and ultimately would)
introduce nondeterminism.
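
As an aside, one way to remove auto-tuning as a variable in such testing would
be to disable cuDNN auto-tuning for the run via the TF_CUDNN_USE_AUTOTUNE
environment variable; that was not done for the results reported here:

  os.environ["TF_CUDNN_USE_AUTOTUNE"] = "0"
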
### SECTION REPLACED. See update below ###
Finally, while looking at the code-paths via which tf.nn.depthwise_conv2d uses
cuDNN convolution (e.g. in conv_grad_filter_ops.cc), it seems that there might
be other (perhaps partial) implementations of cuDNN auto-tuning, independent
of those used for the regular convolution ops (e.g. tf.nn.conv2d). Because
it's very difficult (or perhaps impossible) to test that auto-tuning
nondeterminism is not present, a careful code-level inspection and
understanding of the cuDNN auto-tuning mechanism for this op should be
performed as part of addressing its nondeterminism.
##########################################
2022-03-17 Update:
Finally, I have confirmed that the code-paths via which tf.nn.depthwise_conv2d
uses cuDNN convolution (e.g. in conv_grad_filter_ops.cc) also benefit from the
existing deterministic selection of deterministic cuDNN convolution algorithms
(when op-determinism is expected).
"""
import os
import tensorflow as tf


class DeterministicTest(tf.test.TestCase):

  def _genParams(self, use_cudnn=False, data_format="NHWC", dtype=tf.float32,
                 seed=123):
    tf.random.set_seed(seed)
    batch_size = 2  # no interaction over batch, so make small
    if use_cudnn:
      # One input channel, plus a cuDNN-supported filter size and number of
      # output channels will result in cuDNN being used for both
      # backprop-to-input and backprop-to-filter on cuDNN 7 and higher.
      input_channels = 1
    else:
      input_channels = 2  # no interaction over channels, so make small
    input_height = 500
    input_width = 1000
    if data_format == "NHWC":
      input_shape = (batch_size, input_height, input_width, input_channels)
    else:  # "NCHW"
      input_shape = (batch_size, input_channels, input_height, input_width)
    input_data = tf.random.normal(input_shape, dtype=dtype)
    # The following filter size results in nondeterminism being exercised in
    # cuDNN backprop (when determinism is not enabled) to both input and
    # filter, as well as in the specialized depthwise backprop to filter.
    filter_height = 7
    filter_width = 7
    channel_multiplier = 10
    filter_shape = (
        filter_height, filter_width, input_channels, channel_multiplier)
    filter_data = tf.random.normal(filter_shape, dtype=dtype)
    strides = [1, 1, 1, 1]
    padding = 'SAME'
    output_height = input_height  # because same padding
    output_width = input_width  # because same padding
    output_channels = input_channels * channel_multiplier
    if data_format == "NHWC":
      output_shape = (batch_size, output_height, output_width, output_channels)
    else:  # "NCHW"
      output_shape = (batch_size, output_channels, output_height, output_width)
    return input_data, filter_data, strides, padding, output_shape

  def _testForwardCase(self, use_cudnn=False, data_format="NHWC",
                       dtype=tf.float32):
    for seed in range(5):
      p = self._genParams(use_cudnn, data_format, dtype, seed=seed)
      input_data, filter_data, strides, padding, _ = p
      result_a = tf.nn.depthwise_conv2d(
          input_data, filter_data, strides, padding, data_format)
      result_b = tf.nn.depthwise_conv2d(
          input_data, filter_data, strides, padding, data_format)
      self.assertAllEqual(result_a, result_b)

  def testForwardGPU(self):
    for use_cudnn in [False, True]:
      for data_format in ["NHWC", "NCHW"]:
        for dtype in [tf.float16, tf.float32, tf.float64]:
          self._testForwardCase(use_cudnn, data_format, dtype=dtype)

  def testForwardCPU(self):
    data_format = "NHWC"  # CPU does not implement NCHW version of op
    for dtype in [tf.float16, tf.float32, tf.float64]:
      with tf.device("/cpu:0"):  # self.session(use_gpu=False) does not work
        self._testForwardCase(data_format=data_format, dtype=dtype)

  def _deterministicOpsEnabled(self):
    envvar = os.getenv('TF_DETERMINISTIC_OPS')
    if envvar == "true" or envvar == "1":
      return True
    else:
      return False

  def _testBackwardCase(self, use_cudnn=False, data_format="NHWC",
                        expect_cpu_results=False, dtype=tf.float32):
    p = self._genParams(use_cudnn, data_format, dtype, seed=123)
    input_data, filter_data, strides, padding, output_shape = p

    def gradients(seed):
      tf.random.set_seed(seed)
      upstream_gradients = tf.random.normal(output_shape, dtype=dtype)
      with tf.GradientTape() as tape:
        tape.watch(input_data)
        tape.watch(filter_data)
        op_output = tf.nn.depthwise_conv2d(
            input_data, filter_data, strides, padding, data_format)
        # Multiplying the op output by random upstream gradients and then
        # differentiating the product injects those gradients into the op's
        # backprop computation.
        gradient_injector_output = op_output * upstream_gradients
      return tape.gradient(gradient_injector_output, [input_data, filter_data])

    seed = 987
    input_gradients_a, filter_gradients_a = gradients(seed)
    input_gradients_b, filter_gradients_b = gradients(seed)
    if expect_cpu_results:
      self.assertAllEqual(input_gradients_a, input_gradients_b)
      self.assertAllEqual(filter_gradients_a, filter_gradients_b)
    else:  # GPU results
      if self._deterministicOpsEnabled():
        self.assertAllEqual(input_gradients_a, input_gradients_b)
        if use_cudnn:
          self.assertAllEqual(filter_gradients_a, filter_gradients_b)
        else:
          self.assertNotAllEqual(filter_gradients_a, filter_gradients_b)
      else:  # deterministic ops not enabled (comment-out in main)
        if use_cudnn:
          self.assertNotAllEqual(input_gradients_a, input_gradients_b)
        else:
          self.assertAllEqual(input_gradients_a, input_gradients_b)
        self.assertNotAllEqual(filter_gradients_a, filter_gradients_b)

  def _testBackwardGPU(self):
    for use_cudnn in [False, True]:
      for data_format in ["NHWC", "NCHW"]:
        for dtype in [tf.float16, tf.float32, tf.float64]:
          self._testBackwardCase(use_cudnn, data_format, dtype=dtype)

  def _testBackwardCPU(self):
    data_format = "NHWC"  # CPU does not implement NCHW version of op
    for dtype in [tf.float16, tf.float32, tf.float64]:
      with tf.device("/cpu:0"):  # self.session(use_gpu=False) does not work
        self._testBackwardCase(
            data_format=data_format, expect_cpu_results=True, dtype=dtype)


if __name__ == '__main__':
  os.environ["TF_DETERMINISTIC_OPS"] = "1"
  tf.test.main()
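# Note: the backward tests above are prefixed with an underscore, so
# tf.test.main() will not discover and run them automatically; presumably they
# are run by removing the leading underscore, and the "op-determinism not
# expected" branches are exercised by commenting out the TF_DETERMINISTIC_OPS
# line above.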