-
-
Save zou3519/ddd4b2d4aacc98bf20d114f26b27b082 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Batch rule for at::matmul where `self` is logically a vector (rank 1) and
// `other` is logically a stack of matrices (rank >= 2).
//
// `self_bdim` / `other_bdim` are the physical indices of each operand's vmap
// batch dimension, or nullopt if that operand is not batched.
// Returns the result and the physical index of its batch dimension (nullopt
// when neither operand is batched).
//
// Shape notation: B = vmap batch dim, I/O = the vector-matrix product dims
// (I) @ (I x O), "..." = matmul's own broadcast batch dims.
static std::tuple<Tensor, optional<int64_t>> vt_batch_rule(
    const Tensor& self, optional<int64_t> self_bdim,
    const Tensor& other, optional<int64_t> other_bdim) {
  if (!self_bdim && !other_bdim) {
    return { at::matmul(self, other), nullopt };
  }
  if (!self_bdim) {
    // I, B...IO -> B...O : matmul broadcasts the vector over all leading dims.
    auto other_ = moveBatchDimToFront(other, other_bdim);
    return { at::matmul(self, other_), 0 };
  }
  // From here on `self` is batched and laid out as BI.
  auto self_ = moveBatchDimToFront(self, self_bdim);
  if (!other_bdim) {
    // BI, ...IO -> ...BO : BI is used as a (B x I) matrix broadcast over `...`.
    auto result = at::matmul(self_, other);
    auto result_bdim = result.dim() - 2;
    return { std::move(result), result_bdim };
  }
  // BI, B...IO -> B1I, ...BIO -> ...B1O -> ...BO
  // B must sit directly before the matrix dims (I, O) of `other`, i.e. at
  // position -3, so that it lines up with the leading B of B1I when matmul
  // broadcasts. (Moving it to -2 would put it between I and O.)
  self_ = self_.unsqueeze(-2);
  auto other_ = at::movedim(other, *other_bdim, -3);
  auto result = at::matmul(self_, other_).squeeze(-2);
  // B is second-to-last after squeezing, regardless of how many `...` dims exist.
  auto result_bdim = result.dim() - 2;
  return { std::move(result), result_bdim };
}
// Batch rule for at::matmul under vmap.
//
// `self_bdim` / `other_bdim` are the physical indices of each operand's vmap
// batch dimension, or nullopt if that operand is not batched.
// Returns the result and the physical index of its batch dimension (nullopt
// when neither operand is batched).
//
// Shape notation in comments: B = vmap batch dim, N/I/O = the matrix product
// dims (N x I) @ (I x O), "..." = matmul's own broadcast batch dims,
// "(B)" = B present only if that operand is batched.
static std::tuple<Tensor, optional<int64_t>> matmul_batch_rule(
    const Tensor& self, optional<int64_t> self_bdim,
    const Tensor& other, optional<int64_t> other_bdim) {
  if (!self_bdim && !other_bdim) {
    return { at::matmul(self, other), nullopt };
  }
  auto self_logical_rank = rankWithoutBatchDim(self, self_bdim);
  auto other_logical_rank = rankWithoutBatchDim(other, other_bdim);
  TORCH_CHECK(self_logical_rank >= 1 && other_logical_rank >= 1, "Got wrong shapes for matmul");

  // Shapes with dedicated rules.
  if (self_logical_rank == 2 && other_logical_rank == 2) {
    return mm_batch_rule(self, self_bdim, other, other_bdim);
  }
  if (self_logical_rank == 1 && other_logical_rank == 1) {
    return dot_batch_rule(self, self_bdim, other, other_bdim);
  }
  if (self_logical_rank == 1 && other_logical_rank > 2) {
    return vt_batch_rule(self, self_bdim, other, other_bdim);
  }
  if (self_logical_rank > 2 && other_logical_rank == 1) {
    return tv_batch_rule(self, self_bdim, other, other_bdim);
  }

  // vector @ matrix. The general path below assumes both logical ranks are
  // >= 2 (it moves dims to -3 / merges into -2), so this shape is handled here.
  if (self_logical_rank == 1 && other_logical_rank == 2) {
    if (!self_bdim) {
      // I, BIO -> BO
      auto other_ = moveBatchDimToFront(other, other_bdim);
      return { at::matmul(self, other_), 0 };
    }
    // BI, (B)IO -> B1I, (B)IO -> B1O -> BO
    auto self_ = moveBatchDimToFront(self, self_bdim).unsqueeze(-2);
    auto other_ = other_bdim ? moveBatchDimToFront(other, other_bdim) : other;
    auto result = at::matmul(self_, other_).squeeze(-2);
    return { std::move(result), 0 };
  }

  // matrix @ vector; same reason as above.
  if (self_logical_rank == 2 && other_logical_rank == 1) {
    if (!other_bdim) {
      // BNI, I -> BN
      auto self_ = moveBatchDimToFront(self, self_bdim);
      return { at::matmul(self_, other), 0 };
    }
    // (B)NI, BI -> (B)NI, BI1 -> BN1 -> BN
    auto self_ = self_bdim ? moveBatchDimToFront(self, self_bdim) : self;
    auto other_ = moveBatchDimToFront(other, other_bdim).unsqueeze(-1);
    auto result = at::matmul(self_, other_).squeeze(-1);
    return { std::move(result), 0 };
  }

  // General case: both logical ranks are >= 2 and at least one is > 2.
  if (self_bdim && other_bdim) {
    // fast path: equal ranks broadcast directly with B at the front of both.
    if (self_logical_rank == other_logical_rank) {
      auto self_ = moveBatchDimToFront(self, self_bdim);
      auto other_ = moveBatchDimToFront(other, other_bdim);
      return { at::matmul(self_, other_), 0 };
    }
    // B...NI, B...IO -> ...BNI, ...BIO -> ...BNO
    // B goes to -3 (just before the two matrix dims) so that both B dims are the
    // innermost broadcast dim and line up under right-aligned broadcasting.
    auto self_ = at::movedim(self, *self_bdim, -3);
    auto other_ = at::movedim(other, *other_bdim, -3);
    auto result = at::matmul(self_, other_);
    auto result_bdim = result.dim() - 3;
    return { std::move(result), result_bdim };
  }

  if (other_bdim) {
    // Only `other` is batched.
    // fast path: tell matmul to deal with it
    if (self_logical_rank < other_logical_rank) {
      auto other_ = moveBatchDimToFront(other, other_bdim);
      return { at::matmul(self, other_), 0 };
    }
    // ...NI, B...IO -> ...NI, ...I(BO) -> ...N(BO) -> ...NBO
    // Fold B into the last (O) dim of `other`, multiply, then split it back out.
    auto bdim_size = other.size(*other_bdim);
    auto other_ = reshape_dim_into(*other_bdim, -1, other);
    auto result = at::matmul(self, other_);
    result = reshape_dim_outof(-1, bdim_size, result);
    auto result_bdim = result.dim() - 2;
    return { std::move(result), result_bdim };
  }

  // Only `self` is batched.
  // fast path: tell matmul to deal with it
  if (self_logical_rank > other_logical_rank) {
    auto self_ = moveBatchDimToFront(self, self_bdim);
    return { at::matmul(self_, other), 0 };
  }
  // B...NI, ...IO -> ...(BN)I, ...IO -> ...(BN)O -> ...BNO
  // Fold B into the row (N) dim of `self`, multiply, then split it back out.
  auto bdim_size = self.size(*self_bdim);
  auto self_ = reshape_dim_into(*self_bdim, -2, self);
  auto result = at::matmul(self_, other);
  result = reshape_dim_outof(-2, bdim_size, result);
  auto result_bdim = result.dim() - 3;
  return { std::move(result), result_bdim };
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment