Skip to content

Commit 63b7a75

Browse files
committed
Add return-type to public functions, mostly tests part 1
No change in the effective code. A batch of ~50 files. Modified files pass ruff check --select=ANN201 Partially implements #4393
1 parent 5c96d02 commit 63b7a75

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

51 files changed

+529
-456
lines changed

benchmarks/bench_linalg_decompositions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
# yapf: enable
3030

3131

32-
def time_kak_decomposition(target):
32+
def time_kak_decomposition(target) -> None:
3333
"""Benchmark kak_decomposition
3434
kak_decomposition is benchmarked because it was historically slow.
3535
See https://github.com/quantumlib/Cirq/issues/3840 for status of other benchmarks.

benchmarks/circuit_construction.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,6 @@ class XOnAllQubitsCircuit:
131131
params = [[1, 10, 100, 1000], [1, 10, 100, 1000]]
132132
param_names = ["Number of Qubits(N)", "Depth(D)"]
133133

134-
def time_circuit_construction(self, N: int, D: int):
134+
def time_circuit_construction(self, N: int, D: int) -> cirq.Circuit:
135135
q = cirq.LineQubit.range(N)
136136
return cirq.Circuit(cirq.Moment(cirq.X.on_each(*q)) for _ in range(D))

benchmarks/parameter_resolution.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,15 +24,15 @@ class RabiCalibration:
2424
params = ([50, 100, 150, 200], [20, 40, 60, 80, 100])
2525
param_names = ["num_qubits", "num_scan_points"]
2626

27-
def setup(self, num_qubits: int, _):
27+
def setup(self, num_qubits: int, _) -> None:
2828
qubits = cirq.GridQubit.rect(1, num_qubits)
2929
self.symbols = {q: sympy.Symbol(f'a_{q}') for q in qubits}
3030
self.circuit = cirq.Circuit(
3131
[cirq.X(q) ** self.symbols[q] for q in qubits], cirq.measure_each(*qubits)
3232
)
3333
self.qubit_amps = {q: random.uniform(0.48, 0.52) for q in qubits}
3434

35-
def time_parameter_resolution(self, _, num_scan_points: int):
35+
def time_parameter_resolution(self, _, num_scan_points: int) -> None:
3636
for diff in np.linspace(-0.3, 0.3, num=num_scan_points):
3737
resolver = {self.symbols[q]: amp + diff for q, amp in self.qubit_amps.items()}
3838
_ = cirq.resolve_parameters(self.circuit, resolver)

benchmarks/randomized_benchmarking.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ class SingleQubitRandomizedBenchmarking:
3939
param_names = ["depth", "num_qubits", "num_circuits"]
4040
timeout = 600 # Change timeout to 10 minutes instead of default 60 seconds.
4141

42-
def setup(self, *_):
42+
def setup(self, *_) -> None:
4343
self.sq_xz_matrices = np.array(
4444
[
4545
dot([cirq.unitary(c) for c in reversed(group)])
@@ -60,12 +60,12 @@ def _get_op_grid(self, qubits: list[cirq.Qid], depth: int) -> list[list[cirq.Ope
6060
op_grid.append(op_sequence)
6161
return op_grid
6262

63-
def time_rb_op_grid_generation(self, depth: int, num_qubits: int, num_circuits: int):
63+
def time_rb_op_grid_generation(self, depth: int, num_qubits: int, num_circuits: int) -> None:
6464
qubits = cirq.GridQubit.rect(1, num_qubits)
6565
for _ in range(num_circuits):
6666
self._get_op_grid(qubits, depth)
6767

68-
def time_rb_circuit_construction(self, depth: int, num_qubits: int, num_circuits: int):
68+
def time_rb_circuit_construction(self, depth: int, num_qubits: int, num_circuits: int) -> None:
6969
qubits = cirq.GridQubit.rect(1, num_qubits)
7070
for _ in range(num_circuits):
7171
op_grid = self._get_op_grid(qubits, depth)

benchmarks/serialization.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ class SerializeLargeExpandedCircuits:
2929
params = ([100, 500, 1000], [100, 1000, 4000])
3030
timeout = 600 # Change timeout to 2 minutes instead of default 60 seconds.
3131

32-
def setup(self, num_qubits: int, num_moments: int):
32+
def setup(self, num_qubits: int, num_moments: int) -> None:
3333
qubits = cirq.LineQubit.range(num_qubits)
3434
one_q_x_moment = cirq.Moment(cirq.X(q) for q in qubits[::2])
3535
one_q_y_moment = cirq.Moment(cirq.Y(q) for q in qubits[1::2])
@@ -43,11 +43,11 @@ def setup(self, num_qubits: int, num_moments: int):
4343
* (num_moments // 5)
4444
)
4545

46-
def time_json_serialization(self, *_):
46+
def time_json_serialization(self, *_) -> None:
4747
_ = cirq.to_json(self.circuit)
4848

49-
def time_json_serialization_gzip(self, *_):
49+
def time_json_serialization_gzip(self, *_) -> None:
5050
_ = cirq.to_json_gzip(self.circuit)
5151

52-
def track_json_serialization_gzip_size(self, *_):
52+
def track_json_serialization_gzip_size(self, *_) -> str:
5353
return _human_size(len(cirq.to_json_gzip(self.circuit)))

benchmarks/transformers/routing.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,20 +20,20 @@ class RouteCQC:
2020
param_names = ["qubits", "depth", "op_density", "grid_device_size"]
2121
timeout = 300 # Increase timeout to 5 minutes instead of default 60 seconds.
2222

23-
def setup(self, qubits: int, depth: int, op_density: float, grid_device_size: int):
23+
def setup(self, qubits: int, depth: int, op_density: float, grid_device_size: int) -> None:
2424
gate_domain = {cirq.CNOT: 2, cirq.X: 1}
2525
self.circuit = cirq.testing.random_circuit(
2626
qubits, depth, op_density, gate_domain=gate_domain, random_state=12345
2727
)
2828
self.device = cirq.testing.construct_grid_device(grid_device_size, grid_device_size)
2929
self.router = cirq.RouteCQC(self.device.metadata.nx_graph)
3030

31-
def time_circuit_routing(self, *_):
31+
def time_circuit_routing(self, *_) -> None:
3232
self.routed_circuit = self.router(self.circuit)
3333

3434
def track_routed_circuit_depth_ratio(self, *_) -> float:
3535
self.routed_circuit = self.router(self.circuit)
3636
return len(self.routed_circuit) / len(self.circuit)
3737

38-
def teardown(self, *_):
38+
def teardown(self, *_) -> None:
3939
self.device.validate_circuit(self.routed_circuit)

benchmarks/transformers/transformer_primitives.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ class MapLargeExpandedCircuit:
2020
params = ([100, 500, 1000], [100, 1000, 4000])
2121
timeout = 600 # Change timeout to 2 minutes instead of default 60 seconds.
2222

23-
def setup(self, num_qubits: int, num_moments: int):
23+
def setup(self, num_qubits: int, num_moments: int) -> None:
2424
qubits = cirq.LineQubit.range(num_qubits)
2525
one_q_x_moment = cirq.Moment(cirq.X(q) for q in qubits[::2])
2626
one_q_y_moment = cirq.Moment(cirq.Y(q) for q in qubits[1::2])
@@ -32,7 +32,7 @@ def setup(self, num_qubits: int, num_moments: int):
3232
[one_q_x_moment, two_q_cx_moment, one_q_y_moment, two_q_cz_moment] * (num_moments // 4)
3333
)
3434

35-
def time_map_moments(self, num_qubits: int, _):
35+
def time_map_moments(self, num_qubits: int, _) -> None:
3636
all_qubits = cirq.LineQubit.range(num_qubits)
3737

3838
def map_func(m: cirq.Moment, _) -> cirq.Moment:
@@ -46,19 +46,19 @@ def map_func(m: cirq.Moment, _) -> cirq.Moment:
4646

4747
_ = cirq.map_moments(circuit=self.circuit, map_func=map_func)
4848

49-
def time_map_operations_apply_tag(self, *_):
49+
def time_map_operations_apply_tag(self, *_) -> None:
5050
def map_func(op: cirq.Operation, _) -> cirq.Operation:
5151
return op.with_tags("old op")
5252

5353
_ = cirq.map_operations(circuit=self.circuit, map_func=map_func)
5454

55-
def time_map_operations_to_optree(self, *_):
55+
def time_map_operations_to_optree(self, *_) -> None:
5656
def map_func(op: cirq.Operation, _) -> cirq.OP_TREE:
5757
return [op, op]
5858

5959
_ = cirq.map_operations(circuit=self.circuit, map_func=map_func)
6060

61-
def time_map_operations_to_optree_and_unroll(self, *_):
61+
def time_map_operations_to_optree_and_unroll(self, *_) -> None:
6262
def map_func(op: cirq.Operation, _) -> cirq.OP_TREE:
6363
return [op, op]
6464

cirq-aqt/cirq_aqt/aqt_device.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ def __init__(
189189
self.noise_dict = noise_dict
190190
self.simulate_ideal = simulate_ideal
191191

192-
def generate_circuit_from_list(self, json_string: str):
192+
def generate_circuit_from_list(self, json_string: str) -> None:
193193
"""Generates a list of cirq operations from a json string.
194194
195195
The default behavior is to add a measurement to any qubit at the end
@@ -288,11 +288,11 @@ def __init__(
288288
def metadata(self) -> aqt_device_metadata.AQTDeviceMetadata:
289289
return self._metadata
290290

291-
def validate_gate(self, gate: cirq.Gate):
291+
def validate_gate(self, gate: cirq.Gate) -> None:
292292
if gate not in self.metadata.gateset:
293293
raise ValueError(f'Unsupported gate type: {gate!r}')
294294

295-
def validate_operation(self, operation):
295+
def validate_operation(self, operation) -> None:
296296
if not isinstance(operation, cirq.GateOperation):
297297
raise ValueError(f'Unsupported operation: {operation!r}')
298298

@@ -304,7 +304,7 @@ def validate_operation(self, operation):
304304
if q not in self.qubits:
305305
raise ValueError(f'Qubit not on device: {q!r}')
306306

307-
def validate_circuit(self, circuit: cirq.AbstractCircuit):
307+
def validate_circuit(self, circuit: cirq.AbstractCircuit) -> None:
308308
super().validate_circuit(circuit)
309309
_verify_unique_measurement_keys(circuit.all_operations())
310310

cirq-aqt/cirq_aqt/aqt_device_metadata_test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def metadata(qubits) -> AQTDeviceMetadata:
3838
)
3939

4040

41-
def test_aqtdevice_metadata(metadata, qubits):
41+
def test_aqtdevice_metadata(metadata, qubits) -> None:
4242
assert metadata.qubit_set == frozenset(qubits)
4343
assert set(qubits) == set(metadata.nx_graph.nodes())
4444
edges = metadata.nx_graph.edges()
@@ -48,7 +48,7 @@ def test_aqtdevice_metadata(metadata, qubits):
4848
assert len(metadata.gate_durations) == 4
4949

5050

51-
def test_aqtdevice_duration_of(metadata, qubits):
51+
def test_aqtdevice_duration_of(metadata, qubits) -> None:
5252
q0, q1 = qubits[:2]
5353
ms = cirq.Duration(millis=1)
5454
assert metadata.duration_of(cirq.Z(q0)) == 10 * ms
@@ -59,5 +59,5 @@ def test_aqtdevice_duration_of(metadata, qubits):
5959
metadata.duration_of(cirq.I(q0))
6060

6161

62-
def test_repr(metadata):
62+
def test_repr(metadata) -> None:
6363
cirq.testing.assert_equivalent_repr(metadata, setup_code='import cirq\nimport cirq_aqt\n')

cirq-aqt/cirq_aqt/aqt_device_test.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -43,24 +43,24 @@ def with_qubits(self, *new_qubits) -> NotImplementedOperation:
4343
raise NotImplementedError()
4444

4545
@property
46-
def qubits(self):
46+
def qubits(self) -> tuple[cirq.Qid, ...]:
4747
raise NotImplementedError()
4848

4949

50-
def test_init_qubits(device, qubits):
50+
def test_init_qubits(device, qubits) -> None:
5151
ms = cirq.Duration(millis=1)
5252
assert device.qubits == frozenset(qubits)
5353
with pytest.raises(TypeError, match="NamedQubit"):
5454
aqt_device.AQTDevice(
5555
measurement_duration=100 * ms,
5656
twoq_gates_duration=200 * ms,
5757
oneq_gates_duration=10 * ms,
58-
qubits=[cirq.LineQubit(0), cirq.NamedQubit("a")],
58+
qubits=[cirq.LineQubit(0), cirq.NamedQubit("a")], # type: ignore[list-item]
5959
)
6060

6161

6262
@pytest.mark.parametrize('ms', [cirq.Duration(millis=1), timedelta(milliseconds=1)])
63-
def test_init_durations(ms, qubits):
63+
def test_init_durations(ms, qubits) -> None:
6464
dev = aqt_device.AQTDevice(
6565
qubits=qubits,
6666
measurement_duration=100 * ms,
@@ -72,12 +72,12 @@ def test_init_durations(ms, qubits):
7272
assert dev.metadata.measurement_duration == cirq.Duration(millis=100)
7373

7474

75-
def test_metadata(device, qubits):
75+
def test_metadata(device, qubits) -> None:
7676
assert isinstance(device.metadata, aqt_device_metadata.AQTDeviceMetadata)
7777
assert device.metadata.qubit_set == frozenset(qubits)
7878

7979

80-
def test_repr(device):
80+
def test_repr(device) -> None:
8181
assert repr(device) == (
8282
"cirq_aqt.aqt_device.AQTDevice("
8383
"measurement_duration=cirq.Duration(millis=100), "
@@ -89,13 +89,13 @@ def test_repr(device):
8989
cirq.testing.assert_equivalent_repr(device, setup_code='import cirq\nimport cirq_aqt\n')
9090

9191

92-
def test_validate_measurement_non_adjacent_qubits_ok(device):
92+
def test_validate_measurement_non_adjacent_qubits_ok(device) -> None:
9393
device.validate_operation(
9494
cirq.GateOperation(cirq.MeasurementGate(2, 'key'), (cirq.LineQubit(0), cirq.LineQubit(1)))
9595
)
9696

9797

98-
def test_validate_operation_existing_qubits(device):
98+
def test_validate_operation_existing_qubits(device) -> None:
9999
device.validate_operation(cirq.GateOperation(cirq.XX, (cirq.LineQubit(0), cirq.LineQubit(1))))
100100
device.validate_operation(cirq.Z(cirq.LineQubit(0)))
101101
device.validate_operation(
@@ -114,7 +114,7 @@ def test_validate_operation_existing_qubits(device):
114114
device.validate_operation(cirq.X(cirq.NamedQubit("q1")))
115115

116116

117-
def test_validate_operation_supported_gate(device):
117+
def test_validate_operation_supported_gate(device) -> None:
118118
class MyGate(cirq.Gate):
119119
def num_qubits(self):
120120
return 1
@@ -128,12 +128,12 @@ def num_qubits(self):
128128
device.validate_operation(NotImplementedOperation())
129129

130130

131-
def test_aqt_device_eq(device):
131+
def test_aqt_device_eq(device) -> None:
132132
eq = cirq.testing.EqualsTester()
133133
eq.make_equality_group(lambda: device)
134134

135135

136-
def test_validate_circuit_repeat_measurement_keys(device):
136+
def test_validate_circuit_repeat_measurement_keys(device) -> None:
137137
circuit = cirq.Circuit()
138138
circuit.append(
139139
[cirq.measure(cirq.LineQubit(0), key='a'), cirq.measure(cirq.LineQubit(1), key='a')]
@@ -143,16 +143,16 @@ def test_validate_circuit_repeat_measurement_keys(device):
143143
device.validate_circuit(circuit)
144144

145145

146-
def test_aqt_device_str(device):
146+
def test_aqt_device_str(device) -> None:
147147
assert str(device) == "q(0)───q(1)───q(2)"
148148

149149

150-
def test_aqt_device_pretty_repr(device):
150+
def test_aqt_device_pretty_repr(device) -> None:
151151
cirq.testing.assert_repr_pretty(device, "q(0)───q(1)───q(2)")
152152
cirq.testing.assert_repr_pretty(device, "AQTDevice(...)", cycle=True)
153153

154154

155-
def test_at(device):
155+
def test_at(device) -> None:
156156
assert device.at(-1) is None
157157
assert device.at(0) == cirq.LineQubit(0)
158158
assert device.at(2) == cirq.LineQubit(2)

0 commit comments

Comments
 (0)