-
Notifications
You must be signed in to change notification settings - Fork 205
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
State synthesis for quantum devices #2291
base: main
Are you sure you want to change the base?
Changes from all commits
ac01dd1
21a87c1
3fc56de
7969a75
755d0d1
dc5e77e
382bc99
d3a05d4
ac151f2
51ef054
0cdf3e9
5307aa4
a7f5387
eb8db13
9f0937f
2f3a623
b381350
dc87ca4
e4c7735
53a34c9
30777f3
fe6d409
48704e3
137f621
ad7c6bc
83683f7
78c0a44
6682c39
102f819
6b2c015
5ea1d97
074c60f
310f6ca
f0176ae
d17fa6d
3425182
6fdccba
fc5e154
1dfa805
b67fc88
9563371
f32b066
008e8c1
84a4369
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1467,19 +1467,29 @@ def QuakeOp_GetStateOp : QuakeOp<"get_state", [Pure] > { | |
let summary = "Get state from kernel with the provided name."; | ||
let description = [{ | ||
This operation is created by argument synthesis of state pointer arguments | ||
for quantum devices. It takes a kernel name as ASCIIZ string literal value | ||
and returns the kernel's quantum state. The operation is replaced by a call | ||
to the kernel with the provided name in ReplaceStateByKernel pass. | ||
for quantum devices. | ||
|
||
It takes two kernel names as ASCIIZ string literals: | ||
- "num_qubits" for determining the size of the allocation to initialize | ||
- "init" for initializing the state the same way as the original kernel | ||
passed to `cudaq::get_state`) as ASCIIZ string literal | ||
|
||
And returns the quantum state of the original kernel passed to | ||
`cudaq::get_state`. The operation is replaced by calls to the kernels with | ||
the provided names in `ReplaceStateByKernel` pass. | ||
|
||
```mlir | ||
%0 = quake.get_state "callee" : !cc.ptr<!cc.state> | ||
%0 = quake.get_state "num_qubits" "init" : !cc.ptr<!cc.state> | ||
``` | ||
}]; | ||
|
||
let arguments = (ins StrAttr:$calleeName); | ||
let arguments = (ins | ||
StrAttr:$numQubitsFuncName, | ||
StrAttr:$initFuncName | ||
); | ||
let results = (outs cc_PointerType:$result); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We probably want to force this to be a There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Will do, thanks |
||
let assemblyFormat = [{ | ||
$calleeName `:` qualified(type(results)) attr-dict | ||
$numQubitsFuncName $initFuncName `:` qualified(type(results)) attr-dict | ||
}]; | ||
} | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -58,6 +58,13 @@ std::unique_ptr<mlir::Pass> | |
createArgumentSynthesisPass(mlir::ArrayRef<mlir::StringRef> funcNames, | ||
mlir::ArrayRef<mlir::StringRef> substitutions); | ||
|
||
/// Helper function to build an argument synthesis pass. The names of the | ||
/// functions and the substitutions text can be built as an unzipped pair of | ||
/// lists. | ||
std::unique_ptr<mlir::Pass> | ||
createArgumentSynthesisPass(const std::vector<std::string> &funcNames, | ||
const std::vector<std::string> &substitutions); | ||
Comment on lines
+65
to
+66
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why do we need this? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
|
||
// declarative passes | ||
#define GEN_PASS_DECL | ||
#define GEN_PASS_REGISTRATION | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,132 @@ | ||
/******************************************************************************* | ||
* Copyright (c) 2022 - 2025 NVIDIA Corporation & Affiliates. * | ||
* All rights reserved. * | ||
* * | ||
* This source code and the accompanying materials are made available under * | ||
* the terms of the Apache License 2.0 which accompanies this distribution. * | ||
******************************************************************************/ | ||
|
||
#include "PassDetails.h" | ||
#include "cudaq/Optimizer/Builder/Intrinsics.h" | ||
#include "cudaq/Optimizer/Dialect/CC/CCOps.h" | ||
#include "cudaq/Optimizer/Dialect/Quake/QuakeOps.h" | ||
#include "cudaq/Optimizer/Transforms/Passes.h" | ||
#include "mlir/Dialect/Complex/IR/Complex.h" | ||
#include "mlir/Dialect/Func/IR/FuncOps.h" | ||
#include "mlir/IR/BuiltinOps.h" | ||
#include "mlir/IR/PatternMatch.h" | ||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" | ||
#include "mlir/Transforms/Passes.h" | ||
#include <span> | ||
|
||
namespace cudaq::opt { | ||
#define GEN_PASS_DEF_REPLACESTATEWITHKERNEL | ||
#include "cudaq/Optimizer/Transforms/Passes.h.inc" | ||
} // namespace cudaq::opt | ||
|
||
#define DEBUG_TYPE "replace-state-with-kernel" | ||
|
||
using namespace mlir; | ||
|
||
namespace { | ||
// clang-format off | ||
/// Replace `quake.get_number_of_qubits` by a call to a a function | ||
/// that computes the number of qubits for a state. | ||
/// | ||
/// ``` | ||
/// %0 = quake.get_state "callee.num_qubits_0" "callee.init_0" : !cc.ptr<!cc.state> | ||
/// %1 = quake.get_number_of_qubits %0 : (!cc.ptr<!cc.state>) -> i64 | ||
/// ─────────────────────────────────────────── | ||
/// ... | ||
/// %1 = call @callee.num_qubits_0() : () -> i64 | ||
/// ``` | ||
// clang-format on | ||
class ReplaceGetNumQubitsPattern | ||
: public OpRewritePattern<quake::GetNumberOfQubitsOp> { | ||
public: | ||
using OpRewritePattern::OpRewritePattern; | ||
|
||
LogicalResult matchAndRewrite(quake::GetNumberOfQubitsOp numQubits, | ||
PatternRewriter &rewriter) const override { | ||
|
||
auto stateOp = numQubits.getOperand(); | ||
if (auto getState = stateOp.getDefiningOp<quake::GetStateOp>()) { | ||
auto numQubitsName = getState.getNumQubitsFuncName(); | ||
|
||
rewriter.setInsertionPoint(numQubits); | ||
rewriter.replaceOpWithNewOp<func::CallOp>( | ||
numQubits, numQubits.getType(), numQubitsName, mlir::ValueRange{}); | ||
return success(); | ||
} | ||
return numQubits->emitError( | ||
"ReplaceStateWithKernel: failed to replace `quake.get_num_qubits`"); | ||
} | ||
}; | ||
|
||
// clang-format off | ||
/// Replace `quake.init_state` by a call to a (modified) kernel that produced | ||
/// the state. | ||
/// | ||
/// ``` | ||
/// %0 = quake.get_state "callee.num_qubits_0" "callee.init_0" : !cc.ptr<!cc.state> | ||
/// %3 = quake.init_state %2, %0 : (!quake.veq<?>, !cc.ptr<!cc.state>) -> !quake.veq<?> | ||
/// ─────────────────────────────────────────── | ||
/// ... | ||
/// %3 = call @callee.init_0(%2): (!quake.veq<?>) -> !quake.veq<?> | ||
/// ``` | ||
// clang-format on | ||
class ReplaceInitStatePattern | ||
: public OpRewritePattern<quake::InitializeStateOp> { | ||
public: | ||
using OpRewritePattern::OpRewritePattern; | ||
|
||
LogicalResult matchAndRewrite(quake::InitializeStateOp initState, | ||
PatternRewriter &rewriter) const override { | ||
auto allocaOp = initState.getOperand(0); | ||
auto stateOp = initState.getOperand(1); | ||
|
||
if (auto ptrTy = dyn_cast<cudaq::cc::PointerType>(stateOp.getType())) { | ||
if (isa<cudaq::cc::StateType>(ptrTy.getElementType())) { | ||
if (auto getState = stateOp.getDefiningOp<quake::GetStateOp>()) { | ||
auto initName = getState.getInitFuncName(); | ||
|
||
rewriter.setInsertionPoint(initState); | ||
rewriter.replaceOpWithNewOp<func::CallOp>( | ||
initState, initState.getType(), initName, | ||
mlir::ValueRange{allocaOp}); | ||
|
||
return success(); | ||
} | ||
|
||
return initState->emitError( | ||
"ReplaceStateWithKernel: failed to replace `quake.init_state`"); | ||
} | ||
} | ||
return failure(); | ||
} | ||
}; | ||
|
||
class ReplaceStateWithKernelPass | ||
: public cudaq::opt::impl::ReplaceStateWithKernelBase< | ||
ReplaceStateWithKernelPass> { | ||
public: | ||
using ReplaceStateWithKernelBase::ReplaceStateWithKernelBase; | ||
|
||
void runOnOperation() override { | ||
auto *ctx = &getContext(); | ||
auto func = getOperation(); | ||
RewritePatternSet patterns(ctx); | ||
patterns.insert<ReplaceGetNumQubitsPattern, ReplaceInitStatePattern>(ctx); | ||
|
||
LLVM_DEBUG(llvm::dbgs() | ||
<< "Before replace state with kernel: " << func << '\n'); | ||
|
||
if (failed(applyPatternsAndFoldGreedily(func.getOperation(), | ||
std::move(patterns)))) | ||
signalPassFailure(); | ||
|
||
LLVM_DEBUG(llvm::dbgs() | ||
<< "After replace state with kernel: " << func << '\n'); | ||
} | ||
}; | ||
} // namespace |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If these are truly artifacts that shall be present in the IR, let's make them Symbol attrs.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If the signatures of those functions are changed by the synthesis, would this instruction get updated with symbols with new signatures during application of the substitution? I can try that and see if the synthesis works.