Skip to content

Commit c105cb9

Browse files
authored
Merge branch 'main' into remove-controlled-cz-case
2 parents 830bda4 + b1ff3ae commit c105cb9

28 files changed

+318
-115
lines changed

cirq-core/cirq/circuits/circuit.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1172,7 +1172,7 @@ def to_text_diagram(
11721172
*,
11731173
use_unicode_characters: bool = True,
11741174
transpose: bool = False,
1175-
include_tags: bool = True,
1175+
include_tags: bool | Iterable[type] = True,
11761176
precision: int | None = 3,
11771177
qubit_order: cirq.QubitOrderOrList = ops.QubitOrder.DEFAULT,
11781178
) -> str:
@@ -1182,7 +1182,10 @@ def to_text_diagram(
11821182
use_unicode_characters: Determines if unicode characters are
11831183
allowed (as opposed to ascii-only diagrams).
11841184
transpose: Arranges qubit wires vertically instead of horizontally.
1185-
include_tags: Whether tags on TaggedOperations should be printed
1185+
include_tags: Controls which tags attached to operations are
1186+
included. ``True`` includes all tags, ``False`` includes none,
1187+
or a collection of tag classes may be specified to include only
1188+
those tags.
11861189
precision: Number of digits to display in text diagram
11871190
qubit_order: Determines how qubits are ordered in the diagram.
11881191
@@ -1209,7 +1212,7 @@ def to_text_diagram_drawer(
12091212
use_unicode_characters: bool = True,
12101213
qubit_namer: Callable[[cirq.Qid], str] | None = None,
12111214
transpose: bool = False,
1212-
include_tags: bool = True,
1215+
include_tags: bool | Iterable[type] = True,
12131216
draw_moment_groups: bool = True,
12141217
precision: int | None = 3,
12151218
qubit_order: cirq.QubitOrderOrList = ops.QubitOrder.DEFAULT,
@@ -1224,7 +1227,10 @@ def to_text_diagram_drawer(
12241227
allowed (as opposed to ascii-only diagrams).
12251228
qubit_namer: Names qubits in diagram. Defaults to using _circuit_diagram_info_ or str.
12261229
transpose: Arranges qubit wires vertically instead of horizontally.
1227-
include_tags: Whether to include tags in the operation.
1230+
include_tags: Controls which tags attached to operations are
1231+
included. ``True`` includes all tags, ``False`` includes none,
1232+
or a collection of tag classes may be specified to include only
1233+
those tags.
12281234
draw_moment_groups: Whether to draw moment symbol or not
12291235
precision: Number of digits to use when representing numbers.
12301236
qubit_order: Determines how qubits are ordered in the diagram.
@@ -2534,7 +2540,7 @@ def _draw_moment_annotations(
25342540
get_circuit_diagram_info: Callable[
25352541
[cirq.Operation, cirq.CircuitDiagramInfoArgs], cirq.CircuitDiagramInfo
25362542
],
2537-
include_tags: bool,
2543+
include_tags: bool | Iterable[type],
25382544
first_annotation_row: int,
25392545
transpose: bool,
25402546
):
@@ -2566,7 +2572,7 @@ def _draw_moment_in_diagram(
25662572
get_circuit_diagram_info: (
25672573
Callable[[cirq.Operation, cirq.CircuitDiagramInfoArgs], cirq.CircuitDiagramInfo] | None
25682574
),
2569-
include_tags: bool,
2575+
include_tags: bool | Iterable[type],
25702576
first_annotation_row: int,
25712577
transpose: bool,
25722578
):
@@ -2637,8 +2643,16 @@ def _draw_moment_in_diagram(
26372643
desc = _formatted_phase(global_phase, use_unicode_characters, precision)
26382644
if desc:
26392645
y = max(label_map.values(), default=0) + 1
2640-
if tags and include_tags:
2641-
desc = desc + f"[{', '.join(map(str, tags))}]"
2646+
visible_tags = protocols.CircuitDiagramInfoArgs(
2647+
known_qubits=None,
2648+
known_qubit_count=None,
2649+
use_unicode_characters=True,
2650+
precision=None,
2651+
label_map=None,
2652+
include_tags=include_tags,
2653+
).tags_to_include(tags)
2654+
if visible_tags:
2655+
desc = desc + f"[{', '.join(map(str, visible_tags))}]"
26422656
out_diagram.write(x0, y, desc)
26432657

26442658
if not non_global_ops:

cirq-core/cirq/circuits/moment.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -565,7 +565,7 @@ def to_text_diagram(
565565
extra_qubits: Iterable[cirq.Qid] = (),
566566
use_unicode_characters: bool = True,
567567
precision: int | None = None,
568-
include_tags: bool = True,
568+
include_tags: bool | Iterable[type] = True,
569569
) -> str:
570570
"""Create a text diagram for the moment.
571571
@@ -583,8 +583,10 @@ def to_text_diagram(
583583
precision: How precise numbers, such as angles, should be. Use None
584584
for infinite precision, or an integer for a certain number of
585585
digits of precision.
586-
include_tags: Whether or not to include operation tags in the
587-
diagram.
586+
include_tags: Controls which tags attached to operations are
587+
included. ``True`` includes all tags, ``False`` includes none,
588+
or a collection of tag classes may be specified to include only
589+
those tags.
588590
589591
Returns:
590592
The text diagram rendered into text.

cirq-core/cirq/contrib/paulistring/pauli_string_measurement_with_readout_mitigation.py

Lines changed: 37 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,9 @@
2828
from cirq.experiments.readout_confusion_matrix import TensoredConfusionMatrices
2929

3030
if TYPE_CHECKING:
31-
from cirq.experiments import SingleQubitReadoutCalibrationResult
31+
from cirq.experiments.single_qubit_readout_calibration import (
32+
SingleQubitReadoutCalibrationResult,
33+
)
3234
from cirq.study import ResultDict
3335

3436

@@ -217,6 +219,11 @@ def _normalize_input_paulis(
217219
return cast(dict[circuits.FrozenCircuit, list[list[ops.PauliString]]], circuits_to_pauli)
218220

219221

222+
def _extract_readout_qubits(pauli_strings: list[ops.PauliString]) -> list[ops.Qid]:
223+
"""Extracts unique qubits from a list of QWC Pauli strings."""
224+
return sorted(set(q for ps in pauli_strings for q in ps.qubits))
225+
226+
220227
def _pauli_strings_to_basis_change_ops(
221228
pauli_strings: list[ops.PauliString], qid_list: list[ops.Qid]
222229
):
@@ -315,16 +322,38 @@ def _process_pauli_measurement_results(
315322
for pauli_group_index, circuit_result in enumerate(circuit_results):
316323
measurement_results = circuit_result.measurements["m"]
317324
pauli_strs = pauli_string_groups[pauli_group_index]
325+
pauli_readout_qubits = _extract_readout_qubits(pauli_strs)
326+
327+
calibration_result = (
328+
calibration_results[tuple(pauli_readout_qubits)]
329+
if disable_readout_mitigation is False
330+
else None
331+
)
318332

319333
for pauli_str in pauli_strs:
320334
qubits_sorted = sorted(pauli_str.qubits)
321335
qubit_indices = [qubits.index(q) for q in qubits_sorted]
322336

323-
confusion_matrices = (
324-
_build_many_one_qubits_confusion_matrix(calibration_results[tuple(qubits_sorted)])
325-
if disable_readout_mitigation is False
326-
else _build_many_one_qubits_empty_confusion_matrix(len(qubits_sorted))
327-
)
337+
if disable_readout_mitigation:
338+
pauli_str_calibration_result = None
339+
confusion_matrices = _build_many_one_qubits_empty_confusion_matrix(
340+
len(qubits_sorted)
341+
)
342+
else:
343+
if calibration_result is None:
344+
# This case should be logically impossible if mitigation is on,
345+
# so we raise an error.
346+
raise ValueError(
347+
f"Readout mitigation is enabled, but no calibration result was "
348+
f"found for qubits {pauli_readout_qubits}."
349+
)
350+
pauli_str_calibration_result = calibration_result.readout_result_for_qubits(
351+
qubits_sorted
352+
)
353+
confusion_matrices = _build_many_one_qubits_confusion_matrix(
354+
pauli_str_calibration_result
355+
)
356+
328357
tensored_cm = TensoredConfusionMatrices(
329358
confusion_matrices,
330359
[[q] for q in qubits_sorted],
@@ -356,11 +385,7 @@ def _process_pauli_measurement_results(
356385
mitigated_stddev=d_m_with_coefficient,
357386
unmitigated_expectation=unmitigated_value_with_coefficient,
358387
unmitigated_stddev=d_unmit_with_coefficient,
359-
calibration_result=(
360-
calibration_results[tuple(qubits_sorted)]
361-
if disable_readout_mitigation is False
362-
else None
363-
),
388+
calibration_result=pauli_str_calibration_result,
364389
)
365390
)
366391

@@ -428,8 +453,7 @@ def measure_pauli_strings(
428453
unique_qubit_tuples = set()
429454
for pauli_string_groups in normalized_circuits_to_pauli.values():
430455
for pauli_strings in pauli_string_groups:
431-
for pauli_string in pauli_strings:
432-
unique_qubit_tuples.add(tuple(sorted(pauli_string.qubits)))
456+
unique_qubit_tuples.add(tuple(_extract_readout_qubits(pauli_strings)))
433457
# qubits_list is a list of qubit tuples
434458
qubits_list = sorted(unique_qubit_tuples)
435459

cirq-core/cirq/contrib/paulistring/pauli_string_measurement_with_readout_mitigation_test.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,10 @@
2323

2424
import cirq
2525
from cirq.contrib.paulistring import measure_pauli_strings
26-
from cirq.experiments import SingleQubitReadoutCalibrationResult
26+
from cirq.contrib.paulistring.pauli_string_measurement_with_readout_mitigation import (
27+
_process_pauli_measurement_results,
28+
)
29+
from cirq.experiments.single_qubit_readout_calibration import SingleQubitReadoutCalibrationResult
2730
from cirq.experiments.single_qubit_readout_calibration_test import NoisySingleQubitReadoutSampler
2831

2932

@@ -867,3 +870,37 @@ def test_group_paulis_type_mismatch() -> None:
867870
measure_pauli_strings(
868871
circuits_to_pauli, cirq.Simulator(), 1000, 1000, 1000, np.random.default_rng()
869872
)
873+
874+
875+
def test_process_pauli_measurement_results_raises_error_on_missing_calibration() -> None:
876+
"""Test that the function raises an error if the calibration result is missing."""
877+
qubits: list[cirq.Qid] = [q for q in cirq.LineQubit.range(5)]
878+
879+
measurement_op = cirq.measure(*qubits, key='m')
880+
test_circuits = list[cirq.Circuit]()
881+
for _ in range(3):
882+
circuit_list = []
883+
884+
circuit = _create_ghz(5, qubits) + measurement_op
885+
circuit_list.append(circuit)
886+
test_circuits.extend(circuit_list)
887+
888+
pauli_strings = [_generate_random_pauli_string(qubits, True) for _ in range(3)]
889+
sampler = cirq.Simulator()
890+
891+
circuit_results = sampler.run_batch(test_circuits, repetitions=1000)
892+
893+
empty_calibration_result_dict = {tuple(qubits): None}
894+
895+
with pytest.raises(
896+
ValueError,
897+
match="Readout mitigation is enabled, but no calibration result was found for qubits",
898+
):
899+
_process_pauli_measurement_results(
900+
qubits,
901+
[pauli_strings],
902+
circuit_results[0], # type: ignore[arg-type]
903+
empty_calibration_result_dict, # type: ignore[arg-type]
904+
1000,
905+
1.0,
906+
)

cirq-core/cirq/contrib/quantum_volume/quantum_volume.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ def sample_heavy_set(
127127
# Add measure gates to the end of (a copy of) the circuit. Ensure that those
128128
# gates measure those in the given mapping, preserving this order.
129129
qubits = circuit.all_qubits()
130-
key = None
130+
key: Callable[[cirq.Qid], cirq.Qid] | None = None
131131
if mapping:
132132
# Add any qubits that were not explicitly mapped, so they aren't lost in
133133
# the sorting.
@@ -137,7 +137,7 @@ def sample_heavy_set(
137137
# Don't do a single large measurement gate because then the key will be one
138138
# large string. Instead, do a bunch of single-qubit measurement gates so we
139139
# preserve the qubit keys.
140-
sorted_qubits = sorted(qubits, key=key) # type: ignore[arg-type]
140+
sorted_qubits = sorted(qubits, key=key)
141141
circuit_copy = circuit + [cirq.measure(q) for q in sorted_qubits]
142142

143143
# Run the sampler to compare each output against the Heavy Set.

cirq-core/cirq/experiments/single_qubit_readout_calibration.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,17 @@ def plot_integrated_histogram(
179179
ax.set_ylabel('Percentile')
180180
return ax
181181

182+
def readout_result_for_qubits(
183+
self, readout_qubits: list[ops.Qid]
184+
) -> SingleQubitReadoutCalibrationResult:
185+
"""Builds a calibration result for the specific readout qubits."""
186+
return SingleQubitReadoutCalibrationResult(
187+
zero_state_errors={qubit: self.zero_state_errors[qubit] for qubit in readout_qubits},
188+
one_state_errors={qubit: self.one_state_errors[qubit] for qubit in readout_qubits},
189+
timestamp=self.timestamp,
190+
repetitions=self.repetitions,
191+
)
192+
182193
@classmethod
183194
def _from_json_dict_(
184195
cls, zero_state_errors, one_state_errors, repetitions, timestamp, **kwargs

cirq-core/cirq/linalg/predicates.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,8 +115,12 @@ def is_unitary(matrix: np.ndarray, *, rtol: float = 1e-5, atol: float = 1e-8) ->
115115
Returns:
116116
Whether the matrix is unitary within the given tolerance.
117117
"""
118-
return matrix.shape[0] == matrix.shape[1] and np.allclose(
119-
matrix.dot(np.conj(matrix.T)), np.eye(matrix.shape[0]), rtol=rtol, atol=atol
118+
return (
119+
matrix.ndim == 2
120+
and matrix.shape[0] == matrix.shape[1]
121+
and np.allclose(
122+
matrix.dot(np.conj(matrix.T)), np.eye(matrix.shape[0]), rtol=rtol, atol=atol
123+
)
120124
)
121125

122126

cirq-core/cirq/linalg/predicates_test.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,10 +103,13 @@ def test_is_hermitian_tolerance():
103103

104104

105105
def test_is_unitary():
106+
assert not cirq.is_unitary(np.empty((0,)))
106107
assert cirq.is_unitary(np.empty((0, 0)))
107108
assert not cirq.is_unitary(np.empty((1, 0)))
108109
assert not cirq.is_unitary(np.empty((0, 1)))
110+
assert not cirq.is_unitary(np.empty((0, 0, 0)))
109111

112+
assert not cirq.is_unitary(np.array(1))
110113
assert cirq.is_unitary(np.array([[1]]))
111114
assert cirq.is_unitary(np.array([[-1]]))
112115
assert cirq.is_unitary(np.array([[1j]]))

cirq-core/cirq/ops/raw_types.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -913,11 +913,12 @@ def _resolve_parameters_(
913913

914914
def _circuit_diagram_info_(self, args: cirq.CircuitDiagramInfoArgs) -> cirq.CircuitDiagramInfo:
915915
sub_op_info = protocols.circuit_diagram_info(self.sub_operation, args, NotImplemented)
916-
# Add tag to wire symbol if it exists.
917-
if sub_op_info is not NotImplemented and args.include_tags and sub_op_info.wire_symbols:
918-
sub_op_info.wire_symbols = (
919-
sub_op_info.wire_symbols[0] + f"[{', '.join(map(str, self._tags))}]",
920-
) + sub_op_info.wire_symbols[1:]
916+
if sub_op_info is not NotImplemented and sub_op_info.wire_symbols:
917+
visible_tags = args.tags_to_include(self._tags)
918+
if visible_tags:
919+
sub_op_info.wire_symbols = (
920+
sub_op_info.wire_symbols[0] + f"[{', '.join(map(str, visible_tags))}]",
921+
) + sub_op_info.wire_symbols[1:]
921922
return sub_op_info
922923

923924
@cached_method

cirq-core/cirq/ops/raw_types_test.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -554,6 +554,8 @@ def __str__(self):
554554
diagram_with_non_string_tag = "(1, 1): ───H[<taggy>]───"
555555
assert c.to_text_diagram() == diagram_with_non_string_tag
556556
assert c.to_text_diagram(include_tags=False) == diagram_without_tags
557+
assert c.to_text_diagram(include_tags={str}) == diagram_without_tags
558+
assert c.to_text_diagram(include_tags={TaggyTag}) == diagram_with_non_string_tag
557559

558560

559561
def test_circuit_diagram_tagged_global_phase() -> None:
@@ -651,7 +653,7 @@ def test_tagged_operation_forwards_protocols() -> None:
651653
np.testing.assert_equal(cirq.unitary(tagged_h), cirq.unitary(h))
652654
assert cirq.has_unitary(tagged_h)
653655
assert cirq.decompose(tagged_h) == cirq.decompose(h)
654-
assert [*tagged_h._decompose_()] == cirq.decompose(h)
656+
assert [*tagged_h._decompose_()] == cirq.decompose_once(h)
655657
assert cirq.pauli_expansion(tagged_h) == cirq.pauli_expansion(h)
656658
assert cirq.equal_up_to_global_phase(h, tagged_h)
657659
assert np.isclose(cirq.kraus(h), cirq.kraus(tagged_h)).all()

0 commit comments

Comments
 (0)