Skip to content

Fix circuit reversal in stratified_circuit in cirq-core/cirq/transformers/stratify.py #7531

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions cirq-core/cirq/circuits/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -1765,6 +1765,7 @@ class Circuit(AbstractCircuit):
* batch_remove
* batch_insert_into
* insert_at_frontier
* reverse

Circuits can also be iterated over,

Expand Down Expand Up @@ -2523,6 +2524,16 @@ def clear_operations_touching(
self._moments[k] = self._moments[k].without_operations_touching(qubits)
self._mutated()

def reverse(self) -> None:
"""Reverses the moments in the circuit, and the operations in the moments."""
# Work on a copy in case validation fails halfway through.
copy = self.copy()
backwards = []
for moment in copy[::-1]:
backwards.append(Moment(reversed(moment.operations)))
self._moments = backwards
self._mutated()

@property
def moments(self) -> Sequence[cirq.Moment]:
return self._moments
Expand Down Expand Up @@ -2558,6 +2569,8 @@ def with_noise(self, noise: cirq.NOISE_MODEL_LIKE) -> cirq.Circuit:
# Keep moments aligned
c_noisy += Circuit(op_tree)
return c_noisy




def _pick_inserted_ops_moment_indices(
Expand Down
167 changes: 167 additions & 0 deletions cirq-core/cirq/circuits/circuit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1849,6 +1849,173 @@ def test_clear_operations_touching() -> None:
]
)

def test_reverse_empty_circuit():
circuit = cirq.Circuit()
circuit.reverse()
assert len(circuit) == 0
assert circuit == cirq.Circuit()

def test_reverse_single_moment_single_operation():
q = cirq.GridQubit(0, 0)
circuit = cirq.Circuit(cirq.X(q))
original_str = str(circuit)

circuit.reverse()

assert str(circuit) == original_str
assert len(circuit) == 1

def test_reverse_single_moment_multiple_operations():
"""Test reversing a circuit with one moment and multiple operations."""
q0, q1, q2 = cirq.GridQubit(0, 0), cirq.GridQubit(0, 1), cirq.GridQubit(0, 2)
original_ops = [cirq.X(q0), cirq.Y(q1), cirq.Z(q2)]
circuit = cirq.Circuit(cirq.Moment(original_ops))

circuit.reverse()

# Moment order unchanged (only one moment), but operations reversed
assert len(circuit) == 1
reversed_ops = list(circuit[0])
assert reversed_ops == list(reversed(original_ops))

def test_reverse_multiple_moments_single_operations():
"""Test reversing a circuit with multiple moments, each with single operations."""
q = cirq.GridQubit(0, 0)
circuit = cirq.Circuit([
cirq.Moment([cirq.X(q)]),
cirq.Moment([cirq.Y(q)]),
cirq.Moment([cirq.Z(q)])
])

original_moments = [str(moment) for moment in circuit]
circuit.reverse()

# Moments should be reversed
assert len(circuit) == 3
reversed_moments = [str(moment) for moment in circuit]
assert reversed_moments == list(reversed(original_moments))

def test_reverse_multiple_moments_multiple_operations():
"""Test reversing a circuit with multiple moments and multiple operations."""
q0, q1 = cirq.GridQubit(0, 0), cirq.GridQubit(0, 1)
circuit = cirq.Circuit([
cirq.Moment([cirq.X(q0), cirq.Y(q1)]),
cirq.Moment([cirq.Z(q0), cirq.H(q1)]),
cirq.Moment([cirq.S(q0), cirq.T(q1)])
])

# Store original structure
original_structure = []
for moment in circuit:
original_structure.append(list(moment.operations))

circuit.reverse()

# Check that moments are reversed and operations within each moment are reversed
assert len(circuit) == 3

# First moment should be the reversed last moment
expected_first = list(reversed(original_structure[2]))
actual_first = list(circuit[0])
assert actual_first == expected_first

# Second moment should be the reversed middle moment
expected_second = list(reversed(original_structure[1]))
actual_second = list(circuit[1])
assert actual_second == expected_second

# Third moment should be the reversed first moment
expected_third = list(reversed(original_structure[0]))
actual_third = list(circuit[2])
assert actual_third == expected_third

def test_reverse_twice_returns_original():
"""Test that reversing twice returns the original circuit."""
q0, q1 = cirq.GridQubit(0, 0), cirq.GridQubit(0, 1)
original_circuit = cirq.Circuit([
cirq.Moment([cirq.X(q0), cirq.Y(q1)]),
cirq.Moment([cirq.Z(q0)]),
cirq.Moment([cirq.H(q0), cirq.S(q1)])
])

# Make a copy to compare against
expected = original_circuit.copy()

# Reverse twice
original_circuit.reverse()
original_circuit.reverse()

# Should be back to original
assert original_circuit == expected

def test_reverse_with_measurements():
"""Test reversing a circuit with measurement operations."""
q0, q1 = cirq.GridQubit(0, 0), cirq.GridQubit(0, 1)
circuit = cirq.Circuit([
cirq.Moment([cirq.X(q0), cirq.Y(q1)]),
cirq.Moment([cirq.measure(q0, key='a'), cirq.measure(q1, key='b')])
])

original_structure = []
for moment in circuit:
original_structure.append(list(moment.operations))

circuit.reverse()

# Check structure is properly reversed
assert len(circuit) == 2

# First moment should be reversed measurements
expected_first = list(reversed(original_structure[1]))
actual_first = list(circuit[0])
assert len(actual_first) == 2
assert all(isinstance(op.gate, cirq.MeasurementGate) for op in actual_first)

# Second moment should be reversed X, Y gates
expected_second = list(reversed(original_structure[0]))
actual_second = list(circuit[1])
assert len(actual_second) == 2

def test_reverse_with_two_qubit_gates():
"""Test reversing a circuit with two-qubit gates."""
q0, q1, q2 = cirq.GridQubit(0, 0), cirq.GridQubit(0, 1), cirq.GridQubit(0, 2)
circuit = cirq.Circuit([
cirq.Moment([cirq.CNOT(q0, q1), cirq.X(q2)]),
cirq.Moment([cirq.CZ(q1, q2)]),
cirq.Moment([cirq.SWAP(q0, q2), cirq.Y(q1)])
])

original_structure = []
for moment in circuit:
original_structure.append(list(moment.operations))

circuit.reverse()

# Verify the structure is correctly reversed
assert len(circuit) == 3

# Check that two-qubit gates are preserved correctly
for i, moment in enumerate(circuit):
expected_ops = list(reversed(original_structure[2-i]))
actual_ops = list(moment.operations)
assert actual_ops == expected_ops

def test_reverse_modifies_original_circuit():
"""Test that reverse() modifies the original circuit in-place."""
q = cirq.GridQubit(0, 0)
circuit = cirq.Circuit([
cirq.Moment([cirq.X(q)]),
cirq.Moment([cirq.Y(q)])
])

original_id = id(circuit)
circuit.reverse()

# Should be the same object
assert id(circuit) == original_id

# But content should be different
assert str(circuit[0]) != "X(q(0, 0))" # First moment is now Y

@pytest.mark.parametrize('circuit_cls', [cirq.Circuit, cirq.FrozenCircuit])
def test_all_qubits(circuit_cls) -> None:
Expand Down
5 changes: 3 additions & 2 deletions cirq-core/cirq/transformers/stratify.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def stratified_circuit(
# Try the algorithm with each permutation of the classifiers.
smallest_depth = protocols.num_qubits(circuit) * len(circuit) + 1
shortest_stratified_circuit = circuits.Circuit()
reversed_circuit = circuit[::-1]
reversed_circuit = circuit.copy().reverse()
for ordered_classifiers in itertools.permutations(classifiers):
solution = _stratify_circuit(
circuit,
Expand All @@ -87,7 +87,8 @@ def stratified_circuit(
reversed_circuit,
classifiers=ordered_classifiers,
context=context or transformer_api.TransformerContext(),
)[::-1]
)
solution.reverse()
if len(solution) < smallest_depth:
shortest_stratified_circuit = solution
smallest_depth = len(solution)
Expand Down
Loading