Skip to content

Commit 16b74b6

Browse files
committed
Add first draft of ZXTransformer to contrib.
This is a custom transformer which uses ZX-calculus through the PyZX library to perform circuit optimisation.
1 parent bc4cd6d commit 16b74b6

File tree

5 files changed

+308
-0
lines changed

5 files changed

+308
-0
lines changed

cirq-core/cirq/contrib/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,3 +24,4 @@
2424
from cirq.contrib.qcircuit import circuit_to_latex_using_qcircuit
2525
from cirq.contrib import json
2626
from cirq.contrib.circuitdag import CircuitDag, Unique
27+
from cirq.contrib.zxtransformer import zx_transformer

cirq-core/cirq/contrib/requirements.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,6 @@ pylatex~=1.4
66
# quimb
77
quimb~=1.7
88
opt_einsum
9+
10+
# required for zxtransformer
11+
pyzx==0.8.0
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
# Copyright 2024 The Cirq Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""A custom transformer for Cirq which uses ZX-Calculus for circuit optimization, implemented using
16+
PyZX."""
17+
18+
from cirq.contrib.zxtransformer.zxtransformer import zx_transformer
Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,183 @@
1+
# Copyright 2024 The Cirq Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""A custom transformer for Cirq which uses ZX-Calculus for circuit optimization, implemented
16+
using PyZX."""
17+
18+
import functools
19+
from typing import List, Callable, Optional, Union
20+
from fractions import Fraction
21+
22+
import cirq
23+
from cirq import circuits
24+
25+
import pyzx as zx
26+
from pyzx.circuit import gates as zx_gates
27+
28+
29+
@functools.cache
30+
def _cirq_to_pyzx():
31+
return {
32+
cirq.H: zx_gates.HAD,
33+
cirq.CZ: zx_gates.CZ,
34+
cirq.CNOT: zx_gates.CNOT,
35+
cirq.SWAP: zx_gates.SWAP,
36+
cirq.CCZ: zx_gates.CCZ,
37+
}
38+
39+
40+
def cirq_gate_to_zx_gate(
41+
cirq_gate: Optional[cirq.Gate], qubits: List[int]
42+
) -> Optional[zx_gates.Gate]:
43+
"""Convert a Cirq gate to a PyZX gate."""
44+
45+
if isinstance(cirq_gate, (cirq.Rx, cirq.XPowGate)):
46+
return zx_gates.XPhase(*qubits, phase=Fraction(cirq_gate.exponent))
47+
if isinstance(cirq_gate, (cirq.Ry, cirq.YPowGate)):
48+
return zx_gates.YPhase(*qubits, phase=Fraction(cirq_gate.exponent))
49+
if isinstance(cirq_gate, (cirq.Rz, cirq.ZPowGate)):
50+
return zx_gates.ZPhase(*qubits, phase=Fraction(cirq_gate.exponent))
51+
52+
# TODO: Deal with exponents other than nice ones.
53+
if (gate := _cirq_to_pyzx().get(cirq_gate, None)) is not None:
54+
return gate(*qubits)
55+
56+
return None
57+
58+
59+
cirq_gate_table = {
60+
'rx': cirq.XPowGate,
61+
'ry': cirq.YPowGate,
62+
'rz': cirq.ZPowGate,
63+
'h': cirq.HPowGate,
64+
'cx': cirq.CXPowGate,
65+
'cz': cirq.CZPowGate,
66+
'swap': cirq.SwapPowGate,
67+
'ccz': cirq.CCZPowGate,
68+
}
69+
70+
71+
def _cirq_to_circuits_and_ops(
72+
circuit: circuits.AbstractCircuit, qubits: List[cirq.Qid]
73+
) -> List[Union[zx.Circuit, cirq.Operation]]:
74+
"""Convert an AbstractCircuit to a list of PyZX Circuits and cirq.Operations. As much of the
75+
AbstractCircuit is converted to PyZX as possible, but some gates are not supported by PyZX and
76+
are left as cirq.Operations.
77+
78+
:param circuit: The AbstractCircuit to convert.
79+
:return: A list of PyZX Circuits and cirq.Operations corresponding to the AbstractCircuit.
80+
"""
81+
circuits_and_ops: List[Union[zx.Circuit, cirq.Operation]] = []
82+
qubit_to_index = {qubit: index for index, qubit in enumerate(qubits)}
83+
current_circuit: Optional[zx.Circuit] = None
84+
for moment in circuit:
85+
for op in moment:
86+
gate_qubits = [qubit_to_index[qarg] for qarg in op.qubits]
87+
gate = cirq_gate_to_zx_gate(op.gate, gate_qubits)
88+
if not gate:
89+
# Encountered an operation not supported by PyZX, so just store it.
90+
# Flush the current PyZX Circuit first if there is one.
91+
if current_circuit is not None:
92+
circuits_and_ops.append(current_circuit)
93+
current_circuit = None
94+
circuits_and_ops.append(op)
95+
continue
96+
97+
if current_circuit is None:
98+
current_circuit = zx.Circuit(len(qubits))
99+
current_circuit.add_gate(gate)
100+
101+
# Flush any remaining PyZX Circuit.
102+
if current_circuit is not None:
103+
circuits_and_ops.append(current_circuit)
104+
105+
return circuits_and_ops
106+
107+
108+
def _recover_circuit(
109+
circuits_and_ops: List[Union[zx.Circuit, cirq.Operation]], qubits: List[cirq.Qid]
110+
) -> circuits.Circuit:
111+
"""Recovers a cirq.Circuit from a list of PyZX Circuits and cirq.Operations.
112+
113+
:param circuits_and_ops: The list of (optimized) PyZX Circuits and cirq.Operations from which to
114+
recover the cirq.Circuit.
115+
:return: An optimized version of the original input circuit to ZXTransformer.
116+
:raises ValueError: If an unsupported gate has been encountered.
117+
"""
118+
cirq_circuit = circuits.Circuit()
119+
for circuit_or_op in circuits_and_ops:
120+
if isinstance(circuit_or_op, cirq.Operation):
121+
cirq_circuit.append(circuit_or_op)
122+
continue
123+
for gate in circuit_or_op.gates:
124+
gate_name = (
125+
gate.qasm_name
126+
if not (hasattr(gate, 'adjoint') and gate.adjoint)
127+
else gate.qasm_name_adjoint
128+
)
129+
gate_type = cirq_gate_table[gate_name]
130+
if gate_type is None:
131+
raise ValueError(f"Unsupported gate: {gate_name}.")
132+
qargs: List[cirq.Qid] = []
133+
for attr in ['ctrl1', 'ctrl2', 'control', 'target']:
134+
if hasattr(gate, attr):
135+
qargs.append(qubits[getattr(gate, attr)])
136+
params: List[float] = []
137+
if hasattr(gate, 'phase'):
138+
params = [float(gate.phase)]
139+
elif hasattr(gate, 'phases'):
140+
params = [float(phase) for phase in gate.phases]
141+
elif gate_name in ('h', 'cz', 'cx', 'swap', 'ccz'):
142+
params = [1.0]
143+
cirq_circuit.append(gate_type(exponent=params[0])(*qargs))
144+
return cirq_circuit
145+
146+
147+
def _optimize(c: zx.Circuit) -> zx.Circuit:
148+
g = c.to_graph()
149+
zx.simplify.full_reduce(g)
150+
return zx.extract.extract_circuit(g)
151+
152+
153+
@cirq.transformer
154+
def zx_transformer(
155+
circuit: circuits.AbstractCircuit,
156+
context: Optional[cirq.TransformerContext] = None,
157+
optimizer: Callable[[zx.Circuit], zx.Circuit] = _optimize,
158+
) -> circuits.Circuit:
159+
"""Perform circuit optimization using pyzx.
160+
161+
Args:
162+
circuit: 'cirq.Circuit' input circuit to transform.
163+
context: `cirq.TransformerContext` storing common configurable
164+
options for transformers.
165+
optimizer: The optimization routine to execute. Defaults to `pyzx.simplify.full_reduce` if
166+
not specified.
167+
168+
Returns:
169+
The modified circuit after optimization.
170+
"""
171+
qubits: List[cirq.Qid] = [*circuit.all_qubits()]
172+
173+
circuits_and_ops = _cirq_to_circuits_and_ops(circuit, qubits)
174+
if not circuits_and_ops:
175+
copied_circuit = circuit.unfreeze(copy=True)
176+
return copied_circuit
177+
178+
circuits_and_ops = [
179+
optimizer(circuit) if isinstance(circuit, zx.Circuit) else circuit
180+
for circuit in circuits_and_ops
181+
]
182+
183+
return _recover_circuit(circuits_and_ops, qubits)
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
# Copyright 2024 The Cirq Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Tests for Cirq ZX transformer."""
16+
17+
from typing import Optional, Callable
18+
19+
import cirq
20+
import pyzx as zx
21+
22+
from cirq.contrib.zxtransformer.zxtransformer import zx_transformer, _cirq_to_circuits_and_ops
23+
24+
25+
def _run_zxtransformer(
26+
qc: cirq.Circuit, optimizer: Optional[Callable[[zx.Circuit], zx.Circuit]] = None
27+
) -> None:
28+
zx_qc = zx_transformer(qc) if optimizer is None else zx_transformer(qc, optimizer=optimizer)
29+
qubit_map = {qid: qid for qid in qc.all_qubits()}
30+
cirq.testing.assert_circuits_have_same_unitary_given_final_permutation(qc, zx_qc, qubit_map)
31+
32+
33+
def test_basic_circuit() -> None:
34+
"""Test a basic circuit.
35+
36+
Taken from https://github.com/Quantomatic/pyzx/blob/master/circuits/Fast/mod5_4_before
37+
"""
38+
q = cirq.LineQubit.range(5)
39+
circuit = cirq.Circuit(
40+
cirq.X(q[4]),
41+
cirq.H(q[4]),
42+
cirq.CCZ(q[0], q[3], q[4]),
43+
cirq.CCZ(q[2], q[3], q[4]),
44+
cirq.H(q[4]),
45+
cirq.CX(q[3], q[4]),
46+
cirq.H(q[4]),
47+
cirq.CCZ(q[1], q[2], q[4]),
48+
cirq.H(q[4]),
49+
cirq.CX(q[2], q[4]),
50+
cirq.H(q[4]),
51+
cirq.CCZ(q[0], q[1], q[4]),
52+
cirq.H(q[4]),
53+
cirq.CX(q[1], q[4]),
54+
cirq.CX(q[0], q[4]),
55+
)
56+
57+
_run_zxtransformer(circuit)
58+
59+
60+
def test_fractional_gates() -> None:
61+
"""Test a circuit with gates which have a fractional phase."""
62+
q = cirq.NamedQubit("q")
63+
circuit = cirq.Circuit(cirq.ry(0.5)(q),
64+
cirq.rz(0.5)(q))
65+
_run_zxtransformer(circuit)
66+
67+
68+
def test_custom_optimize() -> None:
69+
"""Test custom optimize method."""
70+
q = cirq.LineQubit.range(4)
71+
circuit = cirq.Circuit(
72+
cirq.H(q[0]),
73+
cirq.H(q[1]),
74+
cirq.H(q[2]),
75+
cirq.H(q[3]),
76+
cirq.CX(q[0], q[1]),
77+
cirq.CX(q[1], q[2]),
78+
cirq.CX(q[2], q[3]),
79+
cirq.CX(q[3], q[0]),
80+
)
81+
82+
def optimize(circ: zx.Circuit) -> zx.Circuit:
83+
# Any function that takes a zx.Circuit and returns a zx.Circuit will do.
84+
return circ.to_basic_gates()
85+
86+
_run_zxtransformer(circuit, optimize)
87+
88+
89+
def test_measurement() -> None:
90+
"""Test a circuit with a measurement."""
91+
q = cirq.NamedQubit("q")
92+
circuit = cirq.Circuit(cirq.H(q), cirq.measure(q, key='c'), cirq.H(q))
93+
circuits_and_ops = _cirq_to_circuits_and_ops(circuit, [*circuit.all_qubits()])
94+
assert len(circuits_and_ops) == 3
95+
assert circuits_and_ops[1] == cirq.measure(q, key='c')
96+
97+
98+
def test_conditional_gate() -> None:
99+
"""Test a circuit with a conditional gate."""
100+
q = cirq.NamedQubit("q")
101+
circuit = cirq.Circuit(cirq.X(q), cirq.H(q).with_classical_controls('c'), cirq.X(q))
102+
circuits_and_ops = _cirq_to_circuits_and_ops(circuit, [*circuit.all_qubits()])
103+
assert len(circuits_and_ops) == 3

0 commit comments

Comments
 (0)