Skip to content

Commit cedf227

Browse files
Simplify decomposition of controlled eigengates with global phase (#7383)
Fixes #7238 1. Create a flag in `DecomposeContext` to configure extraction of terminal gates' global phases into distinct GlobalPhaseGates. a. This flag defaults to False for backward compatibility. 2. Instrument the terminal decomposition gates (X, Y, Z, CZ) to check this flag and extract global phase if it exists. a. Added some convenience methods in GlobalPhaseGate to help. b. Updated some unrelated existing code to use these convenience methods for clarity. 3. Update `ControlledGate.decompose` to attempt decompose the subgate before attempting the brute-force approach. a. This is what ultimately simplifies the decomposition result. b. Step 2 also allowed removal of the entire CZPowGate instanceof block there, since it was about extracting global phase. 4. Add unit tests for each change. --------- Co-authored-by: Noureldin <noureldinyosri@gmail.com>
1 parent cac6aad commit cedf227

9 files changed

+202
-66
lines changed

cirq-core/cirq/ops/common_gates.py

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,14 @@
3737
from cirq import protocols, value
3838
from cirq._compat import proper_repr
3939
from cirq._doc import document
40-
from cirq.ops import control_values as cv, controlled_gate, eigen_gate, gate_features, raw_types
40+
from cirq.ops import (
41+
control_values as cv,
42+
controlled_gate,
43+
eigen_gate,
44+
gate_features,
45+
global_phase_op,
46+
raw_types,
47+
)
4148
from cirq.ops.measurement_gate import MeasurementGate
4249
from cirq.ops.swap_gates import ISWAP, ISwapPowGate, SWAP, SwapPowGate
4350

@@ -235,6 +242,11 @@ def controlled(
235242
return cirq.CCXPowGate(exponent=self._exponent)
236243
return result
237244

245+
def _decompose_with_context_(
246+
self, qubits: tuple[cirq.Qid, ...], context: cirq.DecompositionContext
247+
) -> list[cirq.Operation] | NotImplementedType:
248+
return _extract_phase(self, XPowGate, qubits, context)
249+
238250
def _pauli_expansion_(self) -> value.LinearDict[str]:
239251
if self._dimension != 2:
240252
return NotImplemented # pragma: no cover
@@ -487,6 +499,11 @@ def __repr__(self) -> str:
487499
f'global_shift={self._global_shift!r})'
488500
)
489501

502+
def _decompose_with_context_(
503+
self, qubits: tuple[cirq.Qid, ...], context: cirq.DecompositionContext
504+
) -> list[cirq.Operation] | NotImplementedType:
505+
return _extract_phase(self, YPowGate, qubits, context)
506+
490507

491508
class Ry(YPowGate):
492509
r"""A gate with matrix $e^{-i Y t/2}$ that rotates around the Y axis of the Bloch sphere by $t$.
@@ -699,6 +716,11 @@ def controlled(
699716
return cirq.CCZPowGate(exponent=self._exponent)
700717
return result
701718

719+
def _decompose_with_context_(
720+
self, qubits: tuple[cirq.Qid, ...], context: cirq.DecompositionContext
721+
) -> list[cirq.Operation] | NotImplementedType:
722+
return _extract_phase(self, ZPowGate, qubits, context)
723+
702724
def _qid_shape_(self) -> tuple[int, ...]:
703725
return (self._dimension,)
704726

@@ -1131,6 +1153,11 @@ def controlled(
11311153
control_qid_shape=result.control_qid_shape + (2,),
11321154
)
11331155

1156+
def _decompose_with_context_(
1157+
self, qubits: tuple[cirq.Qid, ...], context: cirq.DecompositionContext
1158+
) -> list[cirq.Operation] | NotImplementedType:
1159+
return _extract_phase(self, CZPowGate, qubits, context)
1160+
11341161
def _circuit_diagram_info_(self, args: cirq.CircuitDiagramInfoArgs) -> cirq.CircuitDiagramInfo:
11351162
return protocols.CircuitDiagramInfo(
11361163
wire_symbols=('@', '@'), exponent=self._diagram_exponent(args)
@@ -1486,3 +1513,25 @@ def _phased_x_or_pauli_gate(
14861513
case 0.5:
14871514
return YPowGate(exponent=exponent)
14881515
return cirq.ops.PhasedXPowGate(exponent=exponent, phase_exponent=phase_exponent)
1516+
1517+
1518+
def _extract_phase(
1519+
gate: cirq.EigenGate,
1520+
gate_class: type,
1521+
qubits: tuple[cirq.Qid, ...],
1522+
context: cirq.DecompositionContext,
1523+
) -> list[cirq.Operation] | NotImplementedType:
1524+
"""Extracts the global phase field to its own gate, or absorbs it if it has no effect.
1525+
1526+
This is for use within the decompose handlers, and will return `NotImplemented` if there is no
1527+
global phase, implying it is already in its simplest form. It will return a list, with the
1528+
original op minus any global phase first, and the global phase op second. If the resulting
1529+
global phase is empty (can happen for example in `XPowGate(global_phase=2/3)**3`), then it is
1530+
excluded from the return value."""
1531+
if not context.extract_global_phases or gate.global_shift == 0:
1532+
return NotImplemented
1533+
result = [gate_class(exponent=gate.exponent).on(*qubits)]
1534+
phase_gate = global_phase_op.from_phase_and_exponent(gate.global_shift, gate.exponent)
1535+
if not phase_gate.is_identity():
1536+
result.append(phase_gate())
1537+
return result

cirq-core/cirq/ops/common_gates_test.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1322,3 +1322,37 @@ def test_parameterized_pauli_expansion(gate_type, exponent) -> None:
13221322
gate_resolved = cirq.resolve_parameters(gate, {'s': 0.5})
13231323
pauli_resolved = cirq.resolve_parameters(pauli, {'s': 0.5})
13241324
assert cirq.approx_eq(pauli_resolved, cirq.pauli_expansion(gate_resolved))
1325+
1326+
1327+
@pytest.mark.parametrize('gate_type', [cirq.XPowGate, cirq.YPowGate, cirq.ZPowGate, cirq.CZPowGate])
1328+
@pytest.mark.parametrize('exponent', [0, 0.5, 2, 3, -0.5, -2, -3, sympy.Symbol('s')])
1329+
def test_decompose_with_extracted_phases(gate_type: type, exponent: cirq.TParamVal) -> None:
1330+
context = cirq.DecompositionContext(cirq.SimpleQubitManager(), extract_global_phases=True)
1331+
test_shift = 2 / 3 # Interesting because e.g. X(shift=2/3) ** 3 == X with no phase
1332+
gate = gate_type(exponent=exponent, global_shift=test_shift)
1333+
op = gate.on(*cirq.LineQubit.range(cirq.num_qubits(gate)))
1334+
decomposed = cirq.decompose(op, context=context)
1335+
1336+
# The first gate should be the original gate, but with shift removed.
1337+
gate0 = decomposed[0].gate
1338+
assert isinstance(gate0, gate_type)
1339+
assert isinstance(gate0, cirq.EigenGate)
1340+
assert gate0.global_shift == 0
1341+
assert gate0.exponent == exponent
1342+
if exponent % 3 == 0:
1343+
# Since test_shift == 2/3, gate**3 nullifies the phase, leaving only the unphased gate.
1344+
assert len(decomposed) == 1
1345+
else:
1346+
# Other exponents emit a global phase gate to compensate.
1347+
assert len(decomposed) == 2
1348+
gate1 = decomposed[1].gate
1349+
assert isinstance(gate1, cirq.GlobalPhaseGate)
1350+
assert gate1.coefficient == 1j ** (2 * exponent * test_shift)
1351+
1352+
# Sanity check that the decomposition is equivalent to the original.
1353+
decomposed_circuit = cirq.Circuit(decomposed)
1354+
if cirq.is_parameterized(exponent):
1355+
resolver = {'s': -1.234} # arbitrary
1356+
op = cirq.resolve_parameters(op, resolver)
1357+
decomposed_circuit = cirq.resolve_parameters(decomposed_circuit, resolver)
1358+
np.testing.assert_allclose(cirq.unitary(op), cirq.unitary(decomposed_circuit), atol=1e-10)

cirq-core/cirq/ops/controlled_gate.py

Lines changed: 23 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
control_values as cv,
2525
controlled_operation as cop,
2626
diagonal_gate as dg,
27-
global_phase_op as gp,
2827
op_tree,
2928
raw_types,
3029
)
@@ -139,19 +138,35 @@ def num_controls(self) -> int:
139138
def _qid_shape_(self) -> tuple[int, ...]:
140139
return self.control_qid_shape + protocols.qid_shape(self.sub_gate)
141140

142-
def _decompose_(self, qubits: tuple[cirq.Qid, ...]) -> None | NotImplementedType | cirq.OP_TREE:
143-
return self._decompose_with_context_(qubits)
144-
145141
def _decompose_with_context_(
146-
self, qubits: tuple[cirq.Qid, ...], context: cirq.DecompositionContext | None = None
147-
) -> None | NotImplementedType | cirq.OP_TREE:
142+
self, qubits: tuple[cirq.Qid, ...], context: cirq.DecompositionContext
143+
) -> NotImplementedType | cirq.OP_TREE:
148144
control_qubits = list(qubits[: self.num_controls()])
149145
controlled_sub_gate = self.sub_gate.controlled(
150146
self.num_controls(), self.control_values, self.control_qid_shape
151147
)
152148
# Prefer the subgate controlled version if available
153149
if self != controlled_sub_gate:
154150
return controlled_sub_gate.on(*qubits)
151+
152+
# Try decomposing the subgate next.
153+
result = protocols.decompose_once_with_qubits(
154+
self.sub_gate,
155+
qubits[self.num_controls() :],
156+
NotImplemented,
157+
flatten=False,
158+
# Extract global phases from decomposition, as controlled phases decompose easily.
159+
context=context.extracting_global_phases(),
160+
)
161+
if result is not NotImplemented:
162+
return op_tree.transform_op_tree(
163+
result,
164+
lambda op: op.controlled_by(
165+
*qubits[: self.num_controls()], control_values=self.control_values
166+
),
167+
)
168+
169+
# Finally try brute-force on the unitary.
155170
if protocols.has_unitary(self.sub_gate) and all(q.dimension == 2 for q in qubits):
156171
n_qubits = protocols.num_qubits(self.sub_gate)
157172
# Case 1: Global Phase (1x1 Matrix)
@@ -173,54 +188,9 @@ def _decompose_with_context_(
173188
protocols.unitary(self.sub_gate), control_qubits, qubits[-1]
174189
)
175190
return invert_ops + decomposed_ops + invert_ops
176-
if isinstance(self.sub_gate, common_gates.CZPowGate):
177-
z_sub_gate = common_gates.ZPowGate(exponent=self.sub_gate.exponent)
178-
num_controls = self.num_controls() + 1
179-
control_values = self.control_values & cv.ProductOfSums(((1,),))
180-
control_qid_shape = self.control_qid_shape + (2,)
181-
controlled_z = (
182-
z_sub_gate.controlled(
183-
num_controls=num_controls,
184-
control_values=control_values,
185-
control_qid_shape=control_qid_shape,
186-
)
187-
if protocols.is_parameterized(self)
188-
else ControlledGate(
189-
z_sub_gate,
190-
num_controls=num_controls,
191-
control_values=control_values,
192-
control_qid_shape=control_qid_shape,
193-
)
194-
)
195-
if self != controlled_z:
196-
result = controlled_z.on(*qubits)
197-
if self.sub_gate.global_shift == 0:
198-
return result
199-
# Reconstruct the controlled global shift of the subgate.
200-
total_shift = self.sub_gate.exponent * self.sub_gate.global_shift
201-
phase_gate = gp.GlobalPhaseGate(1j ** (2 * total_shift))
202-
controlled_phase_op = phase_gate.controlled(
203-
num_controls=self.num_controls(),
204-
control_values=self.control_values,
205-
control_qid_shape=self.control_qid_shape,
206-
).on(*control_qubits)
207-
return [result, controlled_phase_op]
208-
result = protocols.decompose_once_with_qubits(
209-
self.sub_gate,
210-
qubits[self.num_controls() :],
211-
NotImplemented,
212-
flatten=False,
213-
context=context,
214-
)
215-
if result is NotImplemented:
216-
return NotImplemented
217191

218-
return op_tree.transform_op_tree(
219-
result,
220-
lambda op: op.controlled_by(
221-
*qubits[: self.num_controls()], control_values=self.control_values
222-
),
223-
)
192+
# If nothing works, return `NotImplemented`.
193+
return NotImplemented
224194

225195
def on(self, *qubits: cirq.Qid) -> cop.ControlledOperation:
226196
if len(qubits) == 0:

cirq-core/cirq/ops/controlled_gate_test.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -804,3 +804,31 @@ def test_controlled_global_phase_matrix_gate_decomposes(
804804
decomposed = cirq.decompose(cg_matrix(*all_qubits))
805805
assert not any(isinstance(op.gate, cirq.MatrixGate) for op in decomposed)
806806
np.testing.assert_allclose(cirq.unitary(cirq.Circuit(decomposed)), cirq.unitary(cg_matrix))
807+
808+
809+
@pytest.mark.parametrize('gate_type', [cirq.XPowGate, cirq.YPowGate, cirq.ZPowGate, cirq.CZPowGate])
810+
@pytest.mark.parametrize('test_shift', np.pi * (np.random.default_rng(324).random(10) * 2 - 1))
811+
def test_controlled_phase_extracted_before_decomposition(gate_type, test_shift) -> None:
812+
test_shift = 0.123 # arbitrary
813+
814+
shifted_gate = gate_type(global_shift=test_shift).controlled()
815+
unshifted_gate = gate_type().controlled()
816+
qs = cirq.LineQubit.range(cirq.num_qubits(shifted_gate))
817+
shifted_op = shifted_gate.on(*qs)
818+
unshifted_op = unshifted_gate.on(*qs)
819+
shifted_decomposition = cirq.decompose(shifted_op)
820+
unshifted_decomposition = cirq.decompose(unshifted_op)
821+
822+
# No brute-force calculation. It's the standard decomposition plus Z for the controlled shift.
823+
assert shifted_decomposition[:-1] == unshifted_decomposition
824+
z_op = shifted_decomposition[-1]
825+
assert z_op.qubits == (qs[0],)
826+
z = z_op.gate
827+
assert isinstance(z, cirq.ZPowGate)
828+
np.testing.assert_approx_equal(z.exponent, test_shift)
829+
assert z.global_shift == 0
830+
831+
# Sanity check that the decomposition is equivalent
832+
np.testing.assert_allclose(
833+
cirq.unitary(cirq.Circuit(shifted_decomposition)), cirq.unitary(shifted_op), atol=1e-10
834+
)

cirq-core/cirq/ops/global_phase_op.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,13 @@ def _resolve_parameters_(
9292
coefficient = protocols.resolve_parameters(self.coefficient, resolver, recursive)
9393
return GlobalPhaseGate(coefficient=coefficient)
9494

95+
def is_identity(self) -> bool:
96+
"""Checks if gate is equivalent to an identity.
97+
98+
Returns: True if the coefficient is within rounding error of 1.
99+
"""
100+
return not protocols.is_parameterized(self._coefficient) and np.isclose(self.coefficient, 1)
101+
95102
def controlled(
96103
self,
97104
num_controls: int | None = None,
@@ -122,3 +129,23 @@ def global_phase_operation(
122129
) -> cirq.GateOperation:
123130
"""Creates an operation that represents a global phase on the state."""
124131
return GlobalPhaseGate(coefficient, atol)()
132+
133+
134+
def from_phase_and_exponent(
135+
half_turns: cirq.TParamVal, exponent: cirq.TParamVal
136+
) -> cirq.GlobalPhaseGate:
137+
"""Creates a GlobalPhaseGate from the global phase and exponent.
138+
139+
Args:
140+
half_turns: The number of half turns to rotate by.
141+
exponent: The power to raise the phase to.
142+
143+
Returns: A `GlobalPhaseGate` with the corresponding coefficient.
144+
"""
145+
coefficient = 1j ** (2 * half_turns * exponent)
146+
coefficient = (
147+
complex(coefficient)
148+
if isinstance(coefficient, sympy.Expr) and coefficient.is_complex
149+
else coefficient
150+
)
151+
return GlobalPhaseGate(coefficient)

cirq-core/cirq/ops/global_phase_op_test.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import sympy
2020

2121
import cirq
22+
from cirq.ops import global_phase_op
2223

2324

2425
def test_init() -> None:
@@ -304,3 +305,22 @@ def test_global_phase_gate_controlled(coeff, exp) -> None:
304305
assert g.controlled(control_values=xor_control_values) == cirq.ControlledGate(
305306
g, control_values=xor_control_values
306307
)
308+
309+
310+
def test_is_identity() -> None:
311+
g = cirq.GlobalPhaseGate(1)
312+
assert g.is_identity()
313+
g = cirq.GlobalPhaseGate(1j)
314+
assert not g.is_identity()
315+
g = cirq.GlobalPhaseGate(-1)
316+
assert not g.is_identity()
317+
318+
319+
def test_from_phase_and_exponent() -> None:
320+
g = global_phase_op.from_phase_and_exponent(2.5, 0.5)
321+
assert g.coefficient == np.exp(1.25j * np.pi)
322+
a, b = sympy.symbols('a, b')
323+
g = global_phase_op.from_phase_and_exponent(a, b)
324+
assert g.coefficient == 1j ** (2 * a * b)
325+
g = global_phase_op.from_phase_and_exponent(1 / a, a)
326+
assert g.coefficient == -1

cirq-core/cirq/ops/three_qubit_gates.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -106,19 +106,11 @@ def _decompose_(self, qubits):
106106
elif not b.is_adjacent(c):
107107
a, b = b, a
108108

109-
p = common_gates.T**self._exponent
109+
exp = self._exponent
110+
p = common_gates.T**exp
110111
sweep_abc = [common_gates.CNOT(a, b), common_gates.CNOT(b, c)]
111-
global_phase = 1j ** (2 * self.global_shift * self._exponent)
112-
global_phase = (
113-
complex(global_phase)
114-
if protocols.is_parameterized(global_phase) and global_phase.is_complex
115-
else global_phase
116-
)
117-
global_phase_operation = (
118-
[global_phase_op.global_phase_operation(global_phase)]
119-
if protocols.is_parameterized(global_phase) or abs(global_phase - 1.0) > 0
120-
else []
121-
)
112+
global_phase_gate = global_phase_op.from_phase_and_exponent(self.global_shift, exp)
113+
global_phase_operation = [] if global_phase_gate.is_identity() else [global_phase_gate()]
122114
return global_phase_operation + [
123115
p(a),
124116
p(b),

cirq-core/cirq/protocols/decompose_protocol.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,9 +81,16 @@ class DecompositionContext:
8181
Args:
8282
qubit_manager: A `cirq.QubitManager` instance to allocate clean / dirty ancilla qubits as
8383
part of the decompose protocol.
84+
extract_global_phases: If set, will extract the global phases from
85+
`DECOMPOSE_TARGET_GATESET` into independent global phase operations.
8486
"""
8587

8688
qubit_manager: cirq.QubitManager
89+
extract_global_phases: bool = False
90+
91+
def extracting_global_phases(self) -> DecompositionContext:
92+
"""Returns a copy with the `extract_global_phases` field set."""
93+
return dataclasses.replace(self, extract_global_phases=True)
8794

8895

8996
class SupportsDecompose(Protocol):

cirq-core/cirq/protocols/decompose_protocol_test.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -445,3 +445,12 @@ def test_decompose_without_context_succeed() -> None:
445445
cirq.ops.CleanQubit(1, prefix='_decompose_protocol'),
446446
)
447447
]
448+
449+
450+
def test_extracting_global_phases() -> None:
451+
qm = cirq.SimpleQubitManager()
452+
context = cirq.DecompositionContext(qm)
453+
new_context = context.extracting_global_phases()
454+
assert not context.extract_global_phases
455+
assert new_context.extract_global_phases
456+
assert new_context.qubit_manager is qm

0 commit comments

Comments
 (0)