Skip to content

Commit 84224d8

Browse files
authored
Merge branch 'main' into protocol
2 parents c30ae4e + c9839bf commit 84224d8

21 files changed

+782
-117
lines changed

cirq-core/cirq/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -380,6 +380,7 @@
380380
merge_operations_to_circuit_op as merge_operations_to_circuit_op,
381381
merge_single_qubit_gates_to_phased_x_and_z as merge_single_qubit_gates_to_phased_x_and_z,
382382
merge_single_qubit_gates_to_phxz as merge_single_qubit_gates_to_phxz,
383+
merge_single_qubit_gates_to_phxz_symbolized as merge_single_qubit_gates_to_phxz_symbolized,
383384
merge_single_qubit_moments_to_phxz as merge_single_qubit_moments_to_phxz,
384385
optimize_for_target_gateset as optimize_for_target_gateset,
385386
parameterized_2q_op_to_sqrt_iswap_operations as parameterized_2q_op_to_sqrt_iswap_operations,

cirq-core/cirq/ops/common_gates.py

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,14 @@
3737
from cirq import protocols, value
3838
from cirq._compat import proper_repr
3939
from cirq._doc import document
40-
from cirq.ops import control_values as cv, controlled_gate, eigen_gate, gate_features, raw_types
40+
from cirq.ops import (
41+
control_values as cv,
42+
controlled_gate,
43+
eigen_gate,
44+
gate_features,
45+
global_phase_op,
46+
raw_types,
47+
)
4148
from cirq.ops.measurement_gate import MeasurementGate
4249
from cirq.ops.swap_gates import ISWAP, ISwapPowGate, SWAP, SwapPowGate
4350

@@ -235,6 +242,11 @@ def controlled(
235242
return cirq.CCXPowGate(exponent=self._exponent)
236243
return result
237244

245+
def _decompose_with_context_(
246+
self, qubits: tuple[cirq.Qid, ...], context: cirq.DecompositionContext
247+
) -> list[cirq.Operation] | NotImplementedType:
248+
return _extract_phase(self, XPowGate, qubits, context)
249+
238250
def _pauli_expansion_(self) -> value.LinearDict[str]:
239251
if self._dimension != 2:
240252
return NotImplemented # pragma: no cover
@@ -487,6 +499,11 @@ def __repr__(self) -> str:
487499
f'global_shift={self._global_shift!r})'
488500
)
489501

502+
def _decompose_with_context_(
503+
self, qubits: tuple[cirq.Qid, ...], context: cirq.DecompositionContext
504+
) -> list[cirq.Operation] | NotImplementedType:
505+
return _extract_phase(self, YPowGate, qubits, context)
506+
490507

491508
class Ry(YPowGate):
492509
r"""A gate with matrix $e^{-i Y t/2}$ that rotates around the Y axis of the Bloch sphere by $t$.
@@ -699,6 +716,11 @@ def controlled(
699716
return cirq.CCZPowGate(exponent=self._exponent)
700717
return result
701718

719+
def _decompose_with_context_(
720+
self, qubits: tuple[cirq.Qid, ...], context: cirq.DecompositionContext
721+
) -> list[cirq.Operation] | NotImplementedType:
722+
return _extract_phase(self, ZPowGate, qubits, context)
723+
702724
def _qid_shape_(self) -> tuple[int, ...]:
703725
return (self._dimension,)
704726

@@ -1131,6 +1153,11 @@ def controlled(
11311153
control_qid_shape=result.control_qid_shape + (2,),
11321154
)
11331155

1156+
def _decompose_with_context_(
1157+
self, qubits: tuple[cirq.Qid, ...], context: cirq.DecompositionContext
1158+
) -> list[cirq.Operation] | NotImplementedType:
1159+
return _extract_phase(self, CZPowGate, qubits, context)
1160+
11341161
def _circuit_diagram_info_(self, args: cirq.CircuitDiagramInfoArgs) -> cirq.CircuitDiagramInfo:
11351162
return protocols.CircuitDiagramInfo(
11361163
wire_symbols=('@', '@'), exponent=self._diagram_exponent(args)
@@ -1486,3 +1513,25 @@ def _phased_x_or_pauli_gate(
14861513
case 0.5:
14871514
return YPowGate(exponent=exponent)
14881515
return cirq.ops.PhasedXPowGate(exponent=exponent, phase_exponent=phase_exponent)
1516+
1517+
1518+
def _extract_phase(
1519+
gate: cirq.EigenGate,
1520+
gate_class: type,
1521+
qubits: tuple[cirq.Qid, ...],
1522+
context: cirq.DecompositionContext,
1523+
) -> list[cirq.Operation] | NotImplementedType:
1524+
"""Extracts the global phase field to its own gate, or absorbs it if it has no effect.
1525+
1526+
This is for use within the decompose handlers, and will return `NotImplemented` if there is no
1527+
global phase, implying it is already in its simplest form. It will return a list, with the
1528+
original op minus any global phase first, and the global phase op second. If the resulting
1529+
global phase is empty (can happen for example in `XPowGate(global_phase=2/3)**3`), then it is
1530+
excluded from the return value."""
1531+
if not context.extract_global_phases or gate.global_shift == 0:
1532+
return NotImplemented
1533+
result = [gate_class(exponent=gate.exponent).on(*qubits)]
1534+
phase_gate = global_phase_op.from_phase_and_exponent(gate.global_shift, gate.exponent)
1535+
if not phase_gate.is_identity():
1536+
result.append(phase_gate())
1537+
return result

cirq-core/cirq/ops/common_gates_test.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1322,3 +1322,37 @@ def test_parameterized_pauli_expansion(gate_type, exponent) -> None:
13221322
gate_resolved = cirq.resolve_parameters(gate, {'s': 0.5})
13231323
pauli_resolved = cirq.resolve_parameters(pauli, {'s': 0.5})
13241324
assert cirq.approx_eq(pauli_resolved, cirq.pauli_expansion(gate_resolved))
1325+
1326+
1327+
@pytest.mark.parametrize('gate_type', [cirq.XPowGate, cirq.YPowGate, cirq.ZPowGate, cirq.CZPowGate])
1328+
@pytest.mark.parametrize('exponent', [0, 0.5, 2, 3, -0.5, -2, -3, sympy.Symbol('s')])
1329+
def test_decompose_with_extracted_phases(gate_type: type, exponent: cirq.TParamVal) -> None:
1330+
context = cirq.DecompositionContext(cirq.SimpleQubitManager(), extract_global_phases=True)
1331+
test_shift = 2 / 3 # Interesting because e.g. X(shift=2/3) ** 3 == X with no phase
1332+
gate = gate_type(exponent=exponent, global_shift=test_shift)
1333+
op = gate.on(*cirq.LineQubit.range(cirq.num_qubits(gate)))
1334+
decomposed = cirq.decompose(op, context=context)
1335+
1336+
# The first gate should be the original gate, but with shift removed.
1337+
gate0 = decomposed[0].gate
1338+
assert isinstance(gate0, gate_type)
1339+
assert isinstance(gate0, cirq.EigenGate)
1340+
assert gate0.global_shift == 0
1341+
assert gate0.exponent == exponent
1342+
if exponent % 3 == 0:
1343+
# Since test_shift == 2/3, gate**3 nullifies the phase, leaving only the unphased gate.
1344+
assert len(decomposed) == 1
1345+
else:
1346+
# Other exponents emit a global phase gate to compensate.
1347+
assert len(decomposed) == 2
1348+
gate1 = decomposed[1].gate
1349+
assert isinstance(gate1, cirq.GlobalPhaseGate)
1350+
assert gate1.coefficient == 1j ** (2 * exponent * test_shift)
1351+
1352+
# Sanity check that the decomposition is equivalent to the original.
1353+
decomposed_circuit = cirq.Circuit(decomposed)
1354+
if cirq.is_parameterized(exponent):
1355+
resolver = {'s': -1.234} # arbitrary
1356+
op = cirq.resolve_parameters(op, resolver)
1357+
decomposed_circuit = cirq.resolve_parameters(decomposed_circuit, resolver)
1358+
np.testing.assert_allclose(cirq.unitary(op), cirq.unitary(decomposed_circuit), atol=1e-10)

cirq-core/cirq/ops/controlled_gate.py

Lines changed: 23 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
control_values as cv,
2525
controlled_operation as cop,
2626
diagonal_gate as dg,
27-
global_phase_op as gp,
2827
op_tree,
2928
raw_types,
3029
)
@@ -139,19 +138,35 @@ def num_controls(self) -> int:
139138
def _qid_shape_(self) -> tuple[int, ...]:
140139
return self.control_qid_shape + protocols.qid_shape(self.sub_gate)
141140

142-
def _decompose_(self, qubits: tuple[cirq.Qid, ...]) -> None | NotImplementedType | cirq.OP_TREE:
143-
return self._decompose_with_context_(qubits)
144-
145141
def _decompose_with_context_(
146-
self, qubits: tuple[cirq.Qid, ...], context: cirq.DecompositionContext | None = None
147-
) -> None | NotImplementedType | cirq.OP_TREE:
142+
self, qubits: tuple[cirq.Qid, ...], context: cirq.DecompositionContext
143+
) -> NotImplementedType | cirq.OP_TREE:
148144
control_qubits = list(qubits[: self.num_controls()])
149145
controlled_sub_gate = self.sub_gate.controlled(
150146
self.num_controls(), self.control_values, self.control_qid_shape
151147
)
152148
# Prefer the subgate controlled version if available
153149
if self != controlled_sub_gate:
154150
return controlled_sub_gate.on(*qubits)
151+
152+
# Try decomposing the subgate next.
153+
result = protocols.decompose_once_with_qubits(
154+
self.sub_gate,
155+
qubits[self.num_controls() :],
156+
NotImplemented,
157+
flatten=False,
158+
# Extract global phases from decomposition, as controlled phases decompose easily.
159+
context=context.extracting_global_phases(),
160+
)
161+
if result is not NotImplemented:
162+
return op_tree.transform_op_tree(
163+
result,
164+
lambda op: op.controlled_by(
165+
*qubits[: self.num_controls()], control_values=self.control_values
166+
),
167+
)
168+
169+
# Finally try brute-force on the unitary.
155170
if protocols.has_unitary(self.sub_gate) and all(q.dimension == 2 for q in qubits):
156171
n_qubits = protocols.num_qubits(self.sub_gate)
157172
# Case 1: Global Phase (1x1 Matrix)
@@ -173,54 +188,9 @@ def _decompose_with_context_(
173188
protocols.unitary(self.sub_gate), control_qubits, qubits[-1]
174189
)
175190
return invert_ops + decomposed_ops + invert_ops
176-
if isinstance(self.sub_gate, common_gates.CZPowGate):
177-
z_sub_gate = common_gates.ZPowGate(exponent=self.sub_gate.exponent)
178-
num_controls = self.num_controls() + 1
179-
control_values = self.control_values & cv.ProductOfSums(((1,),))
180-
control_qid_shape = self.control_qid_shape + (2,)
181-
controlled_z = (
182-
z_sub_gate.controlled(
183-
num_controls=num_controls,
184-
control_values=control_values,
185-
control_qid_shape=control_qid_shape,
186-
)
187-
if protocols.is_parameterized(self)
188-
else ControlledGate(
189-
z_sub_gate,
190-
num_controls=num_controls,
191-
control_values=control_values,
192-
control_qid_shape=control_qid_shape,
193-
)
194-
)
195-
if self != controlled_z:
196-
result = controlled_z.on(*qubits)
197-
if self.sub_gate.global_shift == 0:
198-
return result
199-
# Reconstruct the controlled global shift of the subgate.
200-
total_shift = self.sub_gate.exponent * self.sub_gate.global_shift
201-
phase_gate = gp.GlobalPhaseGate(1j ** (2 * total_shift))
202-
controlled_phase_op = phase_gate.controlled(
203-
num_controls=self.num_controls(),
204-
control_values=self.control_values,
205-
control_qid_shape=self.control_qid_shape,
206-
).on(*control_qubits)
207-
return [result, controlled_phase_op]
208-
result = protocols.decompose_once_with_qubits(
209-
self.sub_gate,
210-
qubits[self.num_controls() :],
211-
NotImplemented,
212-
flatten=False,
213-
context=context,
214-
)
215-
if result is NotImplemented:
216-
return NotImplemented
217191

218-
return op_tree.transform_op_tree(
219-
result,
220-
lambda op: op.controlled_by(
221-
*qubits[: self.num_controls()], control_values=self.control_values
222-
),
223-
)
192+
# If nothing works, return `NotImplemented`.
193+
return NotImplemented
224194

225195
def on(self, *qubits: cirq.Qid) -> cop.ControlledOperation:
226196
if len(qubits) == 0:

cirq-core/cirq/ops/controlled_gate_test.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -804,3 +804,31 @@ def test_controlled_global_phase_matrix_gate_decomposes(
804804
decomposed = cirq.decompose(cg_matrix(*all_qubits))
805805
assert not any(isinstance(op.gate, cirq.MatrixGate) for op in decomposed)
806806
np.testing.assert_allclose(cirq.unitary(cirq.Circuit(decomposed)), cirq.unitary(cg_matrix))
807+
808+
809+
@pytest.mark.parametrize('gate_type', [cirq.XPowGate, cirq.YPowGate, cirq.ZPowGate, cirq.CZPowGate])
810+
@pytest.mark.parametrize('test_shift', np.pi * (np.random.default_rng(324).random(10) * 2 - 1))
811+
def test_controlled_phase_extracted_before_decomposition(gate_type, test_shift) -> None:
812+
test_shift = 0.123 # arbitrary
813+
814+
shifted_gate = gate_type(global_shift=test_shift).controlled()
815+
unshifted_gate = gate_type().controlled()
816+
qs = cirq.LineQubit.range(cirq.num_qubits(shifted_gate))
817+
shifted_op = shifted_gate.on(*qs)
818+
unshifted_op = unshifted_gate.on(*qs)
819+
shifted_decomposition = cirq.decompose(shifted_op)
820+
unshifted_decomposition = cirq.decompose(unshifted_op)
821+
822+
# No brute-force calculation. It's the standard decomposition plus Z for the controlled shift.
823+
assert shifted_decomposition[:-1] == unshifted_decomposition
824+
z_op = shifted_decomposition[-1]
825+
assert z_op.qubits == (qs[0],)
826+
z = z_op.gate
827+
assert isinstance(z, cirq.ZPowGate)
828+
np.testing.assert_approx_equal(z.exponent, test_shift)
829+
assert z.global_shift == 0
830+
831+
# Sanity check that the decomposition is equivalent
832+
np.testing.assert_allclose(
833+
cirq.unitary(cirq.Circuit(shifted_decomposition)), cirq.unitary(shifted_op), atol=1e-10
834+
)

cirq-core/cirq/ops/global_phase_op.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,13 @@ def _resolve_parameters_(
9292
coefficient = protocols.resolve_parameters(self.coefficient, resolver, recursive)
9393
return GlobalPhaseGate(coefficient=coefficient)
9494

95+
def is_identity(self) -> bool:
96+
"""Checks if gate is equivalent to an identity.
97+
98+
Returns: True if the coefficient is within rounding error of 1.
99+
"""
100+
return not protocols.is_parameterized(self._coefficient) and np.isclose(self.coefficient, 1)
101+
95102
def controlled(
96103
self,
97104
num_controls: int | None = None,
@@ -122,3 +129,23 @@ def global_phase_operation(
122129
) -> cirq.GateOperation:
123130
"""Creates an operation that represents a global phase on the state."""
124131
return GlobalPhaseGate(coefficient, atol)()
132+
133+
134+
def from_phase_and_exponent(
135+
half_turns: cirq.TParamVal, exponent: cirq.TParamVal
136+
) -> cirq.GlobalPhaseGate:
137+
"""Creates a GlobalPhaseGate from the global phase and exponent.
138+
139+
Args:
140+
half_turns: The number of half turns to rotate by.
141+
exponent: The power to raise the phase to.
142+
143+
Returns: A `GlobalPhaseGate` with the corresponding coefficient.
144+
"""
145+
coefficient = 1j ** (2 * half_turns * exponent)
146+
coefficient = (
147+
complex(coefficient)
148+
if isinstance(coefficient, sympy.Expr) and coefficient.is_complex
149+
else coefficient
150+
)
151+
return GlobalPhaseGate(coefficient)

cirq-core/cirq/ops/global_phase_op_test.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import sympy
2020

2121
import cirq
22+
from cirq.ops import global_phase_op
2223

2324

2425
def test_init() -> None:
@@ -304,3 +305,22 @@ def test_global_phase_gate_controlled(coeff, exp) -> None:
304305
assert g.controlled(control_values=xor_control_values) == cirq.ControlledGate(
305306
g, control_values=xor_control_values
306307
)
308+
309+
310+
def test_is_identity() -> None:
311+
g = cirq.GlobalPhaseGate(1)
312+
assert g.is_identity()
313+
g = cirq.GlobalPhaseGate(1j)
314+
assert not g.is_identity()
315+
g = cirq.GlobalPhaseGate(-1)
316+
assert not g.is_identity()
317+
318+
319+
def test_from_phase_and_exponent() -> None:
320+
g = global_phase_op.from_phase_and_exponent(2.5, 0.5)
321+
assert g.coefficient == np.exp(1.25j * np.pi)
322+
a, b = sympy.symbols('a, b')
323+
g = global_phase_op.from_phase_and_exponent(a, b)
324+
assert g.coefficient == 1j ** (2 * a * b)
325+
g = global_phase_op.from_phase_and_exponent(1 / a, a)
326+
assert g.coefficient == -1

cirq-core/cirq/ops/three_qubit_gates.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -106,19 +106,11 @@ def _decompose_(self, qubits):
106106
elif not b.is_adjacent(c):
107107
a, b = b, a
108108

109-
p = common_gates.T**self._exponent
109+
exp = self._exponent
110+
p = common_gates.T**exp
110111
sweep_abc = [common_gates.CNOT(a, b), common_gates.CNOT(b, c)]
111-
global_phase = 1j ** (2 * self.global_shift * self._exponent)
112-
global_phase = (
113-
complex(global_phase)
114-
if protocols.is_parameterized(global_phase) and global_phase.is_complex
115-
else global_phase
116-
)
117-
global_phase_operation = (
118-
[global_phase_op.global_phase_operation(global_phase)]
119-
if protocols.is_parameterized(global_phase) or abs(global_phase - 1.0) > 0
120-
else []
121-
)
112+
global_phase_gate = global_phase_op.from_phase_and_exponent(self.global_shift, exp)
113+
global_phase_operation = [] if global_phase_gate.is_identity() else [global_phase_gate()]
122114
return global_phase_operation + [
123115
p(a),
124116
p(b),

cirq-core/cirq/protocols/decompose_protocol.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,9 +81,16 @@ class DecompositionContext:
8181
Args:
8282
qubit_manager: A `cirq.QubitManager` instance to allocate clean / dirty ancilla qubits as
8383
part of the decompose protocol.
84+
extract_global_phases: If set, will extract the global phases from
85+
`DECOMPOSE_TARGET_GATESET` into independent global phase operations.
8486
"""
8587

8688
qubit_manager: cirq.QubitManager
89+
extract_global_phases: bool = False
90+
91+
def extracting_global_phases(self) -> DecompositionContext:
92+
"""Returns a copy with the `extract_global_phases` field set."""
93+
return dataclasses.replace(self, extract_global_phases=True)
8794

8895

8996
class SupportsDecompose(Protocol):

cirq-core/cirq/protocols/decompose_protocol_test.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -445,3 +445,12 @@ def test_decompose_without_context_succeed() -> None:
445445
cirq.ops.CleanQubit(1, prefix='_decompose_protocol'),
446446
)
447447
]
448+
449+
450+
def test_extracting_global_phases() -> None:
451+
qm = cirq.SimpleQubitManager()
452+
context = cirq.DecompositionContext(qm)
453+
new_context = context.extracting_global_phases()
454+
assert not context.extract_global_phases
455+
assert new_context.extract_global_phases
456+
assert new_context.qubit_manager is qm

0 commit comments

Comments
 (0)