Skip to content

Commit 0b6a9a4

Browse files
address comments
1 parent 9c2064f commit 0b6a9a4

File tree

2 files changed

+52
-14
lines changed

2 files changed

+52
-14
lines changed

cirq-core/cirq/transformers/pauli_insertion.py

Lines changed: 37 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from __future__ import annotations
1818

1919
import inspect
20+
from collections.abc import Mapping
2021

2122
import numpy as np
2223

@@ -38,22 +39,30 @@ class PauliInsertionTransformer:
3839
def __init__(
3940
self,
4041
target: ops.Gate | ops.GateFamily | ops.Gateset | type[ops.Gate],
41-
probabilities: np.ndarray | None = None,
42+
probabilities: np.ndarray | Mapping[tuple[ops.Qid, ops.Qid], np.ndarray] | None = None,
4243
):
4344
"""Makes a pauli insertion transformer that samples 2Q paulis with the given probabilities.
4445
4546
Args:
4647
target: The target gate, gatefamily, gateset, or type (e.g. ZZPowGAte).
47-
probabilities: Optional ndarray representing the probabilities of sampling 2Q paulis.
48-
The order of the paulis is IXYZ. If None, assume uniform distribution.
49-
Returns:
50-
A gauge transformer.
48+
probabilities: Optional ndarray or mapping[qubit-pair, nndarray] representing the
49+
probabilities of sampling 2Q paulis. The order of the paulis is IXYZ.
50+
If at operation `op` a pair (i, j) is sampled then _PAULIS[i] is applied
51+
to op.qubits[0] and _PAULIS[j] is applied to op.qubits[1].
52+
If None, assume uniform distribution.
5153
"""
5254
if probabilities is None:
5355
probabilities = np.ones((4, 4)) / 16
54-
probabilities = np.asarray(probabilities)
55-
assert probabilities.shape == (4, 4)
56-
assert np.isclose(probabilities.sum(), 1)
56+
elif isinstance(probabilities, dict):
57+
probabilities = {k: np.asarray(v) for k, v in probabilities.items()}
58+
for probs in probabilities.values():
59+
assert np.isclose(probs.sum(), 1)
60+
assert probs.shape == (4, 4)
61+
else:
62+
probabilities = np.asarray(probabilities)
63+
assert np.isclose(probabilities.sum(), 1)
64+
assert probabilities.shape == (4, 4)
65+
self.probabilities = probabilities
5766

5867
if inspect.isclass(target):
5968
self.target = ops.GateFamily(target)
@@ -62,7 +71,21 @@ def __init__(
6271
else:
6372
assert isinstance(target, (ops.Gateset, ops.GateFamily))
6473
self.target = target
65-
self._flat_probs = probabilities.reshape(-1)
74+
75+
def _is_target(self, op: ops.Operation) -> bool:
76+
if isinstance(self.probabilities, dict) and op.qubits not in self.probabilities:
77+
return False
78+
return op in self.target
79+
80+
def _sample(
81+
self, qubits: tuple[ops.Qid, ops.Qid], rng: np.random.Generator
82+
) -> tuple[ops.Gate, ops.Gate]:
83+
if isinstance(self.probabilities, dict):
84+
flat_probs = self.probabilities[qubits].reshape(-1)
85+
else:
86+
flat_probs = self.probabilities.reshape(-1)
87+
i, j = np.unravel_index(rng.choice(16, p=flat_probs), (4, 4))
88+
return _PAULIS[i], _PAULIS[j]
6689

6790
def __call__(
6891
self,
@@ -95,14 +118,14 @@ def __call__(
95118
for op in moment:
96119
if any(tag in tags_to_ignore for tag in op.tags):
97120
continue
98-
if op not in self.target:
121+
if not self._is_target(op):
99122
continue
100-
pair = np.unravel_index(rng.choice(16, p=self._flat_probs), (4, 4))
101-
for pauli_index, q in zip(pair, op.qubits):
123+
pair = self._sample(op.qubits, rng)
124+
for pauli, q in zip(pair, op.qubits):
102125
if new_circuit and (q not in new_circuit[-1].qubits):
103-
new_circuit[-1] += _PAULIS[pauli_index](q)
126+
new_circuit[-1] += pauli(q)
104127
else:
105-
new_moment.append(_PAULIS[pauli_index](q))
128+
new_moment.append(pauli(q))
106129
if new_moment:
107130
new_circuit.append(circuits.Moment(new_moment))
108131
new_circuit.append(moment)

cirq-core/cirq/transformers/pauli_insertion_test.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,3 +87,18 @@ def test_transformer_ignores_tagged_moments():
8787
transformer = cirq.transformers.PauliInsertionTransformer(cirq.ZZPowGate)
8888

8989
assert transformer(c, context=cirq.TransformerContext(tags_to_ignore=('ignore',))) == c
90+
91+
92+
def test_transformer_ignores_with_probs_map():
93+
qs = tuple(cirq.LineQubit.range(3))
94+
op = cirq.ZZ(*qs[:2]) ** 0.324
95+
c = cirq.Circuit(cirq.Moment(op))
96+
transformer = cirq.transformers.PauliInsertionTransformer(
97+
cirq.ZZPowGate, {qs[1:]: np.ones((4, 4)) / 16}
98+
)
99+
100+
assert transformer(c) == c # qubits are not in target
101+
102+
c = cirq.Circuit(cirq.Moment(op.with_qubits(*qs[1:])))
103+
nc = transformer(c)
104+
assert len(nc) == 2

0 commit comments

Comments
 (0)