diff --git a/cirq-google/cirq_google/experimental/analog_experiments/__init__.py b/cirq-google/cirq_google/experimental/analog_experiments/__init__.py new file mode 100644 index 00000000000..13bda79ea9b --- /dev/null +++ b/cirq-google/cirq_google/experimental/analog_experiments/__init__.py @@ -0,0 +1,20 @@ +# Copyright 2025 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Folder for Running Analog experiments.""" + +from cirq_google.experimental.analog_experiments.analog_trajectory_util import ( + FrequencyMap as FrequencyMap, + AnalogTrajectory as AnalogTrajectory, +) diff --git a/cirq-google/cirq_google/experimental/analog_experiments/analog_trajectory_util.py b/cirq-google/cirq_google/experimental/analog_experiments/analog_trajectory_util.py new file mode 100644 index 00000000000..4635aa4370e --- /dev/null +++ b/cirq-google/cirq_google/experimental/analog_experiments/analog_trajectory_util.py @@ -0,0 +1,201 @@ +# Copyright 2025 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import AbstractSet, TYPE_CHECKING + +import attrs +import matplotlib.pyplot as plt +import numpy as np +import tunits as tu + +import cirq +from cirq_google.study import symbol_util as su + +if TYPE_CHECKING: + from matplotlib.axes import Axes + + +@attrs.mutable +class FrequencyMap: + """Object containing information about the step to a new analog Hamiltonian. + + Attributes: + duration: duration of step + qubit_freqs: dict describing qubit frequencies at end of step (None if idle) + couplings: dict describing coupling rates at end of step + """ + + duration: su.ValueOrSymbol + qubit_freqs: dict[str, su.ValueOrSymbol | None] + couplings: dict[tuple[str, str], su.ValueOrSymbol] + + def _is_parameterized_(self) -> bool: + return ( + cirq.is_parameterized(self.duration) + or su.is_parameterized_dict(self.qubit_freqs) + or su.is_parameterized_dict(self.couplings) + ) + + def _parameter_names_(self) -> AbstractSet[str]: + return ( + cirq.parameter_names(self.duration) + | su.dict_param_name(self.qubit_freqs) + | su.dict_param_name(self.couplings) + ) + + def _resolve_parameters_( + self, resolver: cirq.ParamResolverOrSimilarType, recursive: bool + ) -> FrequencyMap: + resolver_ = cirq.ParamResolver(resolver) + return FrequencyMap( + duration=su.direct_symbol_replacement(self.duration, resolver_), + qubit_freqs={ + k: su.direct_symbol_replacement(v, resolver_) for k, v in self.qubit_freqs.items() + }, + couplings={ + k: su.direct_symbol_replacement(v, resolver_) for k, v in self.couplings.items() + }, + ) + + +class AnalogTrajectory: + """Class for handling qubit frequency and coupling trajectories that + define analog experiments. The class is defined using a sparse_trajectory, + which contains time durations of each Hamiltonian ramp element and the + corresponding qubit frequencies and couplings (unassigned qubits and/or + couplers are left unchanged). + """ + + def __init__( + self, + *, + full_trajectory: list[FrequencyMap], + qubits: list[str], + pairs: list[tuple[str, str]], + ): + self.full_trajectory = full_trajectory + self.qubits = qubits + self.pairs = pairs + + @classmethod + def from_sparse_trajectory( + cls, + sparse_trajectory: list[ + tuple[ + tu.Value, + dict[str, su.ValueOrSymbol | None], + dict[tuple[str, str], su.ValueOrSymbol], + ], + ], + qubits: list[str] | None = None, + pairs: list[tuple[str, str]] | None = None, + ): + """Construct AnalogTrajectory from sparse trajectory. + + Args: + sparse_trajectory: A list of tuples, where each tuple defines a `FrequencyMap` + and contains three elements: (duration, qubit_freqs, coupling_strengths). + `duration` is a tunits value, `qubit_freqs` is a dictionary mapping qubit strings + to detuning frequencies, and `coupling_strengths` is a dictionary mapping qubit + pairs to their coupling strength. This format is considered "sparse" because each + tuple does not need to fully specify all qubits and coupling pairs; any missing + detuning frequency or coupling strength will be set to the same value as the + previous value in the list. + qubits: The qubits in interest. If not provided, automatically parsed from trajectory. + pairs: The pairs in interest. If not provided, automatically parsed from trajectory. + """ + if qubits is None or pairs is None: + qubits_in_traj: list[str] = [] + pairs_in_traj: list[tuple[str, str]] = [] + for _, q, p in sparse_trajectory: + qubits_in_traj.extend(q.keys()) + pairs_in_traj.extend(p.keys()) + qubits = list(set(qubits_in_traj)) + pairs = list(set(pairs_in_traj)) + + full_trajectory: list[FrequencyMap] = [] + init_qubit_freq_dict: dict[str, tu.Value | None] = {q: None for q in qubits} + init_g_dict: dict[tuple[str, str], tu.Value] = {p: 0 * tu.MHz for p in pairs} + full_trajectory.append(FrequencyMap(0 * tu.ns, init_qubit_freq_dict, init_g_dict)) + + for dt, qubit_freq_dict, g_dict in sparse_trajectory: + # If no freq provided, set equal to previous + new_qubit_freq_dict = { + q: qubit_freq_dict.get(q, full_trajectory[-1].qubit_freqs.get(q)) for q in qubits + } + # If no g provided, set equal to previous + new_g_dict: dict[tuple[str, str], tu.Value] = { + p: g_dict.get(p, full_trajectory[-1].couplings.get(p)) for p in pairs # type: ignore[misc] + } + + full_trajectory.append(FrequencyMap(dt, new_qubit_freq_dict, new_g_dict)) + return cls(full_trajectory=full_trajectory, qubits=qubits, pairs=pairs) + + def get_full_trajectory_with_resolved_idles( + self, idle_freq_map: dict[str, tu.Value] + ) -> list[FrequencyMap]: + """Insert idle frequencies instead of None in trajectory.""" + + resolved_trajectory: list[FrequencyMap] = [] + for freq_map in self.full_trajectory: + resolved_qubit_freqs = { + q: idle_freq_map[q] if f is None else f for q, f in freq_map.qubit_freqs.items() + } + resolved_trajectory.append(attrs.evolve(freq_map, qubit_freqs=resolved_qubit_freqs)) + return resolved_trajectory + + def plot( + self, + idle_freq_map: dict[str, tu.Value] | None = None, + default_idle_freq: tu.Value = 6.5 * tu.GHz, + resolver: cirq.ParamResolverOrSimilarType | None = None, + axes: tuple[Axes, Axes] | None = None, + ) -> tuple[Axes, Axes]: + if idle_freq_map is None: + idle_freq_map = {q: default_idle_freq for q in self.qubits} + full_trajectory_resolved = cirq.resolve_parameters( + self.get_full_trajectory_with_resolved_idles(idle_freq_map), resolver + ) + unresolved_param_names = set().union( + *[cirq.parameter_names(freq_map) for freq_map in full_trajectory_resolved] + ) + if unresolved_param_names: + raise ValueError(f"There are some parameters {unresolved_param_names} not resolved.") + + times = np.cumsum([step.duration[tu.ns] for step in full_trajectory_resolved]) + + if axes is None: + _, axes = plt.subplots(1, 2, figsize=(10, 4)) + + for qubit_agent in self.qubits: + axes[0].plot( + times, + [step.qubit_freqs[qubit_agent][tu.GHz] for step in full_trajectory_resolved], # type: ignore[index] + label=qubit_agent, + ) + for pair_agent in self.pairs: + axes[1].plot( + times, + [step.couplings[pair_agent][tu.MHz] for step in full_trajectory_resolved], + label=pair_agent, + ) + + for ax, ylabel in zip(axes, ["Qubit freq. (GHz)", "Coupling (MHz)"]): + ax.set_xlabel("Time (ns)") + ax.set_ylabel(ylabel) + ax.legend() + plt.tight_layout() + return axes diff --git a/cirq-google/cirq_google/experimental/analog_experiments/analog_trajectory_util_test.py b/cirq-google/cirq_google/experimental/analog_experiments/analog_trajectory_util_test.py new file mode 100644 index 00000000000..e57dfd58411 --- /dev/null +++ b/cirq-google/cirq_google/experimental/analog_experiments/analog_trajectory_util_test.py @@ -0,0 +1,131 @@ +# Copyright 2025 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import sympy +import tunits as tu + +import cirq +from cirq_google.experimental.analog_experiments import analog_trajectory_util as atu + + +@pytest.fixture +def freq_map() -> atu.FrequencyMap: + return atu.FrequencyMap( + 10 * tu.ns, + {"q0_0": 5 * tu.GHz, "q0_1": 6 * tu.GHz, "q0_2": sympy.Symbol("f_q0_2")}, + {("q0_0", "q0_1"): 5 * tu.MHz, ("q0_1", "q0_2"): sympy.Symbol("g_q0_1_q0_2")}, + ) + + +def test_freq_map_param_names(freq_map: atu.FrequencyMap) -> None: + assert cirq.is_parameterized(freq_map) + assert cirq.parameter_names(freq_map) == {"f_q0_2", "g_q0_1_q0_2"} + + +def test_freq_map_resolve(freq_map: atu.FrequencyMap) -> None: + resolved_freq_map = cirq.resolve_parameters( + freq_map, {"f_q0_2": 6 * tu.GHz, "g_q0_1_q0_2": 7 * tu.MHz} + ) + assert resolved_freq_map == atu.FrequencyMap( + 10 * tu.ns, + {"q0_0": 5 * tu.GHz, "q0_1": 6 * tu.GHz, "q0_2": 6 * tu.GHz}, + {("q0_0", "q0_1"): 5 * tu.MHz, ("q0_1", "q0_2"): 7 * tu.MHz}, + ) + + +FreqMapType = tuple[tu.Value, dict[str, tu.Value | None], dict[tuple[str, str], tu.Value]] + + +@pytest.fixture +def sparse_trajectory() -> list[FreqMapType]: + traj1: FreqMapType = (20 * tu.ns, {"q0_1": 5 * tu.GHz}, {}) + traj2: FreqMapType = (30 * tu.ns, {"q0_2": 8 * tu.GHz}, {}) + traj3: FreqMapType = ( + 40 * tu.ns, + {"q0_0": 8 * tu.GHz, "q0_1": None, "q0_2": None}, + {("q0_0", "q0_1"): 5 * tu.MHz, ("q0_1", "q0_2"): 8 * tu.MHz}, + ) + return [traj1, traj2, traj3] + + +def test_full_traj(sparse_trajectory: list[FreqMapType]) -> None: + analog_traj = atu.AnalogTrajectory.from_sparse_trajectory(sparse_trajectory) + assert len(analog_traj.full_trajectory) == 4 + assert analog_traj.full_trajectory[0] == atu.FrequencyMap( + 0 * tu.ns, + {"q0_0": None, "q0_1": None, "q0_2": None}, + {("q0_0", "q0_1"): 0 * tu.MHz, ("q0_1", "q0_2"): 0 * tu.MHz}, + ) + assert analog_traj.full_trajectory[1] == atu.FrequencyMap( + 20 * tu.ns, + {"q0_0": None, "q0_1": 5 * tu.GHz, "q0_2": None}, + {("q0_0", "q0_1"): 0 * tu.MHz, ("q0_1", "q0_2"): 0 * tu.MHz}, + ) + assert analog_traj.full_trajectory[2] == atu.FrequencyMap( + 30 * tu.ns, + {"q0_0": None, "q0_1": 5 * tu.GHz, "q0_2": 8 * tu.GHz}, + {("q0_0", "q0_1"): 0 * tu.MHz, ("q0_1", "q0_2"): 0 * tu.MHz}, + ) + assert analog_traj.full_trajectory[3] == atu.FrequencyMap( + 40 * tu.ns, + {"q0_0": 8 * tu.GHz, "q0_1": None, "q0_2": None}, + {("q0_0", "q0_1"): 5 * tu.MHz, ("q0_1", "q0_2"): 8 * tu.MHz}, + ) + + +def test_get_full_trajectory_with_resolved_idles(sparse_trajectory: list[FreqMapType]) -> None: + + analog_traj = atu.AnalogTrajectory.from_sparse_trajectory(sparse_trajectory) + resolved_full_traj = analog_traj.get_full_trajectory_with_resolved_idles( + {"q0_0": 5 * tu.GHz, "q0_1": 6 * tu.GHz, "q0_2": 7 * tu.GHz} + ) + + assert len(resolved_full_traj) == 4 + assert resolved_full_traj[0] == atu.FrequencyMap( + 0 * tu.ns, + {"q0_0": 5 * tu.GHz, "q0_1": 6 * tu.GHz, "q0_2": 7 * tu.GHz}, + {("q0_0", "q0_1"): 0 * tu.MHz, ("q0_1", "q0_2"): 0 * tu.MHz}, + ) + assert resolved_full_traj[1] == atu.FrequencyMap( + 20 * tu.ns, + {"q0_0": 5 * tu.GHz, "q0_1": 5 * tu.GHz, "q0_2": 7 * tu.GHz}, + {("q0_0", "q0_1"): 0 * tu.MHz, ("q0_1", "q0_2"): 0 * tu.MHz}, + ) + assert resolved_full_traj[2] == atu.FrequencyMap( + 30 * tu.ns, + {"q0_0": 5 * tu.GHz, "q0_1": 5 * tu.GHz, "q0_2": 8 * tu.GHz}, + {("q0_0", "q0_1"): 0 * tu.MHz, ("q0_1", "q0_2"): 0 * tu.MHz}, + ) + assert resolved_full_traj[3] == atu.FrequencyMap( + 40 * tu.ns, + {"q0_0": 8 * tu.GHz, "q0_1": 6 * tu.GHz, "q0_2": 7 * tu.GHz}, + {("q0_0", "q0_1"): 5 * tu.MHz, ("q0_1", "q0_2"): 8 * tu.MHz}, + ) + + +def test_plot_with_unresolved_parameters(): + traj1: FreqMapType = (20 * tu.ns, {"q0_1": sympy.Symbol("qf")}, {}) + traj2: FreqMapType = (sympy.Symbol("t"), {"q0_2": 8 * tu.GHz}, {}) + analog_traj = atu.AnalogTrajectory.from_sparse_trajectory([traj1, traj2]) + + with pytest.raises(ValueError): + analog_traj.plot() + + +def test_analog_traj_plot(): + traj1: FreqMapType = (5 * tu.ns, {"q0_1": sympy.Symbol("qf")}, {("q0_0", "q0_1"): 2 * tu.MHz}) + traj2: FreqMapType = (sympy.Symbol("t"), {"q0_2": 8 * tu.GHz}, {}) + analog_traj = atu.AnalogTrajectory.from_sparse_trajectory([traj1, traj2]) + analog_traj.plot(resolver={"t": 10 * tu.ns, "qf": 5 * tu.GHz}) diff --git a/cirq-google/cirq_google/ops/analog_detune_gates.py b/cirq-google/cirq_google/ops/analog_detune_gates.py index dbd1f095812..86f631705d5 100644 --- a/cirq-google/cirq_google/ops/analog_detune_gates.py +++ b/cirq-google/cirq_google/ops/analog_detune_gates.py @@ -15,24 +15,14 @@ """Define detuning gates for Analog Experiment usage.""" from __future__ import annotations -from typing import AbstractSet, Any, TYPE_CHECKING, TypeAlias - -import sympy -import tunits as tu +from typing import AbstractSet, Any, TYPE_CHECKING import cirq +from cirq_google.study import symbol_util as su if TYPE_CHECKING: import numpy as np -# The gate is intended for the google internal use, hence the typing style -# follows more on the t-unit + symbol instead of float + symbol style. -ValueOrSymbol: TypeAlias = tu.Value | sympy.Basic -FloatOrSymbol: TypeAlias = float | sympy.Basic - -# A sentile for not finding the key in resolver. -NOT_FOUND = "__NOT_FOUND__" - @cirq.value_equality(approximate=True) class AnalogDetuneQubit(cirq.ops.Gate): @@ -60,12 +50,12 @@ class AnalogDetuneQubit(cirq.ops.Gate): def __init__( self, - length: ValueOrSymbol, - w: ValueOrSymbol, - target_freq: ValueOrSymbol | None = None, - prev_freq: ValueOrSymbol | None = None, - neighbor_coupler_g_dict: dict[str, ValueOrSymbol] | None = None, - prev_neighbor_coupler_g_dict: dict[str, ValueOrSymbol] | None = None, + length: su.ValueOrSymbol, + w: su.ValueOrSymbol, + target_freq: su.ValueOrSymbol | None = None, + prev_freq: su.ValueOrSymbol | None = None, + neighbor_coupler_g_dict: dict[str, su.ValueOrSymbol] | None = None, + prev_neighbor_coupler_g_dict: dict[str, su.ValueOrSymbol] | None = None, linear_rise: bool = True, ): """Inits AnalogDetuneQubit. @@ -97,58 +87,37 @@ def num_qubits(self) -> int: return 1 def _is_parameterized_(self) -> bool: - def _is_parameterized_dict(dict_with_value: dict[str, ValueOrSymbol] | None) -> bool: - if dict_with_value is None: - return False # pragma: no cover - return any(cirq.is_parameterized(v) for v in dict_with_value.values()) - return ( cirq.is_parameterized(self.length) or cirq.is_parameterized(self.w) or cirq.is_parameterized(self.target_freq) or cirq.is_parameterized(self.prev_freq) - or _is_parameterized_dict(self.neighbor_coupler_g_dict) - or _is_parameterized_dict(self.prev_neighbor_coupler_g_dict) + or su.is_parameterized_dict(self.neighbor_coupler_g_dict) + or su.is_parameterized_dict(self.prev_neighbor_coupler_g_dict) ) def _parameter_names_(self) -> AbstractSet[str]: - def dict_param_name(dict_with_value: dict[str, ValueOrSymbol] | None) -> AbstractSet[str]: - if dict_with_value is None: - return set() - return {v.name for v in dict_with_value.values() if cirq.is_parameterized(v)} - return ( cirq.parameter_names(self.length) | cirq.parameter_names(self.w) | cirq.parameter_names(self.target_freq) | cirq.parameter_names(self.prev_freq) - | dict_param_name(self.neighbor_coupler_g_dict) - | dict_param_name(self.prev_neighbor_coupler_g_dict) + | su.dict_param_name(self.neighbor_coupler_g_dict) + | su.dict_param_name(self.prev_neighbor_coupler_g_dict) ) def _resolve_parameters_( self, resolver: cirq.ParamResolverOrSimilarType, recursive: bool ) -> AnalogDetuneQubit: - # A shortcut for value resolution to avoid tu.unit compare with float issue. - def _direct_symbol_replacement(x, resolver: cirq.ParamResolver): - if isinstance(x, sympy.Symbol): - value = resolver.param_dict.get(x.name, NOT_FOUND) - if value == NOT_FOUND: - value = resolver.param_dict.get(x, NOT_FOUND) - if value != NOT_FOUND: - return value - return x # pragma: no cover - return x - resolver_ = cirq.ParamResolver(resolver) return AnalogDetuneQubit( - length=_direct_symbol_replacement(self.length, resolver_), - w=_direct_symbol_replacement(self.w, resolver_), - target_freq=_direct_symbol_replacement(self.target_freq, resolver_), - prev_freq=_direct_symbol_replacement(self.prev_freq, resolver_), + length=su.direct_symbol_replacement(self.length, resolver_), + w=su.direct_symbol_replacement(self.w, resolver_), + target_freq=su.direct_symbol_replacement(self.target_freq, resolver_), + prev_freq=su.direct_symbol_replacement(self.prev_freq, resolver_), neighbor_coupler_g_dict=( { - k: _direct_symbol_replacement(v, resolver_) + k: su.direct_symbol_replacement(v, resolver_) for k, v in self.neighbor_coupler_g_dict.items() } if self.neighbor_coupler_g_dict @@ -156,7 +125,7 @@ def _direct_symbol_replacement(x, resolver: cirq.ParamResolver): ), prev_neighbor_coupler_g_dict=( { - k: _direct_symbol_replacement(v, resolver_) + k: su.direct_symbol_replacement(v, resolver_) for k, v in self.prev_neighbor_coupler_g_dict.items() } if self.prev_neighbor_coupler_g_dict diff --git a/cirq-google/cirq_google/study/symbol_util.py b/cirq-google/cirq_google/study/symbol_util.py new file mode 100644 index 00000000000..b2360543201 --- /dev/null +++ b/cirq-google/cirq_google/study/symbol_util.py @@ -0,0 +1,53 @@ +# Copyright 2025 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import AbstractSet, Any, TypeAlias + +import sympy +import tunits as tu + +import cirq + +# The gate is intended for the google internal use, hence the typing style +# follows more on the t-unit + symbol instead of float + symbol style. +ValueOrSymbol: TypeAlias = tu.Value | sympy.Basic + +# A sentile for not finding the key in resolver. +NOT_FOUND = "__NOT_FOUND__" + + +def direct_symbol_replacement(x, resolver: cirq.ParamResolver): + """A shortcut for value resolution to avoid tu.unit compare with float issue.""" + if isinstance(x, sympy.Symbol): + value = resolver.param_dict.get(x.name, NOT_FOUND) + if value == NOT_FOUND: + value = resolver.param_dict.get(x, NOT_FOUND) + if value != NOT_FOUND: + return value + return x # pragma: no cover + return x + + +def dict_param_name(dict_with_value: dict[Any, ValueOrSymbol] | None) -> AbstractSet[str]: + """Find the names of all parameterized value in a dictionary.""" + if dict_with_value is None: + return set() + return {v.name for v in dict_with_value.values() if cirq.is_parameterized(v)} + + +def is_parameterized_dict(dict_with_value: dict[Any, ValueOrSymbol] | None) -> bool: + """Check if any values in the dictionary is parameterized.""" + if dict_with_value is None: + return False # pragma: no cover + return any(cirq.is_parameterized(v) for v in dict_with_value.values()) diff --git a/cirq-google/cirq_google/study/symbol_util_test.py b/cirq-google/cirq_google/study/symbol_util_test.py new file mode 100644 index 00000000000..a6ef9d50b4d --- /dev/null +++ b/cirq-google/cirq_google/study/symbol_util_test.py @@ -0,0 +1,48 @@ +# Copyright 2025 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import sympy +import tunits as tu + +import cirq +from cirq_google.study import symbol_util as su + + +def test_dict_param_name(): + d = {"a": 54, "b": sympy.Symbol("t"), "c": sympy.Symbol("t"), "d": "sd"} + + assert su.dict_param_name(None) == set() + assert su.dict_param_name(d) == {"t"} + + +@pytest.mark.parametrize( + "d,expected", + [ + (None, False), + ({}, False), + ({"a": 50}, False), + ({"a": 54, "b": sympy.Symbol("t"), "c": sympy.Symbol("t"), "d": "sd"}, True), + ], +) +def test_is_parameterized_dict(d, expected): + assert su.is_parameterized_dict(d) == expected + + +def test_direct_symbol_replacement(): + value_list = [sympy.Symbol("t"), sympy.Symbol("v"), sympy.Symbol("z"), 123, "fd"] + resolver = cirq.ParamResolver({"t": 5 * tu.ns, sympy.Symbol("v"): 8 * tu.GHz}) + value_resolved = [su.direct_symbol_replacement(v, resolver) for v in value_list] + + assert value_resolved == [5 * tu.ns, 8 * tu.GHz, sympy.Symbol("z"), 123, "fd"]