Skip to content

Commit 9c2064f

Browse files
mypy & coverage
1 parent 2c46d19 commit 9c2064f

File tree

2 files changed

+44
-26
lines changed

2 files changed

+44
-26
lines changed

cirq-core/cirq/transformers/pauli_insertion.py

Lines changed: 11 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -23,24 +23,7 @@
2323
from cirq import circuits, ops
2424
from cirq.transformers import transformer_api
2525

26-
_PAULIS = [ops.I, ops.X, ops.Y, ops.Z]
27-
28-
29-
def _is_target(
30-
op: ops.Operation,
31-
target: ops.Gate | ops.GateFamily | ops.Gateset | type[ops.Gate | ops.Operation],
32-
):
33-
if inspect.isclass(target):
34-
if issubclass(target, ops.Operation):
35-
return isinstance(op, target)
36-
if not hasattr(op, 'gate'):
37-
return False
38-
return isinstance(op.gate, target)
39-
if isinstance(target, ops.Gate):
40-
if not hasattr(op, 'gate') or op.gate is None:
41-
return False
42-
return op.gate == target
43-
return op in target
26+
_PAULIS: tuple[ops.Gate] = (ops.I, ops.X, ops.Y, ops.Z) # type: ignore[has-type]
4427

4528

4629
@transformer_api.transformer
@@ -54,13 +37,13 @@ class PauliInsertionTransformer:
5437

5538
def __init__(
5639
self,
57-
target: ops.Gate | ops.GateFamily | ops.Gateset | type[ops.Gate | ops.Operation],
40+
target: ops.Gate | ops.GateFamily | ops.Gateset | type[ops.Gate],
5841
probabilities: np.ndarray | None = None,
5942
):
6043
"""Makes a pauli insertion transformer that samples 2Q paulis with the given probabilities.
6144
6245
Args:
63-
target: The target gate, gatefamily, gateset, or type (e.g. PauliSumExponential).
46+
target: The target gate, gatefamily, gateset, or type (e.g. ZZPowGAte).
6447
probabilities: Optional ndarray representing the probabilities of sampling 2Q paulis.
6548
The order of the paulis is IXYZ. If None, assume uniform distribution.
6649
Returns:
@@ -72,7 +55,13 @@ def __init__(
7255
assert probabilities.shape == (4, 4)
7356
assert np.isclose(probabilities.sum(), 1)
7457

75-
self.target = target
58+
if inspect.isclass(target):
59+
self.target = ops.GateFamily(target)
60+
elif isinstance(target, ops.Gate):
61+
self.target = ops.Gateset(target)
62+
else:
63+
assert isinstance(target, (ops.Gateset, ops.GateFamily))
64+
self.target = target
7665
self._flat_probs = probabilities.reshape(-1)
7766

7867
def __call__(
@@ -106,7 +95,7 @@ def __call__(
10695
for op in moment:
10796
if any(tag in tags_to_ignore for tag in op.tags):
10897
continue
109-
if not _is_target(op, self.target):
98+
if op not in self.target:
11099
continue
111100
pair = np.unravel_index(rng.choice(16, p=self._flat_probs), (4, 4))
112101
for pauli_index, q in zip(pair, op.qubits):

cirq-core/cirq/transformers/pauli_insertion_test.py

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,17 @@ def _random_probs(n: int, seed: int | None = None):
2929

3030

3131
@pytest.mark.parametrize('probs', _random_probs(3, 0))
32-
def test_pauli_insertion_with_probabilities(probs):
32+
@pytest.mark.parametrize(
33+
'target',
34+
[cirq.ZZPowGate, cirq.ZZ**0.324, cirq.Gateset(cirq.ZZ**0.324), cirq.GateFamily(cirq.ZZ**0.324)],
35+
)
36+
def test_pauli_insertion_with_probabilities(probs, target):
3337
c = cirq.Circuit(cirq.ZZ(*cirq.LineQubit.range(2)) ** 0.324)
34-
transformer = cirq.transformers.PauliInsertionTransformer(cirq.ZZPowGate, probs)
38+
transformer = cirq.transformers.PauliInsertionTransformer(target, probs)
3539
count = np.zeros((4, 4))
40+
rng = np.random.default_rng(0)
3641
for _ in range(100):
37-
nc = transformer(c)
42+
nc = transformer(c, rng_or_seed=rng)
3843
assert len(nc) == 2
3944
u, v = nc[0]
4045
i = _PAULIS.index(u.gate)
@@ -49,12 +54,36 @@ def test_pauli_insertion_with_probabilities_doesnot_create_moment(probs):
4954
c = cirq.Circuit.from_moments([], [cirq.ZZ(*cirq.LineQubit.range(2)) ** 0.324])
5055
transformer = cirq.transformers.PauliInsertionTransformer(cirq.ZZPowGate, probs)
5156
count = np.zeros((4, 4))
57+
rng = np.random.default_rng(0)
5258
for _ in range(100):
53-
nc = transformer(c)
59+
nc = transformer(c, rng_or_seed=rng)
5460
assert len(nc) == 2
5561
u, v = nc[0]
5662
i = _PAULIS.index(u.gate)
5763
j = _PAULIS.index(v.gate)
5864
count[i, j] += 1
5965
count = count / count.sum()
6066
np.testing.assert_allclose(count, probs, atol=0.1)
67+
68+
69+
def test_invalid_context_raises():
70+
c = cirq.Circuit(cirq.ZZ(*cirq.LineQubit.range(2)) ** 0.324)
71+
transformer = cirq.transformers.PauliInsertionTransformer(cirq.ZZPowGate)
72+
with pytest.raises(ValueError):
73+
_ = transformer(c, context=cirq.TransformerContext(deep=True))
74+
75+
76+
def test_transformer_ignores_tagged_ops():
77+
op = cirq.ZZ(*cirq.LineQubit.range(2)) ** 0.324
78+
c = cirq.Circuit(op.with_tags('ignore'))
79+
transformer = cirq.transformers.PauliInsertionTransformer(cirq.ZZPowGate)
80+
81+
assert transformer(c, context=cirq.TransformerContext(tags_to_ignore=('ignore',))) == c
82+
83+
84+
def test_transformer_ignores_tagged_moments():
85+
op = cirq.ZZ(*cirq.LineQubit.range(2)) ** 0.324
86+
c = cirq.Circuit(cirq.Moment(op).with_tags('ignore'))
87+
transformer = cirq.transformers.PauliInsertionTransformer(cirq.ZZPowGate)
88+
89+
assert transformer(c, context=cirq.TransformerContext(tags_to_ignore=('ignore',))) == c

0 commit comments

Comments
 (0)