Skip to content

Instantly share code, notes, and snippets.

@bmharper
Created February 23, 2019 07:11
Show Gist options
  • Save bmharper/86ca7fc477688baf02494b13ffb15c8a to your computer and use it in GitHub Desktop.
Save bmharper/86ca7fc477688baf02494b13ffb15c8a to your computer and use it in GitHub Desktop.
Demonstrate value/reference copy/assignment semantics of LibTorch at::Tensor
/*
Expected output (each label is printed right-aligned in a 30-character column):
                     operator= reference copy (same memory)
                         copy_ value copy (different memory)
       operator= of operator[] value copy (different memory)
                         slice reference copy (same memory)
*/
#include <stdio.h>
#include <torch/script.h>
void TestTensorSemantics() {
// copy, assignment, etc
std::vector<std::pair<std::string, std::function<void(const at::Tensor& src, at::Tensor& dst)>>> tests = {
{"operator=", [](const at::Tensor& src, at::Tensor& dst) { dst = src; }},
{"copy_", [](const at::Tensor& src, at::Tensor& dst) { dst.copy_(src); }},
{"operator= of operator[]", [](const at::Tensor& src, at::Tensor& dst) { dst[0] = src[0]; }},
{"slice", [](const at::Tensor& src, at::Tensor& dst) { dst = src.slice(0, 0, 1); }},
};
for (auto tx : tests) {
auto src = torch::zeros({4, 3, 1});
auto dst = torch::zeros_like(src);
printf("%30s ", tx.first.c_str());
tx.second(src, dst);
auto src_d = src.accessor<float, 3>();
auto dst_d = dst.accessor<float, 3>();
src_d[0][0][0] = 999;
if (dst_d[0][0][0] == 999)
printf("reference copy (same memory)\n");
else
printf("value copy (different memory)\n");
}
}
// Program entry point: runs the tensor copy/assignment demonstration and
// exits with status 0. The command-line arguments are not used.
int main(int argc, char** argv) {
    (void)argc;
    (void)argv;

    TestTensorSemantics();

    return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment