Skip to content

Add a new transformer that performs random pauli insertion #7558

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions cirq-core/cirq/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,3 +160,6 @@
from cirq.transformers.insertion_sort import (
insertion_sort_transformer as insertion_sort_transformer,
)


from cirq.transformers.pauli_insertion import PauliInsertionTransformer as PauliInsertionTransformer
133 changes: 133 additions & 0 deletions cirq-core/cirq/transformers/pauli_insertion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
# Copyright 2025 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""A pauli insertion transformer."""

from __future__ import annotations

import inspect
from collections.abc import Mapping

import numpy as np

from cirq import circuits, ops
from cirq.transformers import transformer_api

_PAULIS: tuple[ops.Gate, ops.Gate, ops.Gate, ops.Gate] = (ops.I, ops.X, ops.Y, ops.Z) # type: ignore[has-type]


@transformer_api.transformer
class PauliInsertionTransformer:
r"""Creates a pauli insertion transformer.

A pauli insertion operation samples paulis from $\{I, X, Y, Z\}^2$ with the given
probabilities and adds it before the target 2Q gate/operation. This procedure is commonly
used in zero noise extrapolation (ZNE), see appendix D of https://arxiv.org/abs/2503.20870.
"""

def __init__(
self,
target: ops.Gate | ops.GateFamily | ops.Gateset | type[ops.Gate],
probabilities: np.ndarray | Mapping[tuple[ops.Qid, ops.Qid], np.ndarray] | None = None,
):
"""Makes a pauli insertion transformer that samples 2Q paulis with the given probabilities.

Args:
target: The target gate, gatefamily, gateset, or type (e.g. ZZPowGAte).
probabilities: Optional ndarray or mapping[qubit-pair, nndarray] representing the
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You should say more clearly that probabilities contains 4x4 arrays and the [i,j] element is the probability of applying _PAULIS[i] to qubit 0 and _PAULIS[j] to qubit 1, where the two qubits now (if you make the other change I suggest) are in the order specified in the key of the dictionary.

probabilities of sampling 2Q paulis. The order of the paulis is IXYZ.
If at operation `op` a pair (i, j) is sampled then _PAULIS[i] is applied
to op.qubits[0] and _PAULIS[j] is applied to op.qubits[1].
If None, assume uniform distribution.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A uniform distribution would completely depolarize the state, which is probably not what we want. I think probabilities should be required.

Comment on lines +42 to +52
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we should require that probabilities be a mapping from qubit pairs to numpy arrays and use the order in the qubit pair (not in the op, which is less easily accessible to the user) to determine which qubit is which in probabilities.

"""
if probabilities is None:
probabilities = np.ones((4, 4)) / 16
elif isinstance(probabilities, dict):
probabilities = {k: np.asarray(v) for k, v in probabilities.items()}
for probs in probabilities.values():
assert np.isclose(probs.sum(), 1)
assert probs.shape == (4, 4)
Comment on lines +59 to +60
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you change these to ValueErrors and document them in a Raises section of the docstring? Also check that none of them are negative.

else:
probabilities = np.asarray(probabilities)
assert np.isclose(probabilities.sum(), 1)
assert probabilities.shape == (4, 4)
self.probabilities = probabilities

if inspect.isclass(target):
self.target: ops.GateFamily | ops.Gateset = ops.GateFamily(target)
elif isinstance(target, ops.Gate):
self.target = ops.Gateset(target)
else:
assert isinstance(target, (ops.Gateset, ops.GateFamily))
self.target = target

def _is_target(self, op: ops.Operation) -> bool:
if isinstance(self.probabilities, dict) and op.qubits not in self.probabilities:
return False
return op in self.target

def _sample(
self, qubits: tuple[ops.Qid, ...], rng: np.random.Generator
) -> tuple[ops.Gate, ops.Gate]:
if isinstance(self.probabilities, dict):
assert len(qubits) == 2
flat_probs = self.probabilities[qubits].reshape(-1)
else:
flat_probs = self.probabilities.reshape(-1)
i, j = np.unravel_index(rng.choice(16, p=flat_probs), (4, 4))
return _PAULIS[i], _PAULIS[j]

def __call__(
self,
circuit: circuits.AbstractCircuit,
*,
rng_or_seed: np.random.Generator | int | None = None,
context: transformer_api.TransformerContext | None = None,
):
context = (
context
if isinstance(context, transformer_api.TransformerContext)
else transformer_api.TransformerContext()
)
rng = (
rng_or_seed
if isinstance(rng_or_seed, np.random.Generator)
else np.random.default_rng(rng_or_seed)
)

if context.deep:
raise ValueError(f"this transformer doesn't support deep {context=}")

tags_to_ignore = frozenset(context.tags_to_ignore)
new_circuit: list[circuits.Moment] = []
for moment in circuit:
if any(tag in tags_to_ignore for tag in moment.tags):
new_circuit.append(moment)
continue
new_moment = []
for op in moment:
if any(tag in tags_to_ignore for tag in op.tags):
continue
if not self._is_target(op):
continue
pair = self._sample(op.qubits, rng)
for pauli, q in zip(pair, op.qubits):
if new_circuit and (q not in new_circuit[-1].qubits):
new_circuit[-1] += pauli(q)
else:
new_moment.append(pauli(q))
if new_moment:
new_circuit.append(circuits.Moment(new_moment))
new_circuit.append(moment)
return circuits.Circuit.from_moments(*new_circuit)
104 changes: 104 additions & 0 deletions cirq-core/cirq/transformers/pauli_insertion_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
# Copyright 2025 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import pytest

import cirq

_PAULIS = [cirq.I, cirq.X, cirq.Y, cirq.Z]


def _random_probs(n: int, seed: int | None = None):
rng = np.random.default_rng(seed)
for _ in range(n):
probs = rng.random((4, 4))
probs /= probs.sum()
yield probs


@pytest.mark.parametrize('probs', _random_probs(3, 0))
@pytest.mark.parametrize(
'target',
[cirq.ZZPowGate, cirq.ZZ**0.324, cirq.Gateset(cirq.ZZ**0.324), cirq.GateFamily(cirq.ZZ**0.324)],
)
def test_pauli_insertion_with_probabilities(probs, target):
c = cirq.Circuit(cirq.ZZ(*cirq.LineQubit.range(2)) ** 0.324)
transformer = cirq.transformers.PauliInsertionTransformer(target, probs)
count = np.zeros((4, 4))
rng = np.random.default_rng(0)
for _ in range(100):
nc = transformer(c, rng_or_seed=rng)
assert len(nc) == 2
u, v = nc[0]
i = _PAULIS.index(u.gate)
j = _PAULIS.index(v.gate)
count[i, j] += 1
count = count / count.sum()
np.testing.assert_allclose(count, probs, atol=0.1)


@pytest.mark.parametrize('probs', _random_probs(3, 0))
def test_pauli_insertion_with_probabilities_doesnot_create_moment(probs):
c = cirq.Circuit.from_moments([], [cirq.ZZ(*cirq.LineQubit.range(2)) ** 0.324])
transformer = cirq.transformers.PauliInsertionTransformer(cirq.ZZPowGate, probs)
count = np.zeros((4, 4))
rng = np.random.default_rng(0)
for _ in range(100):
nc = transformer(c, rng_or_seed=rng)
assert len(nc) == 2
u, v = nc[0]
i = _PAULIS.index(u.gate)
j = _PAULIS.index(v.gate)
count[i, j] += 1
count = count / count.sum()
np.testing.assert_allclose(count, probs, atol=0.1)


def test_invalid_context_raises():
c = cirq.Circuit(cirq.ZZ(*cirq.LineQubit.range(2)) ** 0.324)
transformer = cirq.transformers.PauliInsertionTransformer(cirq.ZZPowGate)
with pytest.raises(ValueError):
_ = transformer(c, context=cirq.TransformerContext(deep=True))


def test_transformer_ignores_tagged_ops():
op = cirq.ZZ(*cirq.LineQubit.range(2)) ** 0.324
c = cirq.Circuit(op.with_tags('ignore'))
transformer = cirq.transformers.PauliInsertionTransformer(cirq.ZZPowGate)

assert transformer(c, context=cirq.TransformerContext(tags_to_ignore=('ignore',))) == c


def test_transformer_ignores_tagged_moments():
op = cirq.ZZ(*cirq.LineQubit.range(2)) ** 0.324
c = cirq.Circuit(cirq.Moment(op).with_tags('ignore'))
transformer = cirq.transformers.PauliInsertionTransformer(cirq.ZZPowGate)

assert transformer(c, context=cirq.TransformerContext(tags_to_ignore=('ignore',))) == c


def test_transformer_ignores_with_probs_map():
qs = tuple(cirq.LineQubit.range(3))
op = cirq.ZZ(*qs[:2]) ** 0.324
c = cirq.Circuit(cirq.Moment(op))
transformer = cirq.transformers.PauliInsertionTransformer(
cirq.ZZPowGate, {qs[1:]: np.ones((4, 4)) / 16}
)

assert transformer(c) == c # qubits are not in target

c = cirq.Circuit(cirq.Moment(op.with_qubits(*qs[1:])))
nc = transformer(c)
assert len(nc) == 2