Skip to content

Instantly share code, notes, and snippets.

@lmontigny
Created April 12, 2023 15:00
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save lmontigny/a52f17b05fa0e27e8ad483aafa796256 to your computer and use it in GitHub Desktop.
// Emits IR for `lhs_lv * rhs_lv` on integer, decimal or interval operands.
//
// When checkExpressionRanges() cannot prove that the product stays inside the
// signed range of a `type->size()`-byte integer, an overflow check is needed:
//   * on CPU the work is delegated to codegenBinOpWithOverflowForCPU();
//   * otherwise an explicit check is emitted which makes the current function
//     return Executor::ERR_OVERFLOW_OR_UNDERFLOW when the product overflows.
//
// The explicit check follows llvm::MulOverflow (llvm/Support/MathExtras.h).
// It works on the operands' absolute values interpreted as *unsigned*
// integers, so that |MIN| == 2^(n-1) is representable. It then compares one
// operand against (limit / other operand). The limit is 2^(n-1) - 1 for a
// positive product and 2^(n-1) for a negative product.
//
// @param bin_oper           expression being compiled (used for range analysis).
// @param lhs_lv, rhs_lv     operands; must have the same LLVM type.
// @param null_typename      type-name part of the nullable runtime helper name.
// @param null_check_suffix  empty for non-nullable operands. Otherwise the
//                           multiplication is emitted as a call to
//                           "mul_" + null_typename + null_check_suffix, passing
//                           inline_int_null_value(type) as the third argument.
// @param type               result type (integer, decimal or interval).
// @param co                 compilation options (selects the CPU path).
// @param downscale          not referenced by this function.
// @return the product value; the builder's insert point is left in the block
//         that computes it ("mul_ok" when a check was emitted).
llvm::Value* CodeGenerator::codegenMul(const hdk::ir::BinOper* bin_oper,
                                       llvm::Value* lhs_lv,
                                       llvm::Value* rhs_lv,
                                       const std::string& null_typename,
                                       const std::string& null_check_suffix,
                                       const hdk::ir::Type* type,
                                       const CompilationOptions& co,
                                       bool downscale) {
  AUTOMATIC_IR_METADATA(cgen_state_);
  CHECK_EQ(lhs_lv->getType(), rhs_lv->getType());
  CHECK(type->isInteger() || type->isDecimal() || type->isInterval());
  llvm::Value* chosen_max{nullptr};
  llvm::Value* chosen_min{nullptr};
  std::tie(chosen_max, chosen_min) = cgen_state_->inlineIntMaxMin(type->size(), true);
  auto need_overflow_check =
      !checkExpressionRanges(bin_oper,
                             static_cast<llvm::ConstantInt*>(chosen_min)->getSExtValue(),
                             static_cast<llvm::ConstantInt*>(chosen_max)->getSExtValue());
  if (need_overflow_check && co.device_type == ExecutorDeviceType::CPU) {
    return codegenBinOpWithOverflowForCPU(
        bin_oper, lhs_lv, rhs_lv, null_check_suffix, type);
  }
  llvm::BasicBlock* mul_ok{nullptr};
  llvm::BasicBlock* mul_fail{nullptr};
  if (need_overflow_check) {
    auto& builder = cgen_state_->ir_builder_;
    cgen_state_->needs_error_check_ = true;
    mul_ok = llvm::BasicBlock::Create(
        cgen_state_->context_, "mul_ok", cgen_state_->current_func_);
    if (!null_check_suffix.empty()) {
      // Presumably branches to mul_ok when either operand is null, so nulls
      // bypass the overflow check (see the helper to confirm).
      codegenSkipOverflowCheckForNull(lhs_lv, rhs_lv, mul_ok, type);
    }
    mul_fail = llvm::BasicBlock::Create(
        cgen_state_->context_, "mul_fail", cgen_state_->current_func_);
    auto mul_check = llvm::BasicBlock::Create(
        cgen_state_->context_, "mul_check", cgen_state_->current_func_);
    auto const_zero = llvm::ConstantInt::get(rhs_lv->getType(), 0, true);
    auto const_one = llvm::ConstantInt::get(rhs_lv->getType(), 1, true);

    // A zero operand can never overflow. Skipping it also guarantees that the
    // division below never divides by zero.
    auto args_zero = builder.CreateOr(builder.CreateICmpEQ(lhs_lv, const_zero),
                                      builder.CreateICmpEQ(rhs_lv, const_zero));
    builder.CreateCondBr(args_zero, mul_ok, mul_check);
    builder.SetInsertPoint(mul_check);

    auto lhs_is_neg = builder.CreateICmpSLT(lhs_lv, const_zero);
    auto rhs_is_neg = builder.CreateICmpSLT(rhs_lv, const_zero);
    // The product is negative iff exactly one operand is negative (XOR, not OR).
    auto result_is_neg = builder.CreateXor(lhs_is_neg, rhs_is_neg);

    // Absolute values via wrapping negation (no nsw). Read as unsigned they
    // lie in [1, 2^(n-1)], so |MIN| == 2^(n-1) is represented exactly.
    auto lhs_abs = builder.CreateSelect(lhs_is_neg, builder.CreateNeg(lhs_lv), lhs_lv);
    auto rhs_abs = builder.CreateSelect(rhs_is_neg, builder.CreateNeg(rhs_lv), rhs_lv);

    // Largest allowed |product|:
    //   positive product: 2^(n-1) - 1 == MAX;
    //   negative product: 2^(n-1) == MAX + 1.
    // MAX + 1 is built without nsw: it wraps to the bit pattern of MIN, which
    // is exactly 2^(n-1) when read as unsigned (with nsw it would be poison).
    auto neg_limit = builder.CreateAdd(chosen_max, const_one);
    auto limit = builder.CreateSelect(result_is_neg, neg_limit, chosen_max);

    // Unsigned division and comparison: both the negative limit and |MIN| are
    // outside the signed n-bit range.
    auto overflow =
        builder.CreateICmpUGT(lhs_abs, builder.CreateUDiv(limit, rhs_abs));
    builder.CreateCondBr(overflow, mul_fail, mul_ok);
    builder.SetInsertPoint(mul_ok);
  }
  const auto ret =
      null_check_suffix.empty()
          ? cgen_state_->ir_builder_.CreateMul(lhs_lv, rhs_lv)
          : cgen_state_->emitCall(
                "mul_" + null_typename + null_check_suffix,
                {lhs_lv, rhs_lv, cgen_state_->llInt(inline_int_null_value(type))});
  if (need_overflow_check) {
    // mul_fail aborts the generated function with the overflow error code.
    cgen_state_->ir_builder_.SetInsertPoint(mul_fail);
    cgen_state_->ir_builder_.CreateRet(
        cgen_state_->llInt(Executor::ERR_OVERFLOW_OR_UNDERFLOW));
    cgen_state_->ir_builder_.SetInsertPoint(mul_ok);
  }
  return ret;
}
@lmontigny
Copy link
Author

declare i32 @old_filter_func()

; NOTE(review): this appears to be the IR emitted by codegenMul's explicit
; (non-CPU) overflow check for a nullable i64 multiplication whose result is
; summed into %out. The dump exposes defects in the signed-arithmetic version
; of that check; see the notes on %14, %17 and %23/%24 below.

; Function Attrs: alwaysinline
define i32 @filter_func(i8 addrspace(4)* %col_buf0, i64 %pos, i8 addrspace(4)* %col_buf1, i64 addrspace(4)* %out) #32 {
entry:
  ; The filter condition is the constant `true`, so filter_true is always taken.
  br i1 true, label %filter_true, label %filter_false, !DiamondCodegen.cpp !30

filter_true:                                      ; preds = %entry
  ; %2: value from col_buf0 (decode width argument 1), truncated to i8 and
  ;     widened to i64.
  ; %3: value from col_buf1 (decode width argument 8).
  %0 = call i64 @fixed_width_int_decode(i8 addrspace(4)* %col_buf0, i32 1, i64 %pos), !ColumnIR.cpp !31
  %1 = trunc i64 %0 to i8, !ColumnIR.cpp !31
  %2 = call i64 @cast_int8_t_to_int64_t_nullable(i8 %1, i8 -128, i64 -9223372036854775808), !CastIR.cpp !19
  %3 = call i64 @fixed_width_int_decode(i8 addrspace(4)* %col_buf1, i32 8, i64 %pos), !ColumnIR.cpp !32
  ; If either operand equals -9223372036854775808 (the value the *_nullable
  ; helpers receive, presumably the null marker), skip the overflow check.
  %4 = icmp eq i64 %2, -9223372036854775808, !LogicalIR.cpp !33
  %5 = icmp eq i64 %3, -9223372036854775808, !LogicalIR.cpp !33
  %6 = or i1 %4, %5, !ArithmeticIR.cpp !18
  br i1 %6, label %mul_ok, label %operands_not_null, !ArithmeticIR.cpp !18

filter_false:                                     ; preds = %mul_ok, %entry
  ret i32 0, !RowFuncBuilder.cpp !34

mul_ok:                                           ; preds = %mul_check_2, %operands_not_null, %filter_true
  ; Nullable multiply, then accumulate the product into %out.
  %7 = call i64 @mul_int64_t_nullable(i64 %2, i64 %3, i64 -9223372036854775808), !ArithmeticIR.cpp !18
  %8 = call i64 @agg_sum_skip_val(i64 addrspace(4)* %out, i64 %7, i64 -9223372036854775808), !RowFuncBuilder.cpp !20
  br label %filter_false, !DiamondCodegen.cpp !35

operands_not_null:                                ; preds = %filter_true
  ; A zero operand cannot overflow, so go straight to the multiply.
  %9 = icmp eq i64 %2, 0, !ArithmeticIR.cpp !18
  %10 = icmp eq i64 %3, 0, !ArithmeticIR.cpp !18
  %11 = or i1 %10, %9, !ArithmeticIR.cpp !18
  br i1 %11, label %mul_ok, label %mul_check_1, !ArithmeticIR.cpp !18

mul_fail:                                         ; preds = %mul_check_2, %mul_check_1
  ; Error return on detected overflow. NOTE(review): 7 is presumably
  ; Executor::ERR_OVERFLOW_OR_UNDERFLOW; confirm.
  ret i32 7, !ArithmeticIR.cpp !18

mul_check_1:                                      ; preds = %operands_not_null
  ; NOTE(review): %14 ORs the "is negative" tests. It is therefore also true
  ; when both operands are negative, i.e. when the product is positive. The
  ; sign of the product is (%12 xor %13).
  %12 = icmp slt i64 %2, 0, !ArithmeticIR.cpp !18
  %13 = icmp slt i64 %3, 0, !ArithmeticIR.cpp !18
  %14 = or i1 %12, %13, !ArithmeticIR.cpp !18
  ; NOTE(review): %17 ORs the "is positive" tests, so mixed-sign operands
  ; (negative product) are also checked against the positive limit %22 in
  ; mul_check_2.
  %15 = icmp sgt i64 %2, 0, !ArithmeticIR.cpp !18
  %16 = icmp sgt i64 %3, 0, !ArithmeticIR.cpp !18
  %17 = or i1 %15, %16, !ArithmeticIR.cpp !18
  ; %20 and %21 are the absolute values of %2 and %3, computed with wrapping
  ; `sub 0, x`. They are later compared with signed `sgt`, so an absolute
  ; value of -9223372036854775808 would compare as negative.
  %18 = sub i64 0, %2, !ArithmeticIR.cpp !18
  %19 = sub i64 0, %3, !ArithmeticIR.cpp !18
  %20 = select i1 %12, i64 %18, i64 %2, !ArithmeticIR.cpp !18
  %21 = select i1 %13, i64 %19, i64 %3, !ArithmeticIR.cpp !18
  %22 = sdiv i64 9223372036854775807, %21, !ArithmeticIR.cpp !18
  ; NOTE(review): the intended limit "MAX + 1" has been folded to
  ; -9223372036854775808. For %21 >= 1, %23 is therefore <= 0, and %24 is true
  ; for every positive %20. As a result, every non-zero product with a negative
  ; operand branches to mul_fail.
  %23 = sdiv i64 -9223372036854775808, %21, !ArithmeticIR.cpp !18
  %24 = icmp sgt i64 %20, %23, !ArithmeticIR.cpp !18
  %25 = and i1 %14, %24, !ArithmeticIR.cpp !18
  br i1 %25, label %mul_fail, label %mul_check_2, !ArithmeticIR.cpp !18

mul_check_2:                                      ; preds = %mul_check_1
  ; Positive-limit check: fail if |%2| > MAX / |%3| and %17 holds.
  %26 = icmp sgt i64 %20, %22, !ArithmeticIR.cpp !18
  %27 = and i1 %17, %26, !ArithmeticIR.cpp !18
  br i1 %27, label %mul_fail, label %mul_ok, !ArithmeticIR.cpp !18
}

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment