Created
February 23, 2019 07:11
-
-
Save bmharper/86ca7fc477688baf02494b13ffb15c8a to your computer and use it in GitHub Desktop.
Demonstrate value/reference copy/assignment semantics of LibTorch at::Tensor
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/*
Expected output:
operator= reference copy (same memory)
copy_ value copy (different memory)
operator= of operator[] value copy (different memory)
slice reference copy (same memory)
*/
#include <stdio.h>

#include <functional>
#include <string>
#include <utility>
#include <vector>

#include <torch/script.h>
void TestTensorSemantics() { | |
// copy, assignment, etc | |
std::vector<std::pair<std::string, std::function<void(const at::Tensor& src, at::Tensor& dst)>>> tests = { | |
{"operator=", [](const at::Tensor& src, at::Tensor& dst) { dst = src; }}, | |
{"copy_", [](const at::Tensor& src, at::Tensor& dst) { dst.copy_(src); }}, | |
{"operator= of operator[]", [](const at::Tensor& src, at::Tensor& dst) { dst[0] = src[0]; }}, | |
{"slice", [](const at::Tensor& src, at::Tensor& dst) { dst = src.slice(0, 0, 1); }}, | |
}; | |
for (auto tx : tests) { | |
auto src = torch::zeros({4, 3, 1}); | |
auto dst = torch::zeros_like(src); | |
printf("%30s ", tx.first.c_str()); | |
tx.second(src, dst); | |
auto src_d = src.accessor<float, 3>(); | |
auto dst_d = dst.accessor<float, 3>(); | |
src_d[0][0][0] = 999; | |
if (dst_d[0][0][0] == 999) | |
printf("reference copy (same memory)\n"); | |
else | |
printf("value copy (different memory)\n"); | |
} | |
} | |
// Program entry point: runs the tensor copy/assignment semantics demo.
// Falling off the end of main yields exit status 0.
int main(int argc, char** argv) {
    // Command-line arguments are accepted but not used by this demo.
    (void)argc;
    (void)argv;

    TestTensorSemantics();
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment