Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
61 commits
Select commit Hold shift + click to select a range
eb2166e
clarifies I/O in domino train.py
peterdsharpe Jun 12, 2025
0b2460d
Gives paths in config.yaml user-agnostic pathnames
peterdsharpe Jun 12, 2025
7f4dbce
Switches from relu -> gelu to allow smooth gradients
peterdsharpe Jun 12, 2025
7cdbfd2
Adds initial commit for design sensitivities study
peterdsharpe Jun 13, 2025
474f9a7
Corrects outdated type hint
peterdsharpe Jun 13, 2025
ea87787
Refactors parameters in signed_distance_field calls for clarity
peterdsharpe Jun 13, 2025
db42406
Refactors directory handling in create_directory and get_filenames fu…
peterdsharpe Jun 13, 2025
71723c7
Deletes merge(); this function is (a) not used anywhere, (b) can be r…
peterdsharpe Jun 13, 2025
10e5967
black formatting
peterdsharpe Jun 13, 2025
007348b
Code quality improvements
peterdsharpe Jun 13, 2025
36c4bd8
Replaces 'axis' with 'dim' in torch.cat calls for correctness with Py…
peterdsharpe Jun 13, 2025
392b273
Adds initial changes for DoMINO sensitivity
peterdsharpe Jun 13, 2025
1e3e7b5
Refactors DesignDatapipe and DoMINOInference for improved readability…
peterdsharpe Jun 16, 2025
2de7ac0
Refactors DesignDatapipe to directly use STL centers for geometry coo…
peterdsharpe Jun 16, 2025
5b88395
Enhances DesignDatapipe by updating bounding box type hints, improvin…
peterdsharpe Jun 16, 2025
9d7d7da
Implements Laplacian smoothing for mesh data in a new utility functio…
peterdsharpe Jun 16, 2025
1e1af3b
Adds numba to requirements for improved performance in sensitivity an…
peterdsharpe Jun 16, 2025
73cf5f4
Adds sbatch_logs/ to .gitignore to exclude SLURM batch log files from…
peterdsharpe Jun 16, 2025
0c05b1a
Merge branch 'psharpe/domino-sensitivities' of https://github.com/pet…
peterdsharpe Jun 16, 2025
14b5303
Adds compute-optimized mesh_postprocessing utilities
peterdsharpe Jun 17, 2025
a214c03
Working `main.py` with abstracted postprocessing step
peterdsharpe Jun 17, 2025
ca48cb9
formatting
peterdsharpe Jun 17, 2025
b0c9c3f
Refactors main.py to remove duplicate STL combining function and stre…
peterdsharpe Jun 17, 2025
74c5ebd
Commits configuration files for sensitivity studies
peterdsharpe Jun 17, 2025
4dbe76a
Adds requirements.txt
peterdsharpe Jun 18, 2025
0efce6e
Merge branch 'psharpe/domino-sensitivities' of github.com:peterdsharp…
peterdsharpe Jun 18, 2025
ef19d41
Adds raw and smooth drag gradient data files, and implements a plotti…
peterdsharpe Jun 18, 2025
0f12306
Refactors import statements in main.py for consistency and clarity. S…
peterdsharpe Jun 18, 2025
6be2f88
Creates main_gradient_checking.py for drag gradient checking using Do…
peterdsharpe Jun 18, 2025
c4b809a
Updates file paths in main_gradient_checking.py and plot_gradient_che…
peterdsharpe Jun 19, 2025
1a51eb2
Adds a new aerodynamics example using DoMINO to compute design sensit…
peterdsharpe Jun 19, 2025
679bf8e
Merge branch 'physicsnemo/main' into psharpe/domino-sensitivities
peterdsharpe Jun 19, 2025
9e02933
Add README.md for DoMINO sensitivity analysis pipeline, detailing usa…
peterdsharpe Jun 19, 2025
1b96606
black formatting fixes
peterdsharpe Jun 20, 2025
c56ac24
Add SPDX license headers to plot_gradient_checking.py
peterdsharpe Jun 20, 2025
270d2ed
Fixes markdownlint
peterdsharpe Jun 20, 2025
b81910f
Removes unused import
peterdsharpe Jun 20, 2025
f4859ee
Updates license year
peterdsharpe Jun 20, 2025
15facd8
Fixes license year
peterdsharpe Jun 20, 2025
278d1c4
Removes unused main block sections
peterdsharpe Jun 24, 2025
fa46229
Removes erroneous uv.lock commit
peterdsharpe Jun 24, 2025
2e90446
Removes some optimization language
peterdsharpe Jun 24, 2025
3c8e84f
Merge branch 'main' into psharpe/domino-sensitivities
peterdsharpe Jun 25, 2025
38b5a59
Remove unnecessary cached yaml
peterdsharpe Jun 25, 2025
900427b
Refactors to not require separate config (instead pulling it from DoM…
peterdsharpe Jun 25, 2025
085077f
Add warning for loading model without checkpoint in DoMINOInference
peterdsharpe Jun 25, 2025
55a8837
Merge branch 'main' into psharpe/domino-sensitivities
peterdsharpe Jul 7, 2025
63b3b5b
Add verbose option to DoMINOInference for memory usage logging
peterdsharpe Jul 7, 2025
a3542e6
Merge branch 'psharpe/domino-sensitivities' of github.com:peterdsharp…
peterdsharpe Jul 7, 2025
da20092
Refactor imports in design_datapipe.py for clarity and efficiency; re…
peterdsharpe Jul 7, 2025
fc1b48b
Refactor DesignDatapipe to use NearestNeighbors from cuML for neighbo…
peterdsharpe Jul 8, 2025
f79d269
Enhance DesignDatapipe to accept a device parameter for tensor manage…
peterdsharpe Jul 8, 2025
67763af
Merge branch 'main' into psharpe/domino-sensitivities
peterdsharpe Jul 8, 2025
37991f0
Readme cleanup
peterdsharpe Jul 8, 2025
90f1e44
Replace GELU activation with a configurable activation function in Ge…
peterdsharpe Jul 8, 2025
0358008
Merge branch 'physicsnemo/main' into psharpe/domino-sensitivities
peterdsharpe Jul 8, 2025
a15493a
formatting
peterdsharpe Jul 8, 2025
8668758
remove duplicate section
peterdsharpe Jul 8, 2025
5288a65
Makes activations configurable
peterdsharpe Jul 9, 2025
23347ed
formatting
peterdsharpe Jul 9, 2025
7891b17
add license
peterdsharpe Jul 9, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions examples/cfd/external_aerodynamics/domino/src/conf/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@ resume_dir: ${output}/models
# └───────────────────────────────────────────┘
data_processor: # Data processor configurable parameters
kind: drivaer_aws # must be either drivesim or drivaer_aws
output_dir: /lustre/rranade/aws_data_all/
input_dir: /lustre/datasets/drivaer_aws/drivaer_data_full/
cached_dir: /lustre/cached/drivaer_aws/drivaer_data_full/
output_dir: /user/aws_data_all/
input_dir: /data/drivaer_aws/drivaer_data_full/
cached_dir: /user/cached/drivaer_aws/drivaer_data_full/
use_cache: false
num_processors: 12

Expand All @@ -64,8 +64,8 @@ variables:
# │ Training Data Configs │
# └───────────────────────────────────────────┘
data: # Input directory for training and validation data
input_dir: /lustre/rranade/aws_data_all/
input_dir_val: /lustre/rranade/aws_data_all_val/
input_dir: /user/aws_data_all/
input_dir_val: /user/aws_data_all_val/
bounding_box: # Bounding box dimensions for computational domain
min: [-3.5, -2.25 , -0.32]
max: [8.5 , 2.25 , 3.00]
Expand Down Expand Up @@ -153,7 +153,7 @@ train: # Training configurable parameters
sampler:
shuffle: true
drop_last: false
checkpoint_dir: /lustre/rranade/models/ # Use only for retraining
checkpoint_dir: /user/models/ # Use only for retraining

# ┌───────────────────────────────────────────┐
# │ Validation Configs │
Expand All @@ -170,9 +170,9 @@ val: # Validation configurable parameters
# │ Testing data Configs │
# └───────────────────────────────────────────┘
eval: # Testing configurable parameters
test_path: /lustre/rranade/testing_data # Dir for testing data in raw format (vtp, vtu ,stls)
save_path: /lustre/rranade/predicted_data # Dir to save predicted results in raw format (vtp, vtu)
test_path: /user/testing_data # Dir for testing data in raw format (vtp, vtu ,stls)
save_path: /user/predicted_data # Dir to save predicted results in raw format (vtp, vtu)
checkpoint_name: DoMINO.0.455.pt # Name of checkpoint to select from saved checkpoints
scaling_param_path: /lustre/rranade/scaling_params
scaling_param_path: /user/scaling_params
refine_stl: False # Automatically refine STL during inference
stencil_size: 7 # Stencil size for evaluating surface and volume model
4 changes: 3 additions & 1 deletion examples/cfd/external_aerodynamics/domino/src/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,7 +482,9 @@ def main(cfg: DictConfig) -> None:
gpu_handle = nvmlDeviceGetHandleByIndex(dist.device.index)

compute_scaling_factors(
cfg, cfg.data_processor.output_dir, use_cache=cfg.data_processor.use_cache
cfg=cfg,
input_path=cfg.data_processor.output_dir,
use_cache=cfg.data_processor.use_cache
)
model_type = cfg.model.model_type

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,269 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This is the datapipe to read OpenFoam files (vtp/vtu/stl) and save them as point clouds
in npy format.

"""

import time, random, copy
from collections import defaultdict
from pathlib import Path
from typing import Any, Iterable, List, Literal, Mapping, Optional, Union, Callable

import numpy as np
import pandas as pd
import pyvista as pv
import vtk
from physicsnemo.utils.domino.utils import *
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from physicsnemo.utils.sdf import signed_distance_field
from physicsnemo.utils.domino.utils import *

AIR_DENSITY = 1.205
STREAM_VELOCITY = 30.00


def combine_stls(stl_path, stl_files):
meshes = []
for file in stl_files:
if ".stl" in file:
stl_file_path = os.path.join(stl_path, file)
reader = pv.get_reader(stl_file_path)
mesh_stl = reader.read()
meshes.append(mesh_stl)
combined_mesh = pv.merge(meshes)
return combined_mesh


class DesignDatapipe(Dataset):
"""
Datapipe for converting openfoam dataset to npy

"""

def __init__(
self,
mesh_stl,
bounding_box,
bounding_box_surface,
grid_resolution,
stream_velocity,
air_density,
stencil_size=7,
device: int = 0,
):
self.mesh_stl = mesh_stl
self.stl_vertices = self.mesh_stl.points
self.num_points = self.mesh_stl.cell_centers().points.shape[0]
self.bounding_box = bounding_box
self.bounding_box_surface = bounding_box_surface
self.device = device
self.stencil_size = stencil_size
self.grid_resolution = grid_resolution
self.stream_velocity = stream_velocity
self.air_density = air_density
self.out_dict = self.process_stl()

def __len__(self):
return self.num_points
# return 16

def __getitem__(self, idx):
surface_mesh_centers = self.out_dict["surface_mesh_centers"][idx]
surface_mesh_neighbors = self.out_dict["surface_mesh_neighbors"][idx]
surface_normals = self.out_dict["surface_normals"][idx]
surface_neighbors_normals = self.out_dict["surface_neighbors_normals"][idx]
surface_areas = self.out_dict["surface_areas"][idx]
surface_neighbors_areas = self.out_dict["surface_neighbors_areas"][idx]
pos_normals_com_surface = self.out_dict["pos_surface_center_of_mass"][idx]

out_dict_new = {}
out_dict_new["surface_mesh_centers"] = np.float32(surface_mesh_centers)
out_dict_new["surface_mesh_neighbors"] = np.float32(surface_mesh_neighbors)
out_dict_new["surface_normals"] = np.float32(surface_normals)
out_dict_new["surface_neighbors_normals"] = np.float32(
surface_neighbors_normals
)
out_dict_new["surface_areas"] = np.float32(surface_areas)
out_dict_new["surface_neighbors_areas"] = np.float32(surface_neighbors_areas)
out_dict_new["pos_surface_center_of_mass"] = np.float32(pos_normals_com_surface)

return out_dict_new

def process_stl(
self,
):
mesh_stl = self.mesh_stl
length_scale = np.amax(
np.amax(self.stl_vertices, 0) - np.amin(self.stl_vertices, 0)
)
stl_centers = mesh_stl.cell_centers().points
# Assuming triangular elements
stl_faces = np.array(mesh_stl.faces).reshape((-1, 4))[:, 1:]
mesh_indices_flattened = stl_faces.flatten()
print(stl_centers.shape, self.stl_vertices.shape)

surface_areas = mesh_stl.compute_cell_sizes(
length=False, area=True, volume=False
)
surface_areas = np.array(surface_areas.cell_data["Area"])

surface_normals = -1.0 * np.array(mesh_stl.cell_normals, dtype=np.float32)

center_of_mass = calculate_center_of_mass(stl_centers, surface_areas)

s_max = np.asarray(self.bounding_box_surface[1])
s_min = np.asarray(self.bounding_box_surface[0])

v_max = np.asarray(self.bounding_box[1])
v_min = np.asarray(self.bounding_box[0])

# General processing
nx, ny, nz = self.grid_resolution

grid = create_grid(v_max, v_min, self.grid_resolution)
grid_reshaped = grid.reshape(nx * ny * nz, 3)

# SDF on grid
sdf_grid = signed_distance_field(
mesh_vertices=self.stl_vertices,
mesh_indices=mesh_indices_flattened,
input_points=grid_reshaped,
use_sign_winding_number=True,
)
sdf_grid = sdf_grid.numpy().reshape(nx, ny, nz)

s_grid = create_grid(s_max, s_min, self.grid_resolution)
surf_grid_reshaped = s_grid.reshape(nx * ny * nz, 3)

surf_sdf_grid = signed_distance_field(
mesh_vertices=self.stl_vertices,
mesh_indices=mesh_indices_flattened,
input_points=surf_grid_reshaped,
use_sign_winding_number=True,
)
surf_sdf_grid = surf_sdf_grid.numpy().reshape(nx, ny, nz)

# Sample surface_vertices
grid = 2.0 * (grid - v_min) / (v_max - v_min) - 1.0
s_grid = 2.0 * (s_grid - s_min) / (s_max - s_min) - 1.0

# Surface processing
surface_coordinates = stl_centers
interp_func = KDTree(surface_coordinates)
dd, ii = interp_func.query(surface_coordinates, k=self.stencil_size)
surface_neighbors = surface_coordinates[ii]
surface_neighbors = surface_neighbors[:, 1:] + 1e-6
surface_neighbors_normals = surface_normals[ii]
surface_neighbors_normals = surface_neighbors_normals[:, 1:]
surface_neighbors_area = surface_areas[ii]
surface_neighbors_area = surface_neighbors_area[:, 1:]

pos_normals_com_surface = surface_coordinates - center_of_mass

surface_coordinates = (
2.0 * (surface_coordinates - s_min) / (s_max - s_min) - 1.0
)
surface_neighbors = 2.0 * (surface_neighbors - s_min) / (s_max - s_min) - 1.0

# Volume processing
volume_coordinates = (v_max - v_min) * np.random.rand(10, 3) + v_min

sdf_nodes, sdf_node_closest_point = signed_distance_field(
self.stl_vertices,
mesh_indices_flattened,
volume_coordinates,
include_hit_points=True,
use_sign_winding_number=True,
)
sdf_nodes = sdf_nodes.numpy().reshape(-1, 1)
sdf_node_closest_point = sdf_node_closest_point.numpy()
pos_normals_closest = volume_coordinates - sdf_node_closest_point
pos_normals_com = volume_coordinates - center_of_mass
volume_coordinates = 2.0 * (volume_coordinates - v_min) / (v_max - v_min) - 1.0
vol_grid_max_min = np.float32(np.asarray([v_min, v_max]))
surf_grid_max_min = np.float32(np.asarray([s_min, s_max]))

geometry_points = 300_000
geometry_coordinates_sampled, idx_geometry = shuffle_array(
stl_centers, geometry_points
)

# surface_points = 16
# surface_coordinates = surface_coordinates[:surface_points]
# surface_neighbors = surface_neighbors[:surface_points]
# surface_normals = surface_normals[:surface_points]
# surface_neighbors_normals = surface_neighbors_normals[:surface_points]
# surface_areas = surface_areas[:surface_points]
# surface_neighbors_area = surface_neighbors_area[:surface_points]
# pos_normals_com_surface = pos_normals_com_surface[:surface_points]

return {
"pos_volume_closest": pos_normals_closest,
"pos_volume_center_of_mass": pos_normals_com,
"pos_surface_center_of_mass": pos_normals_com_surface,
"geometry_coordinates": geometry_coordinates_sampled,
"grid": grid,
"surf_grid": s_grid,
"sdf_grid": sdf_grid,
"sdf_surf_grid": surf_sdf_grid,
"sdf_nodes": sdf_nodes,
"surface_mesh_centers": surface_coordinates,
"surface_mesh_neighbors": surface_neighbors,
"surface_normals": surface_normals,
"surface_neighbors_normals": surface_neighbors_normals,
"surface_areas": surface_areas,
"surface_neighbors_areas": surface_neighbors_area,
"volume_mesh_centers": volume_coordinates,
"volume_min_max": vol_grid_max_min,
"surface_min_max": surf_grid_max_min,
"length_scale": length_scale,
"stream_velocity": np.expand_dims(
np.array(self.stream_velocity, dtype=np.float32), -1
),
"air_density": np.expand_dims(
np.array(self.air_density, dtype=np.float32), -1
),
}


if __name__ == "__main__":
stl_path = "/raid/rranade/home/rranade/data/"
dirnames = get_filenames(stl_path)
filepath = os.path.join(stl_path, dirnames[0])
stl_files = get_filenames(filepath)
mesh_stl = combine_stls(filepath, stl_files)

bounding_box = [[-3.5, -2.25, -0.32], [8.5, 2.25, 3.00]]
bounding_box_surface = [[-1.1, -1.2, -0.32], [4.5, 1.2, 1.2]]

fd = DesignDatapipe(
mesh_stl,
bounding_box,
bounding_box_surface,
grid_resolution=[128, 64, 48],
stream_velocity=30.0,
air_density=1.205,
device=0,
)

train_dataloader = DataLoader(fd, batch_size=256_000, shuffle=False)

for i_batch, sample_batched in enumerate(train_dataloader):
print(i_batch, sample_batched["surface_mesh_centers"].shape)
Loading