Skip to content

Instantly share code, notes, and snippets.

@killeent
Created October 27, 2017 18:00
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save killeent/2c128b4ed8ae084c3c7dfa94bbcfcaeb to your computer and use it in GitHub Desktop.
diff --git a/torch/csrc/distributed/Module.cpp b/torch/csrc/distributed/Module.cpp
index a985509..293a4e1 100644
--- a/torch/csrc/distributed/Module.cpp
+++ b/torch/csrc/distributed/Module.cpp
@@ -186,8 +186,8 @@ THDTensorDescriptor THDPModule_makeDescriptor(PyObject *obj)
PyObject *type = (PyObject*)Py_TYPE(obj);
#define REGISTER_TH_DESCRIPTOR(TYPE, REAL) \
if (type == THP##TYPE##Class) \
- return at::CPU(REAL).unsafeTensorFromTH(((THP##TYPE*)obj)->cdata, true);
- /* return THDTensorDescriptor_newFromTH##TYPE(((THP##TYPE*)obj)->cdata); */
+ return THDTensorDescriptor_newFromTH##TYPE(((THP##TYPE*)obj)->cdata);
+ /* return at::CPU(REAL).unsafeTensorFromTH(((THP##TYPE*)obj)->cdata, true); */
REGISTER_TH_DESCRIPTOR(DoubleTensor, at::kDouble);
REGISTER_TH_DESCRIPTOR(FloatTensor, at::kFloat);
REGISTER_TH_DESCRIPTOR(LongTensor, at::kLong);
@@ -199,12 +199,12 @@ THDTensorDescriptor THDPModule_makeDescriptor(PyObject *obj)
#ifdef WITH_CUDA
#define REGISTER_THC_DESCRIPTOR(TYPE, REAL) \
if (type == THCP##TYPE##Class) \
- return at::CUDA(REAL).unsafeTensorFromTH(((THP##TYPE*)obj)->cdata, true);
- /* return THDTensorDescriptor_newFromTHCuda##TYPE((THCuda##TYPE*)(((torch::THPVoidTensor*)obj)->cdata)); */
+ return THDTensorDescriptor_newFromTHCuda##TYPE((THCuda##TYPE*)(((torch::THPVoidTensor*)obj)->cdata));
+ /* return at::CUDA(REAL).unsafeTensorFromTH(((THP##TYPE*)obj)->cdata, true); */
REGISTER_THC_DESCRIPTOR(DoubleTensor, at::kDouble);
if (type == THCPFloatTensorClass)
- return at::CUDA(at::kFloat).unsafeTensorFromTH((THCudaTensor*)(((torch::THPVoidTensor*)obj)->cdata), true);
- /* return THDTensorDescriptor_newFromTHCudaFloatTensor((THCudaTensor*)(((torch::THPVoidTensor*)obj)->cdata)); */
+ return THDTensorDescriptor_newFromTHCudaFloatTensor((THCudaTensor*)(((torch::THPVoidTensor*)obj)->cdata));
+ /* return at::CUDA(at::kFloat).unsafeTensorFromTH((THCudaTensor*)(((torch::THPVoidTensor*)obj)->cdata), true); */
REGISTER_THC_DESCRIPTOR(LongTensor, at::kLong);
REGISTER_THC_DESCRIPTOR(IntTensor, at::kInt);
REGISTER_THC_DESCRIPTOR(ShortTensor, at::kShort);
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment