Skip to content

Instantly share code, notes, and snippets.

@lebedov
Created February 23, 2017 15:10
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save lebedov/e8f932e3f6bc129adcfcae43d5229d8b to your computer and use it in GitHub Desktop.
Save lebedov/e8f932e3f6bc129adcfcae43d5229d8b to your computer and use it in GitHub Desktop.
Algorithm for concatenating half-precision PyTorch tensors.
#!/usr/bin/env python
"""
Algorithm for concatenating half precision tensors by allocating new output matrix
of appropriate size and copying each of the constituent tensors into it with
appropriate offsets.
"""
import numpy as np
import torch
from torch.autograd import Variable
from torch.cuda import HalfTensor
def cat_half(inputs, dimension=0):
    """
    Concatenate half precision tensors along specified dimension.

    A new output tensor is allocated whose size along `dimension` is the sum
    of the inputs' sizes along that dimension; each input is then copied into
    its slice of the output at the appropriate offset.

    Parameters
    ----------
    inputs : sequence
        Non-empty sequence whose elements are either all
        ``torch.cuda.HalfTensor`` instances or all ``Variable`` instances
        wrapping ``torch.cuda.HalfTensor`` data.
    dimension : int
        Dimension along which to concatenate. Negative values count back from
        the last dimension, as in ``torch.cat``.

    Returns
    -------
    out : torch.cuda.HalfTensor or Variable
        The concatenated tensor, allocated on the same GPU as the first
        input. It is wrapped in a ``Variable`` if the inputs are ``Variable``
        instances.

    Raises
    ------
    TypeError
        If the inputs are not all half precision CUDA tensors (or all
        ``Variable`` instances wrapping such tensors).
    IndexError
        If `dimension` is out of range for the inputs.
    ValueError
        If `inputs` is empty, the inputs have different numbers of
        dimensions, or their sizes differ along any dimension other than
        `dimension`.
    """
    # Materialize once so that iterators/generators can be traversed
    # several times below.
    inputs = list(inputs)
    if not inputs:
        raise ValueError('cannot concatenate an empty sequence of tensors')

    # Validate input types. An explicit check is used instead of `assert`,
    # which is stripped when Python runs with -O.
    if not (all(isinstance(x, HalfTensor) for x in inputs) or
            all(isinstance(x, Variable) and isinstance(x.data, HalfTensor)
                for x in inputs)):
        raise TypeError('inputs must all be torch.cuda.HalfTensor instances '
                        'or all be Variable instances wrapping them')

    # If the inputs are Variable instances, the output should also be a
    # Variable instance. The type check above guarantees the inputs are
    # homogeneous, so inspecting the first one suffices.
    out_variable = isinstance(inputs[0], Variable)

    # Unwrap Variables so that only raw tensors are copied into the output.
    tensors = [x.data if isinstance(x, Variable) else x for x in inputs]
    sizes = [list(t.size()) for t in tensors]

    ndim = len(sizes[0])
    if any(len(size) != ndim for size in sizes):
        raise ValueError('cannot concatenate tensors with different numbers '
                         'of dimensions')
    if not -ndim <= dimension < ndim:
        raise IndexError('dimension %d out of range for %d-dimensional '
                         'tensors' % (dimension, ndim))
    # Normalize a negative dimension so the comparison below is correct.
    dimension %= ndim

    # Ensure that the magnitude of the dimensions other than that along which
    # the tensors will be concatenated are all equal:
    for i in range(ndim):
        if i != dimension and len({size[i] for size in sizes}) > 1:
            raise ValueError('cannot concatenate')

    # Allocate a new tensor whose concatenation dimension is the sum of the
    # corresponding dimension magnitudes of the inputs. A fresh list avoids
    # mutating `sizes[0]`; allocating under `device_of` places the output on
    # the first input's GPU instead of whatever device is current.
    new_size = list(sizes[0])
    new_size[dimension] = sum(size[dimension] for size in sizes)
    with torch.cuda.device_of(tensors[0]):
        out = HalfTensor(*new_size)

    # Copy each input into its slice of the output:
    offset = 0
    for tensor, size in zip(tensors, sizes):
        length = size[dimension]
        index = [slice(None)] * ndim
        index[dimension] = slice(offset, offset + length)
        out[tuple(index)] = tensor
        offset += length

    return Variable(out) if out_variable else out
if __name__ == '__main__':
    # Self-test (requires a CUDA device): concatenate half precision copies of
    # random float tensors with cat_half() and compare the result against
    # torch.cat() applied to the float originals.
    #
    # torch.rand() draws values from [0, 1); converting them to float16
    # introduces a rounding error of at most 2**-12 (~2.4e-4), so atol=1e-3
    # leaves comfortable headroom.
    cases = [
        ((3, 2), (2, 2), 0),
        ((3, 1, 2), (3, 2, 2), 1),
    ]
    for first_size, second_size, dim in cases:
        first = torch.rand(*first_size)
        second = torch.rand(*second_size)
        expected = torch.cat((first, second), dim).numpy()
        result = cat_half((first.cuda(0).half(), second.cuda(0).half()),
                          dim).float().cpu().numpy()
        if result.shape != expected.shape or \
                not np.allclose(result, expected, atol=1e-3):
            raise RuntimeError('cat_half mismatch for sizes %s, %s along '
                               'dimension %d' % (first_size, second_size, dim))
    print('cat_half matches torch.cat')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment