Skip to content

Commit c9839bf

Browse files
Implements optimizations A.1 and A.2 for Quantum Shannon Decomposition (#7390)
Fixes #6777 I checked that with the current PR, the circuit depth after transpiler matches that obtained by Qiskit. The runtime is slow due `merge_single_qubit_gates_to_phased_x_and_z`. I wrote a different implementation for the special case used by the Shannon decomposition. With this, the runtime is halved. I believe `merge_single_qubit_gates_to_phased_x_and_z` can be optimized too. --------- Co-authored-by: Noureldin <noureldinyosri@google.com>
1 parent cedf227 commit c9839bf

File tree

4 files changed

+246
-42
lines changed

4 files changed

+246
-42
lines changed

cirq-core/cirq/transformers/analytical_decompositions/quantum_shannon_decomposition.py

Lines changed: 131 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -21,26 +21,31 @@
2121

2222
from __future__ import annotations
2323

24-
from typing import Callable, Iterable, TYPE_CHECKING
24+
from typing import Callable, cast, Iterable, TYPE_CHECKING
2525

2626
import numpy as np
27+
from attr import define
2728
from scipy.linalg import cossin
2829

2930
from cirq import ops
3031
from cirq.circuits.frozen_circuit import FrozenCircuit
3132
from cirq.linalg import decompositions, predicates
3233
from cirq.protocols import unitary_protocol
33-
from cirq.transformers.analytical_decompositions.three_qubit_decomposition import (
34-
three_qubit_matrix_to_operations,
35-
)
3634
from cirq.transformers.analytical_decompositions.two_qubit_to_cz import (
3735
two_qubit_matrix_to_cz_operations,
36+
two_qubit_matrix_to_diagonal_and_cz_operations,
3837
)
3938

4039
if TYPE_CHECKING:
4140
import cirq
4241

4342

43+
@define
44+
class _TwoQubitGate:
45+
location: int
46+
matrix: np.ndarray
47+
48+
4449
def quantum_shannon_decomposition(
4550
qubits: list[cirq.Qid], u: np.ndarray, atol: float = 1e-8
4651
) -> Iterable[cirq.Operation]:
@@ -67,14 +72,12 @@ def quantum_shannon_decomposition(
6772
1. _single_qubit_decomposition
6873
OR
6974
(Recursive Case)
70-
1. _msb_demuxer
71-
2. _multiplexed_cossin
72-
3. _msb_demuxer
75+
1. _recursive_decomposition
7376
7477
Yields:
7578
A single 2-qubit or 1-qubit operations from OP TREE
7679
composed from the set
77-
{ CNOT, rz, ry, ZPowGate }
80+
{ CNOT, CZ, rz, ry, ZPowGate }
7881
7982
Raises:
8083
ValueError: If the u matrix is non-unitary
@@ -98,30 +101,92 @@ def quantum_shannon_decomposition(
98101
yield from _single_qubit_decomposition(qubits[0], u)
99102
return
100103

101-
if n == 4:
102-
operations = tuple(
103-
two_qubit_matrix_to_cz_operations(
104-
qubits[0], qubits[1], u, allow_partial_czs=True, clean_operations=True, atol=atol
105-
)
104+
# Collect all operations from the recursive decomposition
105+
shannon_decomp: list[cirq.Operation | list[cirq.Operation]] = [
106+
*_recursive_decomposition(qubits, u)
107+
]
108+
# Separate all 2-qubit generic gates while keeping track of location
109+
two_qubit_gates = [
110+
_TwoQubitGate(location=loc, matrix=unitary_protocol.unitary(o))
111+
for loc, o in enumerate(cast(list[ops.Operation], shannon_decomp))
112+
if isinstance(o.gate, ops.MatrixGate)
113+
]
114+
# Apply case A.2 from Shende et al.
115+
q0 = qubits[-2]
116+
q1 = qubits[-1]
117+
for idx in range(len(two_qubit_gates) - 1, 0, -1):
118+
diagonal, operations = two_qubit_matrix_to_diagonal_and_cz_operations(
119+
q0,
120+
q1,
121+
two_qubit_gates[idx].matrix,
122+
allow_partial_czs=True,
123+
clean_operations=True,
124+
atol=atol,
106125
)
107-
yield from operations
108-
i, j = np.unravel_index(np.argmax(np.abs(u)), u.shape)
109-
new_unitary = unitary_protocol.unitary(FrozenCircuit.from_moments(*operations))
110-
global_phase = np.angle(u[i, j]) - np.angle(new_unitary[i, j])
111-
if np.abs(global_phase) > 1e-9:
112-
yield ops.global_phase_operation(np.exp(1j * global_phase))
113-
return
126+
global_phase = _global_phase_difference(
127+
two_qubit_gates[idx].matrix, [ops.MatrixGate(diagonal)(q0, q1), *operations]
128+
)
129+
if not np.isclose(global_phase, 0, atol=atol):
130+
operations.append(ops.global_phase_operation(np.exp(1j * global_phase)))
131+
# Replace the generic gate with ops from OP TREE
132+
shannon_decomp[two_qubit_gates[idx].location] = operations
133+
# Join the diagonal with the unitary to be decomposed in the next step
134+
two_qubit_gates[idx - 1].matrix = diagonal @ two_qubit_gates[idx - 1].matrix
135+
if len(two_qubit_gates) > 0:
136+
operations = two_qubit_matrix_to_cz_operations(
137+
q0,
138+
q1,
139+
two_qubit_gates[0].matrix,
140+
allow_partial_czs=True,
141+
clean_operations=True,
142+
atol=atol,
143+
)
144+
global_phase = _global_phase_difference(two_qubit_gates[0].matrix, operations)
145+
if not np.isclose(global_phase, 0, atol=atol):
146+
operations.append(ops.global_phase_operation(np.exp(1j * global_phase)))
147+
shannon_decomp[two_qubit_gates[0].location] = operations
148+
# Yield the final operations in order
149+
yield from cast(Iterable[ops.Operation], ops.flatten_op_tree(shannon_decomp))
150+
151+
152+
def _recursive_decomposition(qubits: list[cirq.Qid], u: np.ndarray) -> Iterable[cirq.Operation]:
153+
"""Recursive step in the quantum shannon decomposition.
154+
155+
Decomposes n-qubit unitary into generic 2-qubit gates, CNOT, CZ and 1-qubit gates.
156+
All generic 2-qubit gates are applied to the two least significant qubits and
157+
are not decomposed further here.
158+
159+
Args:
160+
qubits: List of qubits in order of significance
161+
u: Numpy array for unitary matrix representing gate to be decomposed
162+
163+
Calls:
164+
1. _msb_demuxer
165+
2. _multiplexed_cossin
166+
3. _msb_demuxer
167+
168+
Yields:
169+
Generic 2-qubit gates or operations from {ry,rz,CNOT,CZ}.
114170
115-
if n == 8:
116-
operations = tuple(
117-
three_qubit_matrix_to_operations(qubits[0], qubits[1], qubits[2], u, atol=atol)
171+
Raises:
172+
ValueError: If the u matrix is not of shape (2^n,2^n)
173+
ValueError: If the u matrix is not of size at least 4
174+
"""
175+
n = u.shape[0]
176+
if n & (n - 1):
177+
raise ValueError(
178+
f"Expected input matrix u to be a (2^n x 2^n) shaped numpy array, \
179+
but instead got shape {u.shape}"
118180
)
119-
yield from operations
120-
i, j = np.unravel_index(np.argmax(np.abs(u)), u.shape)
121-
new_unitary = unitary_protocol.unitary(FrozenCircuit.from_moments(*operations))
122-
global_phase = np.angle(u[i, j]) - np.angle(new_unitary[i, j])
123-
if np.abs(global_phase) > 1e-9:
124-
yield ops.global_phase_operation(np.exp(1j * global_phase))
181+
182+
if n <= 2:
183+
raise ValueError(
184+
f"Expected input matrix u for recursive step to have size at least 4, \
185+
but it has size {n}"
186+
)
187+
188+
if n == 4:
189+
yield ops.MatrixGate(u).on(*qubits)
125190
return
126191

127192
# Perform a cosine-sine (linalg) decomposition on u
@@ -137,10 +202,30 @@ def quantum_shannon_decomposition(
137202
# Yield ops from multiplexed Ry part
138203
yield from _multiplexed_cossin(qubits, theta, ops.ry)
139204

205+
# Optimization A.1 in Shende et al. - the last CZ gate in the multiplexed Ry part
206+
# is merged into the generic multiplexor (u1, u2)
207+
# This gate is CZ(qubits[1], qubits[0]) = CZ(qubits[0], qubits[1])
208+
# as CZ is symmetric.
209+
# For the u1⊕u2 multiplexor operator:
210+
# as u1 is the operator in case qubits[0] = |0>,
211+
# and u2 is the operator in case qubits[0] = |1>
212+
# we can represent the merge by phasing u2 with Z ⊗ I
213+
cz_diag = np.concatenate((np.ones(n >> 2), np.full(n >> 2, -1)))
214+
u2 = u2 @ np.diag(cz_diag)
215+
140216
# Yield ops from decomposition of multiplexed u1/u2 part
141217
yield from _msb_demuxer(qubits, u1, u2)
142218

143219

220+
def _global_phase_difference(u: np.ndarray, ops: list[cirq.Operation]) -> float:
221+
"""Returns the difference in global phase between unitary u and
222+
a list of operations computing u.
223+
"""
224+
i, j = np.unravel_index(np.argmax(np.abs(u)), u.shape)
225+
new_unitary = unitary_protocol.unitary(FrozenCircuit.from_moments(*ops))
226+
return np.angle(u[i, j]) - np.angle(new_unitary[i, j])
227+
228+
144229
def _single_qubit_decomposition(qubit: cirq.Qid, u: np.ndarray) -> Iterable[cirq.Operation]:
145230
"""Decomposes single-qubit gate, and returns list of operations, keeping phase invariant.
146231
@@ -202,11 +287,14 @@ def _msb_demuxer(
202287
u2: Lower-right quadrant of total unitary to be decomposed (see diagram)
203288
204289
Calls:
205-
1. quantum_shannon_decomposition
290+
1. _recursive_decomposition
206291
2. _multiplexed_cossin
207-
3. quantum_shannon_decomposition
292+
3. _recursive_decomposition
208293
209-
Yields: Single operation from OP TREE of 2-qubit and 1-qubit operations
294+
Yields:
295+
Generic 2-qubit gates on the two least significant qubits,
296+
CNOT gates with the target not on the two least significant qubits,
297+
ry or rz
210298
"""
211299
# Perform a diagonalization to find values
212300
u1 = u1.astype(np.complex128)
@@ -231,15 +319,15 @@ def _msb_demuxer(
231319
# Last term is given by ( I ⊗ W ), demultiplexed
232320
# Remove most-significant (demuxed) control-qubit
233321
# Yield operations for QSD on W
234-
yield from quantum_shannon_decomposition(demux_qubits[1:], W, atol=1e-6)
322+
yield from _recursive_decomposition(demux_qubits[1:], W)
235323

236324
# Use complex phase of d_i to give theta_i (so d_i* gives -theta_i)
237325
# Observe that middle part looks like Σ_i( Rz(theta_i)⊗|i><i| )
238326
# Yield ops from multiplexed Rz part
239327
yield from _multiplexed_cossin(demux_qubits, -np.angle(d), ops.rz)
240328

241329
# Yield operations for QSD on V
242-
yield from quantum_shannon_decomposition(demux_qubits[1:], V, atol=1e-6)
330+
yield from _recursive_decomposition(demux_qubits[1:], V)
243331

244332

245333
def _nth_gray(n: int) -> int:
@@ -263,7 +351,7 @@ def _multiplexed_cossin(
263351
Calls:
264352
No major calls
265353
266-
Yields: Single operation from OP TREE from set 1- and 2-qubit gates: {ry,rz,CNOT}
354+
Yields: Single operation from OP TREE from set 1- and 2-qubit gates: {ry,rz,CNOT,CZ}
267355
"""
268356
# Most significant qubit is main qubit with rotation function applied
269357
main_qubit = cossin_qubits[0]
@@ -304,4 +392,11 @@ def _multiplexed_cossin(
304392
yield rot_func(rotation).on(main_qubit)
305393

306394
# Add a CNOT from the select qubit to the main qubit
307-
yield ops.CNOT(control_qubits[select_qubit], main_qubit)
395+
# Optimization A.1 in Shende et al. - use CZ instead of CNOT for ry rotations
396+
if rot_func == ops.ry:
397+
# Don't emit the last gate, as it will be merged into the generic multiplexor
398+
# in the cosine-sine decomposition
399+
if j < len(angles) - 1:
400+
yield ops.CZ(control_qubits[select_qubit], main_qubit)
401+
else:
402+
yield ops.CNOT(control_qubits[select_qubit], main_qubit)

cirq-core/cirq/transformers/analytical_decompositions/quantum_shannon_decomposition_test.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,12 @@
2020

2121
import cirq
2222
from cirq.ops import common_gates
23+
from cirq.testing import random_two_qubit_circuit_with_czs
2324
from cirq.transformers.analytical_decompositions.quantum_shannon_decomposition import (
2425
_msb_demuxer,
2526
_multiplexed_cossin,
2627
_nth_gray,
28+
_recursive_decomposition,
2729
_single_qubit_decomposition,
2830
quantum_shannon_decomposition,
2931
)
@@ -49,6 +51,14 @@ def test_qsd_n_qubit_errors():
4951
cirq.Circuit(quantum_shannon_decomposition(qubits, np.ones((8, 8))))
5052

5153

54+
def test_recursive_decomposition_n_qubit_errors():
55+
qubits = [cirq.NamedQubit(f'q{i}') for i in range(3)]
56+
with pytest.raises(ValueError, match="shaped numpy array"):
57+
cirq.Circuit(_recursive_decomposition(qubits, np.eye(9)))
58+
with pytest.raises(ValueError, match="size at least 4"):
59+
cirq.Circuit(_recursive_decomposition(qubits, np.eye(2)))
60+
61+
5262
def test_random_single_qubit_decomposition():
5363
U = unitary_group.rvs(2)
5464
qubit = cirq.NamedQubit('q0')
@@ -80,10 +90,18 @@ def test_multiplexed_cossin():
8090
multiplexed_ry = np.array(multiplexed_ry)
8191
qubits = [cirq.NamedQubit(f'q{i}') for i in range(2)]
8292
circuit = cirq.Circuit(_multiplexed_cossin(qubits, [angle_1, angle_2]))
93+
# Add back the CZ gate removed by the A.1 optimization
94+
circuit += cirq.CZ(qubits[1], qubits[0])
8395
# Test return is equal to inital unitary
8496
assert cirq.approx_eq(multiplexed_ry, circuit.unitary(), atol=1e-9)
8597
# Test all operations in gate set
86-
gates = (common_gates.Rz, common_gates.Ry, common_gates.ZPowGate, common_gates.CXPowGate)
98+
gates = (
99+
common_gates.Rz,
100+
common_gates.Ry,
101+
common_gates.ZPowGate,
102+
common_gates.CXPowGate,
103+
common_gates.CZPowGate,
104+
)
87105
assert all(isinstance(op.gate, gates) for op in circuit.all_operations())
88106

89107

@@ -203,3 +221,17 @@ def test_qft5():
203221
)
204222
new_unitary = cirq.unitary(shannon_circuit)
205223
np.testing.assert_allclose(new_unitary, desired_unitary, atol=1e-6)
224+
225+
226+
def test_random_circuit_decomposition():
227+
qubits = cirq.LineQubit.range(3)
228+
test_circuit = (
229+
random_two_qubit_circuit_with_czs(3, qubits[0], qubits[1])
230+
+ random_two_qubit_circuit_with_czs(3, qubits[1], qubits[2])
231+
+ random_two_qubit_circuit_with_czs(3, qubits[0], qubits[2])
232+
)
233+
circuit = cirq.Circuit(quantum_shannon_decomposition(qubits, test_circuit.unitary()))
234+
# Test return is equal to initial unitary
235+
assert cirq.approx_eq(test_circuit.unitary(), circuit.unitary(), atol=1e-9)
236+
# Test all operations have at most 2 qubits.
237+
assert all(cirq.num_qubits(op) <= 2 for op in circuit.all_operations())

0 commit comments

Comments
 (0)