Skip to content

Instantly share code, notes, and snippets.

@zou3519
Created December 27, 2022 21:18
Show Gist options
  • Save zou3519/2e5d020d04b192c714121993da6daeeb to your computer and use it in GitHub Desktop.
Save zou3519/2e5d020d04b192c714121993da6daeeb to your computer and use it in GitHub Desktop.
// vmap plumbing for the in-place operator `index_fill_.int_Scalar`.
//
// Fast path: if neither `self` nor `index` reports as batched at the current
// dynamic layer's level, the call is forwarded unchanged to
// `at::_ops::index_fill__int_Scalar::call` and that result is returned.
//
// Otherwise each tensor is split by `unwrapTensorAtLevel` into its underlying
// value and an optional batch dimension. Those pieces, plus `dim` and `value`
// unchanged, are handed to `batch_rule`. The rule's return value is
// discarded and `self` is returned. Presumably the rule writes into the
// unwrapped `self` in place — NOTE(review): confirm against the batch rule.
//
// @tparam batch_rule_t  type of the batching-rule callable.
// @tparam batch_rule    the batching rule invoked on the unwrapped arguments.
// @param self   tensor being filled; returned by reference on the batched path.
// @param dim    dimension argument, forwarded as-is.
// @param index  index tensor; may itself be batched at the current level.
// @param value  scalar fill value, forwarded as-is.
// @return `self` on the batched path, otherwise the unbatched op's result.
template <typename batch_rule_t, batch_rule_t batch_rule>
at::Tensor & index_fill__int_Scalar_generated_plumbing(at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Scalar & value) {
  // Presumably keeps calls made below from re-entering the FuncTorchBatched
  // dispatch key while this guard is alive — NOTE(review): confirm.
  c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);

  auto maybe_layer = maybeCurrentDynamicLayer();
  // Checked before `maybe_layer` is dereferenced on the next line.
  vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing");
  const int64_t level = maybe_layer->layerId();

  const bool any_batched =
      isBatchedAtLevel(self, level) || isBatchedAtLevel(index, level);
  if (!any_batched) {
    return at::_ops::index_fill__int_Scalar::call(self, dim, index, value);
  }

  // Element 0 of each tuple is the unwrapped tensor; element 1 is its
  // optional batch dimension.
  auto self_unwrapped = unwrapTensorAtLevel(self, level);
  auto index_unwrapped = unwrapTensorAtLevel(index, level);
  batch_rule(std::get<0>(self_unwrapped), std::get<1>(self_unwrapped), dim,
             std::get<0>(index_unwrapped), std::get<1>(index_unwrapped), value);
  return self;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment