Created November 19, 2019 20:40
-
-
Save dlibenzi/4b449de70d95de7460f136c023c6092f to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
(pytorch) dlibenzi@dlibenzi2:~/google-git/pytorch$ git diff aten/src/ATen/Declarations.cwrap aten/src/TH/generic/THTensorMath.h aten/src/TH/generic/THTensorMoreMath.cpp | |
diff --git a/aten/src/ATen/Declarations.cwrap b/aten/src/ATen/Declarations.cwrap | |
index e2f17c5970..6908d2d613 100644 | |
--- a/aten/src/ATen/Declarations.cwrap | |
+++ b/aten/src/ATen/Declarations.cwrap | |
@@ -286,6 +286,7 @@ | |
name: _th_equal | |
cname: equal | |
cpu_bool: True | |
+ cpu_bfloat16: True | |
cuda_bool: True | |
variants: | |
- function | |
diff --git a/aten/src/TH/generic/THTensorMath.h b/aten/src/TH/generic/THTensorMath.h | |
index e2c7c8d981..73432f4406 100644 | |
--- a/aten/src/TH/generic/THTensorMath.h | |
+++ b/aten/src/TH/generic/THTensorMath.h | |
@@ -6,6 +6,7 @@ | |
#include <ATen/core/DistributionsHelper.h> | |
TH_API void THTensor_(nonzero)(THLongTensor *subscript, THTensor *tensor); | |
+TH_API int THTensor_(equal)(THTensor *ta, THTensor *tb); | |
#if !defined(TH_REAL_IS_HALF) | |
@@ -72,7 +73,6 @@ TH_API void THTensor_(div)(THTensor *r_, THTensor *t, scalar_t value); | |
#if !defined(TH_REAL_IS_BFLOAT16) | |
TH_API accreal THTensor_(sumall)(THTensor *t); | |
-TH_API int THTensor_(equal)(THTensor *ta, THTensor *tb); | |
TH_API void THTensor_(bitand)(THTensor *r_, THTensor *t, scalar_t value); | |
TH_API void THTensor_(cbitand)(THTensor *r_, THTensor *t, THTensor *src); | |
diff --git a/aten/src/TH/generic/THTensorMoreMath.cpp b/aten/src/TH/generic/THTensorMoreMath.cpp | |
index 56ae143de4..439cf9915e 100644 | |
--- a/aten/src/TH/generic/THTensorMoreMath.cpp | |
+++ b/aten/src/TH/generic/THTensorMoreMath.cpp | |
@@ -13,8 +13,6 @@ ptrdiff_t THTensor_(numel)(THTensor *t) | |
return THTensor_(nElement)(t); | |
} | |
-#if !defined(TH_REAL_IS_BFLOAT16) | |
- | |
static int THTensor_(equalImpl)(THTensor *ta, THTensor* tb) | |
{ | |
std::atomic<int> equal{1}; | |
@@ -62,6 +60,8 @@ int THTensor_(equal)(THTensor *ta, THTensor* tb) { | |
return THTensor_(equalImpl)(ta, tb); | |
} | |
+#if !defined(TH_REAL_IS_BFLOAT16) | |
+ | |
// Helper function to be used in a reduction operation. | |
// Due to resize semantics of outputs, if the specified output tensor r_ has | |
// same size as the output of the reduction operation, then any noncontiguities | |
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.