Skip to content

Commit 3d997b8

Browse files
committed
fix: update type annotations and tests for Sequence[int] and QubitPermutationGate
1 parent 7560153 commit 3d997b8

File tree

2 files changed

+24
-14
lines changed

2 files changed

+24
-14
lines changed

cirq-core/cirq/sim/classical_simulator.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def _is_identity(action) -> bool:
3838
class ClassicalBasisState(qis.QuantumStateRepresentation):
3939
"""Represents a classical basis state for efficient state evolution."""
4040

41-
def __init__(self, initial_state: list[int] | np.ndarray):
41+
def __init__(self, initial_state: Sequence[int] | np.ndarray):
4242
"""Initializes the ClassicalBasisState object.
4343
4444
Args:
@@ -85,21 +85,19 @@ def __init__(
8585
8686
Args:
8787
qubits: The qubits to simulate.
88-
initial_state: The initial state for the simulation. Accepts int or a sequence of int.
88+
initial_state: The initial state for the simulation. Accepts int or Sequence[int].
8989
classical_data: The classical data container for the simulation.
9090
9191
Raises:
9292
ValueError: If qubits not provided and initial_state is int.
93-
If initial_state is not an int, list[int], tuple[int], or np.ndarray.
93+
If initial_state is not an int or Sequence[int].
9494
If initial_state is a np.ndarray and its shape is not 1-dimensional.
95-
If gate is not one of X, SWAP, QubitPermutationGate, a controlled version
96-
of X or SWAP, or a measurement.
9795
9896
An initial_state value of type integer is parsed in big endian order.
9997
"""
10098
if isinstance(initial_state, int):
10199
if qubits is None:
102-
raise ValueError('qubits must be provided if initial_state is not list[int]')
100+
raise ValueError('qubits must be provided if initial_state is not Sequence[int]')
103101
state = ClassicalBasisState(
104102
big_endian_int_to_bits(initial_state, bit_count=len(qubits))
105103
)
@@ -109,10 +107,10 @@ def __init__(
109107
f'initial_state must be 1-dimensional, got shape {initial_state.shape}'
110108
)
111109
state = ClassicalBasisState(list(initial_state))
112-
elif isinstance(initial_state, (list, tuple)):
110+
elif isinstance(initial_state, Sequence) and not isinstance(initial_state, (str, bytes)):
113111
state = ClassicalBasisState(list(initial_state))
114112
else:
115-
raise ValueError('initial_state must be an int, list[int], tuple[int], or np.ndarray')
113+
raise ValueError('initial_state must be an int or Sequence[int]')
116114
super().__init__(state=state, qubits=qubits, classical_data=classical_data)
117115

118116
def _act_on_fallback_(self, action, qubits: Sequence[cirq.Qid], allow_decompose: bool = True):
@@ -125,6 +123,10 @@ def _act_on_fallback_(self, action, qubits: Sequence[cirq.Qid], allow_decompose:
125123
126124
Returns:
127125
True if the operation was applied successfully.
126+
127+
Raises:
128+
ValueError: If gate is not one of X, SWAP, QubitPermutationGate, a controlled version
129+
of X or SWAP, or a measurement.
128130
"""
129131
gate = action.gate if isinstance(action, ops.Operation) else action
130132
mapped_qubits = [self.qubit_map[i] for i in qubits]

cirq-core/cirq/sim/classical_simulator_test.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -60,13 +60,21 @@ def test_Swap():
6060
np.testing.assert_equal(results, expected_results)
6161

6262

63-
def test_qubit_permutation_gate():
64-
q0, q1, q2 = cirq.LineQubit.range(3)
65-
perm_gate = cirq.QubitPermutationGate([2, 0, 1])
66-
circuit = cirq.Circuit(perm_gate(q0, q1, q2), cirq.measure(q0, q1, q2, key='key'))
63+
@pytest.mark.parametrize(
64+
"n,perm,state",
65+
[
66+
(n, np.random.permutation(n).tolist(), np.random.choice(2, size=n))
67+
for n in np.random.randint(3, 8, size=10)
68+
],
69+
)
70+
def test_qubit_permutation_gate(n, perm, state):
71+
qubits = cirq.LineQubit.range(n)
72+
perm_gate = cirq.QubitPermutationGate(perm)
73+
circuit = cirq.Circuit(perm_gate(*qubits), cirq.measure(*qubits, key='key'))
6774
sim = cirq.ClassicalStateSimulator()
68-
result = sim.simulate(circuit, initial_state=[1, 0, 1])
69-
np.testing.assert_equal(result.measurements['key'], [1, 1, 0])
75+
result = sim.simulate(circuit, initial_state=state)
76+
expected = [state[perm[i]] for i in range(n)]
77+
np.testing.assert_equal(result.measurements['key'], expected)
7078

7179

7280
def test_CCNOT():

0 commit comments

Comments
 (0)