Created April 12, 2023 15:00
-
-
Save lmontigny/a52f17b05fa0e27e8ad483aafa796256 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Emits LLVM IR for an integer / decimal / interval multiplication.
//
// When checkExpressionRanges() cannot prove that the product stays inside the
// range of `type`, an overflow check is generated:
//   * for CPU the work is delegated to codegenBinOpWithOverflowForCPU();
//   * otherwise the check is emitted inline; on overflow the generated code
//     returns Executor::ERR_OVERFLOW_OR_UNDERFLOW from the current function.
//
// The inline check mirrors llvm::MulOverflow (llvm/Support/MathExtras.h):
// both operands are reduced to their magnitudes, which are then compared
// *unsigned* against (max, or max + 1 when the product is negative) divided
// by the magnitude of the other operand. Unsigned arithmetic is required
// because |INT_MIN| == 2^(n-1) is not representable as a signed n-bit value.
//
// @param bin_oper          the multiplication expression (used for range analysis)
// @param lhs_lv            left operand; must have the same LLVM type as rhs_lv
// @param rhs_lv            right operand
// @param null_typename     type suffix used to build the null-aware runtime
//                          helper name "mul_<null_typename><null_check_suffix>"
// @param null_check_suffix non-empty when operands may be NULL; selects the
//                          runtime helper instead of a plain `mul`
// @param type              result type (integer, decimal or interval)
// @param co                compilation options; device_type selects the path
// @param downscale         not used by this function
// @return the product; when an inline overflow check was emitted the IR insert
//         point is left at the end of the "mul_ok" block.
llvm::Value* CodeGenerator::codegenMul(const hdk::ir::BinOper* bin_oper,
                                       llvm::Value* lhs_lv,
                                       llvm::Value* rhs_lv,
                                       const std::string& null_typename,
                                       const std::string& null_check_suffix,
                                       const hdk::ir::Type* type,
                                       const CompilationOptions& co,
                                       bool downscale) {
  AUTOMATIC_IR_METADATA(cgen_state_);
  CHECK_EQ(lhs_lv->getType(), rhs_lv->getType());
  CHECK(type->isInteger() || type->isDecimal() || type->isInterval());
  llvm::Value* chosen_max{nullptr};
  llvm::Value* chosen_min{nullptr};
  std::tie(chosen_max, chosen_min) = cgen_state_->inlineIntMaxMin(type->size(), true);
  auto need_overflow_check =
      !checkExpressionRanges(bin_oper,
                             static_cast<llvm::ConstantInt*>(chosen_min)->getSExtValue(),
                             static_cast<llvm::ConstantInt*>(chosen_max)->getSExtValue());
  if (need_overflow_check && co.device_type == ExecutorDeviceType::CPU) {
    return codegenBinOpWithOverflowForCPU(
        bin_oper, lhs_lv, rhs_lv, null_check_suffix, type);
  }
  llvm::BasicBlock* mul_ok{nullptr};
  llvm::BasicBlock* mul_fail{nullptr};
  if (need_overflow_check) {
    auto& ir_builder = cgen_state_->ir_builder_;
    cgen_state_->needs_error_check_ = true;
    mul_ok = llvm::BasicBlock::Create(
        cgen_state_->context_, "mul_ok", cgen_state_->current_func_);
    if (!null_check_suffix.empty()) {
      codegenSkipOverflowCheckForNull(lhs_lv, rhs_lv, mul_ok, type);
    }
    mul_fail = llvm::BasicBlock::Create(
        cgen_state_->context_, "mul_fail", cgen_state_->current_func_);
    auto mul_check = llvm::BasicBlock::Create(
        cgen_state_->context_, "mul_check", cgen_state_->current_func_);

    auto const_zero = llvm::ConstantInt::get(rhs_lv->getType(), 0, true);
    auto const_one = llvm::ConstantInt::get(rhs_lv->getType(), 1, true);

    // A zero operand can never overflow; branching away here also guarantees
    // the unsigned division below never divides by zero.
    auto lhs_is_zero = ir_builder.CreateICmpEQ(lhs_lv, const_zero);
    auto rhs_is_zero = ir_builder.CreateICmpEQ(rhs_lv, const_zero);
    auto any_zero = ir_builder.CreateOr(lhs_is_zero, rhs_is_zero);
    ir_builder.CreateCondBr(any_zero, mul_ok, mul_check);
    ir_builder.SetInsertPoint(mul_check);

    // The product is negative iff exactly one operand is negative.
    auto lhs_is_neg = ir_builder.CreateICmpSLT(lhs_lv, const_zero);
    auto rhs_is_neg = ir_builder.CreateICmpSLT(rhs_lv, const_zero);
    auto result_is_neg = ir_builder.CreateXor(lhs_is_neg, rhs_is_neg);

    // Magnitudes, to be interpreted as UNSIGNED values. The negation carries
    // no nsw flag, so 0 - INT_MIN wraps to the bit pattern of 2^(n-1), which
    // is the correct unsigned magnitude.
    auto lhs_abs =
        ir_builder.CreateSelect(lhs_is_neg, ir_builder.CreateNeg(lhs_lv), lhs_lv);
    auto rhs_abs =
        ir_builder.CreateSelect(rhs_is_neg, ir_builder.CreateNeg(rhs_lv), rhs_lv);

    // Largest allowed magnitude: max for a positive product, max + 1 (= |min|)
    // for a negative one. max + 1 deliberately wraps (no nsw) to the bit
    // pattern of 2^(n-1), which is only meaningful in unsigned arithmetic.
    auto max_plus_one = ir_builder.CreateAdd(chosen_max, const_one);
    auto max_magnitude =
        ir_builder.CreateSelect(result_is_neg, max_plus_one, chosen_max);

    // |lhs| * |rhs| <= max_magnitude  <=>  |lhs| <= max_magnitude / |rhs|
    auto limit = ir_builder.CreateUDiv(max_magnitude, rhs_abs);
    auto overflow = ir_builder.CreateICmpUGT(lhs_abs, limit);
    ir_builder.CreateCondBr(overflow, mul_fail, mul_ok);
    ir_builder.SetInsertPoint(mul_ok);
  }
  const auto ret =
      null_check_suffix.empty()
          ? cgen_state_->ir_builder_.CreateMul(lhs_lv, rhs_lv)
          : cgen_state_->emitCall(
                "mul_" + null_typename + null_check_suffix,
                {lhs_lv, rhs_lv, cgen_state_->llInt(inline_int_null_value(type))});
  if (need_overflow_check) {
    // The failure block aborts the generated function with an error code;
    // code generation then continues in mul_ok.
    cgen_state_->ir_builder_.SetInsertPoint(mul_fail);
    cgen_state_->ir_builder_.CreateRet(
        cgen_state_->llInt(Executor::ERR_OVERFLOW_OR_UNDERFLOW));
    cgen_state_->ir_builder_.SetInsertPoint(mul_ok);
  }
  return ret;
}
Author lmontigny commented on Apr 12, 2023.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.