From 65068b63bd48cb7b923e26eccc7896ab6f220a02 Mon Sep 17 00:00:00 2001 From: Jasper Woudenberg Date: Sat, 4 May 2024 14:54:27 +0200 Subject: [PATCH] Add llvm generic_compare implementation for integers This implements the following calculation posted by Richard on Zulip: https://roc.zulipchat.com/#narrow/stream/304641-ideas/topic/ordering.2Fsorting.20ability/near/403858126 --- crates/compiler/gen_llvm/src/llvm/lowlevel.rs | 14 +- crates/compiler/gen_llvm/src/llvm/mod.rs | 1 + crates/compiler/gen_llvm/src/llvm/sort.rs | 200 ++++++++++++++++++ 3 files changed, 214 insertions(+), 1 deletion(-) create mode 100644 crates/compiler/gen_llvm/src/llvm/sort.rs diff --git a/crates/compiler/gen_llvm/src/llvm/lowlevel.rs b/crates/compiler/gen_llvm/src/llvm/lowlevel.rs index 5015206f917..531c5cef683 100644 --- a/crates/compiler/gen_llvm/src/llvm/lowlevel.rs +++ b/crates/compiler/gen_llvm/src/llvm/lowlevel.rs @@ -55,6 +55,7 @@ use crate::llvm::{ LLVM_SUB_WITH_OVERFLOW, }, refcounting::PointerToRefcount, + sort::generic_compare, }; use super::{build::Env, convert::zig_dec_type}; @@ -1270,7 +1271,18 @@ pub(crate) fn run_low_level<'a, 'ctx>( BasicValueEnum::IntValue(bool_val) } Compare => { - panic!("TODO: implement this") + // Sort.compare : elem, elem -> [LessThan, Equal, GreaterThan] + arguments_with_layouts!((lhs_arg, lhs_layout), (rhs_arg, rhs_layout)); + + generic_compare( + env, + layout_interner, + layout_ids, + lhs_arg, + rhs_arg, + lhs_layout, + rhs_layout, + ) } Hash => { unimplemented!() diff --git a/crates/compiler/gen_llvm/src/llvm/mod.rs b/crates/compiler/gen_llvm/src/llvm/mod.rs index d606191338c..3d009a2abd5 100644 --- a/crates/compiler/gen_llvm/src/llvm/mod.rs +++ b/crates/compiler/gen_llvm/src/llvm/mod.rs @@ -9,6 +9,7 @@ pub mod externs; mod intrinsics; mod lowlevel; pub mod refcounting; +pub mod sort; mod align; mod erased; diff --git a/crates/compiler/gen_llvm/src/llvm/sort.rs b/crates/compiler/gen_llvm/src/llvm/sort.rs new file mode 100644 index 00000000000..0e0eb354e29 --- /dev/null +++ b/crates/compiler/gen_llvm/src/llvm/sort.rs @@ -0,0 +1,200 @@ +use super::build::BuilderExt; +use crate::llvm::build::Env; +use inkwell::values::{BasicValueEnum, IntValue}; +use inkwell::IntPredicate; +use roc_builtins::bitcode::IntWidth; +use roc_mono::layout::{ + Builtin, InLayout, LayoutIds, LayoutInterner, LayoutRepr, STLayoutInterner, +}; + +pub fn generic_compare<'a, 'ctx>( + env: &Env<'a, 'ctx, '_>, + layout_interner: &STLayoutInterner<'a>, + _layout_ids: &mut LayoutIds<'a>, + lhs_val: BasicValueEnum<'ctx>, + rhs_val: BasicValueEnum<'ctx>, + lhs_layout: InLayout<'a>, + _rhs_layout: InLayout<'a>, +) -> BasicValueEnum<'ctx> { + let lhs_repr = layout_interner.get_repr(lhs_layout); + match lhs_repr { + LayoutRepr::Builtin(Builtin::Int(int_width)) => { + int_compare(env, lhs_val, rhs_val, int_width) + } + LayoutRepr::Builtin(Builtin::Float(_)) => todo!(), + LayoutRepr::Builtin(Builtin::Bool) => todo!(), + LayoutRepr::Builtin(Builtin::Decimal) => todo!(), + LayoutRepr::Builtin(Builtin::Str) => todo!(), + LayoutRepr::Builtin(Builtin::List(_)) => todo!(), + LayoutRepr::Struct(_) => todo!(), + LayoutRepr::LambdaSet(_) => unreachable!("cannot compare closures"), + LayoutRepr::FunctionPointer(_) => unreachable!("cannot compare function pointers"), + LayoutRepr::Erased(_) => unreachable!("cannot compare erased types"), + LayoutRepr::Union(_) => todo!(), + LayoutRepr::Ptr(_) => todo!(), + LayoutRepr::RecursivePointer(_) => todo!(), + } +} + +fn int_compare<'ctx>( + env: &Env<'_, 'ctx, '_>, + lhs_val: BasicValueEnum<'ctx>, + rhs_val: BasicValueEnum<'ctx>, + builtin: IntWidth, +) -> BasicValueEnum<'ctx> { + // The following calculation will return 0 for equals, 1 for greater than, + // and 2 for less than. + // (a > b) + 2 * (a < b); + let lhs_gt_rhs = int_gt(env, lhs_val, rhs_val, builtin); + let lhs_lt_rhs = int_lt(env, lhs_val, rhs_val, builtin); + let two = env.ptr_int().const_int(2, false); + let lhs_lt_rhs_times_two = + env.builder + .new_build_int_mul(lhs_lt_rhs, two, "lhs_lt_rhs_times_two"); + let int_compare = + env.builder + .new_build_int_sub(lhs_gt_rhs, lhs_lt_rhs_times_two, "int_compare"); + BasicValueEnum::IntValue(int_compare) +} + +fn int_lt<'ctx>( + env: &Env<'_, 'ctx, '_>, + lhs_val: BasicValueEnum<'ctx>, + rhs_val: BasicValueEnum<'ctx>, + builtin: IntWidth, +) -> IntValue<'ctx> { + use IntWidth::*; + match builtin { + I128 => env.builder.new_build_int_compare( + IntPredicate::SLT, + lhs_val.into_int_value(), + rhs_val.into_int_value(), + "lhs_gt_rhs_i28", + ), + I64 => env.builder.new_build_int_compare( + IntPredicate::SLT, + lhs_val.into_int_value(), + rhs_val.into_int_value(), + "lhs_gt_rhs_i64", + ), + I32 => env.builder.new_build_int_compare( + IntPredicate::SLT, + lhs_val.into_int_value(), + rhs_val.into_int_value(), + "lhs_gt_rhs_i32", + ), + I16 => env.builder.new_build_int_compare( + IntPredicate::SLT, + lhs_val.into_int_value(), + rhs_val.into_int_value(), + "lhs_gt_rhs_i16", + ), + I8 => env.builder.new_build_int_compare( + IntPredicate::SLT, + lhs_val.into_int_value(), + rhs_val.into_int_value(), + "lhs_gt_rhs_i8", + ), + U128 => env.builder.new_build_int_compare( + IntPredicate::ULT, + lhs_val.into_int_value(), + rhs_val.into_int_value(), + "lhs_gt_rhs_u128", + ), + U64 => env.builder.new_build_int_compare( + IntPredicate::ULT, + lhs_val.into_int_value(), + rhs_val.into_int_value(), + "lhs_gt_rhs_u64", + ), + U32 => env.builder.new_build_int_compare( + IntPredicate::ULT, + lhs_val.into_int_value(), + rhs_val.into_int_value(), + "lhs_gt_rhs_u32", + ), + U16 => env.builder.new_build_int_compare( + IntPredicate::ULT, + lhs_val.into_int_value(), + rhs_val.into_int_value(), + "lhs_gt_rhs_u16", + ), + U8 => env.builder.new_build_int_compare( + IntPredicate::ULT, + lhs_val.into_int_value(), + rhs_val.into_int_value(), + "lhs_gt_rhs_u8", + ), + } +} + +fn int_gt<'ctx>( + env: &Env<'_, 'ctx, '_>, + lhs_val: BasicValueEnum<'ctx>, + rhs_val: BasicValueEnum<'ctx>, + builtin: IntWidth, +) -> IntValue<'ctx> { + use IntWidth::*; + match builtin { + I128 => env.builder.new_build_int_compare( + IntPredicate::SGT, + lhs_val.into_int_value(), + rhs_val.into_int_value(), + "lhs_lt_rhs_i28", + ), + I64 => env.builder.new_build_int_compare( + IntPredicate::SGT, + lhs_val.into_int_value(), + rhs_val.into_int_value(), + "lhs_lt_rhs_i64", + ), + I32 => env.builder.new_build_int_compare( + IntPredicate::SGT, + lhs_val.into_int_value(), + rhs_val.into_int_value(), + "lhs_lt_rhs_i32", + ), + I16 => env.builder.new_build_int_compare( + IntPredicate::SGT, + lhs_val.into_int_value(), + rhs_val.into_int_value(), + "lhs_lt_rhs_i16", + ), + I8 => env.builder.new_build_int_compare( + IntPredicate::SGT, + lhs_val.into_int_value(), + rhs_val.into_int_value(), + "lhs_lt_rhs_i8", + ), + U128 => env.builder.new_build_int_compare( + IntPredicate::UGT, + lhs_val.into_int_value(), + rhs_val.into_int_value(), + "lhs_lt_rhs_u128", + ), + U64 => env.builder.new_build_int_compare( + IntPredicate::UGT, + lhs_val.into_int_value(), + rhs_val.into_int_value(), + "lhs_lt_rhs_u64", + ), + U32 => env.builder.new_build_int_compare( + IntPredicate::UGT, + lhs_val.into_int_value(), + rhs_val.into_int_value(), + "lhs_lt_rhs_u32", + ), + U16 => env.builder.new_build_int_compare( + IntPredicate::UGT, + lhs_val.into_int_value(), + rhs_val.into_int_value(), + "lhs_lt_rhs_u16", + ), + U8 => env.builder.new_build_int_compare( + IntPredicate::UGT, + lhs_val.into_int_value(), + rhs_val.into_int_value(), + "lhs_lt_rhs_u8", + ), + } +}