Skip to content

Implements optimizations A.1 and A.2 for Quantum Shannon Decomposition #7390

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Jun 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,26 +21,31 @@

from __future__ import annotations

from typing import Callable, Iterable, TYPE_CHECKING
from typing import Callable, cast, Iterable, TYPE_CHECKING

import numpy as np
from attr import define
from scipy.linalg import cossin

from cirq import ops
from cirq.circuits.frozen_circuit import FrozenCircuit
from cirq.linalg import decompositions, predicates
from cirq.protocols import unitary_protocol
from cirq.transformers.analytical_decompositions.three_qubit_decomposition import (
three_qubit_matrix_to_operations,
)
from cirq.transformers.analytical_decompositions.two_qubit_to_cz import (
two_qubit_matrix_to_cz_operations,
two_qubit_matrix_to_diagonal_and_cz_operations,
)

if TYPE_CHECKING:
import cirq


@define
class _TwoQubitGate:
location: int
matrix: np.ndarray


def quantum_shannon_decomposition(
qubits: list[cirq.Qid], u: np.ndarray, atol: float = 1e-8
) -> Iterable[cirq.Operation]:
Expand All @@ -67,14 +72,12 @@ def quantum_shannon_decomposition(
1. _single_qubit_decomposition
OR
(Recursive Case)
1. _msb_demuxer
2. _multiplexed_cossin
3. _msb_demuxer
1. _recursive_decomposition

Yields:
A single 2-qubit or 1-qubit operations from OP TREE
composed from the set
{ CNOT, rz, ry, ZPowGate }
{ CNOT, CZ, rz, ry, ZPowGate }

Raises:
ValueError: If the u matrix is non-unitary
Expand All @@ -98,30 +101,92 @@ def quantum_shannon_decomposition(
yield from _single_qubit_decomposition(qubits[0], u)
return

if n == 4:
operations = tuple(
two_qubit_matrix_to_cz_operations(
qubits[0], qubits[1], u, allow_partial_czs=True, clean_operations=True, atol=atol
)
# Collect all operations from the recursive decomposition
shannon_decomp: list[cirq.Operation | list[cirq.Operation]] = [
*_recursive_decomposition(qubits, u)
]
# Separate all 2-qubit generic gates while keeping track of location
two_qubit_gates = [
_TwoQubitGate(location=loc, matrix=unitary_protocol.unitary(o))
for loc, o in enumerate(cast(list[ops.Operation], shannon_decomp))
if isinstance(o.gate, ops.MatrixGate)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why MatrixGate in particular? shouldn't is op.gate is not None be sufficient

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I need a way to differentiate between end gates produced by the Shannon decomposition (such as CZ, CNOT), and 2-qubit gates that need to be further decomposed in the A.2 optimization step. A.2 requires splicing a diagonal matrix from the 2-gate decomposition and multiplying it with the previous 2-qubit gate. So I extract these gates in a separate list, perform A.2, and then replace the gates in the original list with the decomposition result.

]
# Apply case A.2 from Shende et al.
q0 = qubits[-2]
q1 = qubits[-1]
for idx in range(len(two_qubit_gates) - 1, 0, -1):
diagonal, operations = two_qubit_matrix_to_diagonal_and_cz_operations(
q0,
q1,
two_qubit_gates[idx].matrix,
allow_partial_czs=True,
clean_operations=True,
atol=atol,
)
yield from operations
i, j = np.unravel_index(np.argmax(np.abs(u)), u.shape)
new_unitary = unitary_protocol.unitary(FrozenCircuit.from_moments(*operations))
global_phase = np.angle(u[i, j]) - np.angle(new_unitary[i, j])
if np.abs(global_phase) > 1e-9:
yield ops.global_phase_operation(np.exp(1j * global_phase))
return
global_phase = _global_phase_difference(
two_qubit_gates[idx].matrix, [ops.MatrixGate(diagonal)(q0, q1), *operations]
)
if not np.isclose(global_phase, 0, atol=atol):
operations.append(ops.global_phase_operation(np.exp(1j * global_phase)))
# Replace the generic gate with ops from OP TREE
shannon_decomp[two_qubit_gates[idx].location] = operations
# Join the diagonal with the unitary to be decomposed in the next step
two_qubit_gates[idx - 1].matrix = diagonal @ two_qubit_gates[idx - 1].matrix
if len(two_qubit_gates) > 0:
operations = two_qubit_matrix_to_cz_operations(
q0,
q1,
two_qubit_gates[0].matrix,
allow_partial_czs=True,
clean_operations=True,
atol=atol,
)
global_phase = _global_phase_difference(two_qubit_gates[0].matrix, operations)
if not np.isclose(global_phase, 0, atol=atol):
operations.append(ops.global_phase_operation(np.exp(1j * global_phase)))
shannon_decomp[two_qubit_gates[0].location] = operations
# Yield the final operations in order
yield from cast(Iterable[ops.Operation], ops.flatten_op_tree(shannon_decomp))


def _recursive_decomposition(qubits: list[cirq.Qid], u: np.ndarray) -> Iterable[cirq.Operation]:
"""Recursive step in the quantum shannon decomposition.

Decomposes n-qubit unitary into generic 2-qubit gates, CNOT, CZ and 1-qubit gates.
All generic 2-qubit gates are applied to the two least significant qubits and
are not decomposed further here.

Args:
qubits: List of qubits in order of significance
u: Numpy array for unitary matrix representing gate to be decomposed

Calls:
1. _msb_demuxer
2. _multiplexed_cossin
3. _msb_demuxer

Yields:
Generic 2-qubit gates or operations from {ry,rz,CNOT,CZ}.

if n == 8:
operations = tuple(
three_qubit_matrix_to_operations(qubits[0], qubits[1], qubits[2], u, atol=atol)
Raises:
ValueError: If the u matrix is not of shape (2^n,2^n)
ValueError: If the u matrix is not of size at least 4
"""
n = u.shape[0]
if n & (n - 1):
raise ValueError(
f"Expected input matrix u to be a (2^n x 2^n) shaped numpy array, \
but instead got shape {u.shape}"
)
yield from operations
i, j = np.unravel_index(np.argmax(np.abs(u)), u.shape)
new_unitary = unitary_protocol.unitary(FrozenCircuit.from_moments(*operations))
global_phase = np.angle(u[i, j]) - np.angle(new_unitary[i, j])
if np.abs(global_phase) > 1e-9:
yield ops.global_phase_operation(np.exp(1j * global_phase))

if n <= 2:
raise ValueError(
f"Expected input matrix u for recursive step to have size at least 4, \
but it has size {n}"
)

if n == 4:
yield ops.MatrixGate(u).on(*qubits)
return

# Perform a cosine-sine (linalg) decomposition on u
Expand All @@ -137,10 +202,30 @@ def quantum_shannon_decomposition(
# Yield ops from multiplexed Ry part
yield from _multiplexed_cossin(qubits, theta, ops.ry)

# Optimization A.1 in Shende et al. - the last CZ gate in the multiplexed Ry part
# is merged into the generic multiplexor (u1, u2)
# This gate is CZ(qubits[1], qubits[0]) = CZ(qubits[0], qubits[1])
# as CZ is symmetric.
# For the u1⊕u2 multiplexor operator:
# as u1 is the operator in case qubits[0] = |0>,
# and u2 is the operator in case qubits[0] = |1>
# we can represent the merge by phasing u2 with Z ⊗ I
cz_diag = np.concatenate((np.ones(n >> 2), np.full(n >> 2, -1)))
u2 = u2 @ np.diag(cz_diag)

# Yield ops from decomposition of multiplexed u1/u2 part
yield from _msb_demuxer(qubits, u1, u2)


def _global_phase_difference(u: np.ndarray, ops: list[cirq.Operation]) -> float:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we still need this? I thought that the analytical decompositions now preserve global phase #6523, is this not the case?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried to remove it, but unit tests are failing. Looking at #6523 , it seems to me it was mistakenly closed by #7118 . #7118 changes the MatrixGate decomposition, but #6523 requires updating two_qubit_to_cz, a different thing.

"""Returns the difference in global phase between unitary u and
a list of operations computing u.
"""
i, j = np.unravel_index(np.argmax(np.abs(u)), u.shape)
new_unitary = unitary_protocol.unitary(FrozenCircuit.from_moments(*ops))
return np.angle(u[i, j]) - np.angle(new_unitary[i, j])


def _single_qubit_decomposition(qubit: cirq.Qid, u: np.ndarray) -> Iterable[cirq.Operation]:
"""Decomposes single-qubit gate, and returns list of operations, keeping phase invariant.

Expand Down Expand Up @@ -202,11 +287,14 @@ def _msb_demuxer(
u2: Lower-right quadrant of total unitary to be decomposed (see diagram)

Calls:
1. quantum_shannon_decomposition
1. _recursive_decomposition
2. _multiplexed_cossin
3. quantum_shannon_decomposition
3. _recursive_decomposition

Yields: Single operation from OP TREE of 2-qubit and 1-qubit operations
Yields:
Generic 2-qubit gates on the two least significant qubits,
CNOT gates with the target not on the two least significant qubits,
ry or rz
"""
# Perform a diagonalization to find values
u1 = u1.astype(np.complex128)
Expand All @@ -231,15 +319,15 @@ def _msb_demuxer(
# Last term is given by ( I ⊗ W ), demultiplexed
# Remove most-significant (demuxed) control-qubit
# Yield operations for QSD on W
yield from quantum_shannon_decomposition(demux_qubits[1:], W, atol=1e-6)
yield from _recursive_decomposition(demux_qubits[1:], W)

# Use complex phase of d_i to give theta_i (so d_i* gives -theta_i)
# Observe that middle part looks like Σ_i( Rz(theta_i)⊗|i><i| )
# Yield ops from multiplexed Rz part
yield from _multiplexed_cossin(demux_qubits, -np.angle(d), ops.rz)

# Yield operations for QSD on V
yield from quantum_shannon_decomposition(demux_qubits[1:], V, atol=1e-6)
yield from _recursive_decomposition(demux_qubits[1:], V)


def _nth_gray(n: int) -> int:
Expand All @@ -263,7 +351,7 @@ def _multiplexed_cossin(
Calls:
No major calls

Yields: Single operation from OP TREE from set 1- and 2-qubit gates: {ry,rz,CNOT}
Yields: Single operation from OP TREE from set 1- and 2-qubit gates: {ry,rz,CNOT,CZ}
"""
# Most significant qubit is main qubit with rotation function applied
main_qubit = cossin_qubits[0]
Expand Down Expand Up @@ -304,4 +392,11 @@ def _multiplexed_cossin(
yield rot_func(rotation).on(main_qubit)

# Add a CNOT from the select qubit to the main qubit
yield ops.CNOT(control_qubits[select_qubit], main_qubit)
# Optimization A.1 in Shende et al. - use CZ instead of CNOT for ry rotations
if rot_func == ops.ry:
# Don't emit the last gate, as it will be merged into the generic multiplexor
# in the cosine-sine decomposition
if j < len(angles) - 1:
yield ops.CZ(control_qubits[select_qubit], main_qubit)
else:
yield ops.CNOT(control_qubits[select_qubit], main_qubit)
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,12 @@

import cirq
from cirq.ops import common_gates
from cirq.testing import random_two_qubit_circuit_with_czs
from cirq.transformers.analytical_decompositions.quantum_shannon_decomposition import (
_msb_demuxer,
_multiplexed_cossin,
_nth_gray,
_recursive_decomposition,
_single_qubit_decomposition,
quantum_shannon_decomposition,
)
Expand All @@ -49,6 +51,14 @@ def test_qsd_n_qubit_errors():
cirq.Circuit(quantum_shannon_decomposition(qubits, np.ones((8, 8))))


def test_recursive_decomposition_n_qubit_errors():
qubits = [cirq.NamedQubit(f'q{i}') for i in range(3)]
with pytest.raises(ValueError, match="shaped numpy array"):
cirq.Circuit(_recursive_decomposition(qubits, np.eye(9)))
with pytest.raises(ValueError, match="size at least 4"):
cirq.Circuit(_recursive_decomposition(qubits, np.eye(2)))


def test_random_single_qubit_decomposition():
U = unitary_group.rvs(2)
qubit = cirq.NamedQubit('q0')
Expand Down Expand Up @@ -80,10 +90,18 @@ def test_multiplexed_cossin():
multiplexed_ry = np.array(multiplexed_ry)
qubits = [cirq.NamedQubit(f'q{i}') for i in range(2)]
circuit = cirq.Circuit(_multiplexed_cossin(qubits, [angle_1, angle_2]))
# Add back the CZ gate removed by the A.1 optimization
circuit += cirq.CZ(qubits[1], qubits[0])
# Test return is equal to inital unitary
assert cirq.approx_eq(multiplexed_ry, circuit.unitary(), atol=1e-9)
# Test all operations in gate set
gates = (common_gates.Rz, common_gates.Ry, common_gates.ZPowGate, common_gates.CXPowGate)
gates = (
common_gates.Rz,
common_gates.Ry,
common_gates.ZPowGate,
common_gates.CXPowGate,
common_gates.CZPowGate,
)
assert all(isinstance(op.gate, gates) for op in circuit.all_operations())


Expand Down Expand Up @@ -203,3 +221,17 @@ def test_qft5():
)
new_unitary = cirq.unitary(shannon_circuit)
np.testing.assert_allclose(new_unitary, desired_unitary, atol=1e-6)


def test_random_circuit_decomposition():
qubits = cirq.LineQubit.range(3)
test_circuit = (
random_two_qubit_circuit_with_czs(3, qubits[0], qubits[1])
+ random_two_qubit_circuit_with_czs(3, qubits[1], qubits[2])
+ random_two_qubit_circuit_with_czs(3, qubits[0], qubits[2])
)
circuit = cirq.Circuit(quantum_shannon_decomposition(qubits, test_circuit.unitary()))
# Test return is equal to initial unitary
assert cirq.approx_eq(test_circuit.unitary(), circuit.unitary(), atol=1e-9)
# Test all operations have at most 2 qubits.
assert all(cirq.num_qubits(op) <= 2 for op in circuit.all_operations())
Loading
Loading