23
23
from cirq import circuits , ops
24
24
from cirq .transformers import transformer_api
25
25
26
- _PAULIS = [ops .I , ops .X , ops .Y , ops .Z ]
27
-
28
-
29
- def _is_target (
30
- op : ops .Operation ,
31
- target : ops .Gate | ops .GateFamily | ops .Gateset | type [ops .Gate | ops .Operation ],
32
- ):
33
- if inspect .isclass (target ):
34
- if issubclass (target , ops .Operation ):
35
- return isinstance (op , target )
36
- if not hasattr (op , 'gate' ):
37
- return False
38
- return isinstance (op .gate , target )
39
- if isinstance (target , ops .Gate ):
40
- if not hasattr (op , 'gate' ) or op .gate is None :
41
- return False
42
- return op .gate == target
43
- return op in target
26
+ _PAULIS : tuple [ops .Gate ] = (ops .I , ops .X , ops .Y , ops .Z ) # type: ignore[has-type]
44
27
45
28
46
29
@transformer_api .transformer
@@ -54,13 +37,13 @@ class PauliInsertionTransformer:
54
37
55
38
def __init__ (
56
39
self ,
57
- target : ops .Gate | ops .GateFamily | ops .Gateset | type [ops .Gate | ops . Operation ],
40
+ target : ops .Gate | ops .GateFamily | ops .Gateset | type [ops .Gate ],
58
41
probabilities : np .ndarray | None = None ,
59
42
):
60
43
"""Makes a pauli insertion transformer that samples 2Q paulis with the given probabilities.
61
44
62
45
Args:
63
- target: The target gate, gatefamily, gateset, or type (e.g. PauliSumExponential ).
46
+ target: The target gate, gatefamily, gateset, or type (e.g. ZZPowGAte ).
64
47
probabilities: Optional ndarray representing the probabilities of sampling 2Q paulis.
65
48
The order of the paulis is IXYZ. If None, assume uniform distribution.
66
49
Returns:
@@ -72,7 +55,13 @@ def __init__(
72
55
assert probabilities .shape == (4 , 4 )
73
56
assert np .isclose (probabilities .sum (), 1 )
74
57
75
- self .target = target
58
+ if inspect .isclass (target ):
59
+ self .target = ops .GateFamily (target )
60
+ elif isinstance (target , ops .Gate ):
61
+ self .target = ops .Gateset (target )
62
+ else :
63
+ assert isinstance (target , (ops .Gateset , ops .GateFamily ))
64
+ self .target = target
76
65
self ._flat_probs = probabilities .reshape (- 1 )
77
66
78
67
def __call__ (
@@ -106,7 +95,7 @@ def __call__(
106
95
for op in moment :
107
96
if any (tag in tags_to_ignore for tag in op .tags ):
108
97
continue
109
- if not _is_target ( op , self .target ) :
98
+ if op not in self .target :
110
99
continue
111
100
pair = np .unravel_index (rng .choice (16 , p = self ._flat_probs ), (4 , 4 ))
112
101
for pauli_index , q in zip (pair , op .qubits ):
0 commit comments