Skip to content

Instantly share code, notes, and snippets.

@cpuhrsch
Created March 2, 2019 00:21
Show Gist options
  • Save cpuhrsch/f400a6614c0f67dda601aad98a84767c to your computer and use it in GitHub Desktop.
Save cpuhrsch/f400a6614c0f67dda601aad98a84767c to your computer and use it in GitHub Desktop.
diff --git a/build/aten/src/ATen/Declarations.yaml b/build/aten/src/ATen/Declarations.yaml
index 7ca5b4659..249b9a9b5 100644
--- a/build/aten/src/ATen/Declarations.yaml
+++ b/build/aten/src/ATen/Declarations.yaml
@@ -25896,8 +25896,8 @@
with_gil: false
deprecated: false
- name: embedding_backward
- matches_jit_signature: false
- schema_string: aten::embedding_backward(Tensor grad, IndexTensor indices, int num_weights,
+ matches_jit_signature: true
+ schema_string: aten::embedding_backward(Tensor grad, Tensor indices, int num_weights,
int padding_idx, bool scale_grad_by_freq, bool sparse) -> Tensor
method_prefix_derived: ''
arguments:
@@ -25907,7 +25907,7 @@
name: grad
type: const Tensor &
- annotation: null
- dynamic_type: IndexTensor
+ dynamic_type: Tensor
is_nullable: false
name: indices
type: const Tensor &
@@ -26037,9 +26037,9 @@
with_gil: false
deprecated: false
- name: embedding_sparse_backward
- matches_jit_signature: false
- schema_string: aten::embedding_sparse_backward(Tensor grad, IndexTensor indices,
- int num_weights, int padding_idx, bool scale_grad_by_freq) -> Tensor
+ matches_jit_signature: true
+ schema_string: aten::embedding_sparse_backward(Tensor grad, Tensor indices, int
+ num_weights, int padding_idx, bool scale_grad_by_freq) -> Tensor
method_prefix_derived: ''
arguments:
- annotation: null
@@ -26048,7 +26048,7 @@
name: grad
type: const Tensor &
- annotation: null
- dynamic_type: IndexTensor
+ dynamic_type: Tensor
is_nullable: false
name: indices
type: const Tensor &
@@ -26085,9 +26085,9 @@
deprecated: false
- name: embedding_bag
matches_jit_signature: false
- schema_string: aten::embedding_bag(Tensor weight, IndexTensor indices, IndexTensor
- offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False) -> (Tensor,
- Tensor, Tensor, Tensor)
+ schema_string: aten::embedding_bag(Tensor weight, Tensor indices, Tensor offsets,
+ bool scale_grad_by_freq=False, int mode=0, bool sparse=False) -> (Tensor, Tensor,
+ Tensor, Tensor)
method_prefix_derived: ''
arguments:
- annotation: null
@@ -26096,12 +26096,12 @@
name: weight
type: const Tensor &
- annotation: null
- dynamic_type: IndexTensor
+ dynamic_type: Tensor
is_nullable: false
name: indices
type: const Tensor &
- annotation: null
- dynamic_type: IndexTensor
+ dynamic_type: Tensor
is_nullable: false
name: offsets
type: const Tensor &
@@ -26149,10 +26149,10 @@
with_gil: false
deprecated: false
- name: _embedding_bag
- matches_jit_signature: false
- schema_string: aten::_embedding_bag(Tensor weight, IndexTensor indices, IndexTensor
- offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False) -> (Tensor,
- Tensor, Tensor, Tensor)
+ matches_jit_signature: true
+ schema_string: aten::_embedding_bag(Tensor weight, Tensor indices, Tensor offsets,
+ bool scale_grad_by_freq=False, int mode=0, bool sparse=False) -> (Tensor, Tensor,
+ Tensor, Tensor)
method_prefix_derived: ''
arguments:
- annotation: null
@@ -26161,12 +26161,12 @@
name: weight
type: const Tensor &
- annotation: null
- dynamic_type: IndexTensor
+ dynamic_type: Tensor
is_nullable: false
name: indices
type: const Tensor &
- annotation: null
- dynamic_type: IndexTensor
+ dynamic_type: Tensor
is_nullable: false
name: offsets
type: const Tensor &
@@ -26214,10 +26214,10 @@
with_gil: false
deprecated: false
- name: _embedding_bag_backward
- matches_jit_signature: false
- schema_string: aten::_embedding_bag_backward(Tensor grad, IndexTensor indices, IndexTensor
- offsets, IndexTensor offset2bag, IndexTensor bag_size, IndexTensor maximum_indices,
- int num_weights, bool scale_grad_by_freq, int mode, bool sparse) -> Tensor
+ matches_jit_signature: true
+ schema_string: aten::_embedding_bag_backward(Tensor grad, Tensor indices, Tensor
+ offsets, Tensor offset2bag, Tensor bag_size, Tensor maximum_indices, int num_weights,
+ bool scale_grad_by_freq, int mode, bool sparse) -> Tensor
method_prefix_derived: ''
arguments:
- annotation: null
@@ -26226,27 +26226,27 @@
name: grad
type: const Tensor &
- annotation: null
- dynamic_type: IndexTensor
+ dynamic_type: Tensor
is_nullable: false
name: indices
type: const Tensor &
- annotation: null
- dynamic_type: IndexTensor
+ dynamic_type: Tensor
is_nullable: false
name: offsets
type: const Tensor &
- annotation: null
- dynamic_type: IndexTensor
+ dynamic_type: Tensor
is_nullable: false
name: offset2bag
type: const Tensor &
- annotation: null
- dynamic_type: IndexTensor
+ dynamic_type: Tensor
is_nullable: false
name: bag_size
type: const Tensor &
- annotation: null
- dynamic_type: IndexTensor
+ dynamic_type: Tensor
is_nullable: false
name: maximum_indices
type: const Tensor &
@@ -26287,10 +26287,10 @@
with_gil: false
deprecated: false
- name: _embedding_bag_sparse_backward
- matches_jit_signature: false
- schema_string: aten::_embedding_bag_sparse_backward(Tensor grad, IndexTensor indices,
- IndexTensor offsets, IndexTensor offset2bag, IndexTensor bag_size, int num_weights,
- bool scale_grad_by_freq, int mode) -> Tensor
+ matches_jit_signature: true
+ schema_string: aten::_embedding_bag_sparse_backward(Tensor grad, Tensor indices,
+ Tensor offsets, Tensor offset2bag, Tensor bag_size, int num_weights, bool scale_grad_by_freq,
+ int mode) -> Tensor
method_prefix_derived: ''
arguments:
- annotation: null
@@ -26299,22 +26299,22 @@
name: grad
type: const Tensor &
- annotation: null
- dynamic_type: IndexTensor
+ dynamic_type: Tensor
is_nullable: false
name: indices
type: const Tensor &
- annotation: null
- dynamic_type: IndexTensor
+ dynamic_type: Tensor
is_nullable: false
name: offsets
type: const Tensor &
- annotation: null
- dynamic_type: IndexTensor
+ dynamic_type: Tensor
is_nullable: false
name: offset2bag
type: const Tensor &
- annotation: null
- dynamic_type: IndexTensor
+ dynamic_type: Tensor
is_nullable: false
name: bag_size
type: const Tensor &
@@ -39554,11 +39554,11 @@
deprecated: false
- name: where
matches_jit_signature: false
- schema_string: aten::where(BoolTensor condition, Tensor self, Tensor other) -> Tensor
+ schema_string: aten::where(Tensor condition, Tensor self, Tensor other) -> Tensor
method_prefix_derived: ''
arguments:
- annotation: null
- dynamic_type: BoolTensor
+ dynamic_type: Tensor
is_nullable: false
name: condition
type: const Tensor &
@@ -41714,12 +41714,12 @@
deprecated: false
- name: sparse_coo_tensor
matches_jit_signature: false
- schema_string: aten::sparse_coo_tensor(IndexTensor indices, Tensor values, *, TensorOptions
+ schema_string: aten::sparse_coo_tensor(Tensor indices, Tensor values, *, TensorOptions
options=[]) -> Tensor
method_prefix_derived: ''
arguments:
- annotation: null
- dynamic_type: IndexTensor
+ dynamic_type: Tensor
is_nullable: false
name: indices
type: const Tensor &
@@ -41753,12 +41753,12 @@
deprecated: false
- name: sparse_coo_tensor
matches_jit_signature: false
- schema_string: aten::sparse_coo_tensor(IndexTensor indices, Tensor values, int[]
- size, *, TensorOptions options=[]) -> Tensor
+ schema_string: aten::sparse_coo_tensor(Tensor indices, Tensor values, int[] size,
+ *, TensorOptions options=[]) -> Tensor
method_prefix_derived: ''
arguments:
- annotation: null
- dynamic_type: IndexTensor
+ dynamic_type: Tensor
is_nullable: false
name: indices
type: const Tensor &
@@ -41797,12 +41797,12 @@
deprecated: false
- name: _sparse_coo_tensor_unsafe
matches_jit_signature: false
- schema_string: aten::_sparse_coo_tensor_unsafe(IndexTensor indices, Tensor values,
- int[] size, *, TensorOptions options=[]) -> Tensor
+ schema_string: aten::_sparse_coo_tensor_unsafe(Tensor indices, Tensor values, int[]
+ size, *, TensorOptions options=[]) -> Tensor
method_prefix_derived: ''
arguments:
- annotation: null
- dynamic_type: IndexTensor
+ dynamic_type: Tensor
is_nullable: false
name: indices
type: const Tensor &
diff --git a/build/aten/src/ATen/MSNPUType.cpp b/build/aten/src/ATen/MSNPUType.cpp
index 35f9d639c..ff75cbf78 100644
--- a/build/aten/src/ATen/MSNPUType.cpp
+++ b/build/aten/src/ATen/MSNPUType.cpp
@@ -2107,7 +2107,7 @@ Tensor MSNPUType::embedding(const Tensor & weight, const Tensor & indices, int64
return MSNPUTypeDispatch::get_function<Tensor (*)(const Tensor &, const Tensor &, int64_t, bool, bool)>("embedding(Tensor weight, Tensor indices, int64_t padding_idx, bool scale_grad_by_freq, bool sparse) -> Tensor")(weight, indices, padding_idx, scale_grad_by_freq, sparse);
}
Tensor MSNPUType::embedding_backward(const Tensor & grad, const Tensor & indices, int64_t num_weights, int64_t padding_idx, bool scale_grad_by_freq, bool sparse) const {
- return MSNPUTypeDispatch::get_function<Tensor (*)(const Tensor &, const Tensor &, int64_t, int64_t, bool, bool)>("embedding_backward(Tensor grad, IndexTensor indices, int64_t num_weights, int64_t padding_idx, bool scale_grad_by_freq, bool sparse) -> Tensor")(grad, indices, num_weights, padding_idx, scale_grad_by_freq, sparse);
+ return MSNPUTypeDispatch::get_function<Tensor (*)(const Tensor &, const Tensor &, int64_t, int64_t, bool, bool)>("embedding_backward(Tensor grad, Tensor indices, int64_t num_weights, int64_t padding_idx, bool scale_grad_by_freq, bool sparse) -> Tensor")(grad, indices, num_weights, padding_idx, scale_grad_by_freq, sparse);
}
Tensor MSNPUType::embedding_dense_backward(const Tensor & grad, const Tensor & indices, int64_t num_weights, int64_t padding_idx, bool scale_grad_by_freq) const {
return MSNPUTypeDispatch::get_function<Tensor (*)(const Tensor &, const Tensor &, int64_t, int64_t, bool)>("embedding_dense_backward(Tensor grad, IndexTensor indices, int64_t num_weights, int64_t padding_idx, bool scale_grad_by_freq) -> Tensor")(grad, indices, num_weights, padding_idx, scale_grad_by_freq);
@@ -2116,19 +2116,19 @@ Tensor & MSNPUType::embedding_renorm_(Tensor & self, const Tensor & indices, dou
return MSNPUTypeDispatch::get_function<Tensor & (*)(Tensor &, const Tensor &, double, double)>("embedding_renorm_(Tensor self, Tensor indices, double max_norm, double norm_type) -> Tensor")(self, indices, max_norm, norm_type);
}
Tensor MSNPUType::embedding_sparse_backward(const Tensor & grad, const Tensor & indices, int64_t num_weights, int64_t padding_idx, bool scale_grad_by_freq) const {
- return MSNPUTypeDispatch::get_function<Tensor (*)(const Tensor &, const Tensor &, int64_t, int64_t, bool)>("embedding_sparse_backward(Tensor grad, IndexTensor indices, int64_t num_weights, int64_t padding_idx, bool scale_grad_by_freq) -> Tensor")(grad, indices, num_weights, padding_idx, scale_grad_by_freq);
+ return MSNPUTypeDispatch::get_function<Tensor (*)(const Tensor &, const Tensor &, int64_t, int64_t, bool)>("embedding_sparse_backward(Tensor grad, Tensor indices, int64_t num_weights, int64_t padding_idx, bool scale_grad_by_freq) -> Tensor")(grad, indices, num_weights, padding_idx, scale_grad_by_freq);
}
std::tuple<Tensor,Tensor,Tensor,Tensor> MSNPUType::embedding_bag(const Tensor & weight, const Tensor & indices, const Tensor & offsets, bool scale_grad_by_freq, int64_t mode, bool sparse) const {
- return MSNPUTypeDispatch::get_function<std::tuple<Tensor,Tensor,Tensor,Tensor> (*)(const Tensor &, const Tensor &, const Tensor &, bool, int64_t, bool)>("embedding_bag(Tensor weight, IndexTensor indices, IndexTensor offsets, bool scale_grad_by_freq, int64_t mode, bool sparse) -> std::tuple<Tensor,Tensor,Tensor,Tensor>")(weight, indices, offsets, scale_grad_by_freq, mode, sparse);
+ return MSNPUTypeDispatch::get_function<std::tuple<Tensor,Tensor,Tensor,Tensor> (*)(const Tensor &, const Tensor &, const Tensor &, bool, int64_t, bool)>("embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq, int64_t mode, bool sparse) -> std::tuple<Tensor,Tensor,Tensor,Tensor>")(weight, indices, offsets, scale_grad_by_freq, mode, sparse);
}
std::tuple<Tensor,Tensor,Tensor,Tensor> MSNPUType::_embedding_bag(const Tensor & weight, const Tensor & indices, const Tensor & offsets, bool scale_grad_by_freq, int64_t mode, bool sparse) const {
- return MSNPUTypeDispatch::get_function<std::tuple<Tensor,Tensor,Tensor,Tensor> (*)(const Tensor &, const Tensor &, const Tensor &, bool, int64_t, bool)>("_embedding_bag(Tensor weight, IndexTensor indices, IndexTensor offsets, bool scale_grad_by_freq, int64_t mode, bool sparse) -> std::tuple<Tensor,Tensor,Tensor,Tensor>")(weight, indices, offsets, scale_grad_by_freq, mode, sparse);
+ return MSNPUTypeDispatch::get_function<std::tuple<Tensor,Tensor,Tensor,Tensor> (*)(const Tensor &, const Tensor &, const Tensor &, bool, int64_t, bool)>("_embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq, int64_t mode, bool sparse) -> std::tuple<Tensor,Tensor,Tensor,Tensor>")(weight, indices, offsets, scale_grad_by_freq, mode, sparse);
}
Tensor MSNPUType::_embedding_bag_backward(const Tensor & grad, const Tensor & indices, const Tensor & offsets, const Tensor & offset2bag, const Tensor & bag_size, const Tensor & maximum_indices, int64_t num_weights, bool scale_grad_by_freq, int64_t mode, bool sparse) const {
- return MSNPUTypeDispatch::get_function<Tensor (*)(const Tensor &, const Tensor &, const Tensor &, const Tensor &, const Tensor &, const Tensor &, int64_t, bool, int64_t, bool)>("_embedding_bag_backward(Tensor grad, IndexTensor indices, IndexTensor offsets, IndexTensor offset2bag, IndexTensor bag_size, IndexTensor maximum_indices, int64_t num_weights, bool scale_grad_by_freq, int64_t mode, bool sparse) -> Tensor")(grad, indices, offsets, offset2bag, bag_size, maximum_indices, num_weights, scale_grad_by_freq, mode, sparse);
+ return MSNPUTypeDispatch::get_function<Tensor (*)(const Tensor &, const Tensor &, const Tensor &, const Tensor &, const Tensor &, const Tensor &, int64_t, bool, int64_t, bool)>("_embedding_bag_backward(Tensor grad, Tensor indices, Tensor offsets, Tensor offset2bag, Tensor bag_size, Tensor maximum_indices, int64_t num_weights, bool scale_grad_by_freq, int64_t mode, bool sparse) -> Tensor")(grad, indices, offsets, offset2bag, bag_size, maximum_indices, num_weights, scale_grad_by_freq, mode, sparse);
}
Tensor MSNPUType::_embedding_bag_sparse_backward(const Tensor & grad, const Tensor & indices, const Tensor & offsets, const Tensor & offset2bag, const Tensor & bag_size, int64_t num_weights, bool scale_grad_by_freq, int64_t mode) const {
- return MSNPUTypeDispatch::get_function<Tensor (*)(const Tensor &, const Tensor &, const Tensor &, const Tensor &, const Tensor &, int64_t, bool, int64_t)>("_embedding_bag_sparse_backward(Tensor grad, IndexTensor indices, IndexTensor offsets, IndexTensor offset2bag, IndexTensor bag_size, int64_t num_weights, bool scale_grad_by_freq, int64_t mode) -> Tensor")(grad, indices, offsets, offset2bag, bag_size, num_weights, scale_grad_by_freq, mode);
+ return MSNPUTypeDispatch::get_function<Tensor (*)(const Tensor &, const Tensor &, const Tensor &, const Tensor &, const Tensor &, int64_t, bool, int64_t)>("_embedding_bag_sparse_backward(Tensor grad, Tensor indices, Tensor offsets, Tensor offset2bag, Tensor bag_size, int64_t num_weights, bool scale_grad_by_freq, int64_t mode) -> Tensor")(grad, indices, offsets, offset2bag, bag_size, num_weights, scale_grad_by_freq, mode);
}
Tensor MSNPUType::_embedding_bag_dense_backward(const Tensor & grad, const Tensor & indices, const Tensor & offsets, const Tensor & offset2bag, const Tensor & bag_size, const Tensor & maximum_indices, int64_t num_weights, bool scale_grad_by_freq, int64_t mode) const {
return MSNPUTypeDispatch::get_function<Tensor (*)(const Tensor &, const Tensor &, const Tensor &, const Tensor &, const Tensor &, const Tensor &, int64_t, bool, int64_t)>("_embedding_bag_dense_backward(Tensor grad, IndexTensor indices, IndexTensor offsets, IndexTensor offset2bag, IndexTensor bag_size, IndexTensor maximum_indices, int64_t num_weights, bool scale_grad_by_freq, int64_t mode) -> Tensor")(grad, indices, offsets, offset2bag, bag_size, maximum_indices, num_weights, scale_grad_by_freq, mode);
@@ -3151,7 +3151,7 @@ Tensor MSNPUType::view_as(const Tensor & self, const Tensor & other) const {
return MSNPUTypeDispatch::get_function<Tensor (*)(const Tensor &, const Tensor &)>("view_as(Tensor self, Tensor other) -> Tensor")(self, other);
}
Tensor MSNPUType::where(const Tensor & condition, const Tensor & self, const Tensor & other) const {
- return MSNPUTypeDispatch::get_function<Tensor (*)(const Tensor &, const Tensor &, const Tensor &)>("where(BoolTensor condition, Tensor self, Tensor other) -> Tensor")(condition, self, other);
+ return MSNPUTypeDispatch::get_function<Tensor (*)(const Tensor &, const Tensor &, const Tensor &)>("where(Tensor condition, Tensor self, Tensor other) -> Tensor")(condition, self, other);
}
Tensor MSNPUType::_s_where(const Tensor & condition, const Tensor & self, const Tensor & other) const {
return MSNPUTypeDispatch::get_function<Tensor (*)(const Tensor &, const Tensor &, const Tensor &)>("_s_where(Tensor condition, Tensor self, Tensor other) -> Tensor")(condition, self, other);
@@ -3319,13 +3319,13 @@ Tensor MSNPUType::sparse_coo_tensor(IntArrayRef size, const TensorOptions & opti
return MSNPUTypeDispatch::get_function<Tensor (*)(IntArrayRef, const TensorOptions &)>("sparse_coo_tensor(IntArrayRef size, TensorOptions options) -> Tensor")(size, options);
}
Tensor MSNPUType::sparse_coo_tensor(const Tensor & indices, const Tensor & values, const TensorOptions & options) const {
- return MSNPUTypeDispatch::get_function<Tensor (*)(const Tensor &, const Tensor &, const TensorOptions &)>("sparse_coo_tensor(IndexTensor indices, Tensor values, TensorOptions options) -> Tensor")(indices, values, options);
+ return MSNPUTypeDispatch::get_function<Tensor (*)(const Tensor &, const Tensor &, const TensorOptions &)>("sparse_coo_tensor(Tensor indices, Tensor values, TensorOptions options) -> Tensor")(indices, values, options);
}
Tensor MSNPUType::sparse_coo_tensor(const Tensor & indices, const Tensor & values, IntArrayRef size, const TensorOptions & options) const {
- return MSNPUTypeDispatch::get_function<Tensor (*)(const Tensor &, const Tensor &, IntArrayRef, const TensorOptions &)>("sparse_coo_tensor(IndexTensor indices, Tensor values, IntArrayRef size, TensorOptions options) -> Tensor")(indices, values, size, options);
+ return MSNPUTypeDispatch::get_function<Tensor (*)(const Tensor &, const Tensor &, IntArrayRef, const TensorOptions &)>("sparse_coo_tensor(Tensor indices, Tensor values, IntArrayRef size, TensorOptions options) -> Tensor")(indices, values, size, options);
}
Tensor MSNPUType::_sparse_coo_tensor_unsafe(const Tensor & indices, const Tensor & values, IntArrayRef size, const TensorOptions & options) const {
- return MSNPUTypeDispatch::get_function<Tensor (*)(const Tensor &, const Tensor &, IntArrayRef, const TensorOptions &)>("_sparse_coo_tensor_unsafe(IndexTensor indices, Tensor values, IntArrayRef size, TensorOptions options) -> Tensor")(indices, values, size, options);
+ return MSNPUTypeDispatch::get_function<Tensor (*)(const Tensor &, const Tensor &, IntArrayRef, const TensorOptions &)>("_sparse_coo_tensor_unsafe(Tensor indices, Tensor values, IntArrayRef size, TensorOptions options) -> Tensor")(indices, values, size, options);
}
Tensor MSNPUType::_sparse_coo_tensor_with_dims(int64_t sparse_dim, int64_t dense_dim, IntArrayRef size, const TensorOptions & options) const {
return MSNPUTypeDispatch::get_function<Tensor (*)(int64_t, int64_t, IntArrayRef, const TensorOptions &)>("_sparse_coo_tensor_with_dims(int64_t sparse_dim, int64_t dense_dim, IntArrayRef size, TensorOptions options) -> Tensor")(sparse_dim, dense_dim, size, options);
diff --git a/build/aten/src/ATen/XLAType.cpp b/build/aten/src/ATen/XLAType.cpp
index f4e919eee..cbd587522 100644
--- a/build/aten/src/ATen/XLAType.cpp
+++ b/build/aten/src/ATen/XLAType.cpp
@@ -2107,7 +2107,7 @@ Tensor XLAType::embedding(const Tensor & weight, const Tensor & indices, int64_t
return XLATypeDispatch::get_function<Tensor (*)(const Tensor &, const Tensor &, int64_t, bool, bool)>("embedding(Tensor weight, Tensor indices, int64_t padding_idx, bool scale_grad_by_freq, bool sparse) -> Tensor")(weight, indices, padding_idx, scale_grad_by_freq, sparse);
}
Tensor XLAType::embedding_backward(const Tensor & grad, const Tensor & indices, int64_t num_weights, int64_t padding_idx, bool scale_grad_by_freq, bool sparse) const {
- return XLATypeDispatch::get_function<Tensor (*)(const Tensor &, const Tensor &, int64_t, int64_t, bool, bool)>("embedding_backward(Tensor grad, IndexTensor indices, int64_t num_weights, int64_t padding_idx, bool scale_grad_by_freq, bool sparse) -> Tensor")(grad, indices, num_weights, padding_idx, scale_grad_by_freq, sparse);
+ return XLATypeDispatch::get_function<Tensor (*)(const Tensor &, const Tensor &, int64_t, int64_t, bool, bool)>("embedding_backward(Tensor grad, Tensor indices, int64_t num_weights, int64_t padding_idx, bool scale_grad_by_freq, bool sparse) -> Tensor")(grad, indices, num_weights, padding_idx, scale_grad_by_freq, sparse);
}
Tensor XLAType::embedding_dense_backward(const Tensor & grad, const Tensor & indices, int64_t num_weights, int64_t padding_idx, bool scale_grad_by_freq) const {
return XLATypeDispatch::get_function<Tensor (*)(const Tensor &, const Tensor &, int64_t, int64_t, bool)>("embedding_dense_backward(Tensor grad, IndexTensor indices, int64_t num_weights, int64_t padding_idx, bool scale_grad_by_freq) -> Tensor")(grad, indices, num_weights, padding_idx, scale_grad_by_freq);
@@ -2116,19 +2116,19 @@ Tensor & XLAType::embedding_renorm_(Tensor & self, const Tensor & indices, doubl
return XLATypeDispatch::get_function<Tensor & (*)(Tensor &, const Tensor &, double, double)>("embedding_renorm_(Tensor self, Tensor indices, double max_norm, double norm_type) -> Tensor")(self, indices, max_norm, norm_type);
}
Tensor XLAType::embedding_sparse_backward(const Tensor & grad, const Tensor & indices, int64_t num_weights, int64_t padding_idx, bool scale_grad_by_freq) const {
- return XLATypeDispatch::get_function<Tensor (*)(const Tensor &, const Tensor &, int64_t, int64_t, bool)>("embedding_sparse_backward(Tensor grad, IndexTensor indices, int64_t num_weights, int64_t padding_idx, bool scale_grad_by_freq) -> Tensor")(grad, indices, num_weights, padding_idx, scale_grad_by_freq);
+ return XLATypeDispatch::get_function<Tensor (*)(const Tensor &, const Tensor &, int64_t, int64_t, bool)>("embedding_sparse_backward(Tensor grad, Tensor indices, int64_t num_weights, int64_t padding_idx, bool scale_grad_by_freq) -> Tensor")(grad, indices, num_weights, padding_idx, scale_grad_by_freq);
}
std::tuple<Tensor,Tensor,Tensor,Tensor> XLAType::embedding_bag(const Tensor & weight, const Tensor & indices, const Tensor & offsets, bool scale_grad_by_freq, int64_t mode, bool sparse) const {
- return XLATypeDispatch::get_function<std::tuple<Tensor,Tensor,Tensor,Tensor> (*)(const Tensor &, const Tensor &, const Tensor &, bool, int64_t, bool)>("embedding_bag(Tensor weight, IndexTensor indices, IndexTensor offsets, bool scale_grad_by_freq, int64_t mode, bool sparse) -> std::tuple<Tensor,Tensor,Tensor,Tensor>")(weight, indices, offsets, scale_grad_by_freq, mode, sparse);
+ return XLATypeDispatch::get_function<std::tuple<Tensor,Tensor,Tensor,Tensor> (*)(const Tensor &, const Tensor &, const Tensor &, bool, int64_t, bool)>("embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq, int64_t mode, bool sparse) -> std::tuple<Tensor,Tensor,Tensor,Tensor>")(weight, indices, offsets, scale_grad_by_freq, mode, sparse);
}
std::tuple<Tensor,Tensor,Tensor,Tensor> XLAType::_embedding_bag(const Tensor & weight, const Tensor & indices, const Tensor & offsets, bool scale_grad_by_freq, int64_t mode, bool sparse) const {
- return XLATypeDispatch::get_function<std::tuple<Tensor,Tensor,Tensor,Tensor> (*)(const Tensor &, const Tensor &, const Tensor &, bool, int64_t, bool)>("_embedding_bag(Tensor weight, IndexTensor indices, IndexTensor offsets, bool scale_grad_by_freq, int64_t mode, bool sparse) -> std::tuple<Tensor,Tensor,Tensor,Tensor>")(weight, indices, offsets, scale_grad_by_freq, mode, sparse);
+ return XLATypeDispatch::get_function<std::tuple<Tensor,Tensor,Tensor,Tensor> (*)(const Tensor &, const Tensor &, const Tensor &, bool, int64_t, bool)>("_embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq, int64_t mode, bool sparse) -> std::tuple<Tensor,Tensor,Tensor,Tensor>")(weight, indices, offsets, scale_grad_by_freq, mode, sparse);
}
Tensor XLAType::_embedding_bag_backward(const Tensor & grad, const Tensor & indices, const Tensor & offsets, const Tensor & offset2bag, const Tensor & bag_size, const Tensor & maximum_indices, int64_t num_weights, bool scale_grad_by_freq, int64_t mode, bool sparse) const {
- return XLATypeDispatch::get_function<Tensor (*)(const Tensor &, const Tensor &, const Tensor &, const Tensor &, const Tensor &, const Tensor &, int64_t, bool, int64_t, bool)>("_embedding_bag_backward(Tensor grad, IndexTensor indices, IndexTensor offsets, IndexTensor offset2bag, IndexTensor bag_size, IndexTensor maximum_indices, int64_t num_weights, bool scale_grad_by_freq, int64_t mode, bool sparse) -> Tensor")(grad, indices, offsets, offset2bag, bag_size, maximum_indices, num_weights, scale_grad_by_freq, mode, sparse);
+ return XLATypeDispatch::get_function<Tensor (*)(const Tensor &, const Tensor &, const Tensor &, const Tensor &, const Tensor &, const Tensor &, int64_t, bool, int64_t, bool)>("_embedding_bag_backward(Tensor grad, Tensor indices, Tensor offsets, Tensor offset2bag, Tensor bag_size, Tensor maximum_indices, int64_t num_weights, bool scale_grad_by_freq, int64_t mode, bool sparse) -> Tensor")(grad, indices, offsets, offset2bag, bag_size, maximum_indices, num_weights, scale_grad_by_freq, mode, sparse);
}
Tensor XLAType::_embedding_bag_sparse_backward(const Tensor & grad, const Tensor & indices, const Tensor & offsets, const Tensor & offset2bag, const Tensor & bag_size, int64_t num_weights, bool scale_grad_by_freq, int64_t mode) const {
- return XLATypeDispatch::get_function<Tensor (*)(const Tensor &, const Tensor &, const Tensor &, const Tensor &, const Tensor &, int64_t, bool, int64_t)>("_embedding_bag_sparse_backward(Tensor grad, IndexTensor indices, IndexTensor offsets, IndexTensor offset2bag, IndexTensor bag_size, int64_t num_weights, bool scale_grad_by_freq, int64_t mode) -> Tensor")(grad, indices, offsets, offset2bag, bag_size, num_weights, scale_grad_by_freq, mode);
+ return XLATypeDispatch::get_function<Tensor (*)(const Tensor &, const Tensor &, const Tensor &, const Tensor &, const Tensor &, int64_t, bool, int64_t)>("_embedding_bag_sparse_backward(Tensor grad, Tensor indices, Tensor offsets, Tensor offset2bag, Tensor bag_size, int64_t num_weights, bool scale_grad_by_freq, int64_t mode) -> Tensor")(grad, indices, offsets, offset2bag, bag_size, num_weights, scale_grad_by_freq, mode);
}
Tensor XLAType::_embedding_bag_dense_backward(const Tensor & grad, const Tensor & indices, const Tensor & offsets, const Tensor & offset2bag, const Tensor & bag_size, const Tensor & maximum_indices, int64_t num_weights, bool scale_grad_by_freq, int64_t mode) const {
return XLATypeDispatch::get_function<Tensor (*)(const Tensor &, const Tensor &, const Tensor &, const Tensor &, const Tensor &, const Tensor &, int64_t, bool, int64_t)>("_embedding_bag_dense_backward(Tensor grad, IndexTensor indices, IndexTensor offsets, IndexTensor offset2bag, IndexTensor bag_size, IndexTensor maximum_indices, int64_t num_weights, bool scale_grad_by_freq, int64_t mode) -> Tensor")(grad, indices, offsets, offset2bag, bag_size, maximum_indices, num_weights, scale_grad_by_freq, mode);
@@ -3151,7 +3151,7 @@ Tensor XLAType::view_as(const Tensor & self, const Tensor & other) const {
return XLATypeDispatch::get_function<Tensor (*)(const Tensor &, const Tensor &)>("view_as(Tensor self, Tensor other) -> Tensor")(self, other);
}
Tensor XLAType::where(const Tensor & condition, const Tensor & self, const Tensor & other) const {
- return XLATypeDispatch::get_function<Tensor (*)(const Tensor &, const Tensor &, const Tensor &)>("where(BoolTensor condition, Tensor self, Tensor other) -> Tensor")(condition, self, other);
+ return XLATypeDispatch::get_function<Tensor (*)(const Tensor &, const Tensor &, const Tensor &)>("where(Tensor condition, Tensor self, Tensor other) -> Tensor")(condition, self, other);
}
Tensor XLAType::_s_where(const Tensor & condition, const Tensor & self, const Tensor & other) const {
return XLATypeDispatch::get_function<Tensor (*)(const Tensor &, const Tensor &, const Tensor &)>("_s_where(Tensor condition, Tensor self, Tensor other) -> Tensor")(condition, self, other);
@@ -3319,13 +3319,13 @@ Tensor XLAType::sparse_coo_tensor(IntArrayRef size, const TensorOptions & option
return XLATypeDispatch::get_function<Tensor (*)(IntArrayRef, const TensorOptions &)>("sparse_coo_tensor(IntArrayRef size, TensorOptions options) -> Tensor")(size, options);
}
Tensor XLAType::sparse_coo_tensor(const Tensor & indices, const Tensor & values, const TensorOptions & options) const {
- return XLATypeDispatch::get_function<Tensor (*)(const Tensor &, const Tensor &, const TensorOptions &)>("sparse_coo_tensor(IndexTensor indices, Tensor values, TensorOptions options) -> Tensor")(indices, values, options);
+ return XLATypeDispatch::get_function<Tensor (*)(const Tensor &, const Tensor &, const TensorOptions &)>("sparse_coo_tensor(Tensor indices, Tensor values, TensorOptions options) -> Tensor")(indices, values, options);
}
Tensor XLAType::sparse_coo_tensor(const Tensor & indices, const Tensor & values, IntArrayRef size, const TensorOptions & options) const {
- return XLATypeDispatch::get_function<Tensor (*)(const Tensor &, const Tensor &, IntArrayRef, const TensorOptions &)>("sparse_coo_tensor(IndexTensor indices, Tensor values, IntArrayRef size, TensorOptions options) -> Tensor")(indices, values, size, options);
+ return XLATypeDispatch::get_function<Tensor (*)(const Tensor &, const Tensor &, IntArrayRef, const TensorOptions &)>("sparse_coo_tensor(Tensor indices, Tensor values, IntArrayRef size, TensorOptions options) -> Tensor")(indices, values, size, options);
}
Tensor XLAType::_sparse_coo_tensor_unsafe(const Tensor & indices, const Tensor & values, IntArrayRef size, const TensorOptions & options) const {
- return XLATypeDispatch::get_function<Tensor (*)(const Tensor &, const Tensor &, IntArrayRef, const TensorOptions &)>("_sparse_coo_tensor_unsafe(IndexTensor indices, Tensor values, IntArrayRef size, TensorOptions options) -> Tensor")(indices, values, size, options);
+ return XLATypeDispatch::get_function<Tensor (*)(const Tensor &, const Tensor &, IntArrayRef, const TensorOptions &)>("_sparse_coo_tensor_unsafe(Tensor indices, Tensor values, IntArrayRef size, TensorOptions options) -> Tensor")(indices, values, size, options);
}
Tensor XLAType::_sparse_coo_tensor_with_dims(int64_t sparse_dim, int64_t dense_dim, IntArrayRef size, const TensorOptions & options) const {
return XLATypeDispatch::get_function<Tensor (*)(int64_t, int64_t, IntArrayRef, const TensorOptions &)>("_sparse_coo_tensor_with_dims(int64_t sparse_dim, int64_t dense_dim, IntArrayRef size, TensorOptions options) -> Tensor")(sparse_dim, dense_dim, size, options);
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment