[x64][win] Add compiler support for x64 import call optimization (equivalent to MSVC /d2guardretpoline) #126631

Open

dpaoliello wants to merge 1 commit into main
Conversation

dpaoliello
Contributor

This is the x64 equivalent of #121516.

Since import call optimization was originally added to x64 Windows to implement a more efficient retpoline mitigation, the section and constant names relating to this feature all mention "retpoline". In addition to calls to imported functions, we must also mark indirect calls, control-flow guard calls, and jumps for jump tables in the section.

As with the AArch64 feature, this emits a new section into the object file that the MSVC linker uses to generate the Dynamic Value Relocation Table; the section itself does not appear in the final binary.

The Windows Loader requires specific instruction sequences to be emitted when this feature is enabled (sketched below):

  • Indirect calls/jumps must have the target function pointer in rax.
  • Calls to imported functions must use the rex prefix and be followed by a 5-byte nop.
  • Indirect calls must be followed by a 3-byte nop.
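
For illustration, here is a minimal sketch of the shape of these sequences in AT&T syntax. The imported function `a`, the `.Limpcall` labels, and the particular nop encodings are assumptions for this example, not mandated spellings:

.Limpcall0:                         # label recorded in the metadata section
        rex64
        callq   __imp_a             # rex-prefixed call to an imported function
        nopl    8(%rax,%rax)        # 5-byte nop (0F 1F 44 00 08)

.Limpcall1:
        callq   *%rax               # indirect call, target forced into rax
        nopl    (%rax)              # 3-byte nop (0F 1F 00)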

@llvmbot
Member

llvmbot commented Feb 11, 2025

@llvm/pr-subscribers-llvm-transforms

Author: Daniel Paoliello (dpaoliello)


Patch is 29.86 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/126631.diff

15 Files Affected:

  • (modified) llvm/include/llvm/Transforms/CFGuard.h (+3)
  • (modified) llvm/lib/MC/MCObjectFileInfo.cpp (+5)
  • (modified) llvm/lib/Target/X86/X86AsmPrinter.cpp (+32)
  • (modified) llvm/lib/Target/X86/X86AsmPrinter.h (+32-1)
  • (modified) llvm/lib/Target/X86/X86ISelLowering.cpp (+15-3)
  • (modified) llvm/lib/Target/X86/X86ISelLowering.h (+6-2)
  • (modified) llvm/lib/Target/X86/X86ISelLoweringCall.cpp (+5-2)
  • (modified) llvm/lib/Target/X86/X86InstrCompiler.td (+2)
  • (modified) llvm/lib/Target/X86/X86InstrFragments.td (+3)
  • (modified) llvm/lib/Target/X86/X86MCInstLower.cpp (+152-17)
  • (modified) llvm/lib/Transforms/CFGuard/CFGuard.cpp (+13-2)
  • (added) llvm/test/CodeGen/X86/win-import-call-optimization-jumptable.ll (+83)
  • (added) llvm/test/CodeGen/X86/win-import-call-optimization-nocalls.ll (+21)
  • (added) llvm/test/CodeGen/X86/win-import-call-optimization.ll (+65)
  • (added) llvm/test/MC/X86/win-import-call-optimization.s (+69)
diff --git a/llvm/include/llvm/Transforms/CFGuard.h b/llvm/include/llvm/Transforms/CFGuard.h
index caf822a2ec9fb3a..b81db8f487965ff 100644
--- a/llvm/include/llvm/Transforms/CFGuard.h
+++ b/llvm/include/llvm/Transforms/CFGuard.h
@@ -16,6 +16,7 @@
 namespace llvm {
 
 class FunctionPass;
+class GlobalValue;
 
 class CFGuardPass : public PassInfoMixin<CFGuardPass> {
 public:
@@ -34,6 +35,8 @@ FunctionPass *createCFGuardCheckPass();
 /// Insert Control FLow Guard dispatches on indirect function calls.
 FunctionPass *createCFGuardDispatchPass();
 
+bool isCFGuardFunction(const GlobalValue *GV);
+
 } // namespace llvm
 
 #endif
diff --git a/llvm/lib/MC/MCObjectFileInfo.cpp b/llvm/lib/MC/MCObjectFileInfo.cpp
index 150e38a94db6a67..334673c4dba79a9 100644
--- a/llvm/lib/MC/MCObjectFileInfo.cpp
+++ b/llvm/lib/MC/MCObjectFileInfo.cpp
@@ -599,6 +599,11 @@ void MCObjectFileInfo::initCOFFMCObjectFileInfo(const Triple &T) {
   if (T.getArch() == Triple::aarch64) {
     ImportCallSection =
         Ctx->getCOFFSection(".impcall", COFF::IMAGE_SCN_LNK_INFO);
+  } else if (T.getArch() == Triple::x86_64) {
+    // Import Call Optimization on x64 leverages the same metadata as the
+    // retpoline mitigation, hence the unusual section name.
+    ImportCallSection =
+        Ctx->getCOFFSection(".retplne", COFF::IMAGE_SCN_LNK_INFO);
   }
 
   // Debug info.
diff --git a/llvm/lib/Target/X86/X86AsmPrinter.cpp b/llvm/lib/Target/X86/X86AsmPrinter.cpp
index f01e47b41cf5e44..52f8280b259650d 100644
--- a/llvm/lib/Target/X86/X86AsmPrinter.cpp
+++ b/llvm/lib/Target/X86/X86AsmPrinter.cpp
@@ -920,6 +920,9 @@ void X86AsmPrinter::emitStartOfAsmFile(Module &M) {
     OutStreamer->emitSymbolAttribute(S, MCSA_Global);
     OutStreamer->emitAssignment(
         S, MCConstantExpr::create(Feat00Value, MMI->getContext()));
+
+    if (M.getModuleFlag("import-call-optimization"))
+      EnableImportCallOptimization = true;
   }
   OutStreamer->emitSyntaxDirective();
 
@@ -1021,6 +1024,35 @@ void X86AsmPrinter::emitEndOfAsmFile(Module &M) {
     // safe to set.
     OutStreamer->emitAssemblerFlag(MCAF_SubsectionsViaSymbols);
   } else if (TT.isOSBinFormatCOFF()) {
+    // If import call optimization is enabled, emit the appropriate section.
+    // We do this whether or not we recorded any items.
+    if (EnableImportCallOptimization) {
+      OutStreamer->switchSection(getObjFileLowering().getImportCallSection());
+
+      // Section always starts with some magic.
+      constexpr char ImpCallMagic[12] = "RetpolineV1";
+      OutStreamer->emitBytes(StringRef{ImpCallMagic, sizeof(ImpCallMagic)});
+
+      // Layout of this section is:
+      // Per section that contains an item to record:
+      //  uint32_t SectionSize: Size in bytes for information in this section.
+      //  uint32_t Section Number
+      //  Per call to imported function in section:
+      //    uint32_t Kind: the kind of item.
+      //    uint32_t InstOffset: the offset of the instr in its parent section.
+      for (auto &[Section, CallsToImportedFuncs] :
+           SectionToImportedFunctionCalls) {
+        unsigned SectionSize =
+            sizeof(uint32_t) * (2 + 2 * CallsToImportedFuncs.size());
+        OutStreamer->emitInt32(SectionSize);
+        OutStreamer->emitCOFFSecNumber(Section->getBeginSymbol());
+        for (auto &[CallsiteSymbol, Kind] : CallsToImportedFuncs) {
+          OutStreamer->emitInt32(Kind);
+          OutStreamer->emitCOFFSecOffset(CallsiteSymbol);
+        }
+      }
+    }
+
     if (usesMSVCFloatingPoint(TT, M)) {
       // In Windows' libcmt.lib, there is a file which is linked in only if the
       // symbol _fltused is referenced. Linking this in causes some
diff --git a/llvm/lib/Target/X86/X86AsmPrinter.h b/llvm/lib/Target/X86/X86AsmPrinter.h
index 693021eca329588..47e82c4dfcea5d4 100644
--- a/llvm/lib/Target/X86/X86AsmPrinter.h
+++ b/llvm/lib/Target/X86/X86AsmPrinter.h
@@ -31,6 +31,26 @@ class LLVM_LIBRARY_VISIBILITY X86AsmPrinter : public AsmPrinter {
   bool EmitFPOData = false;
   bool ShouldEmitWeakSwiftAsyncExtendedFramePointerFlags = false;
   bool IndCSPrefix = false;
+  bool EnableImportCallOptimization = false;
+
+  enum ImportCallKind : unsigned {
+    IMAGE_RETPOLINE_AMD64_IMPORT_BR = 0x02,
+    IMAGE_RETPOLINE_AMD64_IMPORT_CALL = 0x03,
+    IMAGE_RETPOLINE_AMD64_INDIR_BR = 0x04,
+    IMAGE_RETPOLINE_AMD64_INDIR_CALL = 0x05,
+    IMAGE_RETPOLINE_AMD64_INDIR_BR_REX = 0x06,
+    IMAGE_RETPOLINE_AMD64_CFG_BR = 0x08,
+    IMAGE_RETPOLINE_AMD64_CFG_CALL = 0x09,
+    IMAGE_RETPOLINE_AMD64_CFG_BR_REX = 0x0A,
+    IMAGE_RETPOLINE_AMD64_SWITCHTABLE_FIRST = 0x010,
+    IMAGE_RETPOLINE_AMD64_SWITCHTABLE_LAST = 0x01F,
+  };
+  struct ImportCallInfo {
+    MCSymbol *CalleeSymbol;
+    ImportCallKind Kind;
+  };
+  DenseMap<MCSection *, std::vector<ImportCallInfo>>
+      SectionToImportedFunctionCalls;
 
   // This utility class tracks the length of a stackmap instruction's 'shadow'.
   // It is used by the X86AsmPrinter to ensure that the stackmap shadow
@@ -45,7 +65,7 @@ class LLVM_LIBRARY_VISIBILITY X86AsmPrinter : public AsmPrinter {
     void startFunction(MachineFunction &MF) {
       this->MF = &MF;
     }
-    void count(MCInst &Inst, const MCSubtargetInfo &STI,
+    void count(const MCInst &Inst, const MCSubtargetInfo &STI,
                MCCodeEmitter *CodeEmitter);
 
     // Called to signal the start of a shadow of RequiredSize bytes.
@@ -126,6 +146,17 @@ class LLVM_LIBRARY_VISIBILITY X86AsmPrinter : public AsmPrinter {
   void emitMachOIFuncStubHelperBody(Module &M, const GlobalIFunc &GI,
                                     MCSymbol *LazyPointer) override;
 
+  void emitCallInstruction(const llvm::MCInst &MCI);
+
+  // Emits a label to mark the next instruction as being relevant to Import Call
+  // Optimization.
+  void emitLabelAndRecordForImportCallOptimization(ImportCallKind Kind);
+
+  // Ensure that rax is used as the operand for the given instruction.
+  //
+  // NOTE: This assumes that it is safe to clobber rax.
+  void ensureRaxUsedForOperand(MCInst &TmpInst);
+
 public:
   X86AsmPrinter(TargetMachine &TM, std::unique_ptr<MCStreamer> Streamer);
 
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 6cf6061deba7025..30a98a5a13ebfad 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -18922,7 +18922,7 @@ SDValue X86TargetLowering::LowerJumpTable(SDValue Op, SelectionDAG &DAG) const {
 
 SDValue X86TargetLowering::LowerExternalSymbol(SDValue Op,
                                                SelectionDAG &DAG) const {
-  return LowerGlobalOrExternal(Op, DAG, /*ForCall=*/false);
+  return LowerGlobalOrExternal(Op, DAG, /*ForCall=*/false, nullptr);
 }
 
 SDValue
@@ -18950,7 +18950,8 @@ X86TargetLowering::LowerBlockAddress(SDValue Op, SelectionDAG &DAG) const {
 /// Creates target global address or external symbol nodes for calls or
 /// other uses.
 SDValue X86TargetLowering::LowerGlobalOrExternal(SDValue Op, SelectionDAG &DAG,
-                                                 bool ForCall) const {
+                                                 bool ForCall,
+                                                 bool *IsImpCall) const {
   // Unpack the global address or external symbol.
   SDLoc dl(Op);
   const GlobalValue *GV = nullptr;
@@ -19000,6 +19001,16 @@ SDValue X86TargetLowering::LowerGlobalOrExternal(SDValue Op, SelectionDAG &DAG,
   if (ForCall && !NeedsLoad && !HasPICReg && Offset == 0)
     return Result;
 
+  // If Import Call Optimization is enabled and this is an imported function
+  // then make a note of it and return the global address without wrapping.
+  if (IsImpCall && (OpFlags == X86II::MO_DLLIMPORT) &&
+      Mod.getModuleFlag("import-call-optimization")) {
+    assert(ForCall && "Should only enable import call optimization if we are "
+                      "lowering a call");
+    *IsImpCall = true;
+    return Result;
+  }
+
   Result = DAG.getNode(getGlobalWrapperKind(GV, OpFlags), dl, PtrVT, Result);
 
   // With PIC, the address is actually $g + Offset.
@@ -19025,7 +19036,7 @@ SDValue X86TargetLowering::LowerGlobalOrExternal(SDValue Op, SelectionDAG &DAG,
 
 SDValue
 X86TargetLowering::LowerGlobalAddress(SDValue Op, SelectionDAG &DAG) const {
-  return LowerGlobalOrExternal(Op, DAG, /*ForCall=*/false);
+  return LowerGlobalOrExternal(Op, DAG, /*ForCall=*/false, nullptr);
 }
 
 static SDValue GetTLSADDR(SelectionDAG &DAG, GlobalAddressSDNode *GA,
@@ -34562,6 +34573,7 @@ const char *X86TargetLowering::getTargetNodeName(unsigned Opcode) const {
   NODE_NAME_CASE(FST)
   NODE_NAME_CASE(CALL)
   NODE_NAME_CASE(CALL_RVMARKER)
+  NODE_NAME_CASE(IMP_CALL)
   NODE_NAME_CASE(BT)
   NODE_NAME_CASE(CMP)
   NODE_NAME_CASE(FCMP)
diff --git a/llvm/lib/Target/X86/X86ISelLowering.h b/llvm/lib/Target/X86/X86ISelLowering.h
index fe79fefeed631c6..6324cc65398a0a4 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.h
+++ b/llvm/lib/Target/X86/X86ISelLowering.h
@@ -81,6 +81,10 @@ namespace llvm {
     // marker instruction.
     CALL_RVMARKER,
 
+    // Pseudo for a call to an imported function to ensure the correct machine
+    // instruction is emitted for Import Call Optimization.
+    IMP_CALL,
+
     /// X86 compare and logical compare instructions.
     CMP,
     FCMP,
@@ -1733,8 +1737,8 @@ namespace llvm {
 
     /// Creates target global address or external symbol nodes for calls or
     /// other uses.
-    SDValue LowerGlobalOrExternal(SDValue Op, SelectionDAG &DAG,
-                                  bool ForCall) const;
+    SDValue LowerGlobalOrExternal(SDValue Op, SelectionDAG &DAG, bool ForCall,
+                                  bool *IsImpCall) const;
 
     SDValue LowerSINT_TO_FP(SDValue Op, SelectionDAG &DAG) const;
     SDValue LowerUINT_TO_FP(SDValue Op, SelectionDAG &DAG) const;
diff --git a/llvm/lib/Target/X86/X86ISelLoweringCall.cpp b/llvm/lib/Target/X86/X86ISelLoweringCall.cpp
index 6835c7e336a5cb1..cbbdf37a3fb75f7 100644
--- a/llvm/lib/Target/X86/X86ISelLoweringCall.cpp
+++ b/llvm/lib/Target/X86/X86ISelLoweringCall.cpp
@@ -2402,6 +2402,7 @@ X86TargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
     InGlue = Chain.getValue(1);
   }
 
+  bool IsImpCall = false;
   if (DAG.getTarget().getCodeModel() == CodeModel::Large) {
     assert(Is64Bit && "Large code model is only legal in 64-bit mode.");
     // In the 64-bit large code model, we have to make all calls
@@ -2414,7 +2415,7 @@ X86TargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
     // ForCall to true here has the effect of removing WrapperRIP when possible
     // to allow direct calls to be selected without first materializing the
     // address into a register.
-    Callee = LowerGlobalOrExternal(Callee, DAG, /*ForCall=*/true);
+    Callee = LowerGlobalOrExternal(Callee, DAG, /*ForCall=*/true, &IsImpCall);
   } else if (Subtarget.isTarget64BitILP32() &&
              Callee.getValueType() == MVT::i32) {
     // Zero-extend the 32-bit Callee address into a 64-bit according to x32 ABI
@@ -2536,7 +2537,9 @@ X86TargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
 
   // Returns a chain & a glue for retval copy to use.
   SDVTList NodeTys = DAG.getVTList(MVT::Other, MVT::Glue);
-  if (HasNoCfCheck && IsCFProtectionSupported && IsIndirectCall) {
+  if (IsImpCall) {
+    Chain = DAG.getNode(X86ISD::IMP_CALL, dl, NodeTys, Ops);
+  } else if (HasNoCfCheck && IsCFProtectionSupported && IsIndirectCall) {
     Chain = DAG.getNode(X86ISD::NT_CALL, dl, NodeTys, Ops);
   } else if (CLI.CB && objcarc::hasAttachedCallOpBundle(CLI.CB)) {
     // Calls with a "clang.arc.attachedcall" bundle are special. They should be
diff --git a/llvm/lib/Target/X86/X86InstrCompiler.td b/llvm/lib/Target/X86/X86InstrCompiler.td
index 9687ae29f1c782f..5f603de695906f1 100644
--- a/llvm/lib/Target/X86/X86InstrCompiler.td
+++ b/llvm/lib/Target/X86/X86InstrCompiler.td
@@ -1309,6 +1309,8 @@ def : Pat<(X86call_rvmarker (i64 tglobaladdr:$rvfunc), (i64 texternalsym:$dst)),
 def : Pat<(X86call_rvmarker (i64 tglobaladdr:$rvfunc), (i64 tglobaladdr:$dst)),
           (CALL64pcrel32_RVMARKER tglobaladdr:$rvfunc, tglobaladdr:$dst)>;
 
+def : Pat<(X86imp_call (i64 tglobaladdr:$dst)),
+          (CALL64pcrel32 tglobaladdr:$dst)>;
 
 // Tailcall stuff. The TCRETURN instructions execute after the epilog, so they
 // can never use callee-saved registers. That is the purpose of the GR64_TC
diff --git a/llvm/lib/Target/X86/X86InstrFragments.td b/llvm/lib/Target/X86/X86InstrFragments.td
index ddbc7c55a6113b4..3ab820de78efcbe 100644
--- a/llvm/lib/Target/X86/X86InstrFragments.td
+++ b/llvm/lib/Target/X86/X86InstrFragments.td
@@ -210,6 +210,9 @@ def X86call_rvmarker  : SDNode<"X86ISD::CALL_RVMARKER",     SDT_X86Call,
                         [SDNPHasChain, SDNPOutGlue, SDNPOptInGlue,
                          SDNPVariadic]>;
 
+def X86imp_call  : SDNode<"X86ISD::IMP_CALL",     SDT_X86Call,
+                        [SDNPHasChain, SDNPOutGlue, SDNPOptInGlue,
+                         SDNPVariadic]>;
 
 def X86NoTrackCall : SDNode<"X86ISD::NT_CALL", SDT_X86Call,
                             [SDNPHasChain, SDNPOutGlue, SDNPOptInGlue,
diff --git a/llvm/lib/Target/X86/X86MCInstLower.cpp b/llvm/lib/Target/X86/X86MCInstLower.cpp
index 0f8fbf5be1c9557..f265093a60d12e9 100644
--- a/llvm/lib/Target/X86/X86MCInstLower.cpp
+++ b/llvm/lib/Target/X86/X86MCInstLower.cpp
@@ -47,6 +47,7 @@
 #include "llvm/MC/TargetRegistry.h"
 #include "llvm/Target/TargetLoweringObjectFile.h"
 #include "llvm/Target/TargetMachine.h"
+#include "llvm/Transforms/CFGuard.h"
 #include "llvm/Transforms/Instrumentation/AddressSanitizer.h"
 #include "llvm/Transforms/Instrumentation/AddressSanitizerCommon.h"
 #include <string>
@@ -112,7 +113,7 @@ struct NoAutoPaddingScope {
 static void emitX86Nops(MCStreamer &OS, unsigned NumBytes,
                         const X86Subtarget *Subtarget);
 
-void X86AsmPrinter::StackMapShadowTracker::count(MCInst &Inst,
+void X86AsmPrinter::StackMapShadowTracker::count(const MCInst &Inst,
                                                  const MCSubtargetInfo &STI,
                                                  MCCodeEmitter *CodeEmitter) {
   if (InShadow) {
@@ -2193,6 +2194,27 @@ static void addConstantComments(const MachineInstr *MI,
   }
 }
 
+bool isImportedFunction(const MachineOperand &MO) {
+  return MO.isGlobal() && (MO.getTargetFlags() == X86II::MO_DLLIMPORT);
+}
+
+bool isCallToCFGuardFunction(const MachineInstr *MI) {
+  assert(MI->getOpcode() == X86::TAILJMPm64_REX ||
+         MI->getOpcode() == X86::CALL64m);
+  const MachineOperand &MO = MI->getOperand(3);
+  return MO.isGlobal() && (MO.getTargetFlags() == X86II::MO_NO_FLAG) &&
+         isCFGuardFunction(MO.getGlobal());
+}
+
+bool hasJumpTableInfoInBlock(const llvm::MachineInstr *MI) {
+  const MachineBasicBlock &MBB = *MI->getParent();
+  for (auto I = MBB.instr_rbegin(), E = MBB.instr_rend(); I != E; ++I)
+    if (I->isJumpTableDebugInfo())
+      return true;
+
+  return false;
+}
+
 void X86AsmPrinter::emitInstruction(const MachineInstr *MI) {
   // FIXME: Enable feature predicate checks once all the test pass.
   // X86_MC::verifyInstructionPredicates(MI->getOpcode(),
@@ -2271,20 +2293,64 @@ void X86AsmPrinter::emitInstruction(const MachineInstr *MI) {
   case X86::TAILJMPd64:
     if (IndCSPrefix && MI->hasRegisterImplicitUseOperand(X86::R11))
       EmitAndCountInstruction(MCInstBuilder(X86::CS_PREFIX));
-    [[fallthrough]];
-  case X86::TAILJMPr:
+
+    if (EnableImportCallOptimization && isImportedFunction(MI->getOperand(0))) {
+      emitLabelAndRecordForImportCallOptimization(
+          IMAGE_RETPOLINE_AMD64_IMPORT_BR);
+    }
+
+    // Lower these as normal, but add some comments.
+    OutStreamer->AddComment("TAILCALL");
+    break;
+  case X86::TAILJMPm64_REX:
+    if (EnableImportCallOptimization && isCallToCFGuardFunction(MI)) {
+      emitLabelAndRecordForImportCallOptimization(
+          IMAGE_RETPOLINE_AMD64_CFG_BR_REX);
+    }
+
+    // Lower these as normal, but add some comments.
+    OutStreamer->AddComment("TAILCALL");
+    break;
   case X86::TAILJMPm:
   case X86::TAILJMPd:
   case X86::TAILJMPd_CC:
-  case X86::TAILJMPr64:
   case X86::TAILJMPm64:
   case X86::TAILJMPd64_CC:
-  case X86::TAILJMPr64_REX:
-  case X86::TAILJMPm64_REX:
     // Lower these as normal, but add some comments.
     OutStreamer->AddComment("TAILCALL");
     break;
 
+  case X86::TAILJMPr:
+  case X86::TAILJMPr64:
+  case X86::TAILJMPr64_REX: {
+    MCInst TmpInst;
+    MCInstLowering.Lower(MI, TmpInst);
+
+    if (EnableImportCallOptimization) {
+      // Import call optimization requires all indirect calls go via RAX.
+      ensureRaxUsedForOperand(TmpInst);
+      emitLabelAndRecordForImportCallOptimization(
+          IMAGE_RETPOLINE_AMD64_INDIR_BR);
+    }
+
+    // Lower these as normal, but add some comments.
+    OutStreamer->AddComment("TAILCALL");
+    EmitAndCountInstruction(TmpInst);
+    return;
+  }
+
+  case X86::JMP64r:
+  case X86::JMP64m:
+    if (EnableImportCallOptimization && hasJumpTableInfoInBlock(MI)) {
+      uint16_t EncodedReg =
+          this->getSubtarget().getRegisterInfo()->getEncodingValue(
+              MI->getOperand(0).getReg().asMCReg());
+      emitLabelAndRecordForImportCallOptimization(
+          (ImportCallKind)(IMAGE_RETPOLINE_AMD64_SWITCHTABLE_FIRST +
+                           EncodedReg));
+    }
+    break;
+
   case X86::TLS_addr32:
   case X86::TLS_addr64:
   case X86::TLS_addrX32:
@@ -2469,7 +2535,49 @@ void X86AsmPrinter::emitInstruction(const MachineInstr *MI) {
   case X86::CALL64pcrel32:
     if (IndCSPrefix && MI->hasRegisterImplicitUseOperand(X86::R11))
       EmitAndCountInstruction(MCInstBuilder(X86::CS_PREFIX));
+
+    if (EnableImportCallOptimization && isImportedFunction(MI->getOperand(0))) {
+      emitLabelAndRecordForImportCallOptimization(
+          IMAGE_RETPOLINE_AMD64_IMPORT_CALL);
+
+      MCInst TmpInst;
+      MCInstLowering.Lower(MI, TmpInst);
+
+      // For Import Call Optimization to work, we need the call instruction to
+      // have a rex prefix, and a 5-byte nop after the call instruction.
+      EmitAndCountInstruction(MCInstBuilder(X86::REX64_PREFIX));
+      emitCallInstruction(TmpInst);
+      emitNop(*OutStreamer, 5, Subtarget);
+      return;
+    }
+
     break;
+  case X86::CALL64r:
+    if (EnableImportCallOptimization) {
+      MCInst TmpInst;
+      MCInstLowering.Lower(MI, TmpInst);
+
+      // Import call optimization requires all indirect calls go via RAX.
+      ensureRaxUsedForOperand(TmpInst);
+
+      emitLabelAndRecordForImportCallOptimization(
+          IMAGE_RETPOLINE_AMD64_INDIR_CALL);
+      emitCallInstruction(TmpInst);
+
+      // For Import Call Optimization to work, we need a 3-byte nop after the
+      // call instruction.
+      emitNop(*OutStreamer, 3, Subtarget);
+      return;
+    }
+
+    break;
+  case X86::CALL64m:
+    if (EnableImportCallOptimization && isCallToCFGuardFunction(MI)) {
+      emitLabelAndRecordForImportCallOptimization(
+          IMAGE_RETPOLINE_AMD64_CFG_CALL);
+    }
+    break;
+
   case X86::JCC_1:
     // Two instruction prefixes (2EH for branch not-taken and 3EH for branch
     // taken) are used as branch hints. Here we add branch taken prefix for
@@ -2490,20 +2598,47 @@ void X86AsmPrinter::emitInstruction(const MachineInstr *MI) {
   MCInst TmpInst;
   MCInstLowering.Lower(MI, TmpInst);
 
-  // Stackmap shadows cannot include branch targets, so we can count the bytes
-  // in a call towards the shadow, but must ensure that the no thread returns
-  // in to the stackmap shadow.  The only way to achieve this is if the call
-  // is at the end of the shadow.
   if (MI->isCall()) {
-    // Count then size of the call towards the shadow
-    SMShadowTracker.count(TmpInst, getSubtargetInfo(), CodeEmitter.get());
-    // Then flush the shadow so that we fill with nops before the call, not
-    // after it.
-    SMShadowTracker.emitShadowPadding(*OutStreamer, getSubtargetInfo());
-    // Then emit the call
-    OutStreamer->emitInstruction(TmpInst, getSubtargetInfo());
+    emitCallInstruction(TmpInst);
     return;
   }
 
   EmitAndCountInstruction(TmpInst);
 }
+
+void X86AsmPrinter::emitCallInstruction(const llvm::MCInst &MCI) {
+  // Stackmap shadows cannot include branch targets, so we can count the bytes
+  // in a call towards the shado...
[truncated]
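
To make the layout documented in the emitEndOfAsmFile comments above concrete, here is a hand-written sketch of a `.retplne` section for one code section containing a single call to an imported function. This assumes the `.secnum`/`.secoffset` directives introduced by the AArch64 change; the label name and the use of `.text` are illustrative, and the compiler emits this section itself (this listing is not part of the patch):

        .section .retplne               # flagged IMAGE_SCN_LNK_INFO by the compiler
        .asciz  "RetpolineV1"           # 12-byte magic (11 characters plus a NUL)
        .long   16                      # SectionSize = 4 * (2 + 2 * NumEntries), NumEntries = 1
        .secnum .text                   # section number of the code section
        .long   3                       # Kind = IMAGE_RETPOLINE_AMD64_IMPORT_CALL
        .secoffset .Limpcall0           # offset of the marked call within .text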

@llvmbot
Member

llvmbot commented Feb 11, 2025

@llvm/pr-subscribers-backend-x86


@llvmbot
Member

llvmbot commented Feb 11, 2025

@llvm/pr-subscribers-mc

   } else if (CLI.CB && objcarc::hasAttachedCallOpBundle(CLI.CB)) {
     // Calls with a "clang.arc.attachedcall" bundle are special. They should be
diff --git a/llvm/lib/Target/X86/X86InstrCompiler.td b/llvm/lib/Target/X86/X86InstrCompiler.td
index 9687ae29f1c782f..5f603de695906f1 100644
--- a/llvm/lib/Target/X86/X86InstrCompiler.td
+++ b/llvm/lib/Target/X86/X86InstrCompiler.td
@@ -1309,6 +1309,8 @@ def : Pat<(X86call_rvmarker (i64 tglobaladdr:$rvfunc), (i64 texternalsym:$dst)),
 def : Pat<(X86call_rvmarker (i64 tglobaladdr:$rvfunc), (i64 tglobaladdr:$dst)),
           (CALL64pcrel32_RVMARKER tglobaladdr:$rvfunc, tglobaladdr:$dst)>;
 
+def : Pat<(X86imp_call (i64 tglobaladdr:$dst)),
+          (CALL64pcrel32 tglobaladdr:$dst)>;
 
 // Tailcall stuff. The TCRETURN instructions execute after the epilog, so they
 // can never use callee-saved registers. That is the purpose of the GR64_TC
diff --git a/llvm/lib/Target/X86/X86InstrFragments.td b/llvm/lib/Target/X86/X86InstrFragments.td
index ddbc7c55a6113b4..3ab820de78efcbe 100644
--- a/llvm/lib/Target/X86/X86InstrFragments.td
+++ b/llvm/lib/Target/X86/X86InstrFragments.td
@@ -210,6 +210,9 @@ def X86call_rvmarker  : SDNode<"X86ISD::CALL_RVMARKER",     SDT_X86Call,
                         [SDNPHasChain, SDNPOutGlue, SDNPOptInGlue,
                          SDNPVariadic]>;
 
+def X86imp_call  : SDNode<"X86ISD::IMP_CALL",     SDT_X86Call,
+                        [SDNPHasChain, SDNPOutGlue, SDNPOptInGlue,
+                         SDNPVariadic]>;
 
 def X86NoTrackCall : SDNode<"X86ISD::NT_CALL", SDT_X86Call,
                             [SDNPHasChain, SDNPOutGlue, SDNPOptInGlue,
diff --git a/llvm/lib/Target/X86/X86MCInstLower.cpp b/llvm/lib/Target/X86/X86MCInstLower.cpp
index 0f8fbf5be1c9557..f265093a60d12e9 100644
--- a/llvm/lib/Target/X86/X86MCInstLower.cpp
+++ b/llvm/lib/Target/X86/X86MCInstLower.cpp
@@ -47,6 +47,7 @@
 #include "llvm/MC/TargetRegistry.h"
 #include "llvm/Target/TargetLoweringObjectFile.h"
 #include "llvm/Target/TargetMachine.h"
+#include "llvm/Transforms/CFGuard.h"
 #include "llvm/Transforms/Instrumentation/AddressSanitizer.h"
 #include "llvm/Transforms/Instrumentation/AddressSanitizerCommon.h"
 #include <string>
@@ -112,7 +113,7 @@ struct NoAutoPaddingScope {
 static void emitX86Nops(MCStreamer &OS, unsigned NumBytes,
                         const X86Subtarget *Subtarget);
 
-void X86AsmPrinter::StackMapShadowTracker::count(MCInst &Inst,
+void X86AsmPrinter::StackMapShadowTracker::count(const MCInst &Inst,
                                                  const MCSubtargetInfo &STI,
                                                  MCCodeEmitter *CodeEmitter) {
   if (InShadow) {
@@ -2193,6 +2194,27 @@ static void addConstantComments(const MachineInstr *MI,
   }
 }
 
+bool isImportedFunction(const MachineOperand &MO) {
+  return MO.isGlobal() && (MO.getTargetFlags() == X86II::MO_DLLIMPORT);
+}
+
+bool isCallToCFGuardFunction(const MachineInstr *MI) {
+  assert(MI->getOpcode() == X86::TAILJMPm64_REX ||
+         MI->getOpcode() == X86::CALL64m);
+  const MachineOperand &MO = MI->getOperand(3);
+  return MO.isGlobal() && (MO.getTargetFlags() == X86II::MO_NO_FLAG) &&
+         isCFGuardFunction(MO.getGlobal());
+}
+
+bool hasJumpTableInfoInBlock(const llvm::MachineInstr *MI) {
+  const MachineBasicBlock &MBB = *MI->getParent();
+  for (auto I = MBB.instr_rbegin(), E = MBB.instr_rend(); I != E; ++I)
+    if (I->isJumpTableDebugInfo())
+      return true;
+
+  return false;
+}
+
 void X86AsmPrinter::emitInstruction(const MachineInstr *MI) {
   // FIXME: Enable feature predicate checks once all the test pass.
   // X86_MC::verifyInstructionPredicates(MI->getOpcode(),
@@ -2271,20 +2293,64 @@ void X86AsmPrinter::emitInstruction(const MachineInstr *MI) {
   case X86::TAILJMPd64:
     if (IndCSPrefix && MI->hasRegisterImplicitUseOperand(X86::R11))
       EmitAndCountInstruction(MCInstBuilder(X86::CS_PREFIX));
-    [[fallthrough]];
-  case X86::TAILJMPr:
+
+    if (EnableImportCallOptimization && isImportedFunction(MI->getOperand(0))) {
+      emitLabelAndRecordForImportCallOptimization(
+          IMAGE_RETPOLINE_AMD64_IMPORT_BR);
+    }
+
+    // Lower these as normal, but add some comments.
+    OutStreamer->AddComment("TAILCALL");
+    break;
+  case X86::TAILJMPm64_REX:
+    if (EnableImportCallOptimization && isCallToCFGuardFunction(MI)) {
+      emitLabelAndRecordForImportCallOptimization(
+          IMAGE_RETPOLINE_AMD64_CFG_BR_REX);
+    }
+
+    // Lower these as normal, but add some comments.
+    OutStreamer->AddComment("TAILCALL");
+    break;
   case X86::TAILJMPm:
   case X86::TAILJMPd:
   case X86::TAILJMPd_CC:
-  case X86::TAILJMPr64:
   case X86::TAILJMPm64:
   case X86::TAILJMPd64_CC:
-  case X86::TAILJMPr64_REX:
-  case X86::TAILJMPm64_REX:
     // Lower these as normal, but add some comments.
     OutStreamer->AddComment("TAILCALL");
     break;
 
+  case X86::TAILJMPr:
+  case X86::TAILJMPr64:
+  case X86::TAILJMPr64_REX: {
+    MCInst TmpInst;
+    MCInstLowering.Lower(MI, TmpInst);
+
+    if (EnableImportCallOptimization) {
+      // Import call optimization requires that all indirect calls go via RAX.
+      ensureRaxUsedForOperand(TmpInst);
+      emitLabelAndRecordForImportCallOptimization(
+          IMAGE_RETPOLINE_AMD64_INDIR_BR);
+    }
+
+    // Lower these as normal, but add some comments.
+    OutStreamer->AddComment("TAILCALL");
+    EmitAndCountInstruction(TmpInst);
+    return;
+  }
+
+  case X86::JMP64r:
+  case X86::JMP64m:
+    if (EnableImportCallOptimization && hasJumpTableInfoInBlock(MI)) {
+      uint16_t EncodedReg =
+          this->getSubtarget().getRegisterInfo()->getEncodingValue(
+              MI->getOperand(0).getReg().asMCReg());
+      emitLabelAndRecordForImportCallOptimization(
+          (ImportCallKind)(IMAGE_RETPOLINE_AMD64_SWITCHTABLE_FIRST +
+                           EncodedReg));
+    }
+    break;
+
   case X86::TLS_addr32:
   case X86::TLS_addr64:
   case X86::TLS_addrX32:
@@ -2469,7 +2535,49 @@ void X86AsmPrinter::emitInstruction(const MachineInstr *MI) {
   case X86::CALL64pcrel32:
     if (IndCSPrefix && MI->hasRegisterImplicitUseOperand(X86::R11))
       EmitAndCountInstruction(MCInstBuilder(X86::CS_PREFIX));
+
+    if (EnableImportCallOptimization && isImportedFunction(MI->getOperand(0))) {
+      emitLabelAndRecordForImportCallOptimization(
+          IMAGE_RETPOLINE_AMD64_IMPORT_CALL);
+
+      MCInst TmpInst;
+      MCInstLowering.Lower(MI, TmpInst);
+
+      // For Import Call Optimization to work, the call instruction needs a
+      // rex prefix and must be followed by a 5-byte nop.
+      EmitAndCountInstruction(MCInstBuilder(X86::REX64_PREFIX));
+      emitCallInstruction(TmpInst);
+      emitNop(*OutStreamer, 5, Subtarget);
+      return;
+    }
+
     break;
+  case X86::CALL64r:
+    if (EnableImportCallOptimization) {
+      MCInst TmpInst;
+      MCInstLowering.Lower(MI, TmpInst);
+
+      // Import call optimization requires that all indirect calls go via RAX.
+      ensureRaxUsedForOperand(TmpInst);
+
+      emitLabelAndRecordForImportCallOptimization(
+          IMAGE_RETPOLINE_AMD64_INDIR_CALL);
+      emitCallInstruction(TmpInst);
+
+      // For Import Call Optimization to work, we a 3-byte nop after the call
+      // instruction.
+      emitNop(*OutStreamer, 3, Subtarget);
+      return;
+    }
+
+    break;
+  case X86::CALL64m:
+    if (EnableImportCallOptimization && isCallToCFGuardFunction(MI)) {
+      emitLabelAndRecordForImportCallOptimization(
+          IMAGE_RETPOLINE_AMD64_CFG_CALL);
+    }
+    break;
+
   case X86::JCC_1:
     // Two instruction prefixes (2EH for branch not-taken and 3EH for branch
     // taken) are used as branch hints. Here we add branch taken prefix for
@@ -2490,20 +2598,47 @@ void X86AsmPrinter::emitInstruction(const MachineInstr *MI) {
   MCInst TmpInst;
   MCInstLowering.Lower(MI, TmpInst);
 
-  // Stackmap shadows cannot include branch targets, so we can count the bytes
-  // in a call towards the shadow, but must ensure that the no thread returns
-  // in to the stackmap shadow.  The only way to achieve this is if the call
-  // is at the end of the shadow.
   if (MI->isCall()) {
-    // Count then size of the call towards the shadow
-    SMShadowTracker.count(TmpInst, getSubtargetInfo(), CodeEmitter.get());
-    // Then flush the shadow so that we fill with nops before the call, not
-    // after it.
-    SMShadowTracker.emitShadowPadding(*OutStreamer, getSubtargetInfo());
-    // Then emit the call
-    OutStreamer->emitInstruction(TmpInst, getSubtargetInfo());
+    emitCallInstruction(TmpInst);
     return;
   }
 
   EmitAndCountInstruction(TmpInst);
 }
+
+void X86AsmPrinter::emitCallInstruction(const llvm::MCInst &MCI) {
+  // Stackmap shadows cannot include branch targets, so we can count the bytes
+  // in a call towards the shado...
[truncated]

@dpaoliello changed the title from "[x64][win] Add support for x64 import call optimization (equivalent to MSVC /d2guardretpoline)" to "[x64][win] Add compiler support for x64 import call optimization (equivalent to MSVC /d2guardretpoline)" on Feb 11, 2025
.push_back({CallSiteSymbol, Kind});
}

void X86AsmPrinter::ensureRaxUsedForOperand(MCInst &TmpInst) {
Collaborator:

This seems like a weird way to handle this... why not use a call pseudo-instruction that constrains the input register, so the rest of the compiler is aware of what's happening here?

What happens if a parameter needs to be in RAX? That shouldn't happen for the C calling convention, but it can happen for alternative calling conventions.
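For context, the helper being questioned patches the already-lowered MCInst so that the target register becomes rax. A minimal sketch of that approach (MC-layer APIs; it assumes operand 0 holds the register target and that the X86 MC target headers are in scope, and it is an illustration rather than the PR's exact code):

#include "llvm/MC/MCInst.h"
#include "llvm/MC/MCInstBuilder.h"
#include "llvm/MC/MCStreamer.h"
#include "llvm/MC/MCSubtargetInfo.h"

// Sketch: if the call/jump target is not already in rax, emit
// "mov rax, <reg>" and retarget the instruction at rax. Assumes it is safe
// to clobber rax at this point, matching the NOTE on the declaration.
static void ensureRaxUsedForOperandSketch(llvm::MCInst &Inst,
                                          llvm::MCStreamer &OS,
                                          const llvm::MCSubtargetInfo &STI) {
  using namespace llvm;
  MCOperand &Op = Inst.getOperand(0);
  if (Op.isReg() && Op.getReg() != X86::RAX) {
    OS.emitInstruction(
        MCInstBuilder(X86::MOV64rr).addReg(X86::RAX).addReg(Op.getReg()), STI);
    Op = MCOperand::createReg(X86::RAX);
  }
}

Rewriting at the MC layer like this hides the extra rax def from the register allocator, which is presumably why a register-constrained pseudo-instruction is suggested instead.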

Contributor Author:

Took me a little while to figure out how to do this - so please let me know if it's correct or if there's a better way to approach this.

> What happens if a parameter needs to be in RAX? That shouldn't happen for the C calling convention, but it can happen for alternative calling conventions.

Import call optimization can't be enabled with indirect calls that need to use RAX for parameters - MSVC doesn't handle this in any elegant way (looks like it is UB), not sure what we want to do with it here?
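To make the rax conflict concrete, here is a hedged illustration using regcall as an example of a register-heavy convention (treat the exact register assignment as an assumption; under LLVM's regcall lowering the first integer argument can land in rax):

// Hypothetical example, compiled for an x86-64 target that accepts regcall.
typedef void (__attribute__((regcall)) *RegcallFn)(long long);

void callViaPointer(RegcallFn F, long long V) {
  // The first integer argument may be assigned to rax, while import call
  // optimization needs the indirect call target itself in rax, so the two
  // requirements collide.
  F(V);
}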

Collaborator:

The new pseudo-instruction looks fine.

If it's illegal to use certain calling conventions in /d2guardretpoline mode, we need to emit an error message if someone tries to use such a calling convention. I can't think of any other way to handle it.

Contributor Author:

Ok, added a check.
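A minimal sketch of the kind of check being discussed, sitting in LowerCall next to the RegsToPass assignments (the guard condition, names, and message are assumptions, not the PR's actual code):

// Hypothetical: when import call optimization is enabled, reject indirect
// calls whose argument registers include rax, since rax must hold the
// call target.
if (EnableImportCallOptimization && IsIndirectCall) {
  for (const auto &RegToPass : RegsToPass)
    if (RegToPass.first == X86::RAX)
      report_fatal_error("import call optimization is incompatible with "
                         "calling conventions that pass arguments in rax");
}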

Collaborator:

Given this is a known condition we can never support, can we do something nicer than report_fatal_error? grep for DiagnosticInfoUnsupported.

Contributor Author:

Done
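For reference, the DiagnosticInfoUnsupported pattern the reviewer points at looks roughly like this inside LowerCall (the message text is illustrative):

#include "llvm/IR/DiagnosticInfo.h"

// Hypothetical: surface the unsupported combination as a regular diagnostic
// attached to the current function instead of aborting the compiler.
DAG.getContext()->diagnose(DiagnosticInfoUnsupported(
    DAG.getMachineFunction().getFunction(),
    "import call optimization is incompatible with calling conventions "
    "that pass arguments in rax",
    dl.getDebugLoc()));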

if (EnableImportCallOptimization && hasJumpTableInfoInBlock(MI)) {
uint16_t EncodedReg =
this->getSubtarget().getRegisterInfo()->getEncodingValue(
MI->getOperand(0).getReg().asMCReg());
Collaborator:

JMP64r and JMP64m seem very different... do they really use the same encoding?

Contributor Author:

Good call: looks like x64 Windows will never use JMP64m for jump tables (that form is used only for non-PIC targets), so I can remove this.

IMAGE_RETPOLINE_AMD64_INDIR_CALL);
emitCallInstruction(TmpInst);

// For Import Call Optimization to work, we a 3-byte nop after the call
Collaborator:

Suggested change
// For Import Call Optimization to work, we a 3-byte nop after the call
// For Import Call Optimization to work, we need a 3-byte nop after the call

@dpaoliello force-pushed the x64impcall branch 3 times, most recently from 03ba173 to cbb936b on February 18, 2025 at 23:58

github-actions bot commented Feb 19, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.
