Skip to content

Instantly share code, notes, and snippets.

@bonprosoft
Last active September 11, 2017 07:35
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save bonprosoft/72308c76442b5842ddb0b0ff898d3df6 to your computer and use it in GitHub Desktop.
Save bonprosoft/72308c76442b5842ddb0b0ff898d3df6 to your computer and use it in GitHub Desktop.
Validation of cuDNN dropout function behavior.
Output:
False
True
import cupy
from cupy import cudnn
# Module exposing the raw cuDNN calls used in validate()
# (getDropoutReserveSpaceSize, dropoutForward, dropoutBackward).
libcudnn = cudnn.cudnn
# Seed passed to create_dropout_descriptor in validate(); presumably it seeds
# cuDNN's dropout RNG so the masks are reproducible across runs -- confirm.
SEED = 0
def validate():
    """Check whether cuDNN dropout backward depends on the descriptor's ratio.

    One dropout descriptor is shared by every forward/backward call below.
    The script runs these steps in order:

    1. Run a forward pass with ratio 0.75 and keep its reserve space ``p``.
    2. Change the descriptor's ratio to 0.5 and run an unrelated
       forward/backward pair on differently shaped data.
    3. Replay backward on ``p`` while the descriptor still says 0.5.
    4. Restore 0.75 and replay backward on ``p`` again.

    Prints whether each replayed gradient equals the original one. In the
    output recorded with this script (``False`` then ``True``), only the
    replay made with the original ratio reproduced the original gradient.
    """
    handle = cudnn.get_handle()
    # The descriptor receives only a raw device pointer to d_states, so this
    # array must stay referenced for as long as desc is used.
    d_states = cudnn.create_dropout_states(handle)
    # cuDNN takes the state size in bytes; nbytes is correct whatever the
    # array's dtype is.
    desc = cudnn.create_dropout_descriptor(
        handle, 0.0, d_states.data.ptr, d_states.nbytes, SEED)

    def forward(x):
        """Run dropout forward on ``x`` and return its reserve space.

        The dropout output ``y`` is discarded. Only the reserve space is
        returned, because that is what ``backward`` needs.
        """
        x = cudnn._as4darray(cupy.ascontiguousarray(x))
        x_desc = cudnn.create_tensor_descriptor(x)
        y = cupy.empty_like(x)
        reserve_size = libcudnn.getDropoutReserveSpaceSize(x_desc.value)
        # The reported size is in bytes, so allocate a byte ('b') buffer.
        # The float64 default would allocate eight times too much.
        reserve_space = cupy.empty((reserve_size,), dtype='b')
        libcudnn.dropoutForward(handle, desc.value,
                                x_desc.value, x.data.ptr,
                                x_desc.value, y.data.ptr,
                                reserve_space.data.ptr, reserve_space.nbytes)
        return reserve_space

    def backward(dy, states):
        """Run dropout backward for ``dy`` with the reserve space ``states``.

        ``states`` is expected to be the reserve space returned by
        ``forward`` for an input of the same shape as ``dy``. Returns ``dx``
        with the same shape as ``dy``.
        """
        dy = cupy.ascontiguousarray(dy)
        # Allocate dx from the contiguous dy so its memory layout matches
        # the descriptor built below.
        dx = cupy.empty_like(dy)
        dy_desc = cudnn.create_tensor_descriptor(cudnn._as4darray(dy))
        libcudnn.dropoutBackward(handle, desc.value,
                                 dy_desc.value, dy.data.ptr,
                                 dy_desc.value, dx.data.ptr,
                                 states.data.ptr, states.nbytes)
        return dx

    x = cupy.random.random_sample((10, 10))
    x2 = cupy.random.random_sample((10, 20))
    dy = cupy.random.random_sample((10, 10))
    dy2 = cupy.random.random_sample((10, 20))

    cudnn.set_dropout_descriptor(desc, handle, 0.75)
    p = forward(x)
    gx = backward(dy, p)

    # perform other dropout operation with same descriptor
    cudnn.set_dropout_descriptor(desc, handle, 0.5)
    q = forward(x2)
    gx2 = backward(dy2, q)

    # backward with different dropout ratio
    gx_ = backward(dy, p)

    # backward with same dropout ratio
    cudnn.set_dropout_descriptor(desc, handle, 0.75)
    gx__ = backward(dy, p)

    print(cupy.all(gx == gx_))  # may be different
    print(cupy.all(gx == gx__))  # must be same
if __name__ == "__main__":
    # Run the validation only when executed as a script, not on import.
    validate()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment