Created
February 15, 2018 23:54
-
-
Save anonymous/dc0cd7de343922a8c0c0636ccc4889a9 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Copyright (c) 2017-present, Facebook, Inc. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
############################################################################## | |
import tensor_comprehensions as tc | |
import torch | |
import torch.cuda | |
import os, unittest | |
class TestAbsolute(unittest.TestCase):
    """GPU smoke test for an elementwise absolute-value TC."""

    def test_absolute(self):
        # One output: O1(m, n) = fabs(A(m, n)).
        tc_source = """
def abs(float(M, N) A) -> (O1) {
O1(m, n) = fabs(A(m, n))
}
"""
        abs_unit = tc.define(tc_source, name="abs")
        negated = -1 * torch.randn(3, 4).cuda()
        pointwise = tc.Options("pointwise")
        # Only checks that compilation and launch succeed; the result is unused.
        abs_unit(negated, options=pointwise)


if __name__ == "__main__":
    unittest.main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Copyright (c) 2017-present, Facebook, Inc. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
############################################################################## | |
import tensor_comprehensions as tc | |
from tensor_comprehensions.mapping_options import Options | |
import torch | |
import torch.cuda | |
import os, unittest | |
# TC definition of a plain matrix multiply shared by the autotuner tests:
# output(i, j) = sum over kk of A(i, kk) * B(kk, j).
MATMUL_LANG = """
def matmul(float(M,N) A, float(N,K) B) -> (output) {
output(i, j) +=! A(i, kk) * B(kk, j)
}
"""

# Directory holding the autotuner cache files written by these tests.
PATH_PREFIX = os.path.join("/tmp/", "tc_test")
# Create it without a check-then-create race: another test module (or a
# parallel test process) may create the same directory between an exists()
# check and makedirs(), which would make makedirs() raise.
try:
    os.makedirs(PATH_PREFIX)
except OSError:
    # Tolerate "already exists"; re-raise any other failure.
    if not os.path.isdir(PATH_PREFIX):
        raise
########################################################################### | |
# Autotuner tests | |
########################################################################### | |
class TestAutotuner(unittest.TestCase):
    """Autotuner tests for the MATMUL_LANG matrix multiply.

    Inputs are given either as size tuples or as CUDA tensors. The tuned
    options are either kept in memory, cached to the default location
    (``cache=True``), or cached to explicit file prefixes under PATH_PREFIX.
    """

    @staticmethod
    def _define_matmul():
        """Return a fresh TC unit compiled from MATMUL_LANG."""
        return tc.define(MATMUL_LANG, name="matmul")

    @staticmethod
    def _cache_path(m, n, k):
        """Return the cache-file prefix for an (m x n) @ (n x k) matmul."""
        return "{}/matmul_{}_{}_{}".format(PATH_PREFIX, m, n, k)

    ###########################################################################
    # Pass tuple inputs for autotuning
    ###########################################################################
    def test_autotuner_tuple_size_no_cache(self):
        matmul = self._define_matmul()
        matmul.autotune((3, 4), (4, 5), **tc.small_size_autotuner_options)
        matmul.autotune((100, 400), (400, 500), **tc.autotuner_default_options)

    def test_autotuner_tuple_size_cache_to_default(self):
        matmul = self._define_matmul()
        matmul.autotune((3, 4), (4, 5), cache=True,
                        **tc.small_size_autotuner_options)
        matmul.autotune((100, 400), (400, 500), cache=True,
                        **tc.autotuner_default_options)

    def test_autotuner_tuple_size_cache_to_file_run_kernel(self):
        matmul = self._define_matmul()
        cache1 = self._cache_path(3, 4, 5)
        cache2 = self._cache_path(100, 400, 500)
        matmul.autotune((3, 4), (4, 5), cache=cache1,
                        **tc.small_size_autotuner_options)
        matmul.autotune((100, 400), (400, 500), cache=cache2,
                        **tc.autotuner_default_options)
        # Run the kernel once per size, loading options from each cache file.
        mat1, mat2 = torch.randn(3, 4).cuda(), torch.randn(4, 5).cuda()
        matmul(mat1, mat2, cache=cache1)
        mat1, mat2 = torch.randn(100, 400).cuda(), torch.randn(400, 500).cuda()
        matmul(mat1, mat2, cache=cache2)

    ###########################################################################
    # Pass Tensors for autotuning
    ###########################################################################
    # NOTE: Use "--tuner_min_launch_total_threads=1" for running small sizes
    # tc.small_size_autotuner_options has this option set already
    def test_autotuner_no_cache_small_size(self):
        matmul = self._define_matmul()
        mat1, mat2 = torch.randn(3, 4).cuda(), torch.randn(4, 5).cuda()
        matmul.autotune(mat1, mat2, **tc.small_size_autotuner_options)

    def test_autotuner_no_cache(self):
        matmul = self._define_matmul()
        mat1, mat2 = torch.randn(100, 400).cuda(), torch.randn(400, 500).cuda()
        matmul.autotune(mat1, mat2, **tc.autotuner_default_options)

    def test_autotuner_no_cache_explicit_set(self):
        matmul = self._define_matmul()
        mat1, mat2 = torch.randn(100, 400).cuda(), torch.randn(400, 500).cuda()
        matmul.autotune(mat1, mat2, cache=False, **tc.autotuner_default_options)

    def test_autotuner_cache_to_default(self):
        matmul = self._define_matmul()
        mat1, mat2 = torch.randn(100, 400).cuda(), torch.randn(400, 500).cuda()
        matmul.autotune(mat1, mat2, cache=True, **tc.autotuner_default_options)

    def test_autotuner_cachefile_first(self):
        # Writes the cache file that test_autotuner_cachefile_load reads.
        cache_file = self._cache_path(100, 400, 500)
        matmul = self._define_matmul()
        mat1, mat2 = torch.randn(100, 400).cuda(), torch.randn(400, 500).cuda()
        matmul.autotune(mat1, mat2, cache=cache_file,
                        **tc.autotuner_default_options)

    def test_autotuner_cachefile_load(self):
        # Depends on a cache file written earlier, e.g. by
        # test_autotuner_cachefile_first. The default unittest loader sorts
        # test methods by name, so "..._first" runs before "..._load" when the
        # whole class runs; run on its own on a clean machine, this fails.
        cache_file = self._cache_path(100, 400, 500)
        # Use a TestCase assertion: a bare `assert` is stripped under -O, and
        # the kernel call below would then fail with a less useful error.
        self.assertTrue(os.path.isfile("{}.cuda".format(cache_file)),
                        "looks like the cache_file doesn't exist")
        matmul = self._define_matmul()
        mat1, mat2 = torch.randn(100, 400).cuda(), torch.randn(400, 500).cuda()
        matmul(mat1, mat2, cache=cache_file)

    def test_autotuner_no_cache_and_run_kernel(self):
        matmul = self._define_matmul()
        mat1, mat2 = torch.randn(100, 400).cuda(), torch.randn(400, 500).cuda()
        options = matmul.autotune(mat1, mat2, **tc.autotuner_default_options)
        matmul(mat1, mat2, options=options)

    def test_autotuner_start_options_and_run_kernel(self):
        matmul = self._define_matmul()
        mat1, mat2 = torch.randn(100, 400).cuda(), torch.randn(400, 500).cuda()
        # Seed the autotuner with the "mlp" option preset.
        options = Options("mlp")
        best_options = matmul.autotune(mat1, mat2, cache=True, options=options,
                                       **tc.autotuner_default_options)
        matmul(mat1, mat2, options=best_options)


if __name__ == '__main__':
    unittest.main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Copyright (c) 2017-present, Facebook, Inc. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
############################################################################## | |
import tensor_comprehensions as tc | |
import torch | |
import torch.cuda | |
import unittest | |
class TestAvgPool(unittest.TestCase):
    """Run an average-pooling TC whose strides and window are constants."""

    def test_avgpool(self):
        # NOTE: take note of use of {{ }} below for handling TC with scalars:
        # {sH}/{sW}/{kH}/{kW} are placeholders filled from `constants`, so the
        # TC's own braces are doubled.
        # `+=!` zero-initializes `output` before the reduction; plain `+=`
        # would accumulate into whatever the freshly allocated output holds.
        # Dividing by the window area makes this an average, not a sum.
        LANG = """
def avgpool(float(B, C, H, W) input) -> (output) {{
output(b, c, h, w) +=! input(b, c, h * {sH} + kh, w * {sW} + kw) / ({kH} * {kW}) where kh in 0:{kH}, kw in 0:{kW}
}}
"""
        avgpool = tc.define(LANG, name="avgpool",
                            constants={"sH": 1, "sW": 1, "kH": 2, "kW": 2})
        inp = torch.ones(32, 3, 10, 10).cuda()
        out = avgpool(inp)
        # NOTE(review): the binding may return a single tensor or a list of
        # outputs; take the (only) output either way.
        result = out[0] if isinstance(out, (list, tuple)) else out
        # The average of a window of ones is one everywhere.
        self.assertAlmostEqual(float(result.min()), 1.0, places=5)
        self.assertAlmostEqual(float(result.max()), 1.0, places=5)


if __name__ == '__main__':
    unittest.main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Copyright (c) 2017-present, Facebook, Inc. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
############################################################################## | |
import tensor_comprehensions as tc | |
import torch | |
import torch.cuda | |
import os, unittest | |
# Directory for autotuner cache files. NOTE(review): nothing visible in this
# file reads PATH_PREFIX; it is kept for compatibility with the other tests.
PATH_PREFIX = os.path.join("/tmp/", "tc_test")
# Create it without a check-then-create race: other test modules create the
# same directory, and makedirs() raises if it appears between an exists()
# check and the call.
try:
    os.makedirs(PATH_PREFIX)
except OSError:
    # Tolerate "already exists"; re-raise any other failure.
    if not os.path.isdir(PATH_PREFIX):
        raise
class TestAvgPoolAutotune(unittest.TestCase):
    """Autotune the average-pooling TC, then run it with the tuned options."""

    def test_avgpool_autotune(self):
        # NOTE: take note of use of {{ }} below for handling TC with scalars:
        # {sH}/{sW}/{kH}/{kW} are placeholders filled from `constants`, so the
        # TC's own braces are doubled.
        # `+=!` zero-initializes `output` before the reduction; plain `+=`
        # would accumulate into whatever the freshly allocated output holds.
        # Dividing by the window area makes this an average, not a sum.
        LANG = """
def avgpool(float(B, C, H, W) input) -> (output) {{
output(b, c, h, w) +=! input(b, c, h * {sH} + kh, w * {sW} + kw) / ({kH} * {kW}) where kh in 0:{kH}, kw in 0:{kW}
}}
"""
        avgpool = tc.define(LANG, name="avgpool",
                            constants={"sH": 1, "sW": 1, "kH": 2, "kW": 2})
        inp = torch.ones(32, 3, 10, 10).cuda()
        best_options = avgpool.autotune(inp, **tc.small_size_autotuner_options)
        out = avgpool(inp, options=best_options)
        # NOTE(review): the binding may return a single tensor or a list of
        # outputs; take the (only) output either way.
        result = out[0] if isinstance(out, (list, tuple)) else out
        # Tuned options must not change the math: averaging ones gives ones.
        self.assertAlmostEqual(float(result.min()), 1.0, places=5)
        self.assertAlmostEqual(float(result.max()), 1.0, places=5)


if __name__ == '__main__':
    unittest.main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Copyright (c) 2017-present, Facebook, Inc. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
############################################################################## | |
import tensor_comprehensions as tc | |
import torch | |
import torch.cuda | |
import os, unittest | |
class TestBatchMatmul(unittest.TestCase):
    """Run a batched matrix multiply TC: Z(b) = X(b) @ Y(b)."""

    def test_batchmatmul(self):
        lang = """
def batch_matmul(float(B, N, M) X, float(B, M, K) Y) -> (Z) {
Z(b, n, k) +=! X(b, n, mm) * Y(b, mm, k)
}
"""
        matmul = tc.define(lang, name="batch_matmul")
        mat1, mat2 = torch.randn(32, 100, 400).cuda(), torch.randn(32, 400, 500).cuda()
        out = matmul(mat1, mat2)
        # NOTE(review): the binding may return a single tensor or a list of
        # outputs; normalize to Z either way.
        z = out[0] if isinstance(out, (list, tuple)) else out
        # Z(b, n, k) with B=32, N=100, K=500. Assert rather than print, so a
        # wrong output shape actually fails the test.
        self.assertEqual(tuple(z.size()), (32, 100, 500))


if __name__ == '__main__':
    unittest.main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Copyright (c) 2017-present, Facebook, Inc. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
############################################################################## | |
import torch | |
import torch.cuda | |
import tensor_comprehensions as tc | |
import unittest | |
class TestBatchNorm(unittest.TestCase):
    """Run a batch-norm-style TC that also returns its intermediate tensors."""

    def test_batchnorm(self):
        # NOTE: take note of use of {{ }} below for handling TC with scalars:
        # {momentum}/{eps} are filled from `constants`, so TC braces are doubled.
        lang = """
def batchnorm(float(N,C,H,W) I, float(C) rMeanIn, float(C) rVarIn)
-> (O, rMeanOut, rVarOut, mean, centered, variance, expectedVariance, normalizedOut)
{{
mean(c) +=! I(nn, c, hh, ww)
mean(c) = mean(c) / (N * H * W)
rMeanOut(c) = (1 - {momentum}) * rMeanIn(c) + {momentum} * mean(c)
centered(n, c, h, w) = I(n, c, h, w) - rMeanOut(c)
variance(n, c, h, w) = centered(n, c, h, w) * centered(n, c, h, w)
expectedVariance(c) +=! (variance(n, c, h, w) + {eps}) / (N * H * W)
rVarOut(c) = rsqrt(
(1 - {momentum}) * rVarIn(c) + {momentum} * expectedVariance(c))
O(n, c, h, w) = centered(n, c, h, w) * rVarOut(c)
normalizedOut(n, c, h, w) = O(n, c, h, w)
}}
"""
        batchnorm = tc.define(lang, name="batchnorm",
                              constants={"momentum": 0.5, "eps": 1e-5})
        inp = torch.randn(32, 4, 56, 56).cuda()
        running_mean = torch.randn(4).cuda()
        # A running variance must be non-negative. rVarIn feeds rsqrt() above,
        # and a randn() draw is negative about half the time, which can push
        # the rsqrt argument below zero and produce NaNs. rand() draws from
        # [0, 1).
        running_var = torch.rand(4).cuda()
        batchnorm(inp, running_mean, running_var)


if __name__ == '__main__':
    unittest.main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Copyright (c) 2017-present, Facebook, Inc. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
############################################################################## | |
import tensor_comprehensions as tc | |
import torch | |
import torch.cuda | |
import os, unittest | |
class TestCast(unittest.TestCase):
    """Run a TC that adds a constant to a float matrix and casts to int32."""

    def test_cast(self):
        # {four} is a placeholder filled from `constants`; TC braces are doubled.
        cast_source = """
def cast(float(M,N) A) -> (int32(M,N) O1) {{
O1(m, n) = int32(A(m, n) + {four})
}}
"""
        cast_unit = tc.define(cast_source, name="cast", constants={"four": 4})
        matrix = torch.randn(32, 16).cuda()
        # Smoke test only: the int32 result is not inspected.
        cast_unit(matrix)


if __name__ == "__main__":
    unittest.main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Copyright (c) 2017-present, Facebook, Inc. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
############################################################################## | |
import tensor_comprehensions as tc | |
import torch | |
import torch.cuda | |
import os, unittest | |
class TestConcat(unittest.TestCase):
    """Run a TC that interleaves two (M, N) matrices into one output."""

    def test_concat(self):
        # O1(n, i, m) takes A(m, n) for i == 0 and B(m, n) for i == 1, so the
        # output stacks A and B along a new size-2 middle axis with M and N
        # swapped.
        concat_source = """
def concat(float(M, N) A, float(M, N) B) -> (O1) {
O1(n, i, m) = i == 0 ? A(m, n) : B(m, n) where i in 0:2
}
"""
        concat_unit = tc.define(concat_source, name="concat")
        first = torch.randn(32, 16).cuda()
        second = torch.randn(32, 16).cuda()
        concat_unit(first, second)


if __name__ == "__main__":
    unittest.main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Copyright (c) 2017-present, Facebook, Inc. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
############################################################################## | |
import tensor_comprehensions as tc | |
import torch | |
import torch.cuda | |
import os, unittest | |
class TestConvolutionSimple(unittest.TestCase):
    """Run a stride-1 convolution-plus-bias TC."""

    def test_convolution_simple(self):
        LANG = """
def convolution(float(N,C,H,W) I, float(M,C,KH,KW) W1, float(M) B) -> (O) {
O(n, m, h, w) +=! I(n, c, h + kh, w + kw) * W1(m, c, kh, kw)
O(n, m, h, w) = O(n, m, h, w) + B(m)
}
"""
        # Batch, input channels, height, width, output channels, kernel size.
        # (Strides are unused here: the TC above is hard-wired to stride 1.)
        N, C, H, W, O, kH, kW = 32, 4, 56, 56, 16, 3, 3
        convolution = tc.define(LANG, name="convolution")
        # Call the filter W1 (matching the TC parameter) instead of rebinding
        # the integer width W to a tensor.
        I = torch.randn(N, C, H, W).cuda()
        W1 = torch.randn(O, C, kH, kW).cuda()
        B = torch.randn(O).cuda()
        convolution(I, W1, B)


if __name__ == '__main__':
    unittest.main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Copyright (c) 2017-present, Facebook, Inc. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
############################################################################## | |
import tensor_comprehensions as tc | |
import torch | |
import torch.cuda | |
import os, unittest | |
class TestConvolutionStrided(unittest.TestCase):
    """Run a convolution-plus-bias TC whose strides are TC constants."""

    def test_convolution_strided(self):
        # NOTE: take note of use of {{ }} below for handling TC with scalars:
        # {sh}/{sw} are filled from `constants`, so TC braces are doubled.
        LANG = """
def convolution(float(N,C,H,W) I, float(M,C,KH,KW) W1, float(M) B) -> (O) {{
O(n, m, h, w) +=! I(n, c, {sh} * h + kh, {sw} * w + kw) * W1(m, c, kh, kw)
O(n, m, h, w) = O(n, m, h, w) + B(m)
}}
"""
        # Batch, input channels, height, width, output channels, kernel, strides.
        N, C, H, W, O, kH, kW, sH, sW = 32, 4, 56, 56, 16, 3, 3, 1, 1
        convolution = tc.define(LANG, name="convolution",
                                constants={"sh": sH, "sw": sW})
        # Call the filter W1 (matching the TC parameter) instead of rebinding
        # the integer width W to a tensor.
        I = torch.randn(N, C, H, W).cuda()
        W1 = torch.randn(O, C, kH, kW).cuda()
        B = torch.randn(O).cuda()
        convolution(I, W1, B)


if __name__ == '__main__':
    unittest.main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Copyright (c) 2017-present, Facebook, Inc. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
############################################################################## | |
import tensor_comprehensions as tc | |
import torch | |
import torch.cuda | |
import os, unittest | |
class TestAutotuneConvolutionStrided(unittest.TestCase):
    """Autotune a strided convolution-plus-bias TC and run the tuned kernel."""

    def test_autotune_convolution_strided(self):
        # Literal TC braces are doubled ({{ }}) because {sh}/{sw} appear to be
        # str.format-style placeholders filled in from `constants`.
        conv_source = """
def convolution(float(N,C,H,W) I, float(M,C,KH,KW) W1, float(M) B) -> (O) {{
O(n, m, h, w) +=! I(n, c, {sh} * h + kh, {sw} * w + kw) * W1(m, c, kh, kw)
O(n, m, h, w) = O(n, m, h, w) + B(m)
}}
"""
        batch, in_ch, height, width = 32, 4, 56, 56
        out_ch, k_h, k_w = 16, 3, 3
        stride_h = stride_w = 1
        convolution = tc.define(conv_source, name="convolution",
                                constants={"sh": stride_h, "sw": stride_w})
        I = torch.randn(batch, in_ch, height, width).cuda()
        W1 = torch.randn(out_ch, in_ch, k_h, k_w).cuda()
        B = torch.randn(out_ch).cuda()
        best_options = convolution.autotune(I, W1, B,
                                            **tc.autotuner_default_options)
        convolution(I, W1, B, options=best_options)


if __name__ == "__main__":
    unittest.main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Copyright (c) 2017-present, Facebook, Inc. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
############################################################################## | |
import tensor_comprehensions as tc | |
import torch | |
import torch.cuda | |
from torch.autograd import Variable | |
from torch.nn.parameter import Parameter | |
import unittest | |
class TestTrainConvolutionStrided(unittest.TestCase):
    """Run forward and backward of a strided convolution TC with autograd."""

    def test_train_convolution_strided(self):
        # NOTE: take note of use of {{ }} below for handling TC with scalars:
        # {sh}/{sw} are filled from `constants`, so TC braces are doubled.
        LANG = """
def convolution(float(N,C,H,W) I, float(M,C,KH,KW) W1) -> (O) {{
O(n, m, h, w) +=! I(n, c, {sh} * h + kh, {sw} * w + kw) * W1(m, c, kh, kw)
}}
def convolution_grad(float(N,C,H,W) I, float(M,C,KH,KW) W1, float(N,M,H,W) O_grad)
-> (I_grad, W1_grad)
{{
I_grad(n, c, h, w) +=! O_grad(n, m, {sh} * h - kh, {sw} * w - kw) * W1(m, c, kh, kw)
W1_grad(m, c, kh, kw) +=! O_grad(n, m, {sh} * h - kh, {sw} * w - kw) * I(n, c, h, w)
}}
"""
        # NOTE: TC doesn't support padding yet
        # see https://github.com/facebookresearch/TensorComprehensions/issues/11
        # For this reason we use kernel size 1 for now, because we also want
        # to run the backward pass: with kernel != 1 the values of H, W in the
        # backward TC would be inconsistent with the forward output.
        N, C, H, W, O, kH, kW, sH, sW = 32, 4, 56, 56, 16, 1, 1, 1, 1
        convolution = tc.define(LANG, training=True, name="convolution",
                                backward="convolution_grad",
                                constants={"sh": sH, "sw": sW})
        I = Variable(torch.randn(N, C, H, W).cuda(), requires_grad=True)
        # Call the filter W1 (matching the TC parameter) instead of rebinding
        # the integer width W.
        W1 = Parameter(torch.randn(O, C, kH, kW).cuda())
        out = convolution(I, W1)
        out[0].sum().backward()
        # The backward TC must have populated gradients for both inputs.
        self.assertIsNotNone(I.grad)
        self.assertIsNotNone(W1.grad)


if __name__ == '__main__':
    unittest.main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Copyright (c) 2017-present, Facebook, Inc. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
############################################################################## | |
import tensor_comprehensions as tc | |
import torch | |
import torch.cuda | |
import unittest | |
class TestCopy(unittest.TestCase):
    """Run an identity TC that copies a matrix elementwise."""

    def test_copy(self):
        copy_source = """
def copy(float(M, N) I) -> (O) {
O(i, j) = I(i, j)
}
"""
        copy_unit = tc.define(copy_source, name="copy")
        source = torch.randn(32, 32).cuda()
        copy_unit(source)


if __name__ == "__main__":
    unittest.main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Copyright (c) 2017-present, Facebook, Inc. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
############################################################################## | |
import tensor_comprehensions as tc | |
import torch | |
import torch.cuda | |
import unittest | |
class TestCosine(unittest.TestCase):
    """Run an elementwise cosine TC over a 1-D tensor."""

    def test_cosine(self):
        cos_source = """
def cosine(float(M) I) -> (O) {
O(i) = cos(I(i))
}
"""
        cos_unit = tc.define(cos_source, name="cosine")
        angles = torch.randn(32).cuda()
        cos_unit(angles)


if __name__ == "__main__":
    unittest.main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Copyright (c) 2017-present, Facebook, Inc. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
############################################################################## | |
import tensor_comprehensions as tc | |
import torch | |
import torch.cuda | |
import unittest | |
class TestCosineSimilarity(unittest.TestCase):
    """Row-wise cosine similarity of two (M, N) matrices via a TC."""

    # NOTE: TC can't do allocations itself, so everything has to be declared
    # as input or output. Hence, we return the temporary outputs as well
    def test_cosine_similarity(self):
        # cos(a, b) = <a, b> / (||a|| * ||b||), with ||a|| = sqrt(sumI1).
        # The denominator must use sqrt(sumI1); rsqrt(sumI1) is 1 / ||a||,
        # which would compute <a, b> * ||a|| / ||b|| instead.
        LANG = """
def cosine_similarity(float(M, N) I1, float(M, N) I2) -> (O, sumI1, sumI2) {{
sumI1(m) +=! I1(m, n) * I1(m, n)
sumI2(m) +=! I2(m, n) * I2(m, n)
O(m) +=! (I1(m, n) * I2(m, n)) / fmax(sqrt(sumI1(m)) * sqrt(sumI2(m)), {eps})
}}
"""
        cosine_similarity = tc.define(LANG, name="cosine_similarity",
                                      constants={"eps": 1e-5})
        inp1, inp2 = torch.randn(100, 128).cuda(), torch.randn(100, 128).cuda()
        out = cosine_similarity(inp1, inp2)
        # NOTE(review): with three outputs the binding presumably returns a
        # list (O, sumI1, sumI2); normalize to O either way.
        sim = out[0] if isinstance(out, (list, tuple)) else out
        # A cosine similarity is bounded by 1 in magnitude (small float slack).
        self.assertLessEqual(float(sim.abs().max()), 1.0 + 1e-4)


if __name__ == '__main__':
    unittest.main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Copyright (c) 2017-present, Facebook, Inc. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
############################################################################## | |
import tensor_comprehensions as tc | |
import torch | |
import torch.cuda | |
import unittest | |
# Turn on dumping of the CUDA source TC generates, for every kernel compiled
# in this process (plain runs and autotuner runs alike). This executes at
# import time, not per test.
tc.GlobalDebugInit(["tc", "--dump_cuda=true"])


class TestDumpCuda(unittest.TestCase):
    """Compile and run a small matmul so its generated CUDA gets dumped."""

    def test_dump_cuda(self):
        matmul_source = """
def matmul(float(M,N) A, float(N,K) B) -> (output) {
output(i, j) +=! A(i, kk) * B(kk, j)
}
"""
        matmul = tc.define(matmul_source, name="matmul")
        lhs = torch.randn(3, 4).cuda()
        rhs = torch.randn(4, 5).cuda()
        matmul(lhs, rhs)


if __name__ == "__main__":
    unittest.main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Copyright (c) 2017-present, Facebook, Inc. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
############################################################################## | |
import tensor_comprehensions as tc | |
import torch | |
import torch.cuda | |
import unittest | |
class TestManualCudaInjection(unittest.TestCase):
    """Run a TC whose kernel is supplied as hand-written CUDA code."""

    def test_simple_cuda_injection(self):
        """Launch the injected ``my_add`` kernel for the ``add`` TC.

        ``inject_kernel`` names the ``extern "C"`` kernel in ``cuda_code``,
        which is launched with the explicit ``grid`` / ``block`` sizes. The TC
        definition presumably still supplies the output shapes.
        """
        lang = """
def add(float(N) A, float(N) B) -> (output) {
output(i) = A(i) + B(i) + 1
}
"""
        # The hand-written kernel must compute what the TC above declares
        # (A + B + 1). Previously it dropped the "+ 1". It also spelled the
        # qualifier on B as "__restrict" while output and A use "__restrict__".
        # The kernel has no bounds check, so it must be launched with exactly
        # N threads (block=[100, 1, 1] for N=100 below).
        cuda_code = """
extern "C"{
__global__ void my_add(float* __restrict__ output, const float* __restrict__ A, const float* __restrict__ B)
{
int t = threadIdx.x;
output[t] = A[t] + B[t] + 1.0f;
}
}
"""
        add = tc.define(lang, name="add", inject_kernel="my_add", cuda_code=cuda_code)
        a, b = torch.randn(100).cuda(), torch.randn(100).cuda()
        out = add(a, b, grid=[1, 1, 1], block=[100, 1, 1])


if __name__ == '__main__':
    unittest.main()
# TODO: add test for 'where', cpu codepath
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Copyright (c) 2017-present, Facebook, Inc. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
############################################################################## | |
import tensor_comprehensions as tc | |
import torch | |
import torch.cuda | |
import unittest | |
class TestFC(unittest.TestCase):
    """Fully-connected layer (matrix product plus bias) written as a TC."""

    def test_fc(self):
        # O1(b, n) accumulates I(b, m) * W1(n, m) over m (`+=!` starts from a
        # zeroed O1); the bias B1(n) is then added to every row.
        fc_lang = """
def fc(float(B,M) I, float(N,M) W1, float(N) B1) -> (O1) {
O1(b, n) +=! I(b, m) * W1(n, m)
O1(b, n) = O1(b, n) + B1(n)
}
"""
        batch, in_features, out_features = 100, 128, 100
        fc = tc.define(fc_lang, name="fc")
        inputs = torch.randn(batch, in_features).cuda()
        weights = torch.randn(out_features, in_features).cuda()
        bias = torch.randn(out_features).cuda()
        fc(inputs, weights, bias)


if __name__ == '__main__':
    unittest.main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Copyright (c) 2017-present, Facebook, Inc. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
############################################################################## | |
import tensor_comprehensions as tc | |
import torch | |
import torch.cuda | |
import unittest | |
class TestFCRelu(unittest.TestCase):
    """Fully-connected layer followed by a ReLU, written as a single TC."""

    def test_fcrelu(self):
        # Matrix product over m, add the bias, then clamp at 0 with fmax.
        fcrelu_lang = """
def fcrelu(float(B,M) I, float(N,M) W1, float(N) B1) -> (O1){
O1(b, n) +=! I(b, m) * W1(n, m)
O1(b, n) = O1(b, n) + B1(n)
O1(b, n) = fmax(O1(b, n), 0)
}
"""
        batch, in_features, out_features = 100, 128, 100
        fcrelu = tc.define(fcrelu_lang, name="fcrelu")
        inputs = torch.randn(batch, in_features).cuda()
        weights = torch.randn(out_features, in_features).cuda()
        bias = torch.randn(out_features).cuda()
        fcrelu(inputs, weights, bias)


if __name__ == '__main__':
    unittest.main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Copyright (c) 2017-present, Facebook, Inc. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
############################################################################## | |
import tensor_comprehensions as tc | |
import torch | |
import torch.cuda | |
import unittest | |
class TestGroupConvolutionSimple(unittest.TestCase):
    """Unit-stride grouped 2-D convolution plus per-(group, filter) bias."""

    def test_group_convolution_simple(self):
        # Each group g convolves its C input channels with its own F filters;
        # the reduction runs over c, kh and kw.
        conv_lang = """
def group_convolution(float(N,G,C,H,W) I, float(G,F,C,KH,KW) W1, float(G,F) B) -> (O)
{
O(n, g, f, h, w) +=! I(n, g, c, h + kh, w + kw) * W1(g, f, c, kh, kw)
O(n, g, f, h, w) = O(n, g, f, h, w) + B(g, f)
}
"""
        batch, groups, channels, height, width = 32, 32, 4, 56, 56
        filters, k_h, k_w = 4, 3, 3
        group_convolution = tc.define(conv_lang, name="group_convolution")
        images = torch.randn(batch, groups, channels, height, width).cuda()
        kernels = torch.randn(groups, filters, channels, k_h, k_w).cuda()
        bias = torch.randn(groups, filters).cuda()
        group_convolution(images, kernels, bias)


if __name__ == '__main__':
    unittest.main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Copyright (c) 2017-present, Facebook, Inc. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
############################################################################## | |
import tensor_comprehensions as tc | |
import torch | |
import torch.cuda | |
import unittest | |
class TestGroupConvolutionStrided(unittest.TestCase):
    """Grouped 2-D convolution whose strides are injected as TC constants."""

    def test_group_convolution_strided(self):
        # {sh} / {sw} are filled from `constants`, so the literal TC braces are
        # written doubled as {{ }}.
        conv_lang = """
def group_convolution(float(N,G,C,H,W) I, float(G,F,C,KH,KW) W1, float(G,F) B) -> (O)
{{
O(n, g, f, h, w) +=! I(n, g, c, {sh} * h + kh, {sw} * w + kw) * W1(g, f, c, kh, kw)
O(n, g, f, h, w) = O(n, g, f, h, w) + B(g, f)
}}
"""
        batch, groups, channels, height, width = 32, 32, 4, 56, 56
        filters, k_h, k_w = 4, 3, 3
        stride_h, stride_w = 1, 1
        strides = {"sh": stride_h, "sw": stride_w}
        group_convolution = tc.define(conv_lang, name="group_convolution", constants=strides)
        images = torch.randn(batch, groups, channels, height, width).cuda()
        kernels = torch.randn(groups, filters, channels, k_h, k_w).cuda()
        bias = torch.randn(groups, filters).cuda()
        group_convolution(images, kernels, bias)


if __name__ == '__main__':
    unittest.main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Copyright (c) 2017-present, Facebook, Inc. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
############################################################################## | |
import tensor_comprehensions as tc | |
import torch | |
import torch.cuda | |
import unittest | |
class TestIndexing(unittest.TestCase):
    """Gather rows of a float matrix through an int32 index tensor."""

    def test_indexing(self):
        """Select rows ``index(l)`` of a 4x4 matrix for ``l`` in ``0:L``."""
        # output(l, w) = input(index(l), w). The range of l is pinned with a
        # `where` clause whose bound {L} comes from `constants` (hence the
        # doubled {{ }} around the TC body).
        LANG = """
def indexing(float(H, W) input, int32(L) index) -> (output) {{
output(l, w) = input(index(l), w) where l in 0:{L}
}}
"""
        indexing = tc.define(LANG, name="indexing", constants={"L":2})
        # .float() makes the input dtype match the TC's float(H, W) signature.
        # torch.arange with integer arguments returns an integer tensor in
        # newer PyTorch releases; on older ones this is a no-op.
        inp = torch.arange(0, 16).view(4, 4).float().cuda()
        idx = torch.IntTensor([1, 1]).cuda()
        out = indexing(inp, idx)


if __name__ == '__main__':
    unittest.main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Copyright (c) 2017-present, Facebook, Inc. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
############################################################################## | |
import tensor_comprehensions as tc | |
import torch | |
import torch.cuda | |
import unittest | |
class TestLookupTable(unittest.TestCase):
    """Table lookup: rows of LUT selected by an int32 index, summed over r."""

    def test_lookup_table(self):
        # O(b, n) = sum over r of LUT(I(b, n), r).
        lut_lang = """
def lut(float(B, R) LUT, int32(B, N) I) -> (O) {
O(b, n) +=! LUT(I(b, n), r)
}
"""
        lut = tc.define(lut_lang, name="lut")
        table = torch.rand(17, 22).cuda()
        # Every index is 1, which is a valid row of the 17-row table.
        indices = torch.IntTensor(17, 82).fill_(1).cuda()
        lut(table, indices)


if __name__ == '__main__':
    unittest.main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Copyright (c) 2017-present, Facebook, Inc. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
############################################################################## | |
import tensor_comprehensions as tc | |
import torch | |
import torch.cuda | |
import unittest | |
class TestMatmul(unittest.TestCase):
    """Plain matrix multiplication written as a TC."""

    def test_matmul(self):
        # output(i, j) accumulates A(i, kk) * B(kk, j) over the shared index kk.
        matmul_lang = """
def matmul(float(M,N) A, float(N,K) B) -> (output) {
output(i, j) +=! A(i, kk) * B(kk, j)
}
"""
        matmul = tc.define(matmul_lang, name="matmul")
        lhs = torch.randn(3, 4).cuda()
        rhs = torch.randn(4, 5).cuda()
        matmul(lhs, rhs)


if __name__ == '__main__':
    unittest.main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Copyright (c) 2017-present, Facebook, Inc. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
############################################################################## | |
import tensor_comprehensions as tc | |
import torch | |
import torch.cuda | |
import unittest | |
class TestMatmulReuse(unittest.TestCase):
    """Run a matmul TC twice, reusing the first call's output tensor."""

    def test_matmul_reuse(self):
        matmul_lang = """
def matmul(float(M,N) A, float(N,K) B) -> (output) {
output(i, j) +=! A(i, kk) * B(kk, j)
}
"""
        matmul = tc.define(matmul_lang, name="matmul")
        first_lhs = torch.randn(3, 4).cuda()
        first_rhs = torch.randn(4, 5).cuda()
        product = matmul(first_lhs, first_rhs)
        # Reuse the same outputs now instead of allocating them again, which
        # saves the storage-allocation overhead. Also, because the input sizes
        # are the same, compilation is skipped and the kernel runs directly.
        second_lhs = torch.randn(3, 4).cuda()
        second_rhs = torch.randn(4, 5).cuda()
        matmul(second_lhs, second_rhs, outputs=product)


if __name__ == '__main__':
    unittest.main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Copyright (c) 2017-present, Facebook, Inc. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
############################################################################## | |
import tensor_comprehensions as tc | |
import torch | |
import torch.cuda | |
import unittest | |
class TestMaxPool(unittest.TestCase):
    """2-D max pooling written as a Tensor Comprehension."""

    def test_maxpool(self):
        """Max-pool 2x2 windows with stride 1 over a (32, 3, 10, 10) tensor of ones."""
        # NOTE: take note of use of {{ }}. The {sH}, {sW}, {kH}, {kW}
        # placeholders are filled from `constants`, so the literal TC braces
        # must be doubled.
        # `max=!` initializes each output element before reducing, matching the
        # `+=!` used by every other reduction in these tests. Plain `max=`
        # reduces into whatever the (presumably freshly allocated) output
        # tensor already contains.
        LANG = """
def maxpool(float(B, C, H, W) input) -> (output) {{
output(b, c, h, w) max=! input(b, c, h * {sH} + kh, w * {sW} + kw) where kh in 0:{kH}, kw in 0:{kW}
}}
"""
        maxpool = tc.define(LANG, name="maxpool", constants={"sH":1, "sW":1, "kH":2, "kW":2})
        inp = torch.ones(32, 3, 10, 10).cuda()
        out = maxpool(inp)


if __name__ == '__main__':
    unittest.main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Copyright (c) 2017-present, Facebook, Inc. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
############################################################################## | |
import tensor_comprehensions as tc | |
import torch | |
import torch.cuda | |
import unittest | |
class TestRelu(unittest.TestCase):
    """Element-wise ReLU written as a TC."""

    def test_relu(self):
        # Each element is clamped from below at 0 via fmax.
        relu_lang = """
def relu(float(B,M) I) -> (O1){
O1(b, m) = fmax(I(b, m), 0)
}
"""
        relu = tc.define(relu_lang, name="relu")
        activations = torch.randn(100, 128).cuda()
        relu(activations)


if __name__ == '__main__':
    unittest.main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Copyright (c) 2017-present, Facebook, Inc. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
############################################################################## | |
import tensor_comprehensions as tc | |
import torch | |
import torch.cuda | |
import unittest | |
class TestScale(unittest.TestCase):
    """Multiply every element by a constant injected into the TC source."""

    def test_scale(self):
        # NOTE: take note of use of {{ }}. {s} is filled from `constants`, so
        # the literal TC braces are written doubled.
        scale_lang = """
def scale(float(M, N) I) -> (O) {{
O(m, n) = I(m, n) * {s}
}}
"""
        factor = {"s": 10}
        scale = tc.define(scale_lang, name="scale", constants=factor)
        values = torch.randn(100, 128).cuda()
        scale(values)


if __name__ == '__main__':
    unittest.main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Copyright (c) 2017-present, Facebook, Inc. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
############################################################################## | |
import tensor_comprehensions as tc | |
import torch | |
import torch.cuda | |
import unittest | |
class TestSigmoid(unittest.TestCase):
    """Element-wise logistic sigmoid written as a TC."""

    def test_sigmoid(self):
        # O = 1 / (1 + exp(-I)), applied to every element of a 4-D tensor.
        sigmoid_lang = """
def sigmoid(float(N, C, H, W) I) -> (O) {
O(n, c, h, w) = 1 / (1 + exp(-I(n, c, h, w)))
}
"""
        sigmoid = tc.define(sigmoid_lang, name="sigmoid")
        batch = torch.randn(32, 3, 128, 128).cuda()
        sigmoid(batch)


if __name__ == '__main__':
    unittest.main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Copyright (c) 2017-present, Facebook, Inc. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
############################################################################## | |
import tensor_comprehensions as tc | |
import torch | |
import torch.cuda | |
import unittest | |
class TestSmallMobileNet(unittest.TestCase):
    """Depthwise conv + bias + ReLU, then pointwise (1x1) conv + bias + ReLU."""

    def test_small_mobilenet(self):
        """Autotune the small MobileNet block, then run it with the best options found."""
        LANG = """
def small_mobilenet(float(C1, H, W) I, float(C1, KH1, KW1) W1,
float(C1) B1, float(C2, C1) W2, float(C2) B2)
-> (O1, O2)
{
O1(c1, h, w) +=! I(c1, h + kh, w + kw) * W1(c1, kh, kw)
O1(c1, h, w) = O1(c1, h, w) + B1(c1)
O1(c1, h, w) = fmax(O1(c1, h, w), 0)
O2(c2, h, w) +=! O1(c1, h, w) * W2(c2, c1)
O2(c2, h, w) = O2(c2, h, w) + B2(c2)
O2(c2, h, w) = fmax(O2(c2, h, w), 0)
}
"""
        # W1 is declared float(C1, KH1, KW1), so its last dimension is the
        # kernel *width* KW1. It was previously misnamed KH2, which only
        # worked because both sizes are 3.
        C1, C2, H, W, KH1, KW1 = 128, 128, 16, 16, 3, 3
        small_mobilenet = tc.define(LANG, name="small_mobilenet")
        I, W1 = torch.randn(C1, H, W).cuda(), torch.randn(C1, KH1, KW1).cuda()
        B1, W2 = torch.randn(C1).cuda(), torch.randn(C2, C1).cuda()
        B2 = torch.randn(C2).cuda()
        # The autotuner searches mapping options (presumably for these input
        # sizes); the best ones found are used for the actual run.
        best_options = small_mobilenet.autotune(I, W1, B1, W2, B2, **tc.autotuner_default_options)
        out = small_mobilenet(I, W1, B1, W2, B2, options=best_options)


if __name__ == '__main__':
    unittest.main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Copyright (c) 2017-present, Facebook, Inc. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
############################################################################## | |
import tensor_comprehensions as tc | |
import torch | |
import torch.cuda | |
import unittest | |
class TestSoftmax(unittest.TestCase):
    """Row-wise softmax; the per-row exponential sum is a declared output."""

    def test_softmax(self):
        # expsum is a temporary that must be declared as an output. Note that
        # exp(I) is taken directly (no max subtraction), so very large inputs
        # could overflow; the random normal inputs here are small.
        softmax_lang = """
def softmax(float(N, D) I) -> (O, expsum) {
expsum(n) +=! exp(I(n, d))
O(n, d) = exp(I(n, d)) / expsum(n)
}
"""
        softmax = tc.define(softmax_lang, name="softmax")
        logits = torch.randn(32, 16).cuda()
        softmax(logits)


if __name__ == '__main__':
    unittest.main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Copyright (c) 2017-present, Facebook, Inc. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
############################################################################## | |
import tensor_comprehensions as tc | |
import torch | |
import torch.cuda | |
import unittest | |
class TestTanh(unittest.TestCase):
    """Element-wise hyperbolic tangent on a 1-D tensor, written as a TC."""

    def test_tanh(self):
        tanh_lang = """
def Tanh(float(M) I) -> (O) {
O(m) = tanh(I(m))
}
"""
        tanh_tc = tc.define(tanh_lang, name="Tanh")
        vector = torch.randn(32).cuda()
        tanh_tc(vector)


if __name__ == '__main__':
    unittest.main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Copyright (c) 2017-present, Facebook, Inc. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
############################################################################## | |
import tensor_comprehensions as tc | |
import torch | |
import torch.cuda | |
import unittest | |
class TestTensorDot(unittest.TestCase):
    """Contract the c2 axis of two 5-D tensors, batched over n, h and w."""

    def test_tensordot(self):
        tensordot_lang = """
def tensordot(float(N, C1, C2, H, W) I0, float(N, C2, C3, H, W) I1) -> (O)
{
O(n, c1, c3, h, w) +=! I0(n, c1, c2, h, w) * I1(n, c2, c3, h, w)
}
"""
        N, C1, C2, C3, H, W = 32, 512, 8, 2, 28, 28
        tensordot = tc.define(tensordot_lang, name="tensordot")
        left = torch.randn(N, C1, C2, H, W).cuda()
        right = torch.randn(N, C2, C3, H, W).cuda()
        # Autotune first, then run with the best options the tuner found.
        tuned = tensordot.autotune(left, right, **tc.autotuner_default_options)
        tensordot(left, right, options=tuned)


if __name__ == '__main__':
    unittest.main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Copyright (c) 2017-present, Facebook, Inc. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
############################################################################## | |
import tensor_comprehensions as tc | |
import torch | |
import torch.cuda | |
import unittest | |
class TestTranspose(unittest.TestCase):
    """Permute a 4-D tensor: swap the first two and the last two axes."""

    def test_transpose(self):
        # O(c, n, w, h) = I(n, c, h, w)
        transpose_lang = """
def transpose(float(N, C, H, W) I) -> (O) {
O(c, n, w, h) = I(n, c, h, w)
}
"""
        N, C, H, W = 32, 512, 56, 56
        transpose = tc.define(transpose_lang, name="transpose")
        source = torch.randn(N, C, H, W).cuda()
        transpose(source)


if __name__ == '__main__':
    unittest.main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment