-
Notifications
You must be signed in to change notification settings - Fork 423
Changes to support DoMINO Design Sensitivities work + DoMINO model code fixes #973
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
peterdsharpe
merged 61 commits into
NVIDIA:main
from
peterdsharpe:psharpe/domino-sensitivities
Jul 11, 2025
Merged
Changes from 8 commits
Commits
Show all changes
61 commits
Select commit
Hold shift + click to select a range
eb2166e
clarifies I/O in domino train.py
peterdsharpe 0b2460d
Gives paths in config.yaml user-agnostic pathnames
peterdsharpe 7f4dbce
Switches from relu -> gelu to allow smooth gradients
peterdsharpe 7cdbfd2
Adds initial commit for design sensitivities study
peterdsharpe 474f9a7
Corrects outdated type hint
peterdsharpe ea87787
Refactors parameters in signed_distance_field calls for clarity
peterdsharpe db42406
Refactors directory handling in create_directory and get_filenames fu…
peterdsharpe 71723c7
Deletes merge(); this function is (a) not used anywhere, (b) can be r…
peterdsharpe 10e5967
black formatting
peterdsharpe 007348b
Code quality improvements
peterdsharpe 36c4bd8
Replaces 'axis' with 'dim' in torch.cat calls for correctness with Py…
peterdsharpe 392b273
Adds initial changes for DoMINO sensitivity
peterdsharpe 1e3e7b5
Refactors DesignDatapipe and DoMINOInference for improved readability…
peterdsharpe 2de7ac0
Refactors DesignDatapipe to directly use STL centers for geometry coo…
peterdsharpe 5b88395
Enhances DesignDatapipe by updating bounding box type hints, improvin…
peterdsharpe 9d7d7da
Implements Laplacian smoothing for mesh data in a new utility functio…
peterdsharpe 1e1af3b
Adds numba to requirements for improved performance in sensitivity an…
peterdsharpe 73cf5f4
Adds sbatch_logs/ to .gitignore to exclude SLURM batch log files from…
peterdsharpe 0c05b1a
Merge branch 'psharpe/domino-sensitivities' of https://github.com/pet…
peterdsharpe 14b5303
Adds compute-optimized mesh_postprocessing utilities
peterdsharpe a214c03
Working `main.py` with abstracted postprocessing step
peterdsharpe ca48cb9
formatting
peterdsharpe b0c9c3f
Refactors main.py to remove duplicate STL combining function and stre…
peterdsharpe 74c5ebd
Commits configuration files for sensitivity studies
peterdsharpe 4dbe76a
Adds requirements.txt
peterdsharpe 0efce6e
Merge branch 'psharpe/domino-sensitivities' of github.com:peterdsharp…
peterdsharpe ef19d41
Adds raw and smooth drag gradient data files, and implements a plotti…
peterdsharpe 0f12306
Refactors import statements in main.py for consistency and clarity. S…
peterdsharpe 6be2f88
Creates main_gradient_checking.py for drag gradient checking using Do…
peterdsharpe c4b809a
Updates file paths in main_gradient_checking.py and plot_gradient_che…
peterdsharpe 1a51eb2
Adds a new aerodynamics example using DoMINO to compute design sensit…
peterdsharpe 679bf8e
Merge branch 'physicsnemo/main' into psharpe/domino-sensitivities
peterdsharpe 9e02933
Add README.md for DoMINO sensitivity analysis pipeline, detailing usa…
peterdsharpe 1b96606
black formatting fixes
peterdsharpe c56ac24
Add SPDX license headers to plot_gradient_checking.py
peterdsharpe 270d2ed
Fixes markdownlint
peterdsharpe b81910f
Removes unused import
peterdsharpe f4859ee
Updates license year
peterdsharpe 15facd8
Fixes license year
peterdsharpe 278d1c4
Removes unused main block sections
peterdsharpe fa46229
Removes erroneous uv.lock commit
peterdsharpe 2e90446
Removes some optimization language
peterdsharpe 3c8e84f
Merge branch 'main' into psharpe/domino-sensitivities
peterdsharpe 38b5a59
Remove unnecessary cached yaml
peterdsharpe 900427b
Refactors to not require separate config (instead pulling it from DoM…
peterdsharpe 085077f
Add warning for loading model without checkpoint in DoMINOInference
peterdsharpe 55a8837
Merge branch 'main' into psharpe/domino-sensitivities
peterdsharpe 63b3b5b
Add verbose option to DoMINOInference for memory usage logging
peterdsharpe a3542e6
Merge branch 'psharpe/domino-sensitivities' of github.com:peterdsharp…
peterdsharpe da20092
Refactor imports in design_datapipe.py for clarity and efficiency; re…
peterdsharpe fc1b48b
Refactor DesignDatapipe to use NearestNeighbors from cuML for neighbo…
peterdsharpe f79d269
Enhance DesignDatapipe to accept a device parameter for tensor manage…
peterdsharpe 67763af
Merge branch 'main' into psharpe/domino-sensitivities
peterdsharpe 37991f0
Readme cleanup
peterdsharpe 90f1e44
Replace GELU activation with a configurable activation function in Ge…
peterdsharpe 0358008
Merge branch 'physicsnemo/main' into psharpe/domino-sensitivities
peterdsharpe a15493a
formatting
peterdsharpe 8668758
remove duplicate section
peterdsharpe 5288a65
Makes activations configurable
peterdsharpe 23347ed
formatting
peterdsharpe 7891b17
add license
peterdsharpe File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
269 changes: 269 additions & 0 deletions
269
examples/cfd/external_aerodynamics/domino_sensitivity/design_datapipe.py
peterdsharpe marked this conversation as resolved.
Show resolved
Hide resolved
peterdsharpe marked this conversation as resolved.
Show resolved
Hide resolved
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,269 @@ | ||
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. | ||
# SPDX-FileCopyrightText: All rights reserved. | ||
# SPDX-License-Identifier: Apache-2.0 | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
""" | ||
This is the datapipe to read OpenFoam files (vtp/vtu/stl) and save them as point clouds | ||
in npy format. | ||
|
||
""" | ||
|
||
import time, random, copy | ||
from collections import defaultdict | ||
from pathlib import Path | ||
from typing import Any, Iterable, List, Literal, Mapping, Optional, Union, Callable | ||
|
||
import numpy as np | ||
import pandas as pd | ||
import pyvista as pv | ||
import vtk | ||
from physicsnemo.utils.domino.utils import * | ||
from torch.utils.data import Dataset | ||
from torch.utils.data import DataLoader | ||
from physicsnemo.utils.sdf import signed_distance_field | ||
from physicsnemo.utils.domino.utils import * | ||
|
||
AIR_DENSITY = 1.205 | ||
STREAM_VELOCITY = 30.00 | ||
|
||
|
||
def combine_stls(stl_path, stl_files): | ||
meshes = [] | ||
for file in stl_files: | ||
if ".stl" in file: | ||
stl_file_path = os.path.join(stl_path, file) | ||
reader = pv.get_reader(stl_file_path) | ||
mesh_stl = reader.read() | ||
meshes.append(mesh_stl) | ||
combined_mesh = pv.merge(meshes) | ||
return combined_mesh | ||
|
||
|
||
class DesignDatapipe(Dataset): | ||
""" | ||
Datapipe for converting openfoam dataset to npy | ||
|
||
""" | ||
|
||
def __init__( | ||
self, | ||
mesh_stl, | ||
bounding_box, | ||
bounding_box_surface, | ||
grid_resolution, | ||
stream_velocity, | ||
air_density, | ||
stencil_size=7, | ||
device: int = 0, | ||
): | ||
self.mesh_stl = mesh_stl | ||
self.stl_vertices = self.mesh_stl.points | ||
self.num_points = self.mesh_stl.cell_centers().points.shape[0] | ||
self.bounding_box = bounding_box | ||
self.bounding_box_surface = bounding_box_surface | ||
self.device = device | ||
self.stencil_size = stencil_size | ||
self.grid_resolution = grid_resolution | ||
self.stream_velocity = stream_velocity | ||
self.air_density = air_density | ||
self.out_dict = self.process_stl() | ||
|
||
def __len__(self): | ||
return self.num_points | ||
# return 16 | ||
|
||
def __getitem__(self, idx): | ||
surface_mesh_centers = self.out_dict["surface_mesh_centers"][idx] | ||
surface_mesh_neighbors = self.out_dict["surface_mesh_neighbors"][idx] | ||
surface_normals = self.out_dict["surface_normals"][idx] | ||
surface_neighbors_normals = self.out_dict["surface_neighbors_normals"][idx] | ||
surface_areas = self.out_dict["surface_areas"][idx] | ||
surface_neighbors_areas = self.out_dict["surface_neighbors_areas"][idx] | ||
pos_normals_com_surface = self.out_dict["pos_surface_center_of_mass"][idx] | ||
|
||
out_dict_new = {} | ||
out_dict_new["surface_mesh_centers"] = np.float32(surface_mesh_centers) | ||
out_dict_new["surface_mesh_neighbors"] = np.float32(surface_mesh_neighbors) | ||
out_dict_new["surface_normals"] = np.float32(surface_normals) | ||
out_dict_new["surface_neighbors_normals"] = np.float32( | ||
surface_neighbors_normals | ||
) | ||
out_dict_new["surface_areas"] = np.float32(surface_areas) | ||
out_dict_new["surface_neighbors_areas"] = np.float32(surface_neighbors_areas) | ||
out_dict_new["pos_surface_center_of_mass"] = np.float32(pos_normals_com_surface) | ||
|
||
return out_dict_new | ||
|
||
def process_stl( | ||
self, | ||
): | ||
mesh_stl = self.mesh_stl | ||
length_scale = np.amax( | ||
np.amax(self.stl_vertices, 0) - np.amin(self.stl_vertices, 0) | ||
) | ||
stl_centers = mesh_stl.cell_centers().points | ||
# Assuming triangular elements | ||
stl_faces = np.array(mesh_stl.faces).reshape((-1, 4))[:, 1:] | ||
mesh_indices_flattened = stl_faces.flatten() | ||
print(stl_centers.shape, self.stl_vertices.shape) | ||
|
||
surface_areas = mesh_stl.compute_cell_sizes( | ||
length=False, area=True, volume=False | ||
) | ||
surface_areas = np.array(surface_areas.cell_data["Area"]) | ||
|
||
surface_normals = -1.0 * np.array(mesh_stl.cell_normals, dtype=np.float32) | ||
|
||
center_of_mass = calculate_center_of_mass(stl_centers, surface_areas) | ||
|
||
s_max = np.asarray(self.bounding_box_surface[1]) | ||
s_min = np.asarray(self.bounding_box_surface[0]) | ||
|
||
v_max = np.asarray(self.bounding_box[1]) | ||
v_min = np.asarray(self.bounding_box[0]) | ||
|
||
# General processing | ||
nx, ny, nz = self.grid_resolution | ||
|
||
grid = create_grid(v_max, v_min, self.grid_resolution) | ||
grid_reshaped = grid.reshape(nx * ny * nz, 3) | ||
|
||
# SDF on grid | ||
sdf_grid = signed_distance_field( | ||
mesh_vertices=self.stl_vertices, | ||
mesh_indices=mesh_indices_flattened, | ||
input_points=grid_reshaped, | ||
use_sign_winding_number=True, | ||
) | ||
sdf_grid = sdf_grid.numpy().reshape(nx, ny, nz) | ||
|
||
s_grid = create_grid(s_max, s_min, self.grid_resolution) | ||
surf_grid_reshaped = s_grid.reshape(nx * ny * nz, 3) | ||
|
||
surf_sdf_grid = signed_distance_field( | ||
mesh_vertices=self.stl_vertices, | ||
mesh_indices=mesh_indices_flattened, | ||
input_points=surf_grid_reshaped, | ||
use_sign_winding_number=True, | ||
) | ||
surf_sdf_grid = surf_sdf_grid.numpy().reshape(nx, ny, nz) | ||
|
||
# Sample surface_vertices | ||
grid = 2.0 * (grid - v_min) / (v_max - v_min) - 1.0 | ||
s_grid = 2.0 * (s_grid - s_min) / (s_max - s_min) - 1.0 | ||
|
||
# Surface processing | ||
surface_coordinates = stl_centers | ||
interp_func = KDTree(surface_coordinates) | ||
peterdsharpe marked this conversation as resolved.
Show resolved
Hide resolved
|
||
dd, ii = interp_func.query(surface_coordinates, k=self.stencil_size) | ||
surface_neighbors = surface_coordinates[ii] | ||
surface_neighbors = surface_neighbors[:, 1:] + 1e-6 | ||
surface_neighbors_normals = surface_normals[ii] | ||
surface_neighbors_normals = surface_neighbors_normals[:, 1:] | ||
surface_neighbors_area = surface_areas[ii] | ||
surface_neighbors_area = surface_neighbors_area[:, 1:] | ||
|
||
pos_normals_com_surface = surface_coordinates - center_of_mass | ||
|
||
surface_coordinates = ( | ||
2.0 * (surface_coordinates - s_min) / (s_max - s_min) - 1.0 | ||
) | ||
surface_neighbors = 2.0 * (surface_neighbors - s_min) / (s_max - s_min) - 1.0 | ||
|
||
# Volume processing | ||
peterdsharpe marked this conversation as resolved.
Show resolved
Hide resolved
|
||
volume_coordinates = (v_max - v_min) * np.random.rand(10, 3) + v_min | ||
|
||
sdf_nodes, sdf_node_closest_point = signed_distance_field( | ||
self.stl_vertices, | ||
mesh_indices_flattened, | ||
volume_coordinates, | ||
include_hit_points=True, | ||
use_sign_winding_number=True, | ||
) | ||
sdf_nodes = sdf_nodes.numpy().reshape(-1, 1) | ||
sdf_node_closest_point = sdf_node_closest_point.numpy() | ||
pos_normals_closest = volume_coordinates - sdf_node_closest_point | ||
pos_normals_com = volume_coordinates - center_of_mass | ||
volume_coordinates = 2.0 * (volume_coordinates - v_min) / (v_max - v_min) - 1.0 | ||
vol_grid_max_min = np.float32(np.asarray([v_min, v_max])) | ||
surf_grid_max_min = np.float32(np.asarray([s_min, s_max])) | ||
|
||
geometry_points = 300_000 | ||
geometry_coordinates_sampled, idx_geometry = shuffle_array( | ||
stl_centers, geometry_points | ||
) | ||
|
||
# surface_points = 16 | ||
# surface_coordinates = surface_coordinates[:surface_points] | ||
# surface_neighbors = surface_neighbors[:surface_points] | ||
# surface_normals = surface_normals[:surface_points] | ||
# surface_neighbors_normals = surface_neighbors_normals[:surface_points] | ||
# surface_areas = surface_areas[:surface_points] | ||
# surface_neighbors_area = surface_neighbors_area[:surface_points] | ||
# pos_normals_com_surface = pos_normals_com_surface[:surface_points] | ||
|
||
return { | ||
"pos_volume_closest": pos_normals_closest, | ||
"pos_volume_center_of_mass": pos_normals_com, | ||
"pos_surface_center_of_mass": pos_normals_com_surface, | ||
"geometry_coordinates": geometry_coordinates_sampled, | ||
"grid": grid, | ||
"surf_grid": s_grid, | ||
"sdf_grid": sdf_grid, | ||
"sdf_surf_grid": surf_sdf_grid, | ||
"sdf_nodes": sdf_nodes, | ||
"surface_mesh_centers": surface_coordinates, | ||
"surface_mesh_neighbors": surface_neighbors, | ||
"surface_normals": surface_normals, | ||
"surface_neighbors_normals": surface_neighbors_normals, | ||
"surface_areas": surface_areas, | ||
"surface_neighbors_areas": surface_neighbors_area, | ||
"volume_mesh_centers": volume_coordinates, | ||
"volume_min_max": vol_grid_max_min, | ||
"surface_min_max": surf_grid_max_min, | ||
"length_scale": length_scale, | ||
"stream_velocity": np.expand_dims( | ||
np.array(self.stream_velocity, dtype=np.float32), -1 | ||
), | ||
"air_density": np.expand_dims( | ||
np.array(self.air_density, dtype=np.float32), -1 | ||
), | ||
} | ||
|
||
|
||
if __name__ == "__main__": | ||
stl_path = "/raid/rranade/home/rranade/data/" | ||
dirnames = get_filenames(stl_path) | ||
filepath = os.path.join(stl_path, dirnames[0]) | ||
stl_files = get_filenames(filepath) | ||
mesh_stl = combine_stls(filepath, stl_files) | ||
|
||
bounding_box = [[-3.5, -2.25, -0.32], [8.5, 2.25, 3.00]] | ||
bounding_box_surface = [[-1.1, -1.2, -0.32], [4.5, 1.2, 1.2]] | ||
|
||
fd = DesignDatapipe( | ||
mesh_stl, | ||
bounding_box, | ||
bounding_box_surface, | ||
grid_resolution=[128, 64, 48], | ||
stream_velocity=30.0, | ||
air_density=1.205, | ||
device=0, | ||
) | ||
|
||
train_dataloader = DataLoader(fd, batch_size=256_000, shuffle=False) | ||
|
||
for i_batch, sample_batched in enumerate(train_dataloader): | ||
print(i_batch, sample_batched["surface_mesh_centers"].shape) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.