Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

State synthesis for quantum devices #2291

Open
wants to merge 44 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
ac01dd1
DCO Remediation Commit for Ben Howe <[email protected]>
bmhowe23 Oct 11, 2024
21a87c1
State pointer synthesis for quantum hardware
annagrin Sep 17, 2024
3fc56de
Merge with main
annagrin Oct 17, 2024
7969a75
Merge with main
annagrin Oct 17, 2024
755d0d1
Fix test failure on anyon platform
annagrin Oct 17, 2024
dc5e77e
Merge branch 'main' of https://github.com/NVIDIA/cuda-quantum into qu…
annagrin Oct 17, 2024
382bc99
Make StateInitialization a funcOp pass
annagrin Oct 17, 2024
d3a05d4
Fix issues and tests for the rest of quantum architectures
annagrin Oct 18, 2024
ac151f2
Merge with main
annagrin Oct 18, 2024
51ef054
Fix failing quantinuum state prep tests
annagrin Oct 18, 2024
0cdf3e9
Merge branch 'main' of https://github.com/NVIDIA/cuda-quantum into qu…
annagrin Oct 18, 2024
5307aa4
Merge branch 'main' of https://github.com/NVIDIA/cuda-quantum into qu…
annagrin Oct 21, 2024
a7f5387
Address CR comments
annagrin Oct 21, 2024
eb8db13
Merge with main
annagrin Oct 21, 2024
9f0937f
Format
annagrin Oct 21, 2024
2f3a623
Fix failing test
annagrin Oct 22, 2024
b381350
Format
annagrin Oct 22, 2024
dc87ca4
Format
annagrin Oct 22, 2024
e4c7735
Merge branch 'main' of https://github.com/NVIDIA/cuda-quantum into qu…
annagrin Oct 22, 2024
53a34c9
Replaced getState intrinsic by cc.get_state op
annagrin Oct 22, 2024
30777f3
Merge branch 'main' of https://github.com/NVIDIA/cuda-quantum into qu…
annagrin Oct 22, 2024
fe6d409
Remove print
annagrin Oct 22, 2024
48704e3
Remove getCudaqState references
annagrin Oct 22, 2024
137f621
Minor updates
annagrin Oct 22, 2024
ad7c6bc
Fix failing quake test
annagrin Oct 23, 2024
83683f7
Merge branch 'main' of https://github.com/NVIDIA/cuda-quantum into qu…
annagrin Nov 4, 2024
78c0a44
Add a few state-related cc ops
annagrin Nov 5, 2024
6682c39
Merge branch 'main' of https://github.com/NVIDIA/cuda-quantum into st…
annagrin Nov 5, 2024
102f819
Fix test_argument_conversion
annagrin Nov 5, 2024
6b2c015
Merge branch 'main' of https://github.com/NVIDIA/cuda-quantum into st…
annagrin Nov 5, 2024
5ea1d97
Add printing in failing tests
annagrin Nov 5, 2024
074c60f
Add printing in failing tests
annagrin Nov 5, 2024
310f6ca
Fix failing tests
annagrin Nov 12, 2024
f0176ae
Merge branch 'main' of https://github.com/NVIDIA/cuda-quantum into st…
annagrin Nov 12, 2024
d17fa6d
Merge branch 'main' of https://github.com/NVIDIA/cuda-quantum into qu…
annagrin Nov 12, 2024
3425182
Merge with state-ops
annagrin Nov 12, 2024
6fdccba
Add description for new algorithm for state syntesis
annagrin Nov 12, 2024
fc5e154
Merge with main
annagrin Jan 9, 2025
1dfa805
Fix tests
annagrin Jan 9, 2025
b67fc88
Merge branch 'main' of https://github.com/NVIDIA/cuda-quantum into qu…
annagrin Jan 9, 2025
9563371
Make intermediate IR legal by separating allocs
annagrin Jan 21, 2025
f32b066
Merge branch 'main' of https://github.com/NVIDIA/cuda-quantum into qu…
annagrin Jan 21, 2025
008e8c1
DCO Remediation Commit for Anna Gringauze <[email protected]>
annagrin Jan 21, 2025
84a4369
Merge branch 'main' of https://github.com/NVIDIA/cuda-quantum into qu…
annagrin Jan 22, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 16 additions & 6 deletions include/cudaq/Optimizer/Dialect/Quake/QuakeOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1467,19 +1467,29 @@ def QuakeOp_GetStateOp : QuakeOp<"get_state", [Pure] > {
let summary = "Get state from kernel with the provided name.";
let description = [{
This operation is created by argument synthesis of state pointer arguments
for quantum devices. It takes a kernel name as ASCIIZ string literal value
and returns the kernel's quantum state. The operation is replaced by a call
to the kernel with the provided name in ReplaceStateByKernel pass.
for quantum devices.

It takes two kernel names as ASCIIZ string literals:
- "num_qubits" for determining the size of the allocation to initialize
- "init" for initializing the state the same way as the original kernel
passed to `cudaq::get_state`) as ASCIIZ string literal

And returns the quantum state of the original kernel passed to
`cudaq::get_state`. The operation is replaced by calls to the kernels with
the provided names in `ReplaceStateByKernel` pass.

```mlir
%0 = quake.get_state "callee" : !cc.ptr<!cc.state>
%0 = quake.get_state "num_qubits" "init" : !cc.ptr<!cc.state>
```
}];

let arguments = (ins StrAttr:$calleeName);
let arguments = (ins
StrAttr:$numQubitsFuncName,
StrAttr:$initFuncName
Comment on lines +1487 to +1488
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
StrAttr:$numQubitsFuncName,
StrAttr:$initFuncName
FlatSymbolRefAttr:$numQubitsFuncName,
FlatSymbolRefAttr:$initFuncName

If these are truly artifacts that shall be present in the IR, let's make them Symbol attrs.

Copy link
Collaborator Author

@annagrin annagrin Jan 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the signatures of those functions are changed by the synthesis, would this instruction get updated with symbols with new signatures during application of the substitution? I can try that and see if the synthesis works.

);
let results = (outs cc_PointerType:$result);
Copy link
Collaborator

@schweitzpgi schweitzpgi Jan 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We probably want to force this to be a ptr<state>, no?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will do, thanks

let assemblyFormat = [{
$calleeName `:` qualified(type(results)) attr-dict
$numQubitsFuncName $initFuncName `:` qualified(type(results)) attr-dict
}];
}

Expand Down
7 changes: 7 additions & 0 deletions include/cudaq/Optimizer/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,13 @@ std::unique_ptr<mlir::Pass>
createArgumentSynthesisPass(mlir::ArrayRef<mlir::StringRef> funcNames,
mlir::ArrayRef<mlir::StringRef> substitutions);

/// Helper function to build an argument synthesis pass. The names of the
/// functions and the substitutions text can be built as an unzipped pair of
/// lists.
std::unique_ptr<mlir::Pass>
createArgumentSynthesisPass(const std::vector<std::string> &funcNames,
const std::vector<std::string> &substitutions);
Comment on lines +65 to +66
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need this? ArrayRef should subsume rigid std::vector here. I don't think we need this overload?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

std::vector<string> does not get auto-converted to ArrayRef<StringRef>... I can try ArrayRef<std::string> instead


// declarative passes
#define GEN_PASS_DECL
#define GEN_PASS_REGISTRATION
Expand Down
48 changes: 43 additions & 5 deletions include/cudaq/Optimizer/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -778,8 +778,8 @@ def DeleteStates : Pass<"delete-states", "mlir::ModuleOp"> {
func.func @foo() attributes {"cudaq-entrypoint", "cudaq-kernel", no_this} {
%c8_i64 = arith.constant 8 : i64
%0 = cc.address_of @foo.rodata_synth_0 : !cc.ptr<!cc.array<complex<f32> x 8>>
%4 = cc.create_state %3, %c8_i64 : (!cc.ptr<!cc.array<complex<f32> x 8>>, i64) -> !cc.ptr<!cc.state>
%5 = cc.get_number_of_qubits %4 : (!cc.ptr<!cc.state>) -> i64
%4 = quake.create_state %3, %c8_i64 : (!cc.ptr<!cc.array<complex<f32> x 8>>, i64) -> !cc.ptr<!cc.state>
%5 = quake.get_number_of_qubits %4 : (!cc.ptr<!cc.state>) -> i64
%6 = quake.alloca !quake.veq<?>[%5 : i64]
%7 = quake.init_state %6, %4 : (!quake.veq<?>, !cc.ptr<!cc.state>) -> !quake.veq<?>

Expand All @@ -805,7 +805,7 @@ def DeleteStates : Pass<"delete-states", "mlir::ModuleOp"> {
Before DeleteStates (delete-states):
``` func.func @foo() attributes {"cudaq-entrypoint", "cudaq-kernel", no_this} {
...
%4 = call @__nvqpp_cudaq_state_createFromData_fp32(%3, %c8_i64) : (!cc.ptr<i8>, i64) -> !cc.ptr<!cc.state>
%4 = quake.create_state %3, %c8_i64 : (!cc.ptr<!cc.array<complex<f32> x 8>>, i64) -> !cc.ptr<!cc.state>
call @__nvqpp__mlirgen__sub_kernel(%4) : (!cc.ptr<!cc.state>) -> ()
return
}
Expand All @@ -815,9 +815,47 @@ def DeleteStates : Pass<"delete-states", "mlir::ModuleOp"> {
```
func.func @foo() attributes {"cudaq-entrypoint", "cudaq-kernel", no_this} {
...
%4 = call @__nvqpp_cudaq_state_createFromData_fp32(%3, %c8_i64) : (!cc.ptr<i8>, i64) -> !cc.ptr<!cc.state>
%4 = quake.create_state %3, %c8_i64 : (!cc.ptr<!cc.array<complex<f32> x 8>>, i64) -> !cc.ptr<!cc.state>
call @__nvqpp__mlirgen__sub_kernel(%4) : (!cc.ptr<!cc.state>) -> ()
call @__nvqpp_cudaq_state_delete(%4) : (!cc.ptr<!cc.state>) -> ()
quake.delete_state %3 : !cc.ptr<!cc.state>
return
}
```
}];
}

def ReplaceStateWithKernel : Pass<"replace-state-with-kernel", "mlir::func::FuncOp"> {
let summary =
"Replace `quake.init_state` instructions with call to the kernel generating the state";
let description = [{
Argument synthesis for state pointers for quantum devices substitutes state
argument by a new state created from `__nvqpp_cudaq_state_get` intrinsic, which
in turn accepts the name for the synthesized kernel that generated the state.

This optimization completes the replacement of `quake.init_state` instruction by:

- Replace `quake.init_state` by a call that `get_state` call refers to.
- Remove all unneeded instructions.

For example:

Before ReplaceStateWithKernel (replace-state-with-kernel):
```
func.func @foo() {
%0 = quake.get_state "callee.num_qubits_0" "callee.init_0": !cc.ptr<!cc.state>
%1 = quake.get_number_of_qubits %0 : (!cc.ptr<!cc.state>) -> i64
%2 = quake.alloca !quake.veq<?>[%1 : i64]
%3 = quake.init_state %2, %0 : (!quake.veq<?>, !cc.ptr<!cc.state>) -> !quake.veq<?>
return
}
```

After ReplaceStateWithKernel (replace-state-with-kernel):
```
func.func @foo() {
%1 = call @callee.num_qubits_0() : () -> i64
%2 = quake.alloca !quake.veq<?>[%1 : i64]
%3 = call @callee.init_0(%2) : (!quake.veq<?>) -> !quake.veq<?>
return
}
```
Expand Down
9 changes: 9 additions & 0 deletions lib/Optimizer/Transforms/ArgumentSynthesis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -163,3 +163,12 @@ cudaq::opt::createArgumentSynthesisPass(ArrayRef<StringRef> funcNames,
return std::make_unique<ArgumentSynthesisPass>(
ArgumentSynthesisOptions{pairs});
}

std::unique_ptr<mlir::Pass> cudaq::opt::createArgumentSynthesisPass(
const std::vector<std::string> &funcNames,
const std::vector<std::string> &substitutions) {
return cudaq::opt::createArgumentSynthesisPass(
mlir::SmallVector<mlir::StringRef>{funcNames.begin(), funcNames.end()},
mlir::SmallVector<mlir::StringRef>{substitutions.begin(),
substitutions.end()});
}
1 change: 1 addition & 0 deletions lib/Optimizer/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ add_cudaq_library(OptTransforms
QuakeSynthesizer.cpp
RefToVeqAlloc.cpp
RegToMem.cpp
ReplaceStateWithKernel.cpp
StatePreparation.cpp
UnitarySynthesis.cpp
WiresToWiresets.cpp
Expand Down
132 changes: 132 additions & 0 deletions lib/Optimizer/Transforms/ReplaceStateWithKernel.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
/*******************************************************************************
* Copyright (c) 2022 - 2025 NVIDIA Corporation & Affiliates. *
* All rights reserved. *
* *
* This source code and the accompanying materials are made available under *
* the terms of the Apache License 2.0 which accompanies this distribution. *
******************************************************************************/

#include "PassDetails.h"
#include "cudaq/Optimizer/Builder/Intrinsics.h"
#include "cudaq/Optimizer/Dialect/CC/CCOps.h"
#include "cudaq/Optimizer/Dialect/Quake/QuakeOps.h"
#include "cudaq/Optimizer/Transforms/Passes.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
#include <span>

namespace cudaq::opt {
#define GEN_PASS_DEF_REPLACESTATEWITHKERNEL
#include "cudaq/Optimizer/Transforms/Passes.h.inc"
} // namespace cudaq::opt

#define DEBUG_TYPE "replace-state-with-kernel"

using namespace mlir;

namespace {
// clang-format off
/// Replace `quake.get_number_of_qubits` by a call to a a function
/// that computes the number of qubits for a state.
///
/// ```
/// %0 = quake.get_state "callee.num_qubits_0" "callee.init_0" : !cc.ptr<!cc.state>
/// %1 = quake.get_number_of_qubits %0 : (!cc.ptr<!cc.state>) -> i64
/// ───────────────────────────────────────────
/// ...
/// %1 = call @callee.num_qubits_0() : () -> i64
/// ```
// clang-format on
class ReplaceGetNumQubitsPattern
: public OpRewritePattern<quake::GetNumberOfQubitsOp> {
public:
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(quake::GetNumberOfQubitsOp numQubits,
PatternRewriter &rewriter) const override {

auto stateOp = numQubits.getOperand();
if (auto getState = stateOp.getDefiningOp<quake::GetStateOp>()) {
auto numQubitsName = getState.getNumQubitsFuncName();

rewriter.setInsertionPoint(numQubits);
rewriter.replaceOpWithNewOp<func::CallOp>(
numQubits, numQubits.getType(), numQubitsName, mlir::ValueRange{});
return success();
}
return numQubits->emitError(
"ReplaceStateWithKernel: failed to replace `quake.get_num_qubits`");
}
};

// clang-format off
/// Replace `quake.init_state` by a call to a (modified) kernel that produced
/// the state.
///
/// ```
/// %0 = quake.get_state "callee.num_qubits_0" "callee.init_0" : !cc.ptr<!cc.state>
/// %3 = quake.init_state %2, %0 : (!quake.veq<?>, !cc.ptr<!cc.state>) -> !quake.veq<?>
/// ───────────────────────────────────────────
/// ...
/// %3 = call @callee.init_0(%2): (!quake.veq<?>) -> !quake.veq<?>
/// ```
// clang-format on
class ReplaceInitStatePattern
: public OpRewritePattern<quake::InitializeStateOp> {
public:
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(quake::InitializeStateOp initState,
PatternRewriter &rewriter) const override {
auto allocaOp = initState.getOperand(0);
auto stateOp = initState.getOperand(1);

if (auto ptrTy = dyn_cast<cudaq::cc::PointerType>(stateOp.getType())) {
if (isa<cudaq::cc::StateType>(ptrTy.getElementType())) {
if (auto getState = stateOp.getDefiningOp<quake::GetStateOp>()) {
auto initName = getState.getInitFuncName();

rewriter.setInsertionPoint(initState);
rewriter.replaceOpWithNewOp<func::CallOp>(
initState, initState.getType(), initName,
mlir::ValueRange{allocaOp});

return success();
}

return initState->emitError(
"ReplaceStateWithKernel: failed to replace `quake.init_state`");
}
}
return failure();
}
};

class ReplaceStateWithKernelPass
: public cudaq::opt::impl::ReplaceStateWithKernelBase<
ReplaceStateWithKernelPass> {
public:
using ReplaceStateWithKernelBase::ReplaceStateWithKernelBase;

void runOnOperation() override {
auto *ctx = &getContext();
auto func = getOperation();
RewritePatternSet patterns(ctx);
patterns.insert<ReplaceGetNumQubitsPattern, ReplaceInitStatePattern>(ctx);

LLVM_DEBUG(llvm::dbgs()
<< "Before replace state with kernel: " << func << '\n');

if (failed(applyPatternsAndFoldGreedily(func.getOperation(),
std::move(patterns))))
signalPassFailure();

LLVM_DEBUG(llvm::dbgs()
<< "After replace state with kernel: " << func << '\n');
}
};
} // namespace
5 changes: 3 additions & 2 deletions python/runtime/cudaq/algorithms/py_state.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,9 @@ class PyRemoteSimulationState : public RemoteSimulationState {
}
}

std::pair<std::string, std::vector<void *>> getKernelInfo() const override {
return {kernelName, argsData->getArgs()};
std::optional<std::pair<std::string, std::vector<void *>>>
getKernelInfo() const override {
return std::make_pair(kernelName, argsData->getArgs());
}

std::complex<double> overlap(const cudaq::SimulationState &other) override {
Expand Down
2 changes: 1 addition & 1 deletion python/runtime/cudaq/platform/py_alt_launch_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -535,7 +535,7 @@ MlirModule synthesizeKernel(const std::string &name, MlirModule module,
auto isLocalSimulator = platform.is_simulator() && !platform.is_emulated();
auto isSimulator = isLocalSimulator || isRemoteSimulator;

cudaq::opt::ArgumentConverter argCon(name, unwrap(module), isSimulator);
cudaq::opt::ArgumentConverter argCon(name, unwrap(module));
argCon.gen(runtimeArgs.getArgs());
std::string kernName = cudaq::runtime::cudaqGenPrefixName + name;
SmallVector<StringRef> kernels = {kernName};
Expand Down
Loading
Loading