Skip to content
Draft
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
171 changes: 171 additions & 0 deletions tests/test_plugins/test_invdes2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
from __future__ import annotations

import autograd.numpy as np
import pytest

import tidy3d as td
from tidy3d.plugins.invdes2 import (
DeviceSpec,
FluxMetric,
InverseDesign,
OptimizerSpec,
TopologyDesignRegion,
)

from ..test_components.autograd.test_autograd import use_emulated_run # noqa: F401
from ..utils import run_emulated

sim_base = td.Simulation(
size=(10.0, 10.0, 10.0),
grid_spec=td.GridSpec.auto(wavelength=1.0, min_steps_per_wvl=10),
run_time=1.0,
structures=(),
monitors=[
td.FluxMonitor(center=(0.0, 0.0, 0.0), size=(1.0, 1.0, 1.0), freqs=[2e14], name="flux")
],
sources=(),
boundary_spec=td.BoundarySpec.all_sides(boundary=td.PML()),
medium=td.Medium(permittivity=1.0),
)

sim_data_base = run_emulated(sim_base, task_name="sim_base")

# TODO: Add more metrics here
metrics = [FluxMetric(monitor_name="flux", weight=0.5)]

# TODO: Add more design regions here
design_regions = [
TopologyDesignRegion(
size=(1.0, 1.0, 1.0), center=(0.0, 0.0, 0.0), eps_bounds=(1.0, 4.0), pixel_size=0.02
)
]

device_spec1 = DeviceSpec(
simulation=sim_base, design_regions=design_regions, metrics=metrics, name="d1"
)

device_spec2 = DeviceSpec(
simulation=sim_base, design_regions=design_regions, metrics=metrics, name="d2"
)

device_specs = [device_spec1, device_spec2]

optimizer_spec = OptimizerSpec(learning_rate=0.1, num_steps=1)

invdes = InverseDesign(optimizer_spec=optimizer_spec, device_specs=device_specs)


def test_parameter_shapes():
"""Ensure parameter shape metadata aligns across devices and regions.

- `InverseDesign.parameter_shape` should equal the list of each `DeviceSpec.parameter_shape`.
- Each `DeviceSpec.parameter_shape` should equal the list of each region's `parameter_shape`.
"""
assert invdes.parameter_shape == [d.parameter_shape for d in invdes.device_specs]
for device_spec in invdes.device_specs:
assert device_spec.parameter_shape == [
d.parameter_shape for d in device_spec.design_regions
]


def test_flatten_unflatten_params():
"""Round-trip flatten/unflatten preserves the parameter vector.

Uses helper constructors to build correctly sized parameter arrays, then verifies that
flatten → unflatten → flatten yields an identical 1D vector.
"""
params = invdes.ones()
flat = invdes._flatten_params(params)
restored = invdes._unflatten_params(flat)
flat2 = invdes._flatten_params(restored)

assert np.allclose(flat, flat2)


def test_design_region_to_structure():
"""Each design region can map its parameter vector to a `td.Structure`.

Builds per-region parameter arrays with the provided helper and ensures `to_structure`
returns a structure without error.
"""
for design_region in design_regions:
params = design_region.ones()
_ = design_region.to_structure(params)


def test_device_spec_get_simulation():
"""`DeviceSpec.get_simulation` appends one structure per design region.

The resulting simulation should contain the original structures plus the number of
design regions in the spec.
"""
for device_spec in device_specs:
params = device_spec.ones()
sim = device_spec.get_simulation(params)
assert len(sim.structures) == len(sim_base.structures) + len(device_spec.design_regions)


def test_invdes_get_simulations():
"""`InverseDesign.get_simulations` returns a batch keyed by device names.

Confirms the number of simulations equals the number of device specs and that keys are
exactly the device names.
"""
params = invdes.ones()
sims = invdes.get_simulations(params)

assert len(sims) == len(invdes.device_specs)
assert set(sims.keys()) == {device_spec.name for device_spec in invdes.device_specs}


def test_metric_evaluate():
"""`Metric.evaluate` produces a non-zero scalar from emulated monitor data."""
for metric in metrics:
mnt_data = sim_data_base[metric.monitor_name]
val = metric.evaluate(mnt_data)
assert not np.allclose(val, 0.0)


def test_device_spec_get_metric():
"""`DeviceSpec.get_metric` aggregates weighted metric values into a scalar."""
for device_spec in device_specs:
val = device_spec.get_metric(sim_data_base)
assert not np.allclose(val, 0.0)


def test_invdes_get_metric():
"""`InverseDesign.get_metric` sums device metrics from batch results."""
batch_data = {device_spec.name: sim_data_base for device_spec in invdes.device_specs}
val = invdes.get_metric(batch_data)
assert not np.allclose(val, 0.0)


def test_inverse_design_unique_names_validation():
"""Constructing `InverseDesign` with duplicate device names raises `ValueError`."""
device_specs_fail = [device_spec1, device_spec1]
with pytest.raises(ValueError):
InverseDesign(optimizer_spec=optimizer_spec, device_specs=device_specs_fail)


@pytest.fixture
def use_emulated(monkeypatch):
"""Emulate the InverseDesign.to_simulation_data to call emulated run."""
monkeypatch.setattr(
DeviceSpec,
"run_simulation",
lambda self, simulation: run_emulated(simulation, task_name="test"),
)
monkeypatch.setattr(
InverseDesign,
"run_simulations",
lambda self, sims: {
task_name: run_emulated(sim, task_name=task_name) for task_name, sim in sims.items()
},
)


def test_objective_function(use_emulated):
"""`InverseDesign.get_objective` returns a non-zero scalar using emulated runs."""
params = invdes.ones()
val = invdes.get_objective(params)
assert not np.allclose(val, 0.0)
Empty file.
11 changes: 11 additions & 0 deletions tidy3d/plugins/invdes2/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
"""Public API for the `invdes2` inverse design scaffold."""

from __future__ import annotations

from .design_region import TopologyDesignRegion
from .device_spec import DeviceSpec
from .inverse_design import InverseDesign
from .metric import FluxMetric
from .optimizer_spec import OptimizerSpec

__all__ = ["DeviceSpec", "FluxMetric", "InverseDesign", "OptimizerSpec", "TopologyDesignRegion"]
64 changes: 64 additions & 0 deletions tidy3d/plugins/invdes2/design_region.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
from __future__ import annotations

from abc import abstractmethod
from dataclasses import dataclass
from typing import Union

import autograd.numpy as np

import tidy3d as td


@dataclass
class DesignRegion:
"""Abstract parameterized geometry provider for inverse design.

Implementations transform parameter arrays into concrete `td.Structure`
instances that will be appended to a base simulation.
"""

@abstractmethod
def to_structure(self, params: np.ndarray) -> td.Structure:
"""Return a `td.Structure` built from the provided parameters."""

@property
@abstractmethod
def parameter_shape(self) -> int:
"""Return the (flattened) shape of the parameters for this design region."""

def ones(self, **kwargs) -> np.ndarray:
"""Return an array of ones with the shape of the parameters for this design region."""
return np.ones(self.parameter_shape, **kwargs)


@dataclass
class TopologyDesignRegion(DesignRegion):
"""Design region as a pixellated permittivity grid."""

size: tuple[float, float, float]
center: tuple[float, float, float]
eps_bounds: tuple[float, float]
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

style: The eps_bounds field is defined but never used in the implementation. Consider using it to clamp parameter values or remove if not needed.

Prompt To Fix With AI
This is a comment left during a code review.
Path: tidy3d/plugins/invdes2/design_region.py
Line: 40:40

Comment:
**style:** The `eps_bounds` field is defined but never used in the implementation. Consider using it to clamp parameter values or remove if not needed.

How can I resolve this? If you propose a fix, please make it concise.

pixel_size: float

@property
def shape_3d(self) -> tuple[int, int, int]:
"""Return the shape of the parameters for this design region."""
return tuple(int(np.ceil(size / self.pixel_size)) for size in self.size)

@property
def parameter_shape(self) -> int:
"""Return the shape of the parameters for this design region."""
return int(np.prod(self.shape_3d))

def to_structure(self, params: np.ndarray) -> td.Structure:
"""Return a `td.Structure` built from the provided parameters."""

geometry = td.Box(center=self.center, size=self.size)

# TODO: add transformations
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

style: Remove TODO comment before finalizing the pull request.

Context Used: Rule from dashboard - Remove temporary debugging code (print() calls), commented-out code, and other workarounds before fi... (source)

Prompt To Fix With AI
This is a comment left during a code review.
Path: tidy3d/plugins/invdes2/design_region.py
Line: 58:58

Comment:
**style:** Remove TODO comment before finalizing the pull request.

**Context Used:** Rule from `dashboard` - Remove temporary debugging code (print() calls), commented-out code, and other workarounds before fi... ([source](https://app.greptile.com/review/custom-context?memory=f6a669d8-0060-4f11-9cac-10ac7ee749ea))

How can I resolve this? If you propose a fix, please make it concise.

eps_data = params.reshape(self.shape_3d)

return td.Structure.from_permittivity_array(geometry=geometry, eps_data=eps_data)


DesignRegionType = Union[TopologyDesignRegion]
94 changes: 94 additions & 0 deletions tidy3d/plugins/invdes2/device_spec.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
from __future__ import annotations

from dataclasses import dataclass

import autograd.numpy as np

import tidy3d as td
import tidy3d.web as web

from .design_region import DesignRegionType
from .metric import MetricType


@dataclass
class DeviceSpec:
"""Specification of a single device scenario for inverse design.

Attributes
----------
simulation:
Base `td.Simulation` onto which parameterized structures are appended.
design_regions:
Ordered list of `DesignRegion` instances. Each must consume a
corresponding parameter array in `get_simulation`.
metrics:
List of `Metric` instances whose weighted sum forms the device's
objective contribution.
name:
Unique identifier for this device, used as the `task_name` and as the
key in batch results.
"""

simulation: td.Simulation
design_regions: list[DesignRegionType]
metrics: list[MetricType]
name: str

def get_simulation(self, params: list[np.ndarray]) -> td.Simulation:
"""Construct a simulation by appending parameterized structures.

Parameters
----------
params:
List of arrays matching `design_regions` order; each array is passed
to the corresponding region's `to_structure` to produce a structure.

Returns
-------
td.Simulation
A new simulation with appended structures.
"""
structures = list(self.simulation.structures)
for param, design_region in zip(params, self.design_regions):
structure = design_region.to_structure(param)
structures.append(structure)
return self.simulation.updated_copy(structures=structures)

def run_simulation(self, simulation: td.Simulation) -> web.SimulationData:
"""Run the simulation via Tidy3D Web and return results."""
return web.run(simulation, task_name=self.name)

def get_metric(self, sim_data: web.SimulationData) -> float:
"""Compute the weighted sum of metrics for this device.

Parameters
----------
sim_data:
Simulation results to be consumed by each metric.

Returns
-------
float
Weighted sum of metric values.
"""
value = 0.0
for metric in self.metrics:
mnt_data = sim_data[metric.monitor_name]
value = value + metric.weight * metric.evaluate(mnt_data)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

style: consider using value += metric.weight * metric.evaluate(mnt_data) for more idiomatic Python

Suggested change
value = value + metric.weight * metric.evaluate(mnt_data)
value += metric.weight * metric.evaluate(mnt_data)
Prompt To Fix With AI
This is a comment left during a code review.
Path: tidy3d/plugins/invdes2/device_spec.py
Line: 78:78

Comment:
**style:** consider using `value += metric.weight * metric.evaluate(mnt_data)` for more idiomatic Python

```suggestion
            value += metric.weight * metric.evaluate(mnt_data)
```

How can I resolve this? If you propose a fix, please make it concise.

return value

def get_objective(self, params: list[np.ndarray]) -> float:
"""Build, run, and score this device for the given parameters."""
sim = self.get_simulation(params)
sim_data = self.run_simulation(sim)
return self.get_metric(sim_data)

@property
def parameter_shape(self) -> list[int]:
"""Return the shape of the parameters for each design region."""
return [design_region.parameter_shape for design_region in self.design_regions]

def ones(self, **kwargs) -> list[np.ndarray]:
"""Return a list of arrays of ones with the shape of the parameters for each design region."""
return [design_region.ones(**kwargs) for design_region in self.design_regions]
Loading