Skip to content

Commit 6b9e0a1

Browse files
committed
fix: update type annotations and tests for Sequence[int] and QubitPermutationGate
1 parent 7560153 commit 6b9e0a1

File tree

2 files changed

+19
-11
lines changed

2 files changed

+19
-11
lines changed

cirq-core/cirq/sim/classical_simulator.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -85,12 +85,12 @@ def __init__(
8585
8686
Args:
8787
qubits: The qubits to simulate.
88-
initial_state: The initial state for the simulation. Accepts int or a sequence of int.
88+
initial_state: The initial state for the simulation. Accepts int or Sequence[int].
8989
classical_data: The classical data container for the simulation.
9090
9191
Raises:
9292
ValueError: If qubits not provided and initial_state is int.
93-
If initial_state is not an int, list[int], tuple[int], or np.ndarray.
93+
If initial_state is not an int or Sequence[int].
9494
If initial_state is a np.ndarray and its shape is not 1-dimensional.
9595
If gate is not one of X, SWAP, QubitPermutationGate, a controlled version
9696
of X or SWAP, or a measurement.
@@ -99,7 +99,7 @@ def __init__(
9999
"""
100100
if isinstance(initial_state, int):
101101
if qubits is None:
102-
raise ValueError('qubits must be provided if initial_state is not list[int]')
102+
raise ValueError('qubits must be provided if initial_state is not Sequence[int]')
103103
state = ClassicalBasisState(
104104
big_endian_int_to_bits(initial_state, bit_count=len(qubits))
105105
)
@@ -109,10 +109,10 @@ def __init__(
109109
f'initial_state must be 1-dimensional, got shape {initial_state.shape}'
110110
)
111111
state = ClassicalBasisState(list(initial_state))
112-
elif isinstance(initial_state, (list, tuple)):
112+
elif isinstance(initial_state, Sequence) and not isinstance(initial_state, (str, bytes)):
113113
state = ClassicalBasisState(list(initial_state))
114114
else:
115-
raise ValueError('initial_state must be an int, list[int], tuple[int], or np.ndarray')
115+
raise ValueError('initial_state must be an int or Sequence[int]')
116116
super().__init__(state=state, qubits=qubits, classical_data=classical_data)
117117

118118
def _act_on_fallback_(self, action, qubits: Sequence[cirq.Qid], allow_decompose: bool = True):

cirq-core/cirq/sim/classical_simulator_test.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -60,13 +60,21 @@ def test_Swap():
6060
np.testing.assert_equal(results, expected_results)
6161

6262

63-
def test_qubit_permutation_gate():
64-
q0, q1, q2 = cirq.LineQubit.range(3)
65-
perm_gate = cirq.QubitPermutationGate([2, 0, 1])
66-
circuit = cirq.Circuit(perm_gate(q0, q1, q2), cirq.measure(q0, q1, q2, key='key'))
63+
@pytest.mark.parametrize(
64+
"n,perm,state",
65+
[
66+
(n, np.random.permutation(n).tolist(), np.random.choice(2, size=n))
67+
for n in np.random.randint(3, 8, size=10)
68+
],
69+
)
70+
def test_qubit_permutation_gate(n, perm, state):
71+
qubits = cirq.LineQubit.range(n)
72+
perm_gate = cirq.QubitPermutationGate(perm)
73+
circuit = cirq.Circuit(perm_gate(*qubits), cirq.measure(*qubits, key='key'))
6774
sim = cirq.ClassicalStateSimulator()
68-
result = sim.simulate(circuit, initial_state=[1, 0, 1])
69-
np.testing.assert_equal(result.measurements['key'], [1, 1, 0])
75+
result = sim.simulate(circuit, initial_state=state)
76+
expected = [state[perm[i]] for i in range(n)]
77+
np.testing.assert_equal(result.measurements['key'], expected)
7078

7179

7280
def test_CCNOT():

0 commit comments

Comments
 (0)