Skip to content

Instantly share code, notes, and snippets.

@zou3519
Created May 6, 2021 21:12
Show Gist options
  • Save zou3519/ddd4b2d4aacc98bf20d114f26b27b082 to your computer and use it in GitHub Desktop.
// Batch rule for the vector @ tensor case of matmul:
// `self` is a logical vector (shape I), `other` is a logical tensor of
// rank > 2 ending in (..., I, O). Returns the result tensor plus the
// physical index of its batch dimension (nullopt when unbatched).
static std::tuple<Tensor, optional<int64_t>> vt_batch_rule(
    const Tensor& self, optional<int64_t> self_bdim,
    const Tensor& other, optional<int64_t> other_bdim) {
  if (!self_bdim && !other_bdim) {
    // Neither operand is batched: defer to regular matmul.
    return { at::matmul(self, other), nullopt };
  }
  if (self_bdim && other_bdim) {
    // BI, B...IO -> B1I, ...BIO -> ...B1O -> ...BO
    auto self_ = moveBatchDimToFront(self, self_bdim);
    // Put other's batch dim at -3 so it sits directly in front of the
    // (I, O) matrix dims and lines up with self_'s batch dim under
    // matmul's batch broadcasting. (The previous destination of -2 placed
    // the batch dim *inside* the matrix dims, contradicting the comment
    // above and making matmul contract over the wrong dimension.)
    auto other_ = at::movedim(other, *other_bdim, -3);
    self_ = self_.unsqueeze(-2);
    auto result = at::matmul(self_, other_).squeeze(-2);
    // Result layout is ...BO: the batch dim is second-to-last. Using
    // result.dim() - 2 generalizes the old hard-coded `1`, which was only
    // correct when `other` had logical rank exactly 3.
    return { std::move(result), result.dim() - 2 };
  }
  if (!self_bdim && other_bdim) {
    // I, B...IO -> B...O
    auto other_ = moveBatchDimToFront(other, other_bdim);
    return { at::matmul(self, other_), 0 };
  }
  // Remaining (exhaustive) case: self_bdim && !other_bdim.
  // A plain `else` path guarantees every control path returns a value
  // (the original ended on an `else if`, so control could notionally fall
  // off the end of a non-void function).
  // BI, ...IO -> ...BO
  auto self_ = moveBatchDimToFront(self, self_bdim);
  auto result = at::matmul(self_, other);
  auto result_bdim = result.dim() - 2;
  return { std::move(result), result_bdim };
}
// Batch rule for at::matmul. Dispatches the low-rank combinations to the
// specialized rules (mm/dot/vt/tv) and handles the remaining batched
// tensor-tensor products directly. Returns the result tensor plus the
// physical index of its batch dimension (nullopt when unbatched).
static std::tuple<Tensor, optional<int64_t>> matmul_batch_rule(
    const Tensor& self, optional<int64_t> self_bdim,
    const Tensor& other, optional<int64_t> other_bdim) {
  if (!self_bdim && !other_bdim) {
    // Neither operand is batched: defer to regular matmul.
    return { at::matmul(self, other), nullopt };
  }
  auto self_logical_rank = rankWithoutBatchDim(self, self_bdim);
  auto other_logical_rank = rankWithoutBatchDim(other, other_bdim);
  TORCH_CHECK(self_logical_rank >= 1 && other_logical_rank >= 1, "Got wrong shapes for matmul");
  // Specialized dispatch for low-rank operand combinations.
  // NOTE(review): the (2,1) and (1,2) logical-rank combinations fall
  // through to the general code below — confirm that is intended.
  if (self_logical_rank == 2 && other_logical_rank == 2) {
    return mm_batch_rule(self, self_bdim, other, other_bdim);
  }
  if (self_logical_rank == 1 && other_logical_rank == 1) {
    return dot_batch_rule(self, self_bdim, other, other_bdim);
  }
  if (self_logical_rank == 1 && other_logical_rank > 2) {
    return vt_batch_rule(self, self_bdim, other, other_bdim);
  }
  if (self_logical_rank > 2 && other_logical_rank == 1) {
    return tv_batch_rule(self, self_bdim, other, other_bdim);
  }
  if (self_bdim && other_bdim) {
    // fast path: equal logical ranks — batch dims at the front line up
    // one-to-one, so matmul's batch broadcasting handles everything.
    if (self_logical_rank == other_logical_rank) {
      auto self_ = moveBatchDimToFront(self, self_bdim);
      auto other_ = moveBatchDimToFront(other, other_bdim);
      return { at::matmul(self_, other_), 0 };
    }
    // B...NI, B...IO -> ...BNI, ...BIO -> ...BNO
    // Destination -3 places each batch dim directly in front of its
    // operand's two matrix dims so the batch dims broadcast against each
    // other. (The previous destination of -2 placed the batch dim inside
    // the matrix dims, contradicting the comment above and the
    // result.dim() - 3 bookkeeping below.)
    auto self_ = at::movedim(self, *self_bdim, -3);
    auto other_ = at::movedim(other, *other_bdim, -3);
    auto result = at::matmul(self_, other_);
    // Result layout is ...BNO: batch dim is third-from-last.
    auto result_bdim = result.dim() - 3;
    return { std::move(result), result_bdim };
  }
  if (!self_bdim && other_bdim) {
    // fast path: other has more logical dims, so moving its batch dim to
    // the front lets matmul's broadcasting deal with it.
    if (self_logical_rank < other_logical_rank) {
      auto other_ = moveBatchDimToFront(other, other_bdim);
      return { at::matmul(self, other_), 0 };
    }
    // ...NI, B...IO -> ...NI, ...I(BO) -> ...N(BO) -> ...NBO
    auto bdim_size = other.size(*other_bdim);
    // Fold other's batch dim into its last dim. (Bug fix: the original
    // passed `self` here, reshaping the wrong tensor with other's
    // batch-dim index; the comment and the matmul below require `other`.)
    auto other_ = reshape_dim_into(*other_bdim, -1, other);
    auto result = at::matmul(self, other_);
    // Split the folded (B*O) dim back out: ...N(BO) -> ...NBO.
    result = reshape_dim_outof(-1, bdim_size, result);
    auto result_bdim = result.dim() - 2;
    return { std::move(result), result_bdim };
  }
  // Remaining (exhaustive) case: self_bdim && !other_bdim.
  // A plain `else` path guarantees every control path returns a value
  // (the original ended on an `else if`, so control could notionally fall
  // off the end of a non-void function).
  // fast path: self has more logical dims — move its batch dim to the
  // front and let matmul's broadcasting deal with it.
  if (self_logical_rank > other_logical_rank) {
    auto self_ = moveBatchDimToFront(self, self_bdim);
    return { at::matmul(self_, other), 0 };
  }
  // B...NI, ...IO -> ...(BN)I, ...IO -> ...(BN)O -> ...BNO
  auto bdim_size = self.size(*self_bdim);
  // Fold self's batch dim into its row dim so a single matmul covers it.
  auto self_ = reshape_dim_into(*self_bdim, -2, self);
  auto result = at::matmul(self_, other);
  // Split the folded (B*N) dim back out: ...(BN)O -> ...BNO.
  result = reshape_dim_outof(-2, bdim_size, result);
  auto result_bdim = result.dim() - 3;
  return { std::move(result), result_bdim };
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment