From aa6cfb2669e51b3211017bde2faa6eddba22ecbe Mon Sep 17 00:00:00 2001
From: Ben Kimock <kimockb@gmail.com>
Date: Fri, 8 Mar 2024 21:50:23 -0500
Subject: [PATCH] Sink ptrtoint for RMW ops on pointers to cg_llvm

---
 compiler/rustc_codegen_llvm/src/builder.rs    |  8 ++-
 compiler/rustc_codegen_ssa/src/common.rs      |  2 +-
 .../rustc_codegen_ssa/src/mir/intrinsic.rs    | 50 ++++---------------
 3 files changed, 18 insertions(+), 42 deletions(-)

diff --git a/compiler/rustc_codegen_llvm/src/builder.rs b/compiler/rustc_codegen_llvm/src/builder.rs
index ca2e2b575805f..63e59ea13fc35 100644
--- a/compiler/rustc_codegen_llvm/src/builder.rs
+++ b/compiler/rustc_codegen_llvm/src/builder.rs
@@ -1132,9 +1132,15 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
         &mut self,
         op: rustc_codegen_ssa::common::AtomicRmwBinOp,
         dst: &'ll Value,
-        src: &'ll Value,
+        mut src: &'ll Value,
         order: rustc_codegen_ssa::common::AtomicOrdering,
     ) -> &'ll Value {
+        // The only RMW operation that LLVM supports on pointers is compare-exchange.
+        if self.val_ty(src) == self.type_ptr()
+            && op != rustc_codegen_ssa::common::AtomicRmwBinOp::AtomicXchg
+        {
+            src = self.ptrtoint(src, self.type_isize());
+        }
         unsafe {
             llvm::LLVMBuildAtomicRMW(
                 self.llbuilder,
diff --git a/compiler/rustc_codegen_ssa/src/common.rs b/compiler/rustc_codegen_ssa/src/common.rs
index 641ac3eb80872..44a2434238dad 100644
--- a/compiler/rustc_codegen_ssa/src/common.rs
+++ b/compiler/rustc_codegen_ssa/src/common.rs
@@ -42,7 +42,7 @@ pub enum RealPredicate {
     RealPredicateTrue,
 }
 
-#[derive(Copy, Clone)]
+#[derive(Copy, Clone, PartialEq)]
 pub enum AtomicRmwBinOp {
     AtomicXchg,
     AtomicAdd,
diff --git a/compiler/rustc_codegen_ssa/src/mir/intrinsic.rs b/compiler/rustc_codegen_ssa/src/mir/intrinsic.rs
index 82488829b6e16..1d1826d984474 100644
--- a/compiler/rustc_codegen_ssa/src/mir/intrinsic.rs
+++ b/compiler/rustc_codegen_ssa/src/mir/intrinsic.rs
@@ -350,14 +350,8 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
                         if int_type_width_signed(ty, bx.tcx()).is_some() || ty.is_unsafe_ptr() {
                             let weak = instruction == "cxchgweak";
                             let dst = args[0].immediate();
-                            let mut cmp = args[1].immediate();
-                            let mut src = args[2].immediate();
-                            if ty.is_unsafe_ptr() {
-                                // Some platforms do not support atomic operations on pointers,
-                                // so we cast to integer first.
-                                cmp = bx.ptrtoint(cmp, bx.type_isize());
-                                src = bx.ptrtoint(src, bx.type_isize());
-                            }
+                            let cmp = args[1].immediate();
+                            let src = args[2].immediate();
                             let (val, success) = bx.atomic_cmpxchg(
                                 dst,
                                 cmp,
@@ -385,26 +379,12 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
                             let layout = bx.layout_of(ty);
                             let size = layout.size;
                             let source = args[0].immediate();
-                            if ty.is_unsafe_ptr() {
-                                // Some platforms do not support atomic operations on pointers,
-                                // so we cast to integer first...
-                                let llty = bx.type_isize();
-                                let result = bx.atomic_load(
-                                    llty,
-                                    source,
-                                    parse_ordering(bx, ordering),
-                                    size,
-                                );
-                                // ... and then cast the result back to a pointer
-                                bx.inttoptr(result, bx.backend_type(layout))
-                            } else {
-                                bx.atomic_load(
-                                    bx.backend_type(layout),
-                                    source,
-                                    parse_ordering(bx, ordering),
-                                    size,
-                                )
-                            }
+                            bx.atomic_load(
+                                bx.backend_type(layout),
+                                source,
+                                parse_ordering(bx, ordering),
+                                size,
+                            )
                         } else {
                             invalid_monomorphization(ty);
                             return Ok(());
@@ -415,13 +395,8 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
                         let ty = fn_args.type_at(0);
                         if int_type_width_signed(ty, bx.tcx()).is_some() || ty.is_unsafe_ptr() {
                             let size = bx.layout_of(ty).size;
-                            let mut val = args[1].immediate();
+                            let val = args[1].immediate();
                             let ptr = args[0].immediate();
-                            if ty.is_unsafe_ptr() {
-                                // Some platforms do not support atomic operations on pointers,
-                                // so we cast to integer first.
-                                val = bx.ptrtoint(val, bx.type_isize());
-                            }
                             bx.atomic_store(val, ptr, parse_ordering(bx, ordering), size);
                         } else {
                             invalid_monomorphization(ty);
@@ -465,12 +440,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
                         let ty = fn_args.type_at(0);
                         if int_type_width_signed(ty, bx.tcx()).is_some() || ty.is_unsafe_ptr() {
                             let ptr = args[0].immediate();
-                            let mut val = args[1].immediate();
-                            if ty.is_unsafe_ptr() {
-                                // Some platforms do not support atomic operations on pointers,
-                                // so we cast to integer first.
-                                val = bx.ptrtoint(val, bx.type_isize());
-                            }
+                            let val = args[1].immediate();
                             bx.atomic_rmw(atom_op, ptr, val, parse_ordering(bx, ordering))
                         } else {
                             invalid_monomorphization(ty);