Skip to content

Instantly share code, notes, and snippets.

@dlibenzi
Created November 19, 2019 20:40
Show Gist options
Save dlibenzi/4b449de70d95de7460f136c023c6092f to your computer and use it in GitHub Desktop.
(pytorch) dlibenzi@dlibenzi2:~/google-git/pytorch$ git diff aten/src/ATen/Declarations.cwrap aten/src/TH/generic/THTensorMath.h aten/src/TH/generic/THTensorMoreMath.cpp
diff --git a/aten/src/ATen/Declarations.cwrap b/aten/src/ATen/Declarations.cwrap
index e2f17c5970..6908d2d613 100644
--- a/aten/src/ATen/Declarations.cwrap
+++ b/aten/src/ATen/Declarations.cwrap
@@ -286,6 +286,7 @@
name: _th_equal
cname: equal
cpu_bool: True
+ cpu_bfloat16: True
cuda_bool: True
variants:
- function
diff --git a/aten/src/TH/generic/THTensorMath.h b/aten/src/TH/generic/THTensorMath.h
index e2c7c8d981..73432f4406 100644
--- a/aten/src/TH/generic/THTensorMath.h
+++ b/aten/src/TH/generic/THTensorMath.h
@@ -6,6 +6,7 @@
#include <ATen/core/DistributionsHelper.h>
TH_API void THTensor_(nonzero)(THLongTensor *subscript, THTensor *tensor);
+TH_API int THTensor_(equal)(THTensor *ta, THTensor *tb);
#if !defined(TH_REAL_IS_HALF)
@@ -72,7 +73,6 @@ TH_API void THTensor_(div)(THTensor *r_, THTensor *t, scalar_t value);
#if !defined(TH_REAL_IS_BFLOAT16)
TH_API accreal THTensor_(sumall)(THTensor *t);
-TH_API int THTensor_(equal)(THTensor *ta, THTensor *tb);
TH_API void THTensor_(bitand)(THTensor *r_, THTensor *t, scalar_t value);
TH_API void THTensor_(cbitand)(THTensor *r_, THTensor *t, THTensor *src);
diff --git a/aten/src/TH/generic/THTensorMoreMath.cpp b/aten/src/TH/generic/THTensorMoreMath.cpp
index 56ae143de4..439cf9915e 100644
--- a/aten/src/TH/generic/THTensorMoreMath.cpp
+++ b/aten/src/TH/generic/THTensorMoreMath.cpp
@@ -13,8 +13,6 @@ ptrdiff_t THTensor_(numel)(THTensor *t)
return THTensor_(nElement)(t);
}
-#if !defined(TH_REAL_IS_BFLOAT16)
-
static int THTensor_(equalImpl)(THTensor *ta, THTensor* tb)
{
std::atomic<int> equal{1};
@@ -62,6 +60,8 @@ int THTensor_(equal)(THTensor *ta, THTensor* tb) {
return THTensor_(equalImpl)(ta, tb);
}
+#if !defined(TH_REAL_IS_BFLOAT16)
+
// Helper function to be used in a reduction operation.
// Due to resize semantics of outputs, if the specified output tensor r_ has
// same size as the output of the reduction operation, then any noncontiguities
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment