diff --git a/llpc/lower/llpcSpirvLowerGlobal.cpp b/llpc/lower/llpcSpirvLowerGlobal.cpp index 6939fb06d0..e20271e652 100644 --- a/llpc/lower/llpcSpirvLowerGlobal.cpp +++ b/llpc/lower/llpcSpirvLowerGlobal.cpp @@ -590,27 +590,44 @@ void SpirvLowerGlobal::mapGlobalVariableToProxy(GlobalVariable *globalVar) { const auto &dataLayout = m_module->getDataLayout(); Type *globalVarTy = globalVar->getValueType(); - m_builder->SetInsertPointPastAllocas(m_entryPoint); - Value *proxy = nullptr; - + assert(m_entryPoint); + removeConstantExpr(m_context, globalVar); // Handle special globals, regular allocas will be removed by SROA pass. - if (globalVar->getName().startswith(RtName::HitAttribute)) + if (globalVar->getName().startswith(RtName::HitAttribute)) { proxy = m_entryPoint->getArg(1); - else if (globalVar->getName().startswith(RtName::IncomingRayPayLoad)) + globalVar->replaceAllUsesWith(proxy); + } else if (globalVar->getName().startswith(RtName::IncomingRayPayLoad)) { proxy = m_entryPoint->getArg(0); - else if (globalVar->getName().startswith(RtName::IncomingCallableData)) + globalVar->replaceAllUsesWith(proxy); + } else if (globalVar->getName().startswith(RtName::IncomingCallableData)) { proxy = m_entryPoint->getArg(0); - else - proxy = m_builder->CreateAlloca(globalVarTy, dataLayout.getAllocaAddrSpace(), nullptr, - Twine(LlpcName::GlobalProxyPrefix) + globalVar->getName()); + globalVar->replaceAllUsesWith(proxy); + } else { + // Collect used functions + SmallSet funcs; + for (User *user : globalVar->users()) { + auto inst = cast(user); + funcs.insert(inst->getFunction()); + } + for (Function *func : funcs) { + m_builder->SetInsertPointPastAllocas(func); + proxy = m_builder->CreateAlloca(globalVarTy, dataLayout.getAllocaAddrSpace(), nullptr, + Twine(LlpcName::GlobalProxyPrefix) + globalVar->getName()); - if (globalVar->hasInitializer()) { - auto initializer = globalVar->getInitializer(); - m_builder->CreateStore(initializer, proxy); + if (globalVar->hasInitializer()) { + auto initializer = globalVar->getInitializer(); + m_builder->CreateStore(initializer, proxy); + } + globalVar->mutateType(proxy->getType()); + globalVar->replaceUsesWithIf(proxy, [func](Use &U) { + Instruction *userInst = cast(U.getUser()); + return userInst->getFunction() == func; + }); + } } - m_globalVarProxyMap[globalVar] = proxy; + m_globalVarProxy.insert(globalVar); } // ===================================================================================================================== @@ -687,20 +704,17 @@ void SpirvLowerGlobal::mapOutputToProxy(GlobalVariable *output) { // ===================================================================================================================== // Does lowering operations for SPIR-V global variables, replaces global variables with proxy variables. void SpirvLowerGlobal::lowerGlobalVar() { - if (m_globalVarProxyMap.empty()) { + if (m_globalVarProxy.empty()) { // Skip lowering if there is no global variable return; } - // Replace global variable with proxy variable - for (auto globalVarMap : m_globalVarProxyMap) { - auto globalVar = cast(globalVarMap.first); - auto proxy = globalVarMap.second; - globalVar->mutateType(proxy->getType()); // To clear address space for pointer to make replacement valid - globalVar->replaceAllUsesWith(proxy); + // remove global variables + for (auto globalVar : m_globalVarProxy) { globalVar->dropAllReferences(); globalVar->eraseFromParent(); } + m_globalVarProxy.clear(); } // ===================================================================================================================== diff --git a/llpc/lower/llpcSpirvLowerGlobal.h b/llpc/lower/llpcSpirvLowerGlobal.h index d434c953fc..720ba54684 100644 --- a/llpc/lower/llpcSpirvLowerGlobal.h +++ b/llpc/lower/llpcSpirvLowerGlobal.h @@ -115,8 +115,8 @@ class SpirvLowerGlobal : public SpirvLower, public llvm::PassInfoMixin m_globalVarProxyMap; // Proxy map for lowering global variables - std::unordered_map m_inputProxyMap; // Proxy map for lowering inputs + std::unordered_set m_globalVarProxy; // The unordered_set for lowering global variables + std::unordered_map m_inputProxyMap; // Proxy map for lowering inputs // NOTE: Here we use list to store pairs of output proxy mappings. This is because we want output patching to be // "ordered" (resulting LLVM IR for the patching always be consistent).