Skip to content

Commit

Permalink
* Disallow passing qvector(s) arguments to custom operations.
Browse files Browse the repository at this point in the history
Signed-off-by: Pradnya Khalate <[email protected]>
  • Loading branch information
khalatepradnya committed Dec 5, 2024
1 parent 9d7e20a commit 5ff555f
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 1 deletion.
12 changes: 11 additions & 1 deletion python/cudaq/kernel/ast_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -1826,7 +1826,11 @@ def bodyBuilder(iterVal):
targets = [self.popValue() for _ in range(numTargets)]
targets.reverse()

self.checkControlAndTargetTypes([], targets)
for i, t in enumerate(targets):
if not quake.RefType.isinstance(t.type):
self.emitFatalError(
f'invalid target operand {i}, broadcasting is not supported on custom operations.'
)

globalName = f'{nvqppPrefix}{node.func.id}_generator_{numTargets}.rodata'

Expand Down Expand Up @@ -2678,6 +2682,12 @@ def bodyBuilder(iterVal):
targets = [self.popValue() for _ in range(numTargets)]
targets.reverse()

for i, t in enumerate(targets):
if not quake.RefType.isinstance(t.type):
self.emitFatalError(
f'invalid target operand {i}, broadcasting is not supported on custom operations.'
)

globalName = f'{nvqppPrefix}{node.func.value.id}_generator_{numTargets}.rodata'

currentST = SymbolTable(self.module.operation)
Expand Down
39 changes: 39 additions & 0 deletions python/tests/custom/test_custom_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,45 @@ def bell():
error)


def test_bug_2452():
cudaq.register_operation("custom_i", np.array([1, 0, 0, 1]))

@cudaq.kernel
def kernel1():
qubits = cudaq.qvector(2)
custom_i(qubits)

with pytest.raises(RuntimeError) as error:
kernel1.compile()
assert 'broadcasting is not supported on custom operations' in repr(error)

cudaq.register_operation("custom_x", np.array([0, 1, 1, 0]))

@cudaq.kernel
def kernel2():
qubit = cudaq.qubit()
ancilla = cudaq.qvector(2)
x(ancilla)
custom_x.ctrl(ancilla, qubit) # `controls` can be `qvector`

counts = cudaq.sample(kernel2)
assert len(counts) == 1 and '111' in counts

cudaq.register_operation(
"custom_cz", np.array([1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0,
-1]))

@cudaq.kernel
def kernel3():
qubits = cudaq.qvector(2)
custom_cz(qubits)

with pytest.raises(RuntimeError) as error:
cudaq.sample(kernel3)
assert 'invalid number of arguments (1) passed to custom_cz (requires 2 arguments)' in repr(
error)


# leave for gdb debugging
if __name__ == "__main__":
loc = os.path.abspath(__file__)
Expand Down

0 comments on commit 5ff555f

Please sign in to comment.