From 1fb4980556146250b7175e365d8a3e4c0538a60a Mon Sep 17 00:00:00 2001 From: xiaoyewww <641311428@qq.com> Date: Wed, 5 Mar 2025 00:58:49 +0800 Subject: [PATCH 01/11] feat(domino): support domino --- examples/domino/conf/config.yaml | 125 +++ examples/domino/layers/__init__.py | 15 + examples/domino/layers/ball_query.py | 322 +++++++ examples/domino/model.py | 1181 ++++++++++++++++++++++++++ examples/domino/requirements.txt | 4 + 5 files changed, 1647 insertions(+) create mode 100644 examples/domino/conf/config.yaml create mode 100644 examples/domino/layers/__init__.py create mode 100644 examples/domino/layers/ball_query.py create mode 100644 examples/domino/model.py create mode 100644 examples/domino/requirements.txt diff --git a/examples/domino/conf/config.yaml b/examples/domino/conf/config.yaml new file mode 100644 index 0000000000..6fdd14d0d2 --- /dev/null +++ b/examples/domino/conf/config.yaml @@ -0,0 +1,125 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +project: # Project name + name: AWS_Dataset + +exp_tag: 1 # Experiment tag +# Main output directory. +output: outputs/${project.name}/${exp_tag} + +hydra: # Hydra config + run: + dir: ${output} + output_subdir: hydra # Default is .hydra which causes files not being uploaded in W&B. + +data: # Input directory for training and validation data + input_dir: /lustre/rranade/modulus_dev/data/volume_data/ + input_dir_val: /lustre/rranade/modulus_dev/data/volume_data_val/ + bounding_box: # Bounding box dimensions for computational domain + min: [-3.5, -2.25 , -0.32] + max: [8.5 , 2.25 , 3.00] + bounding_box_surface: # Bounding box dimensions for car surface + min: [-1.1, -1.2 , -0.32] + max: [4.5 , 1.2 , 1.2] + +# The directory to search for checkpoints to continue training. +resume_dir: ${output}/models + +variables: + surface: + solution: + # The following is for AWS DrivAer dataset. + pMeanTrim: scalar + wallShearStressMeanTrim: vector + volume: + solution: + # The following is for AWS DrivAer dataset. + UMeanTrim: vector + pMeanTrim: scalar + nutMeanTrim: scalar + +model: + model_type: combined # train which model? surface, volume, combined + loss_function: "mse" # mse or rmse + interp_res: [128, 64, 48] # resolution of latent space + use_sdf_in_basis_func: true # SDF in basis function network + positional_encoding: false # calculate positional encoding? + volume_points_sample: 8192 # Number of points to sample in volume per epoch + surface_points_sample: 8192 # Number of points to sample on surface per epoch + geom_points_sample: 200_000 # Number of points to sample on STL per epoch + surface_neighbors: true # Pre-compute surface neighborhood from input data + num_surface_neighbors: 7 # How many neighbors? + use_surface_normals: true # Use surface normals and surface areas for surface computation? + use_only_normals: true # Use only surface normals and not surface area + integral_loss_scaling_factor: 0 # Scale integral loss by this factor + normalization: min_max_scaling # or mean_std_scaling + encode_parameters: true # encode inlet velocity and air density in the model + geometry_rep: # Hyperparameters for geometry representation network + base_filters: 16 + geo_conv: + base_neurons: 32 # 256 or 64 + base_neurons_out: 1 + radius_short: 0.1 + radius_long: 0.5 # 1.0, 1.5 + hops: 1 + geo_processor: + base_filters: 8 + geo_processor_sdf: + base_filters: 8 + nn_basis_functions: # Hyperparameters for basis function network + base_layer: 512 + aggregation_model: # Hyperparameters for aggregation network + base_layer: 512 + position_encoder: # Hyperparameters for position encoding network + base_neurons: 512 + geometry_local: # Hyperparameters for local geometry extraction + neighbors_in_radius: 64 + radius: 0.05 # 0.2 in expt 7 + base_layer: 512 + parameter_model: + base_layer: 512 + scaling_params: [30.0, 1.226] # [inlet_velocity, air_density] + +train: # Training configurable parameters + epochs: 500 + checkpoint_interval: 1 + dataloader: + batch_size: 1 + pin_memory: true + sampler: + shuffle: true + drop_last: false + checkpoint_dir: /lustre/rranade/modulus_dev/modulus_forked/modulus/examples/cfd/external_aerodynamics/domino/outputs/AWS_Dataset/3/models/ + +val: # Validation configurable parameters + dataloader: + batch_size: 1 + pin_memory: true + sampler: + shuffle: true + drop_last: false + +eval: # Testing configurable parameters + test_path: /lustre/rranade/benchmarking/drivaer_aws_surface_test_new/ + save_path: /lustre/rranade/domino/mesh_predictions_surf_final1/ + checkpoint_name: DoMINO.0.50.pt + +data_processor: # Data processor configurable parameters + kind: drivaer_aws # must be either drivesim or drivaer_aws + output_dir: /lustre/rranade/modulus_dev/data/volume_data/ + input_dir: /lustre/datasets/drivaer_aws/drivaer_data_full/ + num_processors: 12 diff --git a/examples/domino/layers/__init__.py b/examples/domino/layers/__init__.py new file mode 100644 index 0000000000..b2f171d4ac --- /dev/null +++ b/examples/domino/layers/__init__.py @@ -0,0 +1,15 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/examples/domino/layers/ball_query.py b/examples/domino/layers/ball_query.py new file mode 100644 index 0000000000..58f5cdf594 --- /dev/null +++ b/examples/domino/layers/ball_query.py @@ -0,0 +1,322 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import paddle +import warp as wp + + +class BallQuery(paddle.autograd.PyLayer): + """ + Warp based Ball Query. + """ + + @wp.kernel + def ball_query( + points1: wp.array(dtype=wp.vec3), + points2: wp.array(dtype=wp.vec3), + grid: wp.uint64, + k: wp.int32, + radius: wp.float32, + mapping: wp.array3d(dtype=wp.int32), + num_neighbors: wp.array2d(dtype=wp.int32), + ): + + # Get index of point1 + tid = wp.tid() + + # Get position from points1 + pos = points1[tid] + + # particle contact + neighbors = wp.hash_grid_query(grid, pos, radius) + + # Keep track of the number of neighbors found + nr_found = wp.int32(0) + + # loop through neighbors to compute density + for index in neighbors: + # Check if outside the radius + pos2 = points2[index] + if wp.length(pos - pos2) > radius: + continue + + # Add neighbor to the list + mapping[0, tid, nr_found] = index + + # Increment the number of neighbors found + nr_found += 1 + + # Break if we have found enough neighbors + if nr_found == k: + num_neighbors[0, tid] = k + break + + # Set the number of neighbors + num_neighbors[0, tid] = nr_found + + @wp.kernel + def sparse_ball_query( + points2: wp.array(dtype=wp.vec3), + mapping: wp.array3d(dtype=wp.int32), + num_neighbors: wp.array2d(dtype=wp.int32), + outputs: wp.array4d(dtype=wp.float32), + ): + # Get index of point1 + p1 = wp.tid() + + # Get number of neighbors + k = num_neighbors[0, p1] + + # Loop through neighbors + for _k in range(k): + # Get point2 index + index = mapping[0, p1, _k] + + # Get position from points2 + pos = points2[index] + + # Set the output + outputs[0, p1, _k, 0] = pos[0] + outputs[0, p1, _k, 1] = pos[1] + outputs[0, p1, _k, 2] = pos[2] + + @staticmethod + def forward( + ctx, + points1, + points2, + lengths1, + lengths2, + k, + radius, + hash_grid, + ): + # Only works for batch size 1 + if points1.shape[0] != 1: + raise AssertionError("nly works for batch size 1") + + # Convert from paddle to warp + ctx.points1 = wp.from_paddle( + points1[0], dtype=wp.vec3, requires_grad=points1.stop_gradient + ) + ctx.points2 = wp.from_paddle( + points2[0], dtype=wp.vec3, requires_grad=points2.stop_gradient + ) + ctx.lengths1 = wp.from_paddle(lengths1, dtype=wp.int32, requires_grad=False) + ctx.lengths2 = wp.from_paddle(lengths2, dtype=wp.int32, requires_grad=False) + ctx.k = k + ctx.radius = radius + + # Allocate the mapping and outputs + mapping = paddle.zeros([1, points1.shape[1], k], dtype=paddle.int32) + mapping.stop_gradient = False + ctx.mapping = wp.from_paddle(mapping, dtype=wp.int32, requires_grad=False) + num_neighbors = paddle.zeros([1, points1.shape[1]], dtype=paddle.int32) + num_neighbors.stop_gradient = False + ctx.num_neighbors = wp.from_paddle( + num_neighbors, dtype=wp.int32, requires_grad=False + ) + outputs = paddle.zeros([1, points1.shape[1], k, 3], dtype=paddle.float32) + outputs.stop_gradient = points1.stop_gradient or points2.stop_gradient + ctx.outputs = wp.from_paddle(outputs, dtype=wp.float32) + outputs.stop_gradient = points1.stop_gradient or points2.stop_gradient + + # Make grid + ctx.hash_grid = hash_grid + + # Build the grid + ctx.hash_grid.build(ctx.points2, radius) + + # Run the kernel to get mapping + wp.launch( + BallQuery.ball_query, + inputs=[ + ctx.points1, + ctx.points2, + ctx.hash_grid.id, + k, + radius, + ], + outputs=[ + ctx.mapping, + ctx.num_neighbors, + ], + dim=[ctx.points1.shape[0]], + ) + + # Run the kernel to get outputs + wp.launch( + BallQuery.sparse_ball_query, + inputs=[ + ctx.points2, + ctx.mapping, + ctx.num_neighbors, + ], + outputs=[ + ctx.outputs, + ], + dim=[ctx.points1.shape[0]], + ) + + return ( + wp.to_paddle(ctx.mapping), + wp.to_paddle(ctx.num_neighbors), + wp.to_paddle(ctx.outputs), + ) + + @staticmethod + def backward(ctx, grad_mapping, grad_num_neighbors, grad_outputs): + # Map incoming paddle grads to our output variable + ctx.outputs.grad = wp.from_paddle(grad_outputs, dtype=wp.float32) + + # Run the kernel in adjoint mode + wp.launch( + BallQuery.sparse_ball_query, + inputs=[ + ctx.points2, + ctx.mapping, + ctx.num_neighbors, + ], + outputs=[ + ctx.outputs, + ], + adj_inputs=[ctx.points2.grad, ctx.mapping.grad, ctx.num_neighbors.grad], + adj_outputs=[ + ctx.outputs.grad, + ], + dim=[ctx.points1.shape[0]], + adjoint=True, + ) + + # Return the gradients + return ( + wp.to_paddle(ctx.points1.grad).unsqueeze(0), + wp.to_paddle(ctx.points2.grad).unsqueeze(0), + None, + None, + None, + None, + None, + ) + + +class BallQueryLayer(paddle.nn.Layer): + """ + Paddle layer for differentiable and accelerated Ball Query + operation using Warp. + Args: + k (int): Number of neighbors. + radius (float): Radius of influence. + grid_size (int): Uniform grid resolution + """ + + def __init__(self, k, radius, grid_size=32): + super().__init__() + wp.init() + self.k = k + self.radius = radius + self.hash_grid = wp.HashGrid(grid_size, grid_size, grid_size) + + def forward(self, points1, points2, lengths1, lengths2): + return BallQuery.apply( + points1, + points2, + lengths1, + lengths2, + self.k, + self.radius, + self.hash_grid, + ) + + +if __name__ == "__main__": + # Make function for saving point clouds + import pyvista as pv + + def save_point_cloud(points, name): + cloud = pv.PolyData(points.detach().cpu().numpy()) + cloud.save(name) + + # Check forward pass + # Initialize tensors + n = 1 # number of point clouds + p1 = 128000 # 100000 # number of points in point cloud 1 + d = 3 # dimension of the points + p2 = 39321 # 100000 # number of points in point cloud 2 + points1 = paddle.rand([n, p1, d]) + points1.stop_gradient = False + + points2 = paddle.rand([n, p2, d]) + points2.stop_gradient = False + lengths1 = paddle.full((n,), p1, dtype=paddle.int32) + lengths2 = paddle.full((n,), p2, dtype=paddle.int32) + k = 256 # maximum number of neighbors + radius = 0.1 + + # Make ball query layer + layer = BallQueryLayer(k, radius) + + # Make ball query + with wp.ScopedTimer("ball query", active=True): + mapping, num_neighbors, outputs = layer( + points1, + points2, + lengths1, + lengths2, + ) + + for i in range(2): + p1 += 100 + p2 += 100 + points1 = paddle.rand([n, p1, d]) + points1.stop_gradient = True + points2 = paddle.rand([n, p2, d]) + points2.stop_gradient = True + lengths1 = paddle.full((n,), p1, dtype=paddle.int32) + lengths2 = paddle.full((n,), p2, dtype=paddle.int32) + + mapping, num_neighbors, outputs = layer( + points1, + points2, + lengths1, + lengths2, + ) + + # Perform matrix multiplication as comparison for timing + with wp.ScopedTimer("matrix multiplication 256", active=True): + outputs2 = paddle.matmul(points1[0], paddle.ones([3, k])) + + # Save the point clouds + save_point_cloud(points1[0], "point1.vtk") + save_point_cloud(points2[0], "point2.vtk") + save_point_cloud(outputs[0].reshape([-1, 3]), "outputs.vtk") + + # Optimize the background points to move to the query points + optimizer = paddle.optimizer.SGD(parameters=[points2], learning_rate=0.01) + + # Test optimization + for i in range(2): + optimizer.clear_gradients() + mapping, num_neighbors, outputs = layer(points1, points2, lengths1, lengths2) + + loss = (points1.unsqueeze(2) - outputs).pow(2).sum() + loss.backward() + optimizer.step() + + # Save the point clouds + save_point_cloud(points1[0], "point1_{}.vtk".format(i)) + save_point_cloud(outputs[0].reshape([-1, 3]), "outputs_{}.vtk".format(i)) diff --git a/examples/domino/model.py b/examples/domino/model.py new file mode 100644 index 0000000000..f270a8eb10 --- /dev/null +++ b/examples/domino/model.py @@ -0,0 +1,1181 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This code contains the DoMINO model architecture. +The DoMINO class contains an architecture to model both surface and +volume quantities together as well as separately (controlled using +the config.yaml file) +""" + +# from dataclasses import dataclass + +import paddle +import paddle.nn as nn +import paddle.nn.functional as F +from layers.ball_query import BallQueryLayer + +# from modulus.models.meta import ModelMetaData +# from modulus.models.module import Module + + +def calculate_pos_encoding(nx, d=8): + """Function to caluculate positional encoding""" + vec = [] + for k in range(int(d / 2)): + vec.append(paddle.sin(nx / 10000 ** (2 * (k) / d))) + vec.append(paddle.cos(nx / 10000 ** (2 * (k) / d))) + return vec + + +def scale_sdf(sdf): + """Function to scale SDF""" + return sdf / (0.4 + abs(sdf)) + + +def calculate_gradient(sdf): + """Function to calculate the gradients of SDF""" + m, n, o = sdf.shape[2], sdf.shape[3], sdf.shape[4] + sdf_x = sdf[:, :, 2:m, :, :] - sdf[:, :, 0 : m - 2, :, :] + sdf_y = sdf[:, :, :, 2:n, :] - sdf[:, :, :, 0 : n - 2, :] + sdf_z = sdf[:, :, :, :, 2:o] - sdf[:, :, :, :, 0 : o - 2] + + sdf_x = F.pad(x=sdf_x, pad=(0, 0, 0, 0, 0, 1), mode="constant", value=0.0) + sdf_x = F.pad(x=sdf_x, pad=(0, 0, 0, 0, 1, 0), mode="constant", value=0.0) + sdf_y = F.pad(x=sdf_y, pad=(0, 0, 0, 1, 0, 0), mode="constant", value=0.0) + sdf_y = F.pad(x=sdf_y, pad=(0, 0, 1, 0, 0, 0), mode="constant", value=0.0) + sdf_z = F.pad(x=sdf_z, pad=(0, 1, 0, 0, 0, 0), mode="constant", value=0.0) + sdf_z = F.pad(x=sdf_z, pad=(1, 0, 0, 0, 0, 0), mode="constant", value=0.0) + + return sdf_x, sdf_y, sdf_z + + +def binarize_sdf(sdf): + """Function to calculate the binarize the SDF""" + sdf = paddle.where(sdf >= 0, 0.0, 1.0).to(dtype=sdf.dtype) + return sdf + + +class BQWarp(nn.Layer): + """Warp based ball-query layer""" + + def __init__( + self, + input_features, + grid_resolution=[256, 96, 64], + radius=0.25, + neighbors_in_radius=10, + ): + super().__init__() + self.ball_query_layer = BallQueryLayer(neighbors_in_radius, radius) + self.grid_resolution = grid_resolution + + def forward(self, x, p_grid, reverse_mapping=True): + batch_size = x.shape[0] + nx, ny, nz = ( + self.grid_resolution[0], + self.grid_resolution[1], + self.grid_resolution[2], + ) + + p_grid = paddle.reshape(p_grid, (batch_size, nx * ny * nz, 3)) + p1 = nx * ny * nz + p2 = x.shape[1] + + if reverse_mapping: + lengths1 = paddle.full((batch_size,), p1, dtype=paddle.int32) + lengths2 = paddle.full((batch_size,), p2, dtype=paddle.int32) + mapping, num_neighbors, outputs = self.ball_query_layer( + p_grid, + x, + lengths1, + lengths2, + ) + else: + lengths1 = paddle.full((batch_size,), p2, dtype=paddle.int32) + lengths2 = paddle.full((batch_size,), p1, dtype=paddle.int32) + mapping, num_neighbors, outputs = self.ball_query_layer( + x, + p_grid, + lengths1, + lengths2, + ) + + return mapping, outputs + + +class GeoConvOut(nn.Layer): + """Geometry layer to project STLs on grids""" + + def __init__(self, input_features, model_parameters, grid_resolution=[256, 96, 64]): + super().__init__() + base_neurons = model_parameters.base_neurons + + self.fc1 = nn.Linear(input_features, base_neurons) + self.fc2 = nn.Linear(base_neurons, int(base_neurons / 2)) + self.fc3 = nn.Linear(int(base_neurons / 2), model_parameters.base_neurons_out) + + self.grid_resolution = grid_resolution + + self.activation = F.relu + + def forward(self, x, radius=0.025, neighbors_in_radius=10): + batch_size = x.shape[0] + nx, ny, nz = ( + self.grid_resolution[0], + self.grid_resolution[1], + self.grid_resolution[2], + ) + + mask = abs(x - 0) > 1e-6 + + x = self.activation(self.fc1(x)) + x = self.activation(self.fc2(x)) + x = F.tanh(self.fc3(x)) + mask = mask[:, :, :, 0:1].expand( + [mask.shape[0], mask.shape[1], mask.shape[2], x.shape[-1]] + ) + + # paddle does not support multiplication with boolean tensors, + # so we convert the mask to float + # x = torch.sum(x * mask, 2) + x = paddle.sum(x * mask.to(dtype=x.dtype), 2) + + x = paddle.reshape(x, (batch_size, x.shape[-1], nx, ny, nz)) + return x + + +class GeoProcessor(nn.Layer): + """Geometry processing layer using CNNs""" + + def __init__(self, input_filters, model_parameters): + super().__init__() + base_filters = model_parameters.base_filters + self.conv1 = nn.Conv3D( + input_filters, base_filters, kernel_size=3, padding="same" + ) + self.conv_bn1 = nn.BatchNorm3D(int(base_filters)) + self.conv2 = nn.Conv3D( + base_filters, 2 * base_filters, kernel_size=3, padding="same" + ) + self.conv_bn2 = nn.BatchNorm3D(int(2 * base_filters)) + self.conv3 = nn.Conv3D( + 2 * base_filters, 4 * base_filters, kernel_size=3, padding="same" + ) + self.conv_bn3 = nn.BatchNorm3D(int(4 * base_filters)) + self.conv3_1 = nn.Conv3D( + 4 * base_filters, 4 * base_filters, kernel_size=3, padding="same" + ) + self.conv4 = nn.Conv3D( + 4 * base_filters, 2 * base_filters, kernel_size=3, padding="same" + ) + self.conv_bn4 = nn.BatchNorm3D(int(2 * base_filters)) + self.conv5 = nn.Conv3D( + 4 * base_filters, base_filters, kernel_size=3, padding="same" + ) + self.conv_bn5 = nn.BatchNorm3D(int(base_filters)) + self.conv6 = nn.Conv3D( + 2 * base_filters, input_filters, kernel_size=3, padding="same" + ) + self.conv_bn6 = nn.BatchNorm3D(int(input_filters)) + self.conv7 = nn.Conv3D( + 2 * input_filters, input_filters, kernel_size=3, padding="same" + ) + self.conv8 = nn.Conv3D(input_filters, 1, kernel_size=3, padding="same") + self.avg_pool = paddle.nn.AvgPool3D((2, 2, 2)) + self.max_pool = nn.MaxPool3D(2) + self.upsample = nn.Upsample(scale_factor=2, mode="nearest") + self.activation = F.relu + self.batch_norm = False + + def forward(self, x): + # Encoder + x0 = x + if self.batch_norm: + x = self.activation(self.conv_bn1(self.conv1(x))) + else: + x = self.activation(self.conv1(x)) + x = self.max_pool(x) + x1 = x + if self.batch_norm: + x = self.activation(self.conv_bn2(self.conv2(x))) + else: + x = self.activation((self.conv2(x))) + x = self.max_pool(x) + + x2 = x + if self.batch_norm: + x = self.activation(self.conv_bn3(self.conv2(x))) + else: + x = self.activation((self.conv3(x))) + x = self.max_pool(x) + + # Processor loop + x = F.relu(self.conv3_1(x)) + + # Decoder + if self.batch_norm: + x = self.activation(self.conv_bn4(self.conv4(x))) + else: + x = self.activation((self.conv4(x))) + x = self.upsample(x) + x = paddle.concat((x, x2), axis=1) + + if self.batch_norm: + x = self.activation(self.conv_bn5(self.conv5(x))) + else: + x = self.activation((self.conv5(x))) + x = self.upsample(x) + x = paddle.concat((x, x1), axis=1) + if self.batch_norm: + x = self.activation(self.conv_bn6(self.conv6(x))) + else: + x = self.activation((self.conv6(x))) + x = self.upsample(x) + x = paddle.concat((x, x0), axis=1) + + x = self.activation(self.conv7(x)) + x = self.conv8(x) + + return x + + +class GeometryRep(nn.Layer): + """Geometry representation from STLs block""" + + def __init__(self, input_features, model_parameters=None): + super().__init__() + geometry_rep = model_parameters.geometry_rep + + self.bq_warp_short = BQWarp( + input_features=input_features, + grid_resolution=model_parameters.interp_res, + radius=geometry_rep.geo_conv.radius_short, + ) + + self.bq_warp_long = BQWarp( + input_features=input_features, + grid_resolution=model_parameters.interp_res, + radius=geometry_rep.geo_conv.radius_long, + ) + + self.geo_conv_out = GeoConvOut( + input_features=input_features, + model_parameters=geometry_rep.geo_conv, + grid_resolution=model_parameters.interp_res, + ) + + self.geo_processor_short_range = GeoProcessor( + input_filters=geometry_rep.geo_conv.base_neurons_out, + model_parameters=geometry_rep.geo_processor, + ) + self.geo_processor_long_range = GeoProcessor( + input_filters=geometry_rep.geo_conv.base_neurons_out, + model_parameters=geometry_rep.geo_processor, + ) + self.geo_processor_sdf = GeoProcessor( + input_filters=6, model_parameters=geometry_rep.geo_processor + ) + self.activation = F.relu + self.radius_short = geometry_rep.geo_conv.radius_short + self.radius_long = geometry_rep.geo_conv.radius_long + self.hops = geometry_rep.geo_conv.hops + + def forward(self, x, p_grid, sdf): + + # Expand SDF + sdf = paddle.unsqueeze(sdf, 1) + + # Calculate short-range geoemtry dependency + mapping, k_short = self.bq_warp_short(x, p_grid) + x_encoding_short = self.geo_conv_out(k_short) + + # Calculate long-range geometry dependency + mapping, k_long = self.bq_warp_long(x, p_grid) + x_encoding_long = self.geo_conv_out(k_long) + + # Scaled sdf to emphasis on surface + scaled_sdf = scale_sdf(sdf) + # Binary sdf + binary_sdf = binarize_sdf(sdf) + # Gradients of SDF + sdf_x, sdf_y, sdf_z = calculate_gradient(sdf) + + # Propagate information in the geometry enclosed BBox + for _ in range(self.hops): + dx = self.geo_processor_short_range(x_encoding_short) / self.hops + x_encoding_short = x_encoding_short + dx + + # Propagate information in the computational domain BBox + for _ in range(self.hops): + dx = self.geo_processor_long_range(x_encoding_long) / self.hops + x_encoding_long = x_encoding_long + dx + + # Process SDF and its computed features + sdf = paddle.concat((sdf, scaled_sdf, binary_sdf, sdf_x, sdf_y, sdf_z), 1) + sdf_encoding = self.geo_processor_sdf(sdf) + + # Geometry encoding comprised of short-range, long-range and SDF features + encoding_g = paddle.concat((x_encoding_short, sdf_encoding, x_encoding_long), 1) + + return encoding_g + + +class NNBasisFunctions(nn.Layer): + """Basis function layer for point clouds""" + + def __init__(self, input_features, model_parameters=None): + super(NNBasisFunctions, self).__init__() + self.input_features = input_features + + base_layer = model_parameters.base_layer + self.fc1 = nn.Linear(self.input_features, base_layer) + self.fc2 = nn.Linear(base_layer, int(base_layer)) + self.fc3 = nn.Linear(int(base_layer), int(base_layer)) + self.bn1 = nn.BatchNorm1D(base_layer) + self.bn2 = nn.BatchNorm1D(int(base_layer)) + self.bn3 = nn.BatchNorm1D(int(base_layer)) + + self.activation = F.relu + + def forward(self, x, padded_value=-10): + facets = x + facets = self.activation(self.fc1(facets)) + facets = self.activation(self.fc2(facets)) + facets = self.fc3(facets) + + return facets + + +class ParameterModel(nn.Layer): + """Layer to encode parameters such as inlet velocity and air density""" + + def __init__(self, input_features, model_parameters=None): + super(ParameterModel, self).__init__() + self.input_features = input_features + + base_layer = model_parameters.base_layer + self.fc1 = nn.Linear(self.input_features, base_layer) + self.fc2 = nn.Linear(base_layer, int(base_layer)) + self.fc3 = nn.Linear(int(base_layer), int(base_layer)) + self.bn1 = nn.BatchNorm1D(base_layer) + self.bn2 = nn.BatchNorm1D(int(base_layer)) + self.bn3 = nn.BatchNorm1D(int(base_layer)) + + self.activation = F.relu + + def forward(self, x, padded_value=-10): + params = x + params = self.activation(self.fc1(params)) + params = self.activation(self.fc2(params)) + params = self.fc3(params) + + return params + + +class AggregationModel(nn.Layer): + """Layer to aggregate local geometry encoding with basis functions""" + + def __init__( + self, input_features, output_features, model_parameters=None, new_change=True + ): + super(AggregationModel, self).__init__() + self.input_features = input_features + self.output_features = output_features + self.new_change = new_change + base_layer = model_parameters.base_layer + self.fc1 = nn.Linear(self.input_features, base_layer) + self.fc2 = nn.Linear(base_layer, int(base_layer)) + self.fc3 = nn.Linear(int(base_layer), int(base_layer)) + self.fc4 = nn.Linear(int(base_layer), int(base_layer)) + self.fc5 = nn.Linear(int(base_layer), self.output_features) + self.bn1 = nn.BatchNorm1D(base_layer) + self.bn2 = nn.BatchNorm1D(int(base_layer)) + self.bn3 = nn.BatchNorm1D(int(base_layer)) + self.bn4 = nn.BatchNorm1D(int(base_layer)) + self.activation = F.relu + + def forward(self, x): + out = self.activation(self.fc1(x)) + out = self.activation(self.fc2(out)) + out = self.activation(self.fc3(out)) + out = self.activation(self.fc4(out)) + + out = self.fc5(out) + + return out + + +# @dataclass +# class MetaData(ModelMetaData): +# name: str = "DoMINO" +# # Optimization +# jit: bool = False +# cuda_graphs: bool = True +# amp: bool = True +# # Inference +# onnx_cpu: bool = True +# onnx_gpu: bool = True +# onnx_runtime: bool = True +# # Physics informed +# var_dim: int = 1 +# func_torch: bool = False +# auto_grad: bool = False + + +class DoMINO(nn.Layer): + """DoMINO model architecture + Parameters + ---------- + input_features : int + Number of point input features + output_features_vol : int + Number of output features in volume + output_features_surf : int + Number of output features on surface + model_parameters: dict + Dictionary of model parameters controlled by config.yaml + + Example + ------- + >>> from modulus.models.domino.model import DoMINO + >>> import torch, os + >>> from hydra import compose, initialize + >>> from omegaconf import OmegaConf + >>> device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + >>> cfg = OmegaConf.register_new_resolver("eval", eval) + >>> with initialize(version_base="1.3", config_path="examples/cfd/external_aerodynamics/domino/src/conf"): + ... cfg = compose(config_name="config") + >>> cfg.model.model_type = "combined" + >>> model = DoMINO( + ... input_features=3, + ... output_features_vol=5, + ... output_features_surf=4, + ... model_parameters=cfg.model + ... ) + + Warp ... + >>> bsize = 1 + >>> nx, ny, nz = 128, 64, 48 + >>> num_neigh = 7 + >>> pos_normals_closest_vol = paddle.randn([bsize, 100, 3]) + >>> pos_normals_com_vol = paddle.randn([bsize, 100, 3]) + >>> pos_normals_com_surface = paddle.randn([bsize, 100, 3]) + >>> geom_centers = paddle.randn([bsize, 100, 3]) + >>> grid = paddle.randn([bsize, nx, ny, nz, 3]) + >>> surf_grid = paddle.randn([bsize, nx, ny, nz, 3]) + >>> sdf_grid = paddle.randn([bsize, nx, ny, nz]) + >>> sdf_surf_grid = paddle.randn([bsize, nx, ny, nz]) + >>> sdf_nodes = paddle.randn([bsize, 100, 1]) + >>> surface_coordinates = paddle.randn([bsize, 100, 3]) + >>> surface_neighbors = paddle.randn([bsize, 100, num_neigh, 3]) + >>> surface_normals = paddle.randn([bsize, 100, 3]) + >>> surface_neighbors_normals = paddle.randn([bsize, 100, num_neigh, 3]) + >>> surface_sizes = paddle.randn([bsize, 100, 3]) + >>> surface_neighbors_sizes = paddle.randn([bsize, 100, num_neigh, 3]) + >>> volume_coordinates = paddle.randn([bsize, 100, 3]) + >>> vol_grid_max_min = paddle.randn([bsize, 2, 3]) + >>> surf_grid_max_min = paddle.randn([bsize, 2, 3]) + >>> stream_velocity = paddle.randn([bsize, 1]) + >>> air_density = paddle.randn([bsize, 1]) + >>> input_dict = { + ... "pos_volume_closest": pos_normals_closest_vol, + ... "pos_volume_center_of_mass": pos_normals_com_vol, + ... "pos_surface_center_of_mass": pos_normals_com_surface, + ... "geometry_coordinates": geom_centers, + ... "grid": grid, + ... "surf_grid": surf_grid, + ... "sdf_grid": sdf_grid, + ... "sdf_surf_grid": sdf_surf_grid, + ... "sdf_nodes": sdf_nodes, + ... "surface_mesh_centers": surface_coordinates, + ... "surface_mesh_neighbors": surface_neighbors, + ... "surface_normals": surface_normals, + ... "surface_neighbors_normals": surface_neighbors_normals, + ... "surface_areas": surface_sizes, + ... "surface_neighbors_areas": surface_neighbors_sizes, + ... "volume_mesh_centers": volume_coordinates, + ... "volume_min_max": vol_grid_max_min, + ... "surface_min_max": surf_grid_max_min, + ... "stream_velocity": stream_velocity, + ... "air_density": air_density, + ... } + >>> output = model(input_dict) + Module ... + >>> print(f"{output[0].shape}, {output[1].shape}") + torch.Size([1, 100, 5]), torch.Size([1, 100, 4]) + """ + + def __init__( + self, + input_features, + output_features_vol=None, + output_features_surf=None, + model_parameters=None, + ): + super(DoMINO, self).__init__() + self.input_features = input_features + self.output_features_vol = output_features_vol + self.output_features_surf = output_features_surf + + if self.output_features_vol is None and self.output_features_surf is None: + raise ValueError("Need to specify number of volume or surface features") + + self.num_variables_vol = output_features_vol + self.num_variables_surf = output_features_surf + self.grid_resolution = model_parameters.interp_res + self.surface_neighbors = model_parameters.surface_neighbors + self.use_surface_normals = model_parameters.use_surface_normals + self.use_only_normals = model_parameters.use_only_normals + self.encode_parameters = model_parameters.encode_parameters + self.param_scaling_factors = model_parameters.parameter_model.scaling_params + + if self.use_surface_normals: + if self.use_only_normals: + input_features_surface = input_features + 3 + else: + input_features_surface = input_features + 4 + else: + input_features_surface = input_features + + if self.encode_parameters: + # Defining the parameter model + base_layer_p = model_parameters.parameter_model.base_layer + self.parameter_model = ParameterModel( + input_features=2, model_parameters=model_parameters.parameter_model + ) + else: + base_layer_p = 0 + + self.geo_rep = GeometryRep( + input_features=input_features, + model_parameters=model_parameters, + ) + + # Basis functions for surface and volume + base_layer_nn = model_parameters.nn_basis_functions.base_layer + if self.output_features_surf is not None: + self.nn_basis_surf = nn.LayerList() + for _ in range(self.num_variables_surf): + self.nn_basis_surf.append( + NNBasisFunctions( + input_features=input_features_surface, + model_parameters=model_parameters.nn_basis_functions, + ) + ) + + if self.output_features_vol is not None: + self.nn_basis_vol = nn.LayerList() + for _ in range(self.num_variables_vol): + self.nn_basis_vol.append( + NNBasisFunctions( + input_features=input_features, + model_parameters=model_parameters.nn_basis_functions, + ) + ) + + # Positional encoding + position_encoder_base_neurons = model_parameters.position_encoder.base_neurons + if self.output_features_vol is not None: + if model_parameters.positional_encoding: + inp_pos_vol = 25 if model_parameters.use_sdf_in_basis_func else 12 + else: + inp_pos_vol = 7 if model_parameters.use_sdf_in_basis_func else 3 + + self.fc_p_vol = nn.Linear(inp_pos_vol, position_encoder_base_neurons) + + if self.output_features_surf is not None: + if model_parameters.positional_encoding: + inp_pos_surf = 12 + else: + inp_pos_surf = 3 + + self.fc_p_surf = nn.Linear(inp_pos_surf, position_encoder_base_neurons) + + # Positional encoding hidden layers + self.fc_p1 = nn.Linear( + position_encoder_base_neurons, position_encoder_base_neurons + ) + self.fc_p2 = nn.Linear( + position_encoder_base_neurons, position_encoder_base_neurons + ) + + # BQ for surface and volume + self.neighbors_in_radius = model_parameters.geometry_local.neighbors_in_radius + self.radius = model_parameters.geometry_local.radius + self.bq_warp = BQWarp( + input_features=input_features, + grid_resolution=model_parameters.interp_res, + radius=self.radius, + neighbors_in_radius=self.neighbors_in_radius, + ) + + base_layer_geo = model_parameters.geometry_local.base_layer + self.fc_1 = nn.Linear(self.neighbors_in_radius * 3, base_layer_geo) + self.fc_2 = nn.Linear(base_layer_geo, base_layer_geo) + self.activation = F.relu + + # Aggregation model + if self.output_features_surf is not None: + # Surface + self.agg_model_surf = nn.LayerList() + for _ in range(self.num_variables_surf): + self.agg_model_surf.append( + AggregationModel( + input_features=position_encoder_base_neurons + + base_layer_nn + + base_layer_geo + + base_layer_p, + output_features=1, + model_parameters=model_parameters.aggregation_model, + ) + ) + + if self.output_features_vol is not None: + # Volume + self.agg_model_vol = nn.LayerList() + for _ in range(self.num_variables_vol): + self.agg_model_vol.append( + AggregationModel( + input_features=position_encoder_base_neurons + + base_layer_nn + + base_layer_geo + + base_layer_p, + output_features=1, + model_parameters=model_parameters.aggregation_model, + ) + ) + + def geometry_encoder(self, geo_centers, p_grid, sdf): + """Function to return local geometry encoding""" + return self.geo_rep(geo_centers, p_grid, sdf) + + def position_encoder(self, encoding_node, eval_mode="volume"): + """Function to calculate positional encoding""" + if eval_mode == "volume": + x = self.activation(self.fc_p_vol(encoding_node)) + elif eval_mode == "surface": + x = self.activation(self.fc_p_surf(encoding_node)) + x = self.activation(self.fc_p1(x)) + x = self.fc_p2(x) + return x + + def geo_encoding_local_surface(self, encoding_g, volume_mesh_centers, p_grid): + """Function to calculate local geometry encoding from global encoding for surface""" + batch_size = volume_mesh_centers.shape[0] + nx, ny, nz = ( + self.grid_resolution[0], + self.grid_resolution[1], + self.grid_resolution[2], + ) + p_grid = paddle.reshape(p_grid, (batch_size, nx * ny * nz, 3)) + mapping, outputs = self.bq_warp( + volume_mesh_centers, p_grid, reverse_mapping=False + ) + mapping = mapping.astype(paddle.int64) + mask = mapping != 0 + + geo_encoding = paddle.reshape(encoding_g[:, 0], (batch_size, 1, nx * ny * nz)) + geo_encoding = geo_encoding.expand( + [batch_size, volume_mesh_centers.shape[1], geo_encoding.shape[2]] + ) + sdf_encoding = paddle.reshape(encoding_g[:, 1], (batch_size, 1, nx * ny * nz)) + sdf_encoding = sdf_encoding.expand( + [batch_size, volume_mesh_centers.shape[1], sdf_encoding.shape[2]] + ) + geo_encoding_long = paddle.reshape( + encoding_g[:, 2], (batch_size, 1, nx * ny * nz) + ) + geo_encoding_long = geo_encoding_long.expand( + [batch_size, volume_mesh_centers.shape[1], geo_encoding_long.shape[2]] + ) + + # geo_encoding_sampled = torch.gather(geo_encoding, 2, mapping) * mask + # sdf_encoding_sampled = torch.gather(sdf_encoding, 2, mapping) * mask + # geo_encoding_long_sampled = torch.gather(geo_encoding_long, 2, mapping) * mask + geo_encoding_sampled = paddle.take_along_axis( + geo_encoding, axis=2, indices=mapping + ) * mask.to(dtype=geo_encoding.dtype) + sdf_encoding_sampled = paddle.take_along_axis( + sdf_encoding, axis=2, indices=mapping + ) * mask.to(dtype=geo_encoding.dtype) + geo_encoding_long_sampled = paddle.take_along_axis( + geo_encoding_long, axis=2, indices=mapping + ) * mask.to(dtype=geo_encoding.dtype) + + encoding_g = paddle.concat( + (geo_encoding_sampled, sdf_encoding_sampled, geo_encoding_long_sampled), + axis=2, + ) + encoding_g = self.activation(self.fc_1(encoding_g)) + encoding_g = self.fc_2(encoding_g) + + return encoding_g + + def geo_encoding_local(self, encoding_g, volume_mesh_centers, p_grid): + """Function to calculate local geometry encoding from global encoding""" + batch_size = volume_mesh_centers.shape[0] + nx, ny, nz = ( + self.grid_resolution[0], + self.grid_resolution[1], + self.grid_resolution[2], + ) + p_grid = paddle.reshape(p_grid, (batch_size, nx * ny * nz, 3)) + mapping, outputs = self.bq_warp( + volume_mesh_centers, p_grid, reverse_mapping=False + ) + mapping = mapping.astype(paddle.int64) + mask = mapping != 0 + + geo_encoding = paddle.reshape(encoding_g[:, 0], (batch_size, 1, nx * ny * nz)) + geo_encoding = geo_encoding.expand( + [batch_size, volume_mesh_centers.shape[1], geo_encoding.shape[2]] + ) + sdf_encoding = paddle.reshape(encoding_g[:, 1], (batch_size, 1, nx * ny * nz)) + sdf_encoding = sdf_encoding.expand( + [batch_size, volume_mesh_centers.shape[1], sdf_encoding.shape[2]] + ) + geo_encoding_long = paddle.reshape( + encoding_g[:, 2], (batch_size, 1, nx * ny * nz) + ) + geo_encoding_long = geo_encoding_long.expand( + [batch_size, volume_mesh_centers.shape[1], geo_encoding_long.shape[2]] + ) + + # geo_encoding_sampled = torch.gather(geo_encoding, 2, mapping) * mask + # sdf_encoding_sampled = torch.gather(sdf_encoding, 2, mapping) * mask + # geo_encoding_long_sampled = torch.gather(geo_encoding_long, 2, mapping) * mask + geo_encoding_sampled = paddle.take_along_axis( + geo_encoding, axis=2, indices=mapping + ) * mask.to(dtype=geo_encoding.dtype) + sdf_encoding_sampled = paddle.take_along_axis( + sdf_encoding, axis=2, indices=mapping + ) * mask.to(dtype=geo_encoding.dtype) + geo_encoding_long_sampled = paddle.take_along_axis( + geo_encoding_long, axis=2, indices=mapping + ) * mask.to(dtype=geo_encoding.dtype) + + encoding_g = paddle.concat( + (geo_encoding_sampled, sdf_encoding_sampled, geo_encoding_long_sampled), + axis=2, + ) + encoding_g = self.activation(self.fc_1(encoding_g)) + encoding_g = self.fc_2(encoding_g) + + return encoding_g + + def calculate_solution_with_neighbors( + self, + surface_mesh_centers, + encoding_g, + encoding_node, + surface_mesh_neighbors, + surface_normals, + surface_neighbors_normals, + surface_areas, + surface_neighbors_areas, + inlet_velocity, + air_density, + ): + """Function to approximate solution given the neighborhood information""" + num_variables = self.num_variables_surf + nn_basis = self.nn_basis_surf + agg_model = self.agg_model_surf + num_sample_points = surface_mesh_neighbors.shape[2] + 1 + + if self.encode_parameters: + inlet_velocity = paddle.unsqueeze(inlet_velocity, 1) + inlet_velocity = inlet_velocity.expand( + [ + inlet_velocity.shape[0], + surface_mesh_centers.shape[1], + inlet_velocity.shape[2], + ] + ) + inlet_velocity = inlet_velocity / self.param_scaling_factors[0] + + air_density = paddle.unsqueeze(air_density, 1) + air_density = air_density.expand( + [ + air_density.shape[0], + surface_mesh_centers.shape[1], + air_density.shape[2], + ] + ) + air_density = air_density / self.param_scaling_factors[1] + + params = paddle.concat((inlet_velocity, air_density), axis=-1) + param_encoding = self.parameter_model(params) + + if self.use_surface_normals: + if self.use_only_normals: + surface_mesh_centers = paddle.concat( + (surface_mesh_centers, surface_normals), + axis=-1, + ) + surface_mesh_neighbors = paddle.concat( + ( + surface_mesh_neighbors, + surface_neighbors_normals, + ), + axis=-1, + ) + + else: + surface_mesh_centers = paddle.concat( + (surface_mesh_centers, surface_normals, 10**5 * surface_areas), + axis=-1, + ) + surface_mesh_neighbors = paddle.concat( + ( + surface_mesh_neighbors, + surface_neighbors_normals, + 10**5 * surface_neighbors_areas, + ), + axis=-1, + ) + + for f in range(num_variables): + for p in range(num_sample_points): + if p == 0: + volume_m_c = surface_mesh_centers + else: + volume_m_c = surface_mesh_neighbors[:, :, p - 1] + noise = surface_mesh_centers - volume_m_c + dist = paddle.sqrt( + noise[:, :, 0:1] ** 2.0 + + noise[:, :, 1:2] ** 2.0 + + noise[:, :, 2:3] ** 2.0 + ) + basis_f = nn_basis[f](volume_m_c) + output = paddle.concat((basis_f, encoding_node, encoding_g), axis=-1) + if self.encode_parameters: + output = paddle.concat((output, param_encoding), axis=-1) + if p == 0: + output_center = agg_model[f](output) + else: + if p == 1: + output_neighbor = agg_model[f](output) * (1.0 / dist) + dist_sum = 1.0 / dist + else: + output_neighbor += agg_model[f](output) * (1.0 / dist) + dist_sum += 1.0 / dist + if num_sample_points > 1: + output_res = 0.5 * output_center + 0.5 * output_neighbor / dist_sum + else: + output_res = output_center + if f == 0: + output_all = output_res + else: + output_all = paddle.concat((output_all, output_res), axis=-1) + + return output_all + + def calculate_solution( + self, + volume_mesh_centers, + encoding_g, + encoding_node, + inlet_velocity, + air_density, + eval_mode, + num_sample_points=20, + noise_intensity=50, + ): + """Function to approximate solution sampling the neighborhood information""" + if eval_mode == "volume": + num_variables = self.num_variables_vol + nn_basis = self.nn_basis_vol + agg_model = self.agg_model_vol + elif eval_mode == "surface": + num_variables = self.num_variables_surf + nn_basis = self.nn_basis_surf + agg_model = self.agg_model_surf + + if self.encode_parameters: + inlet_velocity = paddle.unsqueeze(inlet_velocity, 1) + inlet_velocity = inlet_velocity.expand( + [ + inlet_velocity.shape[0], + volume_mesh_centers.shape[1], + inlet_velocity.shape[2], + ] + ) + inlet_velocity = inlet_velocity / self.param_scaling_factors[0] + + air_density = paddle.unsqueeze(air_density, 1) + air_density = air_density.expand( + [ + air_density.shape[0], + volume_mesh_centers.shape[1], + air_density.shape[2], + ] + ) + air_density = air_density / self.param_scaling_factors[1] + + params = paddle.concat((inlet_velocity, air_density), axis=-1) + param_encoding = self.parameter_model(params) + + for f in range(num_variables): + for p in range(num_sample_points): + if p == 0: + volume_m_c = volume_mesh_centers + else: + noise = paddle.rand( + shape=volume_mesh_centers.shape, dtype=volume_mesh_centers.dtype + ) + noise = 2 * (noise - 0.5) + noise = noise / noise_intensity + dist = paddle.sqrt( + noise[:, :, 0:1] ** 2.0 + + noise[:, :, 1:2] ** 2.0 + + noise[:, :, 2:3] ** 2.0 + ) + volume_m_c = volume_mesh_centers + noise + basis_f = nn_basis[f](volume_m_c) + output = paddle.concat((basis_f, encoding_node, encoding_g), axis=-1) + if self.encode_parameters: + output = paddle.concat((output, param_encoding), axis=-1) + if p == 0: + output_center = agg_model[f](output) + else: + if p == 1: + output_neighbor = agg_model[f](output) * (1.0 / dist) + dist_sum = 1.0 / dist + else: + output_neighbor += agg_model[f](output) * (1.0 / dist) + dist_sum += 1.0 / dist + if num_sample_points > 1: + output_res = 0.5 * output_center + 0.5 * output_neighbor / dist_sum + else: + output_res = output_center + if f == 0: + output_all = output_res + else: + output_all = paddle.concat((output_all, output_res), axis=-1) + + return output_all + + def forward( + self, + data_dict, + ): + # Loading STL inputs, bounding box grids, precomputed SDF and scaling factors + + # STL nodes + geo_centers = data_dict["geometry_coordinates"] + + # Bounding box grid + s_grid = data_dict["surf_grid"] + sdf_surf_grid = data_dict["sdf_surf_grid"] + # Scaling factors + surf_max = data_dict["surface_min_max"][:, 1] + surf_min = data_dict["surface_min_max"][:, 0] + + # Parameters + stream_velocity = data_dict["stream_velocity"] + air_density = data_dict["air_density"] + + if self.output_features_vol is not None: + # Represent geometry on computational grid + # Computational domain grid + p_grid = data_dict["grid"] + sdf_grid = data_dict["sdf_grid"] + # Scaling factors + vol_max = data_dict["volume_min_max"][:, 1] + vol_min = data_dict["volume_min_max"][:, 0] + + # Normalize based on computational domain + geo_centers_vol = 2.0 * (geo_centers - vol_min) / (vol_max - vol_min) - 1 + encoding_g_vol = self.geo_rep(geo_centers_vol, p_grid, sdf_grid) + + # Normalize based on BBox around surface (car) + geo_centers_surf = ( + 2.0 * (geo_centers - surf_min) / (surf_max - surf_min) - 1 + ) + encoding_g_surf = self.geo_rep(geo_centers_surf, s_grid, sdf_surf_grid) + + # SDF on volume mesh nodes + sdf_nodes = data_dict["sdf_nodes"] + # Positional encoding based on closest point on surface to a volume node + pos_volume_closest = data_dict["pos_volume_closest"] + # Positional encoding based on center of mass of geometry to volume node + pos_volume_center_of_mass = data_dict["pos_volume_center_of_mass"] + encoding_node_vol = paddle.concat( + (sdf_nodes, pos_volume_closest, pos_volume_center_of_mass), axis=-1 + ) + + # Calculate positional encoding on volume nodes + encoding_node_vol = self.position_encoder( + encoding_node_vol, eval_mode="volume" + ) + + if self.output_features_surf is not None: + # Represent geometry on bounding box + geo_centers_surf = ( + 2.0 * (geo_centers - surf_min) / (surf_max - surf_min) - 1 + ) + encoding_g_surf = self.geo_rep(geo_centers_surf, s_grid, sdf_surf_grid) + + # Positional encoding based on center of mass of geometry to surface node + pos_surface_center_of_mass = data_dict["pos_surface_center_of_mass"] + encoding_node_surf = pos_surface_center_of_mass + + # Calculate positional encoding on surface centers + encoding_node_surf = self.position_encoder( + encoding_node_surf, eval_mode="surface" + ) + + encoding_g = 0.5 * encoding_g_surf + # Average the encodings + if self.output_features_vol is not None: + encoding_g += 0.5 * encoding_g_vol + + if self.output_features_vol is not None: + # Calculate local geometry encoding for volume + # Sampled points on volume + volume_mesh_centers = data_dict["volume_mesh_centers"] + encoding_g_vol = self.geo_encoding_local( + encoding_g, volume_mesh_centers, p_grid + ) + + # Approximate solution on volume node + output_vol = self.calculate_solution( + volume_mesh_centers, + encoding_g_vol, + encoding_node_vol, + stream_velocity, + air_density, + eval_mode="volume", + ) + else: + output_vol = None + + if self.output_features_surf is not None: + # Sampled points on surface + surface_mesh_centers = data_dict["surface_mesh_centers"] + surface_normals = data_dict["surface_normals"] + surface_areas = data_dict["surface_areas"] + + # Neighbors of sampled points on surface + surface_mesh_neighbors = data_dict["surface_mesh_neighbors"] + surface_neighbors_normals = data_dict["surface_neighbors_normals"] + surface_neighbors_areas = data_dict["surface_neighbors_areas"] + surface_areas = paddle.unsqueeze(surface_areas, -1) + surface_neighbors_areas = paddle.unsqueeze(surface_neighbors_areas, -1) + # Calculate local geometry encoding for surface + encoding_g_surf = self.geo_encoding_local_surface( + 0.5 * encoding_g_surf, surface_mesh_centers, s_grid + ) + + # Approximate solution on surface cell center + if not self.surface_neighbors: + output_surf = self.calculate_solution( + surface_mesh_centers, + encoding_g_surf, + encoding_node_surf, + stream_velocity, + air_density, + eval_mode="surface", + num_sample_points=1, + noise_intensity=500, + ) + else: + output_surf = self.calculate_solution_with_neighbors( + surface_mesh_centers, + encoding_g_surf, + encoding_node_surf, + surface_mesh_neighbors, + surface_normals, + surface_neighbors_normals, + surface_areas, + surface_neighbors_areas, + stream_velocity, + air_density, + ) + else: + output_surf = None + + return output_vol, output_surf + + +if __name__ == "__main__": + from hydra import compose + from hydra import initialize + from omegaconf import OmegaConf + + if paddle.device.cuda.device_count() >= 1: + paddle.set_device("gpu") + else: + paddle.set_device("cpu") + cfg = OmegaConf.register_new_resolver("eval", eval) + with initialize(version_base="1.3", config_path="conf"): + cfg = compose(config_name="config") + cfg.model.model_type = "combined" + model = DoMINO( + input_features=3, + output_features_vol=5, + output_features_surf=4, + model_parameters=cfg.model, + ) + + bsize = 1 + nx, ny, nz = 128, 64, 48 + num_neigh = 7 + pos_normals_closest_vol = paddle.randn([bsize, 100, 3]) + pos_normals_com_vol = paddle.randn([bsize, 100, 3]) + pos_normals_com_surface = paddle.randn([bsize, 100, 3]) + geom_centers = paddle.randn([bsize, 100, 3]) + grid = paddle.randn([bsize, nx, ny, nz, 3]) + surf_grid = paddle.randn([bsize, nx, ny, nz, 3]) + sdf_grid = paddle.randn([bsize, nx, ny, nz]) + sdf_surf_grid = paddle.randn([bsize, nx, ny, nz]) + sdf_nodes = paddle.randn([bsize, 100, 1]) + surface_coordinates = paddle.randn([bsize, 100, 3]) + surface_neighbors = paddle.randn([bsize, 100, num_neigh, 3]) + surface_normals = paddle.randn([bsize, 100, 3]) + surface_neighbors_normals = paddle.randn([bsize, 100, num_neigh, 3]) + surface_sizes = paddle.randn([bsize, 100, 3]) + surface_neighbors_sizes = paddle.randn([bsize, 100, num_neigh, 3]) + volume_coordinates = paddle.randn([bsize, 100, 3]) + vol_grid_max_min = paddle.randn([bsize, 2, 3]) + surf_grid_max_min = paddle.randn([bsize, 2, 3]) + stream_velocity = paddle.randn([bsize, 1]) + air_density = paddle.randn([bsize, 1]) + input_dict = { + "pos_volume_closest": pos_normals_closest_vol, + "pos_volume_center_of_mass": pos_normals_com_vol, + "pos_surface_center_of_mass": pos_normals_com_surface, + "geometry_coordinates": geom_centers, + "grid": grid, + "surf_grid": surf_grid, + "sdf_grid": sdf_grid, + "sdf_surf_grid": sdf_surf_grid, + "sdf_nodes": sdf_nodes, + "surface_mesh_centers": surface_coordinates, + "surface_mesh_neighbors": surface_neighbors, + "surface_normals": surface_normals, + "surface_neighbors_normals": surface_neighbors_normals, + "surface_areas": surface_sizes, + "surface_neighbors_areas": surface_neighbors_sizes, + "volume_mesh_centers": volume_coordinates, + "volume_min_max": vol_grid_max_min, + "surface_min_max": surf_grid_max_min, + "stream_velocity": stream_velocity, + "air_density": air_density, + } + output = model(input_dict) + print(f"{output[0].shape}, {output[1].shape}") diff --git a/examples/domino/requirements.txt b/examples/domino/requirements.txt new file mode 100644 index 0000000000..2ff5dd1ccf --- /dev/null +++ b/examples/domino/requirements.txt @@ -0,0 +1,4 @@ +hydra-core +importlib_metadata +pyvista==0.34.2 +warp-lang From a313b5d4573c07153766c4f4d71ec2b583026462 Mon Sep 17 00:00:00 2001 From: xiaoyewww <641311428@qq.com> Date: Sat, 15 Mar 2025 23:22:17 +0800 Subject: [PATCH 02/11] feat(domino): support domino for training and test --- examples/domino/README.md | 43 + examples/domino/conf/config.yaml | 23 +- examples/domino/modulus/__init__.py | 0 .../{layers => modulus/datapipes}/__init__.py | 0 .../domino/modulus/datapipes/cae/__init__.py | 17 + .../modulus/datapipes/cae/domino_datapipe.py | 553 ++++++++++ .../domino/modulus/distributed/__init__.py | 20 + examples/domino/modulus/distributed/config.py | 250 +++++ .../domino/modulus/distributed/manager.py | 611 +++++++++++ examples/domino/modulus/launch/__init__.py | 15 + .../domino/modulus/launch/logging/__init__.py | 17 + .../domino/modulus/launch/logging/console.py | 88 ++ .../domino/modulus/launch/logging/launch.py | 443 ++++++++ .../domino/modulus/launch/logging/mlflow.py | 199 ++++ .../domino/modulus/launch/logging/utils.py | 60 ++ .../domino/modulus/launch/logging/wandb.py | 131 +++ .../domino/modulus/launch/utils/__init__.py | 18 + .../domino/modulus/launch/utils/checkpoint.py | 387 +++++++ examples/domino/modulus/models/__init__.py | 0 .../domino/modulus/models/layers/__init__.py | 15 + .../{ => modulus/models}/layers/ball_query.py | 0 examples/domino/{ => modulus/models}/model.py | 27 +- examples/domino/modulus/utils/__init__.py | 15 + examples/domino/modulus/utils/domino/utils.py | 385 +++++++ examples/domino/modulus/utils/sdf.py | 143 +++ examples/domino/test.py | 742 ++++++++++++++ examples/domino/train.py | 953 ++++++++++++++++++ 27 files changed, 5141 insertions(+), 14 deletions(-) create mode 100644 examples/domino/README.md create mode 100644 examples/domino/modulus/__init__.py rename examples/domino/{layers => modulus/datapipes}/__init__.py (100%) create mode 100644 examples/domino/modulus/datapipes/cae/__init__.py create mode 100644 examples/domino/modulus/datapipes/cae/domino_datapipe.py create mode 100644 examples/domino/modulus/distributed/__init__.py create mode 100644 examples/domino/modulus/distributed/config.py create mode 100644 examples/domino/modulus/distributed/manager.py create mode 100644 examples/domino/modulus/launch/__init__.py create mode 100644 examples/domino/modulus/launch/logging/__init__.py create mode 100644 examples/domino/modulus/launch/logging/console.py create mode 100644 examples/domino/modulus/launch/logging/launch.py create mode 100644 examples/domino/modulus/launch/logging/mlflow.py create mode 100644 examples/domino/modulus/launch/logging/utils.py create mode 100644 examples/domino/modulus/launch/logging/wandb.py create mode 100644 examples/domino/modulus/launch/utils/__init__.py create mode 100644 examples/domino/modulus/launch/utils/checkpoint.py create mode 100644 examples/domino/modulus/models/__init__.py create mode 100644 examples/domino/modulus/models/layers/__init__.py rename examples/domino/{ => modulus/models}/layers/ball_query.py (100%) rename examples/domino/{ => modulus/models}/model.py (98%) create mode 100644 examples/domino/modulus/utils/__init__.py create mode 100644 examples/domino/modulus/utils/domino/utils.py create mode 100644 examples/domino/modulus/utils/sdf.py create mode 100644 examples/domino/test.py create mode 100644 examples/domino/train.py diff --git a/examples/domino/README.md b/examples/domino/README.md new file mode 100644 index 0000000000..0994a2e878 --- /dev/null +++ b/examples/domino/README.md @@ -0,0 +1,43 @@ +# DoMINO: Decomposable Multi-scale Iterative Neural Operator for External Aerodynamics + +DoMINO代码复现。 + +## 安装依赖 + +```shell +# 安装PaddlePaddle + +cd /path/PaddleScience +pip install -e . + +cd /path/PaddleScience/examples/domino +pip install -r requirements.txt +``` + +## 数据集下载与处理 + +1. 参考[DoMINO](https://github.com/NVIDIA/modulus/tree/main/examples/cfd/external_aerodynamics/domino#training-the-domino-model)数据下载和处理方式,执行`download_aws_dataset.sh`和`process_data.py`,获取数据。 + +2. 本仓库训练数据为`process_data.py`的后处理数据。 + +## 训练 + +1. 修改`conf/config.yaml`路径, 原始配置文件参考[DoMINO](https://github.com/NVIDIA/modulus/blob/main/examples/cfd/external_aerodynamics/domino/src/conf/config.yaml)。 + +2. 训练 + +```shell +cd /path/PaddleScience/examples/domino +python train.py +``` + +## 推理 + +1. 修改`conf/config.yaml`路径, 原始配置文件参考[DoMINO](https://github.com/NVIDIA/modulus/blob/main/examples/cfd/external_aerodynamics/domino/src/conf/config.yaml)。 + +2. 推理 + +```shell +cd /path/PaddleScience/examples/domino +python test.py +``` diff --git a/examples/domino/conf/config.yaml b/examples/domino/conf/config.yaml index 6fdd14d0d2..2a124b1107 100644 --- a/examples/domino/conf/config.yaml +++ b/examples/domino/conf/config.yaml @@ -17,6 +17,7 @@ project: # Project name name: AWS_Dataset +seed: 42 exp_tag: 1 # Experiment tag # Main output directory. output: outputs/${project.name}/${exp_tag} @@ -27,8 +28,8 @@ hydra: # Hydra config output_subdir: hydra # Default is .hydra which causes files not being uploaded in W&B. data: # Input directory for training and validation data - input_dir: /lustre/rranade/modulus_dev/data/volume_data/ - input_dir_val: /lustre/rranade/modulus_dev/data/volume_data_val/ + input_dir: /home/aistudio/modulus/examples/cfd/external_aerodynamics/domino/outputs/volume_data/ + input_dir_val: /home/aistudio/modulus/examples/cfd/external_aerodynamics/domino/outputs/volume_data/ bounding_box: # Bounding box dimensions for computational domain min: [-3.5, -2.25 , -0.32] max: [8.5 , 2.25 , 3.00] @@ -55,12 +56,12 @@ variables: model: model_type: combined # train which model? surface, volume, combined loss_function: "mse" # mse or rmse - interp_res: [128, 64, 48] # resolution of latent space + interp_res: [64, 32, 24] # resolution of latent space use_sdf_in_basis_func: true # SDF in basis function network positional_encoding: false # calculate positional encoding? - volume_points_sample: 8192 # Number of points to sample in volume per epoch - surface_points_sample: 8192 # Number of points to sample on surface per epoch - geom_points_sample: 200_000 # Number of points to sample on STL per epoch + volume_points_sample: 1024 # Number of points to sample in volume per epoch + surface_points_sample: 1024 # Number of points to sample on surface per epoch + geom_points_sample: 2_000 # Number of points to sample on STL per epoch surface_neighbors: true # Pre-compute surface neighborhood from input data num_surface_neighbors: 7 # How many neighbors? use_surface_normals: true # Use surface normals and surface areas for surface computation? @@ -95,11 +96,10 @@ model: scaling_params: [30.0, 1.226] # [inlet_velocity, air_density] train: # Training configurable parameters - epochs: 500 + epochs: 50 checkpoint_interval: 1 dataloader: batch_size: 1 - pin_memory: true sampler: shuffle: true drop_last: false @@ -108,15 +108,14 @@ train: # Training configurable parameters val: # Validation configurable parameters dataloader: batch_size: 1 - pin_memory: true sampler: shuffle: true drop_last: false eval: # Testing configurable parameters - test_path: /lustre/rranade/benchmarking/drivaer_aws_surface_test_new/ - save_path: /lustre/rranade/domino/mesh_predictions_surf_final1/ - checkpoint_name: DoMINO.0.50.pt + test_path: /home/aistudio/modulus/examples/cfd/external_aerodynamics/domino/drivaer_data_full_new + save_path: /home/aistudio/modulus/examples/cfd/external_aerodynamics/domino/outputs/mesh_predictions_surf_final1/ + checkpoint_name: /home/aistudio/xiaoyewww/PaddleScience/examples/domino/outputs/AWS_Dataset/1/models/DoMINO.0.30.pdparams data_processor: # Data processor configurable parameters kind: drivaer_aws # must be either drivesim or drivaer_aws diff --git a/examples/domino/modulus/__init__.py b/examples/domino/modulus/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/examples/domino/layers/__init__.py b/examples/domino/modulus/datapipes/__init__.py similarity index 100% rename from examples/domino/layers/__init__.py rename to examples/domino/modulus/datapipes/__init__.py diff --git a/examples/domino/modulus/datapipes/cae/__init__.py b/examples/domino/modulus/datapipes/cae/__init__.py new file mode 100644 index 0000000000..71e4d00436 --- /dev/null +++ b/examples/domino/modulus/datapipes/cae/__init__.py @@ -0,0 +1,17 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .domino_datapipe import DoMINODataPipe # noqa: F401 diff --git a/examples/domino/modulus/datapipes/cae/domino_datapipe.py b/examples/domino/modulus/datapipes/cae/domino_datapipe.py new file mode 100644 index 0000000000..c5768eb4f7 --- /dev/null +++ b/examples/domino/modulus/datapipes/cae/domino_datapipe.py @@ -0,0 +1,553 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This code provides the datapipe for reading the processed npy files, +generating multi-res grids, calculating signed distance fields, +positional encodings, sampling random points in the volume and on surface, +normalizing fields and returning the output tensors as a dictionary. + +This datapipe also non-dimensionalizes the fields, so the order in which the variables should +be fixed: velocity, pressure, turbulent viscosity for volume variables and +pressure, wall-shear-stress for surface variables. The different parameters such as +variable names, domain resolution, sampling size etc. are configurable in config.yaml. +""" + +from pathlib import Path +from typing import Literal +from typing import Optional +from typing import Sequence +from typing import Union + +import numpy as np +from paddle.io import Dataset + +from ...utils.domino.utils import KDTree +from ...utils.domino.utils import area_weighted_shuffle_array +from ...utils.domino.utils import calculate_center_of_mass +from ...utils.domino.utils import calculate_normal_positional_encoding +from ...utils.domino.utils import create_grid +from ...utils.domino.utils import get_filenames +from ...utils.domino.utils import normalize +from ...utils.domino.utils import pad +from ...utils.domino.utils import shuffle_array +from ...utils.domino.utils import standardize +from ...utils.sdf import signed_distance_field + + +class DoMINODataPipe(Dataset): + """ + Datapipe for DoMINO + + """ + + def __init__( + self, + data_path: Union[str, Path], # Input data path + phase: Literal["train", "val", "test"] = "train", # Train, test or val + surface_variables: Optional[Sequence] = ( + "pMean", + "wallShearStress", + ), # Names of surface variables + volume_variables: Optional[Sequence] = ( + "UMean", + "pMean", + ), # Names of volume variables + sampling: bool = False, # Sampling True or False + device: int = 0, # GPU device id + grid_resolution: Optional[Sequence] = ( + 256, + 96, + 64, + ), # Resolution of latent grid + normalize_coordinates: bool = False, # Normalize coordinates? + sample_in_bbox: bool = False, # Sample points in a specified bounding box + volume_points_sample: int = 1024, # Number of volume points sampled per batch + surface_points_sample: int = 1024, # Number of surface points sampled per batch + geom_points_sample: int = 300000, # Number of STL points sampled per batch + positional_encoding: bool = False, # Positional encoding, True or False + volume_factors=None, # Non-dimensionalization factors for volume variables + surface_factors=None, # Non-dimensionalization factors for surface variables + scaling_type=None, # Scaling min_max or mean_std + model_type=None, # Model_type, surface, volume or combined + bounding_box_dims=None, # Dimensions of bounding box + bounding_box_dims_surf=None, # Dimensions of bounding box + compute_scaling_factors=False, + num_surface_neighbors=11, # Surface neighbors to consider + ): + if isinstance(data_path, str): + data_path = Path(data_path) + data_path = data_path.expanduser() + + self.data_path = data_path + + if phase not in [ + "train", + "val", + "test", + ]: + raise AssertionError( + f"phase should be one of ['train', 'val', 'test'], got {phase}" + ) + + if not self.data_path.exists(): + raise AssertionError(f"Path {self.data_path} does not exist") + + if not self.data_path.is_dir(): + raise AssertionError(f"Path {self.data_path} is not a directory") + + self.sampling = sampling + self.grid_resolution = grid_resolution + self.normalize_coordinates = normalize_coordinates + self.model_type = model_type + self.bounding_box_dims = [] + self.bounding_box_dims.append(np.asarray(bounding_box_dims.max)) + self.bounding_box_dims.append(np.asarray(bounding_box_dims.min)) + + self.bounding_box_dims_surf = [] + self.bounding_box_dims_surf.append(np.asarray(bounding_box_dims_surf.max)) + self.bounding_box_dims_surf.append(np.asarray(bounding_box_dims_surf.min)) + + self.filenames = get_filenames(self.data_path) + total_files = len(self.filenames) + + self.phase = phase + if phase == "train": + self.indices = np.array(range(total_files)) + elif phase == "val": + self.indices = np.array(range(total_files)) + elif phase == "test": + self.indices = np.array(range(total_files)) + + np.random.shuffle(self.indices) + self.surface_variables = surface_variables + self.volume_variables = volume_variables + self.volume_points = volume_points_sample + self.surface_points = surface_points_sample + self.geom_points_sample = geom_points_sample + self.sample_in_bbox = sample_in_bbox + self.device = device + self.positional_encoding = positional_encoding + self.volume_factors = volume_factors + self.surface_factors = surface_factors + self.scaling_type = scaling_type + self.compute_scaling_factors = compute_scaling_factors + self.num_surface_neighbors = num_surface_neighbors + + def __len__(self): + return len(self.indices) + + def __getitem__(self, idx): + index = self.indices[idx] + cfd_filename = self.filenames[index] + + filepath = self.data_path / cfd_filename + data_dict = np.load(filepath, allow_pickle=True).item() + + stl_vertices = data_dict["stl_coordinates"] + stl_centers = data_dict["stl_centers"] + mesh_indices_flattened = data_dict["stl_faces"] + stl_sizes = data_dict["stl_areas"] + + # Check if stream velocity in keys + if "stream_velocity" in data_dict.keys(): + STREAM_VELOCITY = data_dict["stream_velocity"] + AIR_DENSITY = data_dict["air_density"] + else: + AIR_DENSITY = 1.205 + STREAM_VELOCITY = 30.00 + + # + length_scale = np.amax(np.amax(stl_vertices, 0) - np.amin(stl_vertices, 0)) + + # Center of mass calculation + center_of_mass = calculate_center_of_mass(stl_centers, stl_sizes) + + if self.bounding_box_dims_surf is None: + s_max = np.amax(stl_vertices, 0) + s_min = np.amin(stl_vertices, 0) + else: + s_max = np.float32(self.bounding_box_dims_surf[0]) + s_min = np.float32(self.bounding_box_dims_surf[1]) + + nx, ny, nz = self.grid_resolution + + surf_grid = create_grid(s_max, s_min, [nx, ny, nz]) + surf_grid_reshaped = surf_grid.reshape(nx * ny * nz, 3) + + # SDF calculation on the grid using WARP + sdf_surf_grid = ( + signed_distance_field( + stl_vertices, + mesh_indices_flattened, + surf_grid_reshaped, + use_sign_winding_number=True, + ) + .numpy() + .reshape(nx, ny, nz) + ) + surf_grid = np.float32(surf_grid) + sdf_surf_grid = np.float32(sdf_surf_grid) + surf_grid_max_min = np.float32(np.asarray([s_min, s_max])) + + if self.model_type == "volume" or self.model_type == "combined": + volume_coordinates = data_dict["volume_mesh_centers"] + volume_fields = data_dict["volume_fields"] + + if not self.compute_scaling_factors: + if self.bounding_box_dims is None: + c_max = s_max + (s_max - s_min) / 2 + c_min = s_min - (s_max - s_min) / 2 + c_min[2] = s_min[2] + else: + c_max = np.float32(self.bounding_box_dims[0]) + c_min = np.float32(self.bounding_box_dims[1]) + + ids_in_bbox = np.where( + (volume_coordinates[:, 0] > c_min[0]) + & (volume_coordinates[:, 0] < c_max[0]) + & (volume_coordinates[:, 1] > c_min[1]) + & (volume_coordinates[:, 1] < c_max[1]) + & (volume_coordinates[:, 2] > c_min[2]) + & (volume_coordinates[:, 2] < c_max[2]) + ) + + if self.sample_in_bbox: + volume_coordinates = volume_coordinates[ids_in_bbox] + volume_fields = volume_fields[ids_in_bbox] + + dx, dy, dz = ( + (c_max[0] - c_min[0]) / nx, + (c_max[1] - c_min[1]) / ny, + (c_max[2] - c_min[2]) / nz, + ) + + # Generate a grid of specified resolution to map the bounding box + # The grid is used for capturing structured geometry features and SDF representation of geometry + grid = create_grid(c_max, c_min, [nx, ny, nz]) + grid_reshaped = grid.reshape(nx * ny * nz, 3) + + # SDF calculation on the grid using WARP + sdf_grid = ( + signed_distance_field( + stl_vertices, + mesh_indices_flattened, + grid_reshaped, + use_sign_winding_number=True, + ) + .numpy() + .reshape(nx, ny, nz) + ) + + if self.sampling: + volume_coordinates_sampled, idx_volume = shuffle_array( + volume_coordinates, self.volume_points + ) + if volume_coordinates_sampled.shape[0] < self.volume_points: + volume_coordinates_sampled = pad( + volume_coordinates_sampled, + self.volume_points, + pad_value=-10.0, + ) + volume_fields = volume_fields[idx_volume] + volume_coordinates = volume_coordinates_sampled + + sdf_nodes, sdf_node_closest_point = signed_distance_field( + stl_vertices, + mesh_indices_flattened, + volume_coordinates, + include_hit_points=True, + use_sign_winding_number=True, + ) + sdf_nodes = sdf_nodes.numpy().reshape(-1, 1) + sdf_node_closest_point = sdf_node_closest_point.numpy() + + if self.positional_encoding: + pos_normals_closest_vol = calculate_normal_positional_encoding( + volume_coordinates, + sdf_node_closest_point, + cell_length=[dx, dy, dz], + ) + pos_normals_com_vol = calculate_normal_positional_encoding( + volume_coordinates, center_of_mass, cell_length=[dx, dy, dz] + ) + else: + pos_normals_closest_vol = ( + volume_coordinates - sdf_node_closest_point + ) + pos_normals_com_vol = volume_coordinates - center_of_mass + + if self.normalize_coordinates: + volume_coordinates = normalize(volume_coordinates, c_max, c_min) + grid = normalize(grid, c_max, c_min) + + if self.scaling_type is not None: + if self.volume_factors is not None: + if self.scaling_type == "mean_std_scaling": + vol_mean = self.volume_factors[0] + vol_std = self.volume_factors[1] + volume_fields = standardize( + volume_fields, vol_mean, vol_std + ) + elif self.scaling_type == "min_max_scaling": + vol_min = self.volume_factors[1] + vol_max = self.volume_factors[0] + volume_fields = normalize(volume_fields, vol_max, vol_min) + + volume_fields = np.float32(volume_fields) + pos_normals_closest_vol = np.float32(pos_normals_closest_vol) + pos_normals_com_vol = np.float32(pos_normals_com_vol) + volume_coordinates = np.float32(volume_coordinates) + sdf_nodes = np.float32(sdf_nodes) + sdf_grid = np.float32(sdf_grid) + grid = np.float32(grid) + vol_grid_max_min = np.float32(np.asarray([c_min, c_max])) + else: + pos_normals_closest_vol = None + pos_normals_com_vol = None + sdf_nodes = None + sdf_grid = None + grid = None + vol_grid_max_min = None + + else: + volume_coordinates = None + volume_fields = None + pos_normals_closest_vol = None + pos_normals_com_vol = None + sdf_nodes = None + sdf_grid = None + grid = None + vol_grid_max_min = None + + if self.model_type == "surface" or self.model_type == "combined": + surface_coordinates = data_dict["surface_mesh_centers"] + surface_normals = data_dict["surface_normals"] + surface_sizes = data_dict["surface_areas"] + surface_fields = data_dict["surface_fields"] + + if not self.compute_scaling_factors: + + c_max = np.float32(self.bounding_box_dims[0]) + c_min = np.float32(self.bounding_box_dims[1]) + + ids_in_bbox = np.where( + (surface_coordinates[:, 0] > c_min[0]) + & (surface_coordinates[:, 0] < c_max[0]) + & (surface_coordinates[:, 1] > c_min[1]) + & (surface_coordinates[:, 1] < c_max[1]) + & (surface_coordinates[:, 2] > c_min[2]) + & (surface_coordinates[:, 2] < c_max[2]) + ) + surface_coordinates = surface_coordinates[ids_in_bbox] + surface_normals = surface_normals[ids_in_bbox] + surface_sizes = surface_sizes[ids_in_bbox] + surface_fields = surface_fields[ids_in_bbox] + + # Get neighbors + interp_func = KDTree(surface_coordinates) + dd, ii = interp_func.query( + surface_coordinates, k=self.num_surface_neighbors + ) + surface_neighbors = surface_coordinates[ii] + surface_neighbors = surface_neighbors[:, 1:] + + surface_neighbors_normals = surface_normals[ii] + surface_neighbors_normals = surface_neighbors_normals[:, 1:] + surface_neighbors_sizes = surface_sizes[ii] + surface_neighbors_sizes = surface_neighbors_sizes[:, 1:] + + dx, dy, dz = ( + (s_max[0] - s_min[0]) / nx, + (s_max[1] - s_min[1]) / ny, + (s_max[2] - s_min[2]) / nz, + ) + + if self.positional_encoding: + pos_normals_com_surface = calculate_normal_positional_encoding( + surface_coordinates, center_of_mass, cell_length=[dx, dy, dz] + ) + else: + pos_normals_com_surface = surface_coordinates - center_of_mass + + if self.normalize_coordinates: + surface_coordinates = normalize(surface_coordinates, s_max, s_min) + surface_neighbors = normalize(surface_neighbors, s_max, s_min) + surf_grid = normalize(surf_grid, s_max, s_min) + + if self.sampling: + ( + surface_coordinates_sampled, + idx_surface, + ) = area_weighted_shuffle_array( + surface_coordinates, self.surface_points, surface_sizes + ) + if surface_coordinates_sampled.shape[0] < self.surface_points: + surface_coordinates_sampled = pad( + surface_coordinates_sampled, + self.surface_points, + pad_value=-10.0, + ) + + surface_fields = surface_fields[idx_surface] + pos_normals_com_surface = pos_normals_com_surface[idx_surface] + surface_normals = surface_normals[idx_surface] + surface_sizes = surface_sizes[idx_surface] + surface_neighbors = surface_neighbors[idx_surface] + surface_neighbors_normals = surface_neighbors_normals[idx_surface] + surface_neighbors_sizes = surface_neighbors_sizes[idx_surface] + surface_coordinates = surface_coordinates_sampled + + if self.scaling_type is not None: + if self.surface_factors is not None: + if self.scaling_type == "mean_std_scaling": + surf_mean = self.surface_factors[0] + surf_std = self.surface_factors[1] + surface_fields = standardize( + surface_fields, surf_mean, surf_std + ) + elif self.scaling_type == "min_max_scaling": + surf_min = self.surface_factors[1] + surf_max = self.surface_factors[0] + surface_fields = normalize( + surface_fields, surf_max, surf_min + ) + + surface_coordinates = np.float32(surface_coordinates) + surface_fields = np.float32(surface_fields) + surface_sizes = np.float32(surface_sizes) + surface_normals = np.float32(surface_normals) + surface_neighbors = np.float32(surface_neighbors) + surface_neighbors_normals = np.float32(surface_neighbors_normals) + surface_neighbors_sizes = np.float32(surface_neighbors_sizes) + pos_normals_com_surface = np.float32(pos_normals_com_surface) + else: + surface_sizes = None + surface_normals = None + surface_neighbors = None + surface_neighbors_normals = None + surface_neighbors_sizes = None + pos_normals_com_surface = None + + else: + surface_coordinates = None + surface_fields = None + surface_sizes = None + surface_normals = None + surface_neighbors = None + surface_neighbors_normals = None + surface_neighbors_sizes = None + pos_normals_com_surface = None + + if self.sampling: + geometry_points = self.geom_points_sample + geometry_coordinates_sampled, idx_geometry = shuffle_array( + stl_vertices, geometry_points + ) + if geometry_coordinates_sampled.shape[0] < geometry_points: + geometry_coordinates_sampled = pad( + geometry_coordinates_sampled, geometry_points, pad_value=-100.0 + ) + geom_centers = geometry_coordinates_sampled + else: + geom_centers = stl_vertices + + geom_centers = np.float32(geom_centers) + + if self.model_type == "combined": + # Add the parameters to the dictionary + return { + "pos_volume_closest": pos_normals_closest_vol, + "pos_volume_center_of_mass": pos_normals_com_vol, + "pos_surface_center_of_mass": pos_normals_com_surface, + "geometry_coordinates": geom_centers, + "grid": grid, + "surf_grid": surf_grid, + "sdf_grid": sdf_grid, + "sdf_surf_grid": sdf_surf_grid, + "sdf_nodes": sdf_nodes, + "surface_mesh_centers": surface_coordinates, + "surface_mesh_neighbors": surface_neighbors, + "surface_normals": surface_normals, + "surface_neighbors_normals": surface_neighbors_normals, + "surface_areas": surface_sizes, + "surface_neighbors_areas": surface_neighbors_sizes, + "volume_fields": volume_fields, + "volume_mesh_centers": volume_coordinates, + "surface_fields": surface_fields, + "volume_min_max": vol_grid_max_min, + "surface_min_max": surf_grid_max_min, + "length_scale": length_scale, + "stream_velocity": np.expand_dims( + np.array(STREAM_VELOCITY, dtype=np.float32), -1 + ), + "air_density": np.expand_dims( + np.array(AIR_DENSITY, dtype=np.float32), -1 + ), + } + elif self.model_type == "surface": + return { + "pos_surface_center_of_mass": pos_normals_com_surface, + "geometry_coordinates": geom_centers, + "surf_grid": surf_grid, + "sdf_surf_grid": sdf_surf_grid, + "surface_mesh_centers": surface_coordinates, + "surface_mesh_neighbors": surface_neighbors, + "surface_normals": surface_normals, + "surface_neighbors_normals": surface_neighbors_normals, + "surface_areas": surface_sizes, + "surface_neighbors_areas": surface_neighbors_sizes, + "surface_fields": surface_fields, + "surface_min_max": surf_grid_max_min, + "length_scale": length_scale, + "stream_velocity": np.expand_dims( + np.array(STREAM_VELOCITY, dtype=np.float32), -1 + ), + "air_density": np.expand_dims( + np.array(AIR_DENSITY, dtype=np.float32), -1 + ), + } + elif self.model_type == "volume": + return { + "pos_volume_closest": pos_normals_closest_vol, + "pos_volume_center_of_mass": pos_normals_com_vol, + "geometry_coordinates": geom_centers, + "grid": grid, + "surf_grid": surf_grid, + "sdf_grid": sdf_grid, + "sdf_surf_grid": sdf_surf_grid, + "sdf_nodes": sdf_nodes, + "volume_fields": volume_fields, + "volume_mesh_centers": volume_coordinates, + "volume_min_max": vol_grid_max_min, + "surface_min_max": surf_grid_max_min, + "length_scale": length_scale, + "stream_velocity": np.expand_dims( + np.array(STREAM_VELOCITY, dtype=np.float32), -1 + ), + "air_density": np.expand_dims( + np.array(AIR_DENSITY, dtype=np.float32), -1 + ), + } + + +if __name__ == "__main__": + fm_data = DoMINODataPipe( + data_path="/code/processed_data/new_models_1/", + phase="train", + sampling=False, + sample_in_bbox=False, + ) diff --git a/examples/domino/modulus/distributed/__init__.py b/examples/domino/modulus/distributed/__init__.py new file mode 100644 index 0000000000..4bc2bff921 --- /dev/null +++ b/examples/domino/modulus/distributed/__init__.py @@ -0,0 +1,20 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from .manager import DistributedManager # noqa: F401 +from .manager import ModulusUndefinedGroupError # noqa: F401 +from .manager import ModulusUninitializedDistributedManagerWarning # noqa: F401 diff --git a/examples/domino/modulus/distributed/config.py b/examples/domino/modulus/distributed/config.py new file mode 100644 index 0000000000..831fed1aef --- /dev/null +++ b/examples/domino/modulus/distributed/config.py @@ -0,0 +1,250 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict +from typing import List +from typing import Optional +from typing import Union + +from treelib import Tree + + +class ProcessGroupNode: + """ + Class to store the attributes of a distributed process group + + Attributes + ---------- + name : str + Name of the process group + size : Optional[int] + Optional, number of processes in the process group + """ + + def __init__( + self, + name: str, + size: Optional[int] = None, + ): + """ + Constructor for the ProcessGroupNode class + + Parameters + ---------- + name : str + Name of the process group + size : Optional[int] + Optional, size of the process group + """ + self.name = name + self.size = size + + def __str__(self): + """ + String representation of the process group node + + Returns + ------- + str + String representation of the process group node + """ + return "ProcessGroupNode(" f"name={self.name}, " f"size={self.size}, " + + def __repr__(self): + """ + String representation of the process group node + + Returns + ------- + str + String representation of the process group node + """ + return self.__str__() + + +class ProcessGroupConfig: + """ + Class to define the configuration of a model's parallel process group structure as a + tree. Each node of the tree is of type `ProcessGroupNode`. + + Once the process group config structure (i.e, the tree structure) is set, it is + sufficient to set only the sizes for each leaf process group. Then, the size of + every parent group can be automatically computed as the product reduction of the + sub-tree of that parent group node. + + Examples + -------- + >>> from modulus.distributed import ProcessGroupNode, ProcessGroupConfig + >>> + >>> # Create world group that contains all processes that are part of this job + >>> world = ProcessGroupNode("world") + >>> + >>> # Create the process group config with the highest level process group + >>> config = ProcessGroupConfig(world) + >>> + >>> # Create model and data parallel sub-groups + >>> # Sub-groups of a single node are guaranteed to be orthogonal by construction + >>> # Nodes can be added with either the name of the node or the node itself + >>> config.add_node(ProcessGroupNode("model_parallel"), parent=world) + >>> config.add_node(ProcessGroupNode("data_parallel"), parent="world") + >>> + >>> # Create spatial and channel parallel sub-groups + >>> config.add_node(ProcessGroupNode("spatial_parallel"), parent="model_parallel") + >>> config.add_node(ProcessGroupNode("channel_parallel"), parent="model_parallel") + >>> + >>> config.leaf_groups() + ['data_parallel', 'spatial_parallel', 'channel_parallel'] + >>> + >>> # Set leaf group sizes + >>> # Note: product of all leaf-node sizes should be the world size + >>> group_sizes = {"channel_parallel": 3, "spatial_parallel": 2, "data_parallel": 4} + >>> config.set_leaf_group_sizes(group_sizes) # Update all parent group sizes too + >>> config.get_node("model_parallel").size + 6 + """ + + def __init__(self, node: ProcessGroupNode): + """ + Constructor to the ProcessGroupConfig class + + Parameters + ---------- + node : ProcessGroupNode + Root node of the tree, typically would be 'world' + Note, it is generally recommended to set the child groups for 'world' + to 'model_parallel' and 'data_parallel' to aid with distributed + data parallel training unless there is a specific reason to choose a + different structure + """ + self.root = node + self.root_id = node.name + self.tree = Tree() + self.tree.create_node(node.name, node.name, data=node) + + def add_node(self, node: ProcessGroupNode, parent=Union[str, ProcessGroupNode]): + """ + Add a node to the process group config + + Parameters + ---------- + node : ProcessGroupNode + The new node to be added to the config + parent : Union[str, ProcessGroupNode] + Parent node of the node to be added. Should already be in the config. + If str, it is the name of the parent node. Otherwise, the parent + ProcessGroupNode itself. + """ + if isinstance(parent, ProcessGroupNode): + parent = parent.name + self.tree.create_node(node.name, node.name, data=node, parent=parent) + + def get_node(self, name: str) -> ProcessGroupNode: + """ + Method to get the node given the name of the node + + Parameters + ---------- + name : str + Name of the node to retrieve + + Returns + ------- + ProcessGroupNode + Node with the given name from the config + """ + return self.tree.get_node(name).data + + def update_parent_sizes(self, verbose: bool = False) -> int: + """ + Method to update parent node sizes after setting the sizes for each leaf node + + Parameters + ---------- + verbose : bool + If True, print a message each time a parent node size was updated + + Returns + ------- + int + Size of the root node + """ + return _tree_product_reduction(self.tree, self.root_id, verbose=verbose) + + def leaf_groups(self) -> List[str]: + """ + Get a list of all leaf group names + + Returns + ------- + List[str] + List of all leaf node names + """ + return [n.identifier for n in self.tree.leaves()] + + def set_leaf_group_sizes( + self, group_sizes: Dict[str, int], update_parent_sizes: bool = True + ): + """ + Set process group sizes for all leaf groups + + Parameters + ---------- + group_sizes : Dict[str, int] + Dictionary with a mapping of each leaf group name to its size + update_parent_sizes : bool + Update all parent group sizes based on the leaf group if True + If False, only set the leaf group sizes. + """ + for id, size in group_sizes.items(): + if not self.tree.contains(id): + raise AssertionError( + f"Process group {id} is not in this process group config" + ) + node = self.tree.get_node(id) + if not node.is_leaf(): + raise AssertionError(f"Process group {id} is not a leaf group") + node.data.size = size + + if update_parent_sizes: + self.update_parent_sizes() + + +def _tree_product_reduction(tree, node_id, verbose=False): + """ + Function to traverse a tree and compute the product reduction of + the sub-tree for each node starting from `node_id` + """ + children = tree.children(node_id) + node = tree.get_node(node_id) + if not children: + if node.data.size is None: + raise AssertionError("Leaf nodes should have a valid size set") + return node.data.size + + product = 1 + + for child in children: + product *= _tree_product_reduction(tree, child.identifier) + + if node.data.size != product: + if verbose: + print( + "Updating size of node " + f"{node.data.name} from {node.data.size} to {product}" + ) + node.data.size = product + + return product diff --git a/examples/domino/modulus/distributed/manager.py b/examples/domino/modulus/distributed/manager.py new file mode 100644 index 0000000000..eccc69f6a3 --- /dev/null +++ b/examples/domino/modulus/distributed/manager.py @@ -0,0 +1,611 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import queue +from typing import Optional +from warnings import warn + +import numpy as np +import paddle +import paddle.distributed as dist + +from .config import ProcessGroupConfig +from .config import ProcessGroupNode + + +class ModulusUndefinedGroupError(Exception): + """Exception for querying an undefined process group using the Modulus DistributedManager""" + + def __init__(self, name: str): + """ + + Parameters + ---------- + name : str + Name of the process group being queried. + + """ + message = ( + f"Cannot query process group '{name}' before it is explicitly created." + ) + super().__init__(message) + + +class ModulusUninitializedDistributedManagerWarning(Warning): + """Warning to indicate usage of an uninitialized DistributedManager""" + + def __init__(self): + message = ( + "A DistributedManager object is being instantiated before " + + "this singleton class has been initialized. Instantiating a manager before " + + "initialization can lead to unexpected results where processes fail " + + "to communicate. Initialize the distributed manager via " + + "DistributedManager.initialize() before instantiating." + ) + super().__init__(message) + + +class DistributedManager(object): + """Distributed Manager for setting up distributed training environment. + + This is a singleton that creates a persistance class instance for storing parallel + environment information through out the life time of the program. This should be + used to help set up Distributed Data Parallel and parallel datapipes. + + Note + ---- + One should call `DistributedManager.initialize()` prior to constructing a manager + object + + Example + ------- + >>> DistributedManager.initialize() + >>> manager = DistributedManager() + >>> manager.rank + 0 + >>> manager.world_size + 1 + """ + + _shared_state = {} + + def __new__(cls): + obj = super(DistributedManager, cls).__new__(cls) + obj.__dict__ = cls._shared_state + + # Set the defaults + if not hasattr(obj, "_rank"): + obj._rank = 0 + if not hasattr(obj, "_world_size"): + obj._world_size = 1 + if not hasattr(obj, "_local_rank"): + obj._local_rank = 0 + if not hasattr(obj, "_distributed"): + obj._distributed = False + if not hasattr(obj, "_device"): + obj._device = "gpu:0" if paddle.device.cuda.device_count() else "cpu" + if not hasattr(obj, "_cuda"): + obj._cuda = paddle.device.cuda.device_count() >= 1 + if not hasattr(obj, "_broadcast_buffers"): + obj._broadcast_buffers = False + if not hasattr(obj, "_find_unused_parameters"): + obj._find_unused_parameters = False + if not hasattr(obj, "_initialization_method"): + obj._initialization_method = "None" + if not hasattr(obj, "_groups"): + obj._groups = {} + if not hasattr(obj, "_group_ranks"): + obj._group_ranks = {} + if not hasattr(obj, "_group_names"): + obj._group_names = {} + if not hasattr(obj, "_is_initialized"): + obj._is_initialized = False + + return obj + + def __init__(self): + if not self._is_initialized: + raise ModulusUninitializedDistributedManagerWarning() + super().__init__() + + @property + def rank(self): + """Process rank""" + return self._rank + + @property + def local_rank(self): + """Process rank on local machine""" + return self._local_rank + + @property + def world_size(self): + """Number of processes in distributed enviroment""" + return self._world_size + + @property + def device(self): + """Process device""" + return self._device + + @property + def distributed(self): + """Distributed enviroment""" + return self._distributed + + @property + def cuda(self): + """If cuda is available""" + return self._cuda + + @property + def group_names(self): + """ + Returns a list of all named process groups created + """ + return self._groups.keys() + + def group(self, name=None): + """ + Returns a process group with the given name + If name is None, group is also None indicating the default process group + If named group does not exist, ModulusUndefinedGroupError exception is raised + """ + if name in self._groups.keys(): + return self._groups[name] + elif name is None: + return None + else: + raise ModulusUndefinedGroupError(name) + + def group_size(self, name=None): + """ + Returns the size of named process group + """ + if name is None: + return self._world_size + group = self.group(name) + return dist.get_world_size(group=group) + + def group_rank(self, name=None): + """ + Returns the rank in named process group + """ + if name is None: + return self._rank + group = self.group(name) + return dist.get_rank(group=group) + + def group_name(self, group=None): + """ + Returns the name of process group + """ + if group is None: + return None + return self._group_names[group] + + @property + def broadcast_buffers(self): + """broadcast_buffers in PyTorch DDP""" + return self._broadcast_buffers + + @broadcast_buffers.setter + def broadcast_buffers(self, broadcast: bool): + """Setter for broadcast_buffers""" + self._broadcast_buffers = broadcast + + @property + def find_unused_parameters(self): + """find_unused_parameters in PyTorch DDP""" + return self._find_unused_parameters + + @find_unused_parameters.setter + def find_unused_parameters(self, find_params: bool): + """Setter for find_unused_parameters""" + if find_params: + warn( + "Setting `find_unused_parameters` in DDP to true, " + "use only if necessary." + ) + self._find_unused_parameters = find_params + + def __str__(self): + output = ( + f"Initialized process {self.rank} of {self.world_size} using " + f"method '{self._initialization_method}'. Device set to {str(self.device)}" + ) + return output + + @classmethod + def is_initialized(cls) -> bool: + """If manager singleton has been initialized""" + return cls._shared_state.get("_is_initialized", False) + + @staticmethod + def get_available_backend(): + """Get communication backend""" + if ( + paddle.device.cuda.device_count() >= 1 + and paddle.core.is_compiled_with_nccl() + ): + return "nccl" + else: + return "gloo" + + @staticmethod + def initialize_env(): + """Setup method using generic initialization""" + rank = int(os.environ.get("RANK")) + world_size = int(os.environ.get("WORLD_SIZE")) + if "LOCAL_RANK" in os.environ: + local_rank = os.environ.get("LOCAL_RANK") + if local_rank is not None: + local_rank = int(local_rank) + else: + local_rank = rank % paddle.device.cuda.device_count() + + else: + local_rank = rank % paddle.device.cuda.device_count() + + # Read env variables + addr = os.environ.get("MASTER_ADDR") + port = os.environ.get("MASTER_PORT") + + DistributedManager.setup( + rank=rank, + world_size=world_size, + local_rank=local_rank, + addr=addr, + port=port, + backend=DistributedManager.get_available_backend(), + ) + + @staticmethod + def initialize_open_mpi(addr, port): + """Setup method using OpenMPI initialization""" + rank = int(os.environ.get("OMPI_COMM_WORLD_RANK")) + world_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE")) + local_rank = int(os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK")) + + DistributedManager.setup( + rank=rank, + world_size=world_size, + local_rank=local_rank, + addr=addr, + port=port, + backend=DistributedManager.get_available_backend(), + method="openmpi", + ) + + @staticmethod + def initialize_slurm(port): + """Setup method using SLURM initialization""" + rank = int(os.environ.get("SLURM_PROCID")) + world_size = int(os.environ.get("SLURM_NPROCS")) + local_rank = int(os.environ.get("SLURM_LOCALID")) + addr = os.environ.get("SLURM_LAUNCH_NODE_IPADDR") + + DistributedManager.setup( + rank=rank, + world_size=world_size, + local_rank=local_rank, + addr=addr, + port=port, + backend=DistributedManager.get_available_backend(), + method="slurm", + ) + + @staticmethod + def initialize(): + """ + Initialize distributed manager + + Current supported initialization methods are: + `ENV`: PyTorch environment variable initialization + https://pytorch.org/docs/stable/distributed.html#environment-variable-initialization + `SLURM`: Initialization on SLURM systems. + Uses `SLURM_PROCID`, `SLURM_NPROCS`, `SLURM_LOCALID` and + `SLURM_LAUNCH_NODE_IPADDR` environment variables. + `OPENMPI`: Initialization for OpenMPI launchers. + Uses `OMPI_COMM_WORLD_RANK`, `OMPI_COMM_WORLD_SIZE` and + `OMPI_COMM_WORLD_LOCAL_RANK` environment variables. + + Initialization by default is done using the first valid method in the order + listed above. Initialization method can also be explicitly controlled using the + `MODULUS_DISTRIBUTED_INITIALIZATION_METHOD` environment variable and setting it + to one of the options above. + """ + if DistributedManager.is_initialized(): + warn("Distributed manager is already intialized") + return + + addr = os.getenv("MASTER_ADDR", "localhost") + port = os.getenv("MASTER_PORT", "12355") + # https://pytorch.org/docs/master/notes/cuda.html#id5 + os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "0" + initialization_method = os.getenv("MODULUS_DISTRIBUTED_INITIALIZATION_METHOD") + if initialization_method is None: + try: + DistributedManager.initialize_env() + except TypeError: + if "SLURM_PROCID" in os.environ: + DistributedManager.initialize_slurm(port) + elif "OMPI_COMM_WORLD_RANK" in os.environ: + DistributedManager.initialize_open_mpi(addr, port) + else: + warn( + "Could not initialize using ENV, SLURM or OPENMPI methods. Assuming this is a single process job" + ) + DistributedManager._shared_state["_is_initialized"] = True + elif initialization_method == "ENV": + DistributedManager.initialize_env() + elif initialization_method == "SLURM": + DistributedManager.initialize_slurm(port) + elif initialization_method == "OPENMPI": + DistributedManager.initialize_open_mpi(addr, port) + else: + raise RuntimeError( + "Unknown initialization method " + f"{initialization_method}. " + "Supported values for " + "MODULUS_DISTRIBUTED_INITIALIZATION_METHOD are " + "ENV, SLURM and OPENMPI" + ) + + # Set per rank numpy random seed for data sampling + np.random.seed(seed=DistributedManager().rank) + + @staticmethod + def setup( + rank=0, + world_size=1, + local_rank=None, + addr="localhost", + port="12355", + backend="nccl", + method="env", + ): + """Set up PyTorch distributed process group and update manager attributes""" + os.environ["MASTER_ADDR"] = addr + os.environ["MASTER_PORT"] = str(port) + + DistributedManager._shared_state["_is_initialized"] = True + manager = DistributedManager() + + manager._distributed = paddle.distributed.is_available() + if manager._distributed: + # Update rank and world_size if using distributed + manager._rank = rank + manager._world_size = world_size + if local_rank is None: + manager._local_rank = rank % paddle.device.cuda.device_count() + else: + manager._local_rank = local_rank + + manager._device = ( + f"gpu:{manager.local_rank}" + if paddle.device.cuda.device_count() >= 1 + else "cpu" + ) + + if manager._distributed: + # Setup distributed process group + dist.init_process_group() + + if paddle.device.cuda.device_count() >= 0: + # Set device for this process and empty cache to optimize memory usage + paddle.set_device(manager.device) + paddle.device.cuda.empty_cache() + + manager._initialization_method = method + + @staticmethod + def create_process_subgroup( + name: str, size: int, group_name: Optional[str] = None, verbose: bool = False + ): # pragma: no cover + """ + Create a process subgroup of a parent process group. This must be a collective + call by all processes participating in this application. + + Parameters + ---------- + name : str + Name of the process subgroup to be created. + + size : int + Size of the process subgroup to be created. This must be an integer factor of + the parent group's size. + + group_name : Optional[str] + Name of the parent process group, optional. If None, the default process group + will be used. Default None. + + verbose : bool + Print out ranks of each created process group, default False. + + """ + manager = DistributedManager() + if not manager.distributed: + raise AssertionError( + "paddle.distributed is unavailable. " + "Check paddle build to ensure the distributed package is available. " + "If building PyTorch from source, set `USE_DISTRIBUTED=1` " + "to enable the distributed package" + ) + + if name in manager._groups: + raise AssertionError(f"Group with name {name} already exists") + + # Get parent group's params + group = manager._groups[group_name] if group_name else None + group_size = dist.get_world_size(group=group) + num_groups = manager.world_size // group_size + + # Get number of sub-groups per parent group + if group_size % size != 0: + raise AssertionError( + f"Cannot divide group size {group_size} evenly into subgroups of" + f" size {size}" + ) + num_subgroups = group_size // size + + # Create all the sub-groups + # Note: all ranks in the job need to create all sub-groups in + # the same order even if a rank is not part of a sub-group + manager._group_ranks[name] = [] + for g in range(num_groups): + for i in range(num_subgroups): + # Get global ranks that are part of this sub-group + start = i * size + end = start + size + if group_name: + ranks = manager._group_ranks[group_name][g][start:end] + else: + ranks = list(range(start, end)) + # Create sub-group and keep track of ranks + tmp_group = dist.new_group(ranks=ranks) + manager._group_ranks[name].append(ranks) + if manager.rank in ranks: + # Set group in manager only if this rank is part of the group + manager._groups[name] = tmp_group + manager._group_names[tmp_group] = name + + if verbose and manager.rank == 0: + print(f"Process group '{name}':") + for grp in manager._group_ranks[name]: + print(" ", grp) + + @staticmethod + def create_orthogonal_process_group( + orthogonal_group_name: str, group_name: str, verbose: bool = False + ): # pragma: no cover + """ + Create a process group that is orthogonal to the specified process group. + + Parameters + ---------- + orthogonal_group_name : str + Name of the orthogonal process group to be created. + + group_name : str + Name of the existing process group. + + verbose : bool + Print out ranks of each created process group, default False. + + """ + manager = DistributedManager() + if not manager.distributed: + raise AssertionError( + "paddle.distributed is unavailable. " + "Check paddle build to ensure the distributed package is available. " + "If building PyTorch from source, set `USE_DISTRIBUTED=1` " + "to enable the distributed package" + ) + + if group_name not in manager._groups: + raise ValueError(f"Group with name {group_name} does not exist") + if orthogonal_group_name in manager._groups: + raise ValueError(f"Group with name {orthogonal_group_name} already exists") + + group_ranks = manager._group_ranks[group_name] + orthogonal_ranks = [list(i) for i in zip(*group_ranks)] + + for ranks in orthogonal_ranks: + tmp_group = dist.new_group(ranks=ranks) + if manager.rank in ranks: + # Set group in manager only if this rank is part of the group + manager._groups[orthogonal_group_name] = tmp_group + manager._group_names[tmp_group] = orthogonal_group_name + + manager._group_ranks[orthogonal_group_name] = orthogonal_ranks + + if verbose and manager.rank == 0: + print(f"Process group '{orthogonal_group_name}':") + for grp in manager._group_ranks[orthogonal_group_name]: + print(" ", grp) + + @staticmethod + def create_group_from_node( + node: ProcessGroupNode, + parent: Optional[str] = None, + verbose: bool = False, + ): # pragma: no cover + if node.size is None: + raise AssertionError( + "Cannot create groups from a ProcessGroupNode that is not fully" + " populated. Ensure that config.set_leaf_group_sizes is called first" + " with `update_parent_sizes = True`" + ) + + DistributedManager.create_process_subgroup( + node.name, node.size, group_name=parent, verbose=verbose + ) + # Create orthogonal process group + orthogonal_group = f"__orthogonal_to_{node.name}" + DistributedManager.create_orthogonal_process_group( + orthogonal_group, node.name, verbose=verbose + ) + return orthogonal_group + + @staticmethod + def create_groups_from_config( + config: ProcessGroupConfig, verbose: bool = False + ): # pragma: no cover + # Traverse process group tree in breadth first order + # to create nested process groups + q = queue.Queue() + q.put(config.root_id) + DistributedManager.create_group_from_node(config.root) + + while not q.empty(): + node_id = q.get() + if verbose: + print(f"Node ID: {node_id}") + + children = config.tree.children(node_id) + if verbose: + print(f" Children: {children}") + + parent_group = node_id + for child in children: + # Create child group and replace parent group by orthogonal group so + # that each child forms an independent block of processes + parent_group = DistributedManager.create_group_from_node( + child.data, + parent=parent_group, + ) + + # Add child ids to the queue + q.put(child.identifier) + + @staticmethod + def cleanup(): + """Clean up distributed group and singleton""" + # Destroying group.WORLD is enough for all process groups to get destroyed + if ( + "_is_initialized" in DistributedManager._shared_state + and DistributedManager._shared_state["_is_initialized"] + and "_distributed" in DistributedManager._shared_state + and DistributedManager._shared_state["_distributed"] + ): + if paddle.device.cuda.device_count() >= 1: + dist.barrier() + else: + dist.barrier() + dist.destroy_process_group() + DistributedManager._shared_state = {} diff --git a/examples/domino/modulus/launch/__init__.py b/examples/domino/modulus/launch/__init__.py new file mode 100644 index 0000000000..b2f171d4ac --- /dev/null +++ b/examples/domino/modulus/launch/__init__.py @@ -0,0 +1,15 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/examples/domino/modulus/launch/logging/__init__.py b/examples/domino/modulus/launch/logging/__init__.py new file mode 100644 index 0000000000..f9b51fe41d --- /dev/null +++ b/examples/domino/modulus/launch/logging/__init__.py @@ -0,0 +1,17 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .console import PythonLogger # noqa: F401 diff --git a/examples/domino/modulus/launch/logging/console.py b/examples/domino/modulus/launch/logging/console.py new file mode 100644 index 0000000000..423157693e --- /dev/null +++ b/examples/domino/modulus/launch/logging/console.py @@ -0,0 +1,88 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os + +from termcolor import colored + + +class PythonLogger: + """Simple console logger for DL training + This is a WIP + """ + + def __init__(self, name: str = "launch"): + self.logger = logging.getLogger(name) + + def file_logging(self, file_name: str = "launch.log"): + """Log to file""" + if os.path.exists(file_name): + try: + os.remove(file_name) + except FileNotFoundError: + # ignore if already removed (can happen with multiple processes) + pass + formatter = logging.Formatter( + "[%(asctime)s - %(name)s - %(levelname)s] %(message)s", + datefmt="%H:%M:%S", + ) + filehandler = logging.FileHandler(file_name) + filehandler.setFormatter(formatter) + filehandler.setLevel(logging.DEBUG) + self.logger.addHandler(filehandler) + + def log(self, message: str): + """Log message""" + self.logger.info(message) + + def info(self, message: str): + """Log info""" + self.logger.info(colored(message, "light_blue")) + + def success(self, message: str): + """Log success""" + self.logger.info(colored(message, "light_green")) + + def warning(self, message: str): + """Log warning""" + self.logger.warning(colored(message, "light_yellow")) + + def error(self, message: str): + """Log error""" + self.logger.error(colored(message, "light_red")) + + +class RankZeroLoggingWrapper: + """Wrapper class to only log from rank 0 process in distributed training.""" + + def __init__(self, obj, dist): + self.obj = obj + self.dist = dist + + def __getattr__(self, name): + attr = getattr(self.obj, name) + if callable(attr): + + def wrapper(*args, **kwargs): + if self.dist.rank == 0: + return attr(*args, **kwargs) + else: + return None + + return wrapper + else: + return attr diff --git a/examples/domino/modulus/launch/logging/launch.py b/examples/domino/modulus/launch/logging/launch.py new file mode 100644 index 0000000000..c8714da61a --- /dev/null +++ b/examples/domino/modulus/launch/logging/launch.py @@ -0,0 +1,443 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +import sys +import time +from os import getcwd +from os import makedirs +from os.path import abspath +from os.path import exists +from os.path import join +from typing import Dict +from typing import Tuple +from typing import Union + +import torch +import torch.cuda.profiler as profiler +from modulus.distributed import DistributedManager +from modulus.distributed import reduce_loss + +from .console import PythonLogger + + +class LaunchLogger(object): + """Modulus Launch logger + + An abstracted logger class that takes care of several fundamental logging functions. + This class should first be initialized and then used via a context manager. This will + auto compute epoch metrics. This is the standard logger for Modulus examples. + + Parameters + ---------- + name_space : str + Namespace of logger to use. This will define the loggers title in the console and + the wandb group the metric is plotted + epoch : int, optional + Current epoch, by default 1 + num_mini_batch : Union[int, None], optional + Number of mini-batches used to calculate the epochs progress, by default None + profile : bool, optional + Profile code using nvtx markers, by default False + mini_batch_log_freq : int, optional + Frequency to log mini-batch losses, by default 100 + epoch_alert_freq : Union[int, None], optional + Epoch frequency to send training alert, by default None + + Example + ------- + >>> from modulus.launch.logging import LaunchLogger + >>> LaunchLogger.initialize() + >>> epochs = 3 + >>> for i in range(epochs): + ... with LaunchLogger("Train", epoch=i) as log: + ... # Log 3 mini-batches manually + ... log.log_minibatch({"loss": 1.0}) + ... log.log_minibatch({"loss": 2.0}) + ... log.log_minibatch({"loss": 3.0}) + """ + + _instances = {} + console_backend = True + wandb_backend = False + mlflow_backend = False + tensorboard_backend = False + enable_profiling = False + + mlflow_run = None + mlflow_client = None + + def __new__(cls, name_space, *args, **kwargs): + # If namespace already has an instance just return that + if name_space in cls._instances: + return cls._instances[name_space] + + # Otherwise create new singleton instance for this namespace + self = super().__new__(cls) # don't pass remaining parameters to object.__new__ + cls._instances[name_space] = self + + # Constructor set up to only be ran once by a logger + self.pyLogger = PythonLogger(name_space) + self.total_iteration_index = None + # Distributed + self.root = True + if DistributedManager.is_initialized(): + self.root = DistributedManager().rank == 0 + # Profiler utils + if torch.cuda.is_available(): + self.profiler = torch.autograd.profiler.emit_nvtx( + enabled=cls.enable_profiling + ) + self.start_event = torch.cuda.Event(enable_timing=True) + self.end_event = torch.cuda.Event(enable_timing=True) + else: + self.profiler = None + + return self + + def __init__( + self, + name_space: str, + epoch: int = 1, + num_mini_batch: Union[int, None] = None, + profile: bool = False, + mini_batch_log_freq: int = 100, + epoch_alert_freq: Union[int, None] = None, + ): + self.name_space = name_space + self.mini_batch_index = 0 + self.minibatch_losses = {} + self.epoch_losses = {} + + self.mini_batch_log_freq = mini_batch_log_freq + self.epoch_alert_freq = epoch_alert_freq + self.epoch = epoch + self.num_mini_batch = num_mini_batch + self.profile = profile + # Init initial iteration based on current epoch + if self.total_iteration_index is None: + if num_mini_batch is not None: + self.total_iteration_index = (epoch - 1) * num_mini_batch + else: + self.total_iteration_index = 0 + + # Set x axis metric to epoch for this namespace + if self.wandb_backend: + import wandb + + wandb.define_metric(name_space + "/mini_batch_*", step_metric="iter") + wandb.define_metric(name_space + "/*", step_metric="epoch") + + def log_minibatch(self, losses: Dict[str, float]): + """Logs metrics for a mini-batch epoch + + This function should be called every mini-batch iteration. It will accumulate + loss values over a datapipe. At the end of a epoch the average of these losses + from each mini-batch will get calculated. + + Parameters + ---------- + losses : Dict[str, float] + Dictionary of metrics/loss values to log + """ + self.mini_batch_index += 1 + self.total_iteration_index += 1 + for name, value in losses.items(): + if name not in self.minibatch_losses: + self.minibatch_losses[name] = 0 + self.minibatch_losses[name] += value + + # Log of mini-batch loss + if self.mini_batch_index % self.mini_batch_log_freq == 0: + # Backend Logging + mini_batch_metrics = {} + for name, value in losses.items(): + mini_batch_metrics[f"{self.name_space}/mini_batch_{name}"] = value + self._log_backends( + mini_batch_metrics, step=("iter", self.total_iteration_index) + ) + + # Console + if self.root: + message = "Mini-Batch Losses:" + for name, value in losses.items(): + message += f" {name} = {value:10.3e}," + message = message[:-1] + # If we have datapipe length we can get a percent complete + if self.num_mini_batch: + mbp = 100 * (float(self.mini_batch_index) / self.num_mini_batch) + message = f"[{mbp:.02f}%] " + message + + self.pyLogger.log(message) + + def log_epoch(self, losses: Dict[str, float]): + """Logs metrics for a single epoch + + Parameters + ---------- + losses : Dict[str, float] + Dictionary of metrics/loss values to log + """ + for name, value in losses.items(): + self.epoch_losses[name] = value + + def __enter__(self): + self.mini_batch_index = 0 + self.minibatch_losses = {} + self.epoch_losses = {} + + # Trigger profiling + if self.profile and self.profiler: + self.logger.warning(f"Starting profile for epoch {self.epoch}") + self.profiler.__enter__() + profiler.start() + + # Timing stuff + if torch.cuda.is_available(): + self.start_event.record() + else: + self.start_event = time.time() + + if self.mlflow_backend: + self.mlflow_client.update_run(self.mlflow_run.info.run_id, "RUNNING") + + return self + + def __exit__(self, exc_type, exc_value, exc_tb): + # Abnormal exit dont log + if exc_type is not None: + if self.mlflow_backend: + self.mlflow_client.set_terminated( + self.mlflow_run.info.run_id, status="KILLED" + ) + return + # Reduce mini-batch losses + for name, value in self.minibatch_losses.items(): + process_loss = value / self.mini_batch_index + self.epoch_losses[name] = process_loss + # Compute global loss + if DistributedManager.is_initialized() and DistributedManager().distributed: + self.epoch_losses[name] = reduce_loss(process_loss) + + if self.root: + # Console printing + # TODO: add out of total epochs progress + message = f"Epoch {self.epoch} Metrics:" + for name, value in self.epoch_losses.items(): + message += f" {name} = {value:10.3e}," + message = message[:-1] + self.pyLogger.info(message) + + metrics = { + f"{self.name_space}/{key}": value + for key, value in self.epoch_losses.items() + } + + # Exit profiling + if self.profile and self.profiler: + self.logger.warning("Ending profile") + self.profiler.__exit__() + profiler.end() + + # Timing stuff, TODO: histograms not line plots + if torch.cuda.is_available(): + self.end_event.record() + torch.cuda.synchronize() + # Returns milliseconds + # https://pytorch.org/docs/stable/generated/torch.cuda.Event.html#torch.cuda.Event.elapsed_time + epoch_time = self.start_event.elapsed_time(self.end_event) / 1000.0 + else: + end_event = time.time() + epoch_time = end_event - self.start_event + + # Return MS for time / iter + time_per_iter = 1000 * epoch_time / max([1, self.mini_batch_index]) + + if self.root: + message = f"Epoch Execution Time: {epoch_time:10.3e}s" + message += f", Time/Iter: {time_per_iter:10.3e}ms" + self.pyLogger.info(message) + + metrics[f"{self.name_space}/Epoch Time (s)"] = epoch_time + metrics[f"{self.name_space}/Time per iter (ms)"] = time_per_iter + + self._log_backends(metrics, step=("epoch", self.epoch)) + + # TODO this should be in some on delete method / clean up + if self.mlflow_backend: + self.mlflow_client.set_terminated( + self.mlflow_run.info.run_id, status="FINISHED" + ) + + # Alert + if ( + self.epoch_alert_freq + and self.root + and self.epoch % self.epoch_alert_freq == 0 + ): + if self.wandb_backend: + import wandb + + from .wandb import alert + + # TODO: Make this a little more informative? + alert( + title=f"{sys.argv[0]} training progress report", + text=f"Run {wandb.run.name} is at epoch {self.epoch}.", + ) + + def _log_backends( + self, + metric_dict: Dict[str, float], + step: Tuple[str, int] = None, + ): + """Logs a dictionary of metrics to different supported backends + + Parameters + ---------- + metric_dict : Dict[str, float] + Metric dictionary + step : Tuple[str, int], optional + Tuple containing (step name, step index), by default None + print : bool, optional + Print metrics, by default False + """ + + # MLFlow Logging + if self.mlflow_backend: + for key, value in metric_dict.items(): + # If value is None just skip + if value is None: + continue + # Keys only allow alpha numeric, ., -, /, _ and spaces + key = re.sub("[^a-zA-Z0-9\.\-\s\/\_]+", "", key) + self.mlflow_client.log_metric( + self.mlflow_run.info.run_id, key, value, step=step[1] + ) + + # WandB Logging + if self.wandb_backend: + import wandb + + # For WandB send step in as a metric + # Step argument in lod function does not work with multiple log calls at + # different intervals + metric_dict[step[0]] = step[1] + wandb.log(metric_dict) + + def log_figure( + self, + figure, + artifact_file: str = "artifact", + plot_dir: str = "./", + log_to_file: bool = False, + ): + """Logs figures on root process to wand or mlflow. Will store it to file in case neither are selected. + + Parameters + ---------- + figure : Figure + matplotlib or plotly figure to plot + artifact_file : str, optional + File name. CAUTION overrides old files of same name + plot_dir : str, optional + output directory for plot + log_to_file : bool, optional + set to true in case figure shall be stored to file in addition to logging it to mlflow/wandb + """ + dist = DistributedManager() + if dist.rank != 0: + return + + if self.wandb_backend: + import wandb + + wandb.log({artifact_file: figure}) + + if self.mlflow_backend: + self.mlflow_client.log_figure( + figure=figure, + artifact_file=artifact_file, + run_id=self.mlflow_run.info.run_id, + ) + + if (not self.wandb_backend) and (not self.mlflow_backend): + log_to_file = True + + if log_to_file: + plot_dir = abspath(join(getcwd(), plot_dir)) + if not exists(plot_dir): + makedirs(plot_dir) + if not artifact_file.endswith(".png"): + artifact_file += ".png" + figure.savefig(join(plot_dir, artifact_file)) + + @classmethod + def toggle_wandb(cls, value: bool): + """Toggle WandB logging + + Parameters + ---------- + value : bool + Use WandB logging + """ + cls.wandb_backend = value + + @classmethod + def toggle_mlflow(cls, value: bool): + """Toggle MLFlow logging + + Parameters + ---------- + value : bool + Use MLFlow logging + """ + cls.mlflow_backend = value + + @staticmethod + def initialize(use_wandb: bool = False, use_mlflow: bool = False): + """Initialize logging singleton + + Parameters + ---------- + use_wandb : bool, optional + Use WandB logging, by default False + use_mlflow : bool, optional + Use MLFlow logging, by default False + """ + if use_wandb: + import wandb + + if wandb.run is None: + PythonLogger().warning("WandB not initialized, turning off") + use_wandb = False + + if use_wandb: + LaunchLogger.toggle_wandb(True) + wandb.define_metric("epoch") + wandb.define_metric("iter") + + # let only root process log to mlflow + if DistributedManager.is_initialized(): + if DistributedManager().rank != 0: + return + + if LaunchLogger.mlflow_run is None and use_mlflow: + PythonLogger().warning("MLFlow not initialized, turning off") + use_mlflow = False + + if use_mlflow: + LaunchLogger.toggle_mlflow(True) diff --git a/examples/domino/modulus/launch/logging/mlflow.py b/examples/domino/modulus/launch/logging/mlflow.py new file mode 100644 index 0000000000..69e85e0e8e --- /dev/null +++ b/examples/domino/modulus/launch/logging/mlflow.py @@ -0,0 +1,199 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from datetime import datetime +from pathlib import Path +from typing import Literal +from typing import Tuple + +import torch + +try: + import mlflow # noqa: F401 for docs + from mlflow.entities.run import Run + from mlflow.tracking import MlflowClient +except ImportError: + raise ImportError( + "These utilities require the MLFlow library. Install MLFlow using `pip install mlflow`. " + + "For more info, refer: https://www.mlflow.org/docs/2.5.0/quickstart.html#install-mlflow" + ) + +from modulus.distributed import DistributedManager + +from .console import PythonLogger +from .launch import LaunchLogger + +logger = PythonLogger("mlflow") + + +def initialize_mlflow( + experiment_name: str, + experiment_desc: str = None, + run_name: str = None, + run_desc: str = None, + user_name: str = None, + mode: Literal["offline", "online", "ngc"] = "offline", + tracking_location: str = None, + artifact_location: str = None, +) -> Tuple[MlflowClient, Run]: + """Initializes MLFlow logging client and run. + + Parameters + ---------- + experiment_name : str + Experiment name + experiment_desc : str, optional + Experiment description, by default None + run_name : str, optional + Run name, by default None + run_desc : str, optional + Run description, by default None + user_name : str, optional + User name, by default None + mode : str, optional + MLFlow mode. Supports "offline", "online" and "ngc". Offline mode records logs to + local file system. Online mode is for remote tracking servers. NGC is specific + standardized setup for NGC runs, default "offline" + tracking_location : str, optional + Tracking location for MLFlow. For offline this would be an absolute folder directory. + For online mode this would be a http URI or databricks. For NGC, this option is + ignored, by default "//mlruns" + artifact_location : str, optional + Optional separate artifact location, by default None + + Note + ---- + For NGC mode, one needs to mount a NGC workspace / folder system with a metric folder + at `/mlflow/mlflow_metrics/` and a artifact folder at `/mlflow/mlflow_artifacts/`. + + Note + ---- + This will set up Modulus Launch logger for MLFlow logging. Only one MLFlow logging + client is supported with the Modulus Launch logger. + + Returns + ------- + Tuple[MlflowClient, Run] + Returns MLFlow logging client and active run object + """ + dist = DistributedManager() + if dist.rank != 0: # only root process should be logging to mlflow + return + + start_time = datetime.now().astimezone() + time_string = start_time.strftime("%m/%d/%y_%H-%M-%S") + group_name = f"{run_name}_{time_string}" + + # Set default value here for Hydra + if tracking_location is None: + tracking_location = str(Path("./mlruns").absolute()) + + # Set up URI (remote or local) + if mode == "online": + tracking_uri = tracking_location + elif mode == "offline": + if not tracking_location.startswith("file://"): + tracking_location = "file://" + tracking_location + tracking_uri = tracking_location + elif mode == "ngc": + if not Path("/mlflow/mlflow_metrics").is_dir(): + raise IOError( + "NGC MLFlow config select but metrics folder '/mlflow/mlflow_metrics'" + + " not found. Aborting MLFlow setup." + ) + return + + if not Path("/mlflow/mlflow_artifacts").is_dir(): + raise IOError( + "NGC MLFlow config select but artifact folder '/mlflow/mlflow_artifacts'" + + " not found. Aborting MLFlow setup." + ) + return + tracking_uri = "file:///mlflow/mlflow_metrics" + artifact_location = "file:///mlflow/mlflow_artifacts" + else: + logger.warning(f"Unsupported MLFlow mode '{mode}' provided") + tracking_uri = "file://" + str(Path("./mlruns").absolute()) + + mlflow.set_tracking_uri(tracking_uri) + client = MlflowClient() + + check_mlflow_logged_in(client) + + experiment = client.get_experiment_by_name(experiment_name) + # If experiment does not exist create one + if experiment is None: + logger.info(f"No {experiment_name} experiment found, creating...") + experiment_id = client.create_experiment( + experiment_name, artifact_location=artifact_location + ) + client.set_experiment_tag(experiment_id, "mlflow.note.content", experiment_desc) + else: + logger.success(f"Existing {experiment_name} experiment found") + experiment_id = experiment.experiment_id + + # Create an run and set its tags + run = client.create_run( + experiment_id, tags={"mlflow.user": user_name}, run_name=run_name + ) + client.set_tag(run.info.run_id, "mlflow.note.content", run_desc) + + start_time = datetime.now().astimezone() + time_string = start_time.strftime("%m/%d/%y %H:%M:%S") + client.set_tag(run.info.run_id, "date", time_string) + client.set_tag(run.info.run_id, "host", os.uname()[1]) + if torch.cuda.is_available(): + client.set_tag(run.info.run_id, "gpu", torch.cuda.get_device_name(dist.device)) + client.set_tag(run.info.run_id, "group", group_name) + + run = client.get_run(run.info.run_id) + + # Set run instance in Modulus logger + LaunchLogger.mlflow_run = run + LaunchLogger.mlflow_client = client + + return client, run + + +def check_mlflow_logged_in(client: MlflowClient): + """Checks to see if MLFlow URI is functioning + + This isn't the best solution right now and overrides http timeout. Can update if MLFlow + use is increased. + """ + + logger.warning( + "Checking MLFlow logging location is working (if this hangs it's not)" + ) + t0 = os.environ.get("MLFLOW_HTTP_REQUEST_TIMEOUT", None) + try: + # Adjust http timeout to 5 seconds + os.environ["MLFLOW_HTTP_REQUEST_TIMEOUT"] = str(max(int(t0), 5)) if t0 else "5" + experiment = client.create_experiment("test") + client.delete_experiment(experiment) + + except Exception as e: + logger.error("Failed to validate MLFlow logging location works") + raise e + finally: + # Restore http request + if t0: + os.environ["MLFLOW_HTTP_REQUEST_TIMEOUT"] = t0 + else: + del os.environ["MLFLOW_HTTP_REQUEST_TIMEOUT"] + + logger.success("MLFlow logging location is working") diff --git a/examples/domino/modulus/launch/logging/utils.py b/examples/domino/modulus/launch/logging/utils.py new file mode 100644 index 0000000000..0a9e66dde8 --- /dev/null +++ b/examples/domino/modulus/launch/logging/utils.py @@ -0,0 +1,60 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from datetime import datetime + +import torch +from modulus.distributed import DistributedManager + + +def create_ddp_group_tag(group_name: str = None) -> str: + """Creates a common group tag for logging + + For some reason this does not work with multi-node. Seems theres a bug in PyTorch + when one uses a distributed util before DDP + + Parameters + ---------- + group_name : str, optional + Optional group name prefix. If None will use ``"DDP_Group_"``, by default None + + Returns + ------- + str + Group tag + """ + dist = DistributedManager() + if dist.rank == 0: + # Store time stamp as int tensor for broadcasting + def tint(x): + return int(datetime.now().strftime(f"%{x}")) + + time_index = torch.IntTensor( + [tint(x) for x in ["m", "d", "y", "H", "M", "S"]] + ).to(dist.device) + else: + time_index = torch.IntTensor([0, 0, 0, 0, 0, 0]).to(dist.device) + + if torch.distributed.is_available(): + # Broadcast group ID to all processes + torch.distributed.broadcast(time_index, src=0) + + time_string = f"{time_index[0]}/{time_index[1]}/{time_index[2]}_\ + {time_index[3]}-{time_index[4]}-{time_index[5]}" + + if group_name is None: + group_name = "DDP_Group" + return group_name + "_" + time_string diff --git a/examples/domino/modulus/launch/logging/wandb.py b/examples/domino/modulus/launch/logging/wandb.py new file mode 100644 index 0000000000..f147cb6888 --- /dev/null +++ b/examples/domino/modulus/launch/logging/wandb.py @@ -0,0 +1,131 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Weights and Biases Routines and Utilities""" + +import logging +import os +from datetime import datetime +from pathlib import Path +from typing import Literal + +import wandb +from modulus.distributed import DistributedManager +from wandb import AlertLevel + +from .utils import create_ddp_group_tag + +DEFAULT_WANDB_CONFIG = "~/.netrc" +logger = logging.getLogger(__name__) + +_WANDB_INITIALIZED = False + + +def initialize_wandb( + project: str, + entity: str, + name: str = "train", + group: str = None, + sync_tensorboard: bool = False, + save_code: bool = False, + resume: str = None, + wandb_id: str = None, + config=None, + mode: Literal["offline", "online", "disabled"] = "offline", + results_dir: str = None, +): + """Function to initialize wandb client with the weights and biases server. + + Parameters + ---------- + project : str + Name of the project to sync data with + entity : str, + Name of the wanbd entity + sync_tensorboard : bool, optional + sync tensorboard summary writer with wandb, by default False + save_code : bool, optional + Whether to push a copy of the code to wandb dashboard, by default False + name : str, optional + Name of the task running, by default "train" + group : str, optional + Group name of the task running. Good to set for ddp runs, by default None + resume: str, optional + Sets the resuming behavior. Options: "allow", "must", "never", "auto" or None, + by default None. + wandb_id: str, optional + A unique ID for this run, used for resuming. Used in conjunction with `resume` + parameter to enable experiment resuming. + See W&B documentation for more details: + https://docs.wandb.ai/guides/runs/resuming/ + config : optional + a dictionary-like object for saving inputs , like hyperparameters. + If dict, argparse or absl.flags, it will load the key value pairs into the + wandb.config object. If str, it will look for a yaml file by that name, + by default None. + mode: str, optional + Can be "offline", "online" or "disabled", by default "offline" + results_dir : str, optional + Output directory of the experiment, by default "//wandb" + """ + + # Set default value here for Hydra + if results_dir is None: + results_dir = str(Path("./wandb").absolute()) + + wandb_dir = results_dir + if DistributedManager.is_initialized() and DistributedManager().distributed: + if group is None: + group = create_ddp_group_tag() + start_time = datetime.now().astimezone() + time_string = start_time.strftime("%m/%d/%y_%H:%M:%S") + wandb_name = f"{name}_Process_{DistributedManager().rank}_{time_string}" + else: + start_time = datetime.now().astimezone() + time_string = start_time.strftime("%m/%d/%y_%H:%M:%S") + wandb_name = f"{name}_{time_string}" + + if not os.path.exists(wandb_dir): + os.makedirs(wandb_dir, exist_ok=True) + + wandb.init( + project=project, + entity=entity, + sync_tensorboard=sync_tensorboard, + name=wandb_name, + resume=resume, + config=config, + mode=mode, + dir=wandb_dir, + group=group, + save_code=save_code, + id=wandb_id, + ) + + +def alert(title, text, duration=300, level=0, is_master=True): + """Send alert.""" + alert_levels = {0: AlertLevel.INFO, 1: AlertLevel.WARN, 2: AlertLevel.ERROR} + if is_wandb_initialized() and is_master: + wandb.alert( + title=title, text=text, level=alert_levels[level], wait_duration=duration + ) + + +def is_wandb_initialized(): + """Check if wandb has been initialized.""" + global _WANDB_INITIALIZED + return _WANDB_INITIALIZED diff --git a/examples/domino/modulus/launch/utils/__init__.py b/examples/domino/modulus/launch/utils/__init__.py new file mode 100644 index 0000000000..ddd9be3cdf --- /dev/null +++ b/examples/domino/modulus/launch/utils/__init__.py @@ -0,0 +1,18 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .checkpoint import load_checkpoint # noqa: F401 +from .checkpoint import save_checkpoint # noqa: F401 diff --git a/examples/domino/modulus/launch/utils/checkpoint.py b/examples/domino/modulus/launch/utils/checkpoint.py new file mode 100644 index 0000000000..1f5f062155 --- /dev/null +++ b/examples/domino/modulus/launch/utils/checkpoint.py @@ -0,0 +1,387 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import glob +import re +from pathlib import Path +from typing import Any +from typing import Dict +from typing import List +from typing import NewType +from typing import Optional +from typing import Union + +import paddle +from paddle.amp import GradScaler +from paddle.optimizer.lr import LRScheduler + +from ...distributed import DistributedManager +from ...launch.logging import PythonLogger + +# from modulus.utils.capture import _StaticCapture + +optimizer = NewType("optimizer", paddle.optimizer) +scheduler = NewType("scheduler", LRScheduler) +scaler = NewType("scaler", GradScaler) + +checkpoint_logging = PythonLogger("checkpoint") + + +def _get_checkpoint_filename( + path: str, + base_name: str = "checkpoint", + index: Union[int, None] = None, + saving: bool = False, + model_type: str = "mdlus", +) -> str: + """Gets the file name /path of checkpoint + + This function has three different ways of providing a checkout filename: + - If supplied an index this will return the checkpoint name using that index. + - If index is None and saving is false, this will get the checkpoint with the + largest index (latest save). + - If index is None and saving is true, it will return the next valid index file name + which is calculated by indexing the largest checkpoint index found by one. + + Parameters + ---------- + path : str + Path to checkpoints + base_name: str, optional + Base file name, by default checkpoint + index : Union[int, None], optional + Checkpoint index, by default None + saving : bool, optional + Get filename for saving a new checkpoint, by default False + model_type : str + Model type, by default "mdlus" for Modulus models and "pdparams" for PyTorch models + + + Returns + ------- + str + Checkpoint file name + """ + # Get model parallel rank so all processes in the first model parallel group + # can save their checkpoint. In the case without model parallelism, + # model_parallel_rank should be the same as the process rank itself and + # only rank 0 saves + if not DistributedManager.is_initialized(): + checkpoint_logging.warning( + "`DistributedManager` not initialized already. Initializing now, but this might lead to unexpected errors" + ) + DistributedManager.initialize() + manager = DistributedManager() + model_parallel_rank = ( + manager.group_rank("model_parallel") + if "model_parallel" in manager.group_names + else 0 + ) + + # Input file name + checkpoint_filename = str( + Path(path).resolve() / f"{base_name}.{model_parallel_rank}" + ) + + # File extension for Modulus models or PaddlePaddle models + file_extension = ".pdparams" + + # If epoch is provided load that file + if index is not None: + checkpoint_filename = checkpoint_filename + f".{index}" + checkpoint_filename += file_extension + # Otherwise try loading the latest epoch or rolling checkpoint + else: + file_names = [ + Path(fname).name + for fname in glob.glob( + checkpoint_filename + "*" + file_extension, recursive=False + ) + ] + + if len(file_names) > 0: + # If checkpoint from a null index save exists load that + # This is the most likely line to error since it will fail with + # invalid checkpoint names + file_idx = [ + int( + re.sub( + f"^{base_name}.{model_parallel_rank}.|" + file_extension, + "", + fname, + ) + ) + for fname in file_names + ] + file_idx.sort() + # If we are saving index by 1 to get the next free file name + if saving: + checkpoint_filename = checkpoint_filename + f".{file_idx[-1]+1}" + else: + checkpoint_filename = checkpoint_filename + f".{file_idx[-1]}" + checkpoint_filename += file_extension + else: + checkpoint_filename += ".0" + file_extension + + return checkpoint_filename + + +def _unique_model_names( + models: List[paddle.nn.Layer], +) -> Dict[str, paddle.nn.Layer]: + """Util to clean model names and index if repeat names, will also strip DDP wrappers + if they exist. + + Parameters + ---------- + model : List[paddle.nn.Layer] + List of models to generate names for + + Returns + ------- + Dict[str, paddle.nn.Layer] + Dictionary of model names and respective modules + """ + # Loop through provided models and set up base names + model_dict = {} + for model0 in models: + if hasattr(model0, "module"): + # Strip out DDP layer + model0 = model0.module + # Base name of model is meta.name unless paddle model + base_name = model0.__class__.__name__ + # if isinstance(model0, modulus): + # base_name = model0.meta.name + # If we have multiple models of the same name, introduce another index + if base_name in model_dict: + model_dict[base_name].append(model0) + else: + model_dict[base_name] = [model0] + + # Set up unique model names if needed + output_dict = {} + for key, model in model_dict.items(): + if len(model) > 1: + for i, model0 in enumerate(model): + output_dict[key + str(i)] = model0 + else: + output_dict[key] = model[0] + + return output_dict + + +def save_checkpoint( + path: str, + models: Union[paddle.nn.Layer, List[paddle.nn.Layer], None] = None, + optimizer: Union[optimizer, None] = None, + scheduler: Union[scheduler, None] = None, + scaler: Union[scaler, None] = None, + epoch: Union[int, None] = None, + metadata: Optional[Dict[str, Any]] = None, +) -> None: + """Training checkpoint saving utility + + This will save a training checkpoint in the provided path following the file naming + convention "checkpoint.{model parallel id}.{epoch/index}.mdlus". The load checkpoint + method in Modulus core can then be used to read this file. + + Parameters + ---------- + path : str + Path to save the training checkpoint + models : Union[paddle.nn.Layer, List[paddle.nn.Layer], None], optional + A single or list of PaddlePaddle models, by default None + optimizer : Union[optimizer, None], optional + Optimizer, by default None + scheduler : Union[scheduler, None], optional + Learning rate scheduler, by default None + scaler : Union[scaler, None], optional + AMP grad scaler. Will attempt to save on in static capture if none provided, by + default None + epoch : Union[int, None], optional + Epoch checkpoint to load. If none this will save the checkpoint in the next + valid index, by default None + metadata : Optional[Dict[str, Any]], optional + Additional metadata to save, by default None + """ + # Create checkpoint directory if it does not exist + if not Path(path).is_dir(): + checkpoint_logging.warning( + f"Output directory {path} does not exist, will " "attempt to create" + ) + Path(path).mkdir(parents=True, exist_ok=True) + + # == Saving model checkpoint == + if models: + if not isinstance(models, list): + models = [models] + models = _unique_model_names(models) + for name, model in models.items(): + # Get model type + model_type = "pdparams" + + # Get full file path / name + file_name = _get_checkpoint_filename( + path, name, index=epoch, saving=True, model_type=model_type + ) + + # Save state dictionary + paddle.save(model.state_dict(), file_name) + checkpoint_logging.success(f"Saved model state dictionary: {file_name}") + + # == Saving training checkpoint == + checkpoint_dict = {} + # Optimizer state dict + if optimizer: + checkpoint_dict["optimizer_state_dict"] = optimizer.state_dict() + + # Scheduler state dict + if scheduler: + checkpoint_dict["scheduler_state_dict"] = scheduler.state_dict() + + # Scheduler state dict + if scaler: + checkpoint_dict["scaler_state_dict"] = scaler.state_dict() + # Static capture is being used, save its grad scaler + # if _StaticCapture._amp_scalers: + # checkpoint_dict["static_capture_state_dict"] = _StaticCapture.state_dict() + + # Output file name + output_filename = _get_checkpoint_filename( + path, index=epoch, saving=True, model_type="pdparams" + ) + if epoch: + checkpoint_dict["epoch"] = epoch + if metadata: + checkpoint_dict["metadata"] = metadata + # Save checkpoint to memory + if bool(checkpoint_dict): + paddle.save( + checkpoint_dict, + output_filename, + ) + checkpoint_logging.success(f"Saved training checkpoint: {output_filename}") + + +def load_checkpoint( + path: str, + models: Union[paddle.nn.Layer, List[paddle.nn.Layer], None] = None, + optimizer: Union[optimizer, None] = None, + scheduler: Union[scheduler, None] = None, + scaler: Union[scaler, None] = None, + epoch: Union[int, None] = None, + metadata_dict: Optional[Dict[str, Any]] = {}, +) -> int: + """Checkpoint loading utility + + This loader is designed to be used with the save checkpoint utility in Modulus + Launch. Given a path, this method will try to find a checkpoint and load state + dictionaries into the provided training objects. + + Parameters + ---------- + path : str + Path to training checkpoint + models : Union[paddle.nn.Layer, List[paddle.nn.Layer], None], optional + A single or list of PyTorch models, by default None + optimizer : Union[optimizer, None], optional + Optimizer, by default None + scheduler : Union[scheduler, None], optional + Learning rate scheduler, by default None + scaler : Union[scaler, None], optional + AMP grad scaler, by default None + epoch : Union[int, None], optional + Epoch checkpoint to load. If none is provided this will attempt to load the + checkpoint with the largest index, by default None + metadata_dict: Optional[Dict[str, Any]], optional + Dictionary to store metadata from the checkpoint, by default None + + Returns + ------- + int + Loaded epoch + """ + # Check if checkpoint directory exists + if not Path(path).is_dir(): + checkpoint_logging.warning( + f"Provided checkpoint directory {path} does not exist, skipping load" + ) + return 0 + + # == Loading model checkpoint == + if models: + if not isinstance(models, list): + models = [models] + models = _unique_model_names(models) + for name, model in models.items(): + # Get model type + model_type = "pdparams" + + # Get full file path / name + file_name = _get_checkpoint_filename( + path, name, index=epoch, model_type=model_type + ) + if not Path(file_name).exists(): + checkpoint_logging.error( + f"Could not find valid model file {file_name}, skipping load" + ) + continue + # Load state dictionary + model.set_state_dict(paddle.load(file_name)) + + checkpoint_logging.success(f"Loaded model state dictionary {file_name}") + + # == Loading training checkpoint == + checkpoint_filename = _get_checkpoint_filename( + path, index=epoch, model_type="pdparams" + ) + if not Path(checkpoint_filename).is_file(): + checkpoint_logging.warning( + "Could not find valid checkpoint file, skipping load" + ) + return 0 + + checkpoint_dict = paddle.load(checkpoint_filename) + checkpoint_logging.success(f"Loaded checkpoint file {checkpoint_filename}") + + # Optimizer state dict + if optimizer and "optimizer_state_dict" in checkpoint_dict: + optimizer.set_state_dict(checkpoint_dict["optimizer_state_dict"]) + checkpoint_logging.success("Loaded optimizer state dictionary") + + # Scheduler state dict + if scheduler and "scheduler_state_dict" in checkpoint_dict: + scheduler.set_state_dict(checkpoint_dict["scheduler_state_dict"]) + checkpoint_logging.success("Loaded scheduler state dictionary") + + # Scaler state dict + if scaler and "scaler_state_dict" in checkpoint_dict: + scaler.load_state_dict(checkpoint_dict["scaler_state_dict"]) + checkpoint_logging.success("Loaded grad scaler state dictionary") + + # if "static_capture_state_dict" in checkpoint_dict: + # _StaticCapture.set_state_dict(checkpoint_dict["static_capture_state_dict"]) + # checkpoint_logging.success("Loaded static capture state dictionary") + + epoch = 0 + if "epoch" in checkpoint_dict: + epoch = checkpoint_dict["epoch"] + # Update metadata if exists and the dictionary object is provided + metadata = checkpoint_dict.get("metadata", {}) + for key, value in metadata.items(): + metadata_dict[key] = value + + return epoch diff --git a/examples/domino/modulus/models/__init__.py b/examples/domino/modulus/models/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/examples/domino/modulus/models/layers/__init__.py b/examples/domino/modulus/models/layers/__init__.py new file mode 100644 index 0000000000..b2f171d4ac --- /dev/null +++ b/examples/domino/modulus/models/layers/__init__.py @@ -0,0 +1,15 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/examples/domino/layers/ball_query.py b/examples/domino/modulus/models/layers/ball_query.py similarity index 100% rename from examples/domino/layers/ball_query.py rename to examples/domino/modulus/models/layers/ball_query.py diff --git a/examples/domino/model.py b/examples/domino/modulus/models/model.py similarity index 98% rename from examples/domino/model.py rename to examples/domino/modulus/models/model.py index f270a8eb10..ca87ceb4aa 100644 --- a/examples/domino/model.py +++ b/examples/domino/modulus/models/model.py @@ -22,16 +22,37 @@ """ # from dataclasses import dataclass +import math import paddle import paddle.nn as nn import paddle.nn.functional as F -from layers.ball_query import BallQueryLayer + +import ppsci + +from .layers.ball_query import BallQueryLayer # from modulus.models.meta import ModelMetaData # from modulus.models.module import Module +def kaiming_init(layer): + if isinstance(layer, (nn.layer.conv._ConvNd, nn.Linear)): + print(f"layer: {layer} ") + init_kaimingUniform = paddle.nn.initializer.KaimingUniform( + nonlinearity="leaky_relu", negative_slope=math.sqrt(5) + ) + init_kaimingUniform(layer.weight) + if layer.bias is not None: + fan_in, _ = ppsci.utils.initializer._calculate_fan_in_and_fan_out( + layer.weight + ) + if fan_in != 0: + bound = 1 / math.sqrt(fan_in) + init_uniform = paddle.nn.initializer.Uniform(low=-bound, high=bound) + init_uniform(layer.bias) + + def calculate_pos_encoding(nx, d=8): """Function to caluculate positional encoding""" vec = [] @@ -659,6 +680,8 @@ def __init__( ) ) + self.apply(kaiming_init) + def geometry_encoder(self, geo_centers, p_grid, sdf): """Function to return local geometry encoding""" return self.geo_rep(geo_centers, p_grid, sdf) @@ -1122,7 +1145,7 @@ def forward( else: paddle.set_device("cpu") cfg = OmegaConf.register_new_resolver("eval", eval) - with initialize(version_base="1.3", config_path="conf"): + with initialize(version_base="1.3", config_path="../../scripts/conf"): cfg = compose(config_name="config") cfg.model.model_type = "combined" model = DoMINO( diff --git a/examples/domino/modulus/utils/__init__.py b/examples/domino/modulus/utils/__init__.py new file mode 100644 index 0000000000..b2f171d4ac --- /dev/null +++ b/examples/domino/modulus/utils/__init__.py @@ -0,0 +1,15 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/examples/domino/modulus/utils/domino/utils.py b/examples/domino/modulus/utils/domino/utils.py new file mode 100644 index 0000000000..351eead555 --- /dev/null +++ b/examples/domino/modulus/utils/domino/utils.py @@ -0,0 +1,385 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Important utilities for data processing and training, testing DoMINO. +""" + +import os +import time + +import numpy as np +from scipy.spatial import KDTree + +try: + import pyvista as pv + + PV_AVAILABLE = True +except ImportError: + PV_AVAILABLE = False +try: + import vtk + from vtk import vtkDataSetTriangleFilter + from vtk.util import numpy_support + + VTK_AVAILABLE = True +except ImportError: + VTK_AVAILABLE = False + + +def calculate_center_of_mass(stl_centers, stl_sizes): + """Function to calculate center of mass""" + stl_sizes = np.expand_dims(stl_sizes, -1) + center_of_mass = np.sum(stl_centers * stl_sizes, axis=0) / np.sum(stl_sizes, axis=0) + return center_of_mass + + +def normalize(field, mx, mn): + """Function to normalize fields""" + return 2.0 * (field - mn) / (mx - mn) - 1.0 + + +def unnormalize(field, mx, mn): + """Function to unnormalize fields""" + return (field + 1.0) * (mx - mn) * 0.5 + mn + + +def standardize(field, mean, std): + """Function to standardize fields""" + return (field - mean) / std + + +def unstandardize(field, mean, std): + """Function to unstandardize fields""" + return field * std + mean + + +def write_to_vtp(polydata, filename): + """Function to write polydata to vtp""" + if not VTK_AVAILABLE: + raise ImportError("VTK or is not installed. This function cannot be used.") + writer = vtk.vtkXMLPolyDataWriter() + writer.SetFileName(filename) + writer.SetInputData(polydata) + writer.Write() + + +def write_to_vtu(polydata, filename): + """Function to write polydata to vtu""" + if not VTK_AVAILABLE: + raise ImportError("VTK or is not installed. This function cannot be used.") + writer = vtk.vtkXMLUnstructuredGridWriter() + writer.SetFileName(filename) + writer.SetInputData(polydata) + writer.Write() + + +def extract_surface_triangles(tet_mesh): + """Extracts the surface triangles from a triangular mesh.""" + if not VTK_AVAILABLE: + raise ImportError("VTK or is not installed. This function cannot be used.") + if not PV_AVAILABLE: + raise ImportError("PyVista is not installed. This function cannot be used.") + surface_filter = vtk.vtkDataSetSurfaceFilter() + surface_filter.SetInputData(tet_mesh) + surface_filter.Update() + + surface_mesh = pv.wrap(surface_filter.GetOutput()) + triangle_indices = [] + faces = surface_mesh.faces.reshape((-1, 4)) + for face in faces: + if face[0] == 3: + triangle_indices.extend([face[1], face[2], face[3]]) + else: + raise ValueError("Face is not a triangle") + + return triangle_indices + + +def convert_to_tet_mesh(polydata): + """Function to convert tet to stl""" + if not VTK_AVAILABLE: + raise ImportError("VTK or is not installed. This function cannot be used.") + # Create a VTK DataSetTriangleFilter object + tet_filter = vtkDataSetTriangleFilter() + tet_filter.SetInputData(polydata) + tet_filter.Update() # Update to apply the filter + + # Get the output as an UnstructuredGrid + # tet_mesh = pv.wrap(tet_filter.GetOutput()) + tet_mesh = tet_filter.GetOutput() + return tet_mesh + + +def get_node_to_elem(polydata): + """Function to convert node to elem""" + if not VTK_AVAILABLE: + raise ImportError("VTK or is not installed. This function cannot be used.") + c2p = vtk.vtkPointDataToCellData() + c2p.SetInputData(polydata) + c2p.Update() + cell_data = c2p.GetOutput() + return cell_data + + +def get_fields_from_cell(ptdata, var_list): + """Function to get fields from elem""" + fields = [] + for var in var_list: + variable = ptdata.GetArray(var) + num_tuples = variable.GetNumberOfTuples() + cell_fields = [] + for j in range(num_tuples): + variable_value = np.array(variable.GetTuple(j)) + cell_fields.append(variable_value) + cell_fields = np.asarray(cell_fields) + fields.append(cell_fields) + fields = np.transpose(np.asarray(fields), (1, 0)) + + return fields + + +def get_fields(data, variables): + """Function to get fields from VTP/VTU""" + if not VTK_AVAILABLE: + raise ImportError("VTK or is not installed. This function cannot be used.") + fields = [] + for array_name in variables: + try: + array = data.GetArray(array_name) + except ValueError: + raise ValueError( + f"Failed to get array {array_name} from the unstructured grid." + ) + array_data = numpy_support.vtk_to_numpy(array).reshape( + array.GetNumberOfTuples(), array.GetNumberOfComponents() + ) + fields.append(array_data) + return fields + + +def get_vertices(polydata): + """Function to get vertices""" + if not VTK_AVAILABLE: + raise ImportError("VTK or is not installed. This function cannot be used.") + points = polydata.GetPoints() + vertices = numpy_support.vtk_to_numpy(points.GetData()) + return vertices + + +def get_volume_data(polydata, variables): + """Function to get volume data""" + vertices = get_vertices(polydata) + point_data = polydata.GetPointData() + + fields = get_fields(point_data, variables) + + return vertices, fields + + +def get_surface_data(polydata, variables): + """Function to get surface data""" + if not VTK_AVAILABLE: + raise ImportError("VTK or is not installed. This function cannot be used.") + points = polydata.GetPoints() + vertices = np.array([points.GetPoint(i) for i in range(points.GetNumberOfPoints())]) + + point_data = polydata.GetPointData() + fields = [] + for array_name in variables: + try: + array = point_data.GetArray(array_name) + except ValueError: + raise ValueError( + f"Failed to get array {array_name} from the unstructured grid." + ) + array_data = np.zeros( + (points.GetNumberOfPoints(), array.GetNumberOfComponents()) + ) + for j in range(points.GetNumberOfPoints()): + array.GetTuple(j, array_data[j]) + fields.append(array_data) + + polys = polydata.GetPolys() + if polys is None: + raise ValueError("Failed to get polygons from the polydata.") + polys.InitTraversal() + edges = [] + id_list = vtk.vtkIdList() + for _ in range(polys.GetNumberOfCells()): + polys.GetNextCell(id_list) + num_ids = id_list.GetNumberOfIds() + edges = [ + (id_list.GetId(j), id_list.GetId((j + 1) % num_ids)) for j in range(num_ids) + ] + + return vertices, fields, edges + + +def calculate_normal_positional_encoding( + coordinates_a, coordinates_b=None, cell_length=[] +): + """Function to get normal positional encoding""" + dx = cell_length[0] + dy = cell_length[1] + dz = cell_length[2] + if coordinates_b is not None: + normals = coordinates_a - coordinates_b + pos_x = np.asarray(calculate_pos_encoding(normals[:, 0] / dx, d=4)) + pos_y = np.asarray(calculate_pos_encoding(normals[:, 1] / dy, d=4)) + pos_z = np.asarray(calculate_pos_encoding(normals[:, 2] / dz, d=4)) + pos_normals = np.concatenate((pos_x, pos_y, pos_z), axis=0).reshape(-1, 12) + else: + normals = coordinates_a + pos_x = np.asarray(calculate_pos_encoding(normals[:, 0] / dx, d=4)) + pos_y = np.asarray(calculate_pos_encoding(normals[:, 1] / dy, d=4)) + pos_z = np.asarray(calculate_pos_encoding(normals[:, 2] / dz, d=4)) + pos_normals = np.concatenate((pos_x, pos_y, pos_z), axis=0).reshape(-1, 12) + + return pos_normals + + +def nd_interpolator(coodinates, field, grid): + """Function to for nd interpolation""" + interp_func = KDTree(coodinates[0]) + dd, ii = interp_func.query(grid, k=2) + + field_grid = field[ii] + field_grid = np.float32(np.mean(field_grid, (3))) + return field_grid + + +def pad(arr, npoin, pad_value=0.0): + """Function for padding""" + arr_pad = pad_value * np.ones( + (npoin - arr.shape[0], arr.shape[1]), dtype=np.float32 + ) + arr_padded = np.concatenate((arr, arr_pad), axis=0) + return arr_padded + + +def pad_inp(arr, npoin, pad_value=0.0): + """Function for padding arrays""" + arr_pad = pad_value * np.ones( + (npoin - arr.shape[0], arr.shape[1], arr.shape[2]), dtype=np.float32 + ) + arr_padded = np.concatenate((arr, arr_pad), axis=0) + return arr_padded + + +def shuffle_array(arr, npoin): + """Function for shuffling arrays""" + np.random.seed(seed=int(time.time())) + idx = np.arange(arr.shape[0]) + np.random.shuffle(idx) + idx = idx[:npoin] + return arr[idx], idx + + +def shuffle_array_without_sampling(arr): + """Function for shuffline arrays without sampling.""" + idx = np.arange(arr.shape[0]) + np.random.shuffle(idx) + return arr[idx], idx + + +def create_directory(filepath): + """Function to create directories""" + if not os.path.exists(filepath): + os.makedirs(filepath) + + +def get_filenames(filepath): + """Function to get filenames from a directory""" + if os.path.exists(filepath): + filenames = os.listdir(filepath) + return filenames + else: + FileNotFoundError() + + +def calculate_pos_encoding(nx, d=8): + """Function for calculating positional encoding""" + vec = [] + for k in range(int(d / 2)): + vec.append(np.sin(nx / 10000 ** (2 * (k) / d))) + vec.append(np.cos(nx / 10000 ** (2 * (k) / d))) + return vec + + +def combine_dict(old_dict, new_dict): + """Function to combine dictionaries""" + for j in old_dict.keys(): + old_dict[j] += new_dict[j] + return old_dict + + +def merge(*lists): + """Function to merge lists""" + newlist = lists[:] + for x in lists: + if x not in newlist: + newlist.extend(x) + return newlist + + +def create_grid(mx, mn, nres): + """Function to create grid""" + dx = np.linspace(mn[0], mx[0], nres[0]) + dy = np.linspace(mn[1], mx[1], nres[1]) + dz = np.linspace(mn[2], mx[2], nres[2]) + + xv, yv, zv = np.meshgrid(dx, dy, dz) + xv = np.expand_dims(xv, -1) + yv = np.expand_dims(yv, -1) + zv = np.expand_dims(zv, -1) + grid = np.concatenate((xv, yv, zv), axis=-1) + grid = np.transpose(grid, (1, 0, 2, 3)) + + return grid + + +def mean_std_sampling(field, mean, std, tolerance=3.0): + """Function for mean/std based sampling""" + idx_all = [] + for v in range(field.shape[-1]): + fv = field[:, v] + idx = np.where( + (fv > mean[v] + tolerance * std[v]) | (fv < mean[v] - tolerance * std[v]) + ) + if len(idx[0]) != 0: + idx_all += list(idx[0]) + + return idx_all + + +def dict_to_device(state_dict, device): + """Function to load dictionary to device""" + new_state_dict = {} + for k, v in state_dict.items(): + new_state_dict[k] = v.to(device) + return new_state_dict + + +def area_weighted_shuffle_array(arr, npoin, area): + factor = 1.0 + total_area = np.sum(area**factor) + probs = area**factor / total_area + np.random.seed(seed=int(time.time())) + idx = np.arange(arr.shape[0]) + np.random.shuffle(idx) + ids = np.random.choice(idx, npoin, p=probs[idx]) + return arr[ids], ids diff --git a/examples/domino/modulus/utils/sdf.py b/examples/domino/modulus/utils/sdf.py new file mode 100644 index 0000000000..914b54bd76 --- /dev/null +++ b/examples/domino/modulus/utils/sdf.py @@ -0,0 +1,143 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ruff: noqa: F401 + +import numpy as np +import warp as wp +from numpy.typing import NDArray + + +@wp.kernel +def _bvh_query_distance( + mesh: wp.uint64, + points: wp.array(dtype=wp.vec3f), + max_dist: wp.float32, + sdf: wp.array(dtype=wp.float32), + sdf_hit_point: wp.array(dtype=wp.vec3f), + sdf_hit_point_id: wp.array(dtype=wp.int32), + use_sign_winding_number: bool = False, +): + + """ + Computes the signed distance from each point in the given array `points` + to the mesh represented by `mesh`,within the maximum distance `max_dist`, + and stores the result in the array `sdf`. + + Parameters: + mesh (wp.uint64): The identifier of the mesh. + points (wp.array): An array of 3D points for which to compute the + signed distance. + max_dist (wp.float32): The maximum distance within which to search + for the closest point on the mesh. + sdf (wp.array): An array to store the computed signed distances. + sdf_hit_point (wp.array): An array to store the computed hit points. + sdf_hit_point_id (wp.array): An array to store the computed hit point ids. + use_sign_winding_number (bool): Flag to use sign_winding_number method for SDF. + + Returns: + None + """ + tid = wp.tid() + + if use_sign_winding_number: + res = wp.mesh_query_point_sign_winding_number(mesh, points[tid], max_dist) + else: + res = wp.mesh_query_point_sign_normal(mesh, points[tid], max_dist) + + mesh_ = wp.mesh_get(mesh) + + p0 = mesh_.points[mesh_.indices[3 * res.face + 0]] + p1 = mesh_.points[mesh_.indices[3 * res.face + 1]] + p2 = mesh_.points[mesh_.indices[3 * res.face + 2]] + + p_closest = res.u * p0 + res.v * p1 + (1.0 - res.u - res.v) * p2 + + sdf[tid] = res.sign * wp.abs(wp.length(points[tid] - p_closest)) + sdf_hit_point[tid] = p_closest + sdf_hit_point_id[tid] = res.face + + +def signed_distance_field( + mesh_vertices: list[tuple[float, float, float]], + mesh_indices: NDArray[float], + input_points: list[tuple[float, float, float]], + max_dist: float = 1e8, + include_hit_points: bool = False, + include_hit_points_id: bool = False, + use_sign_winding_number: bool = False, +) -> wp.array: + """ + Computes the signed distance field (SDF) for a given mesh and input points. + + Parameters: + ---------- + mesh_vertices (list[tuple[float, float, float]]): List of vertices defining the mesh. + mesh_indices (list[tuple[int, int, int]]): List of indices defining the triangles of the mesh. + input_points (list[tuple[float, float, float]]): List of input points for which to compute the SDF. + max_dist (float, optional): Maximum distance within which to search for + the closest point on the mesh. Default is 1e8. + include_hit_points (bool, optional): Whether to include hit points in + the output. Default is False. + include_hit_points_id (bool, optional): Whether to include hit point + IDs in the output. Default is False. + + Returns: + ------- + wp.array: An array containing the computed signed distance field. + + Example: + ------- + >>> mesh_vertices = [(0, 0, 0), (1, 0, 0), (0, 1, 0)] + >>> mesh_indices = np.array((0, 1, 2)) + >>> input_points = [(0.5, 0.5, 0.5)] + >>> signed_distance_field(mesh_vertices, mesh_indices, input_points).numpy() + Module ... + array([0.5], dtype=float32) + """ + + wp.init() + mesh = wp.Mesh( + wp.array(mesh_vertices, dtype=wp.vec3), wp.array(mesh_indices, dtype=wp.int32) + ) + + sdf_points = wp.array(input_points, dtype=wp.vec3) + sdf = wp.zeros(shape=sdf_points.shape, dtype=wp.float32) + sdf_hit_point = wp.zeros(shape=sdf_points.shape, dtype=wp.vec3f) + sdf_hit_point_id = wp.zeros(shape=sdf_points.shape, dtype=wp.int32) + + wp.launch( + kernel=_bvh_query_distance, + dim=len(sdf_points), + inputs=[ + mesh.id, + sdf_points, + max_dist, + sdf, + sdf_hit_point, + sdf_hit_point_id, + use_sign_winding_number, + ], + ) + + if include_hit_points and include_hit_points_id: + return (sdf, sdf_hit_point, sdf_hit_point_id) + elif include_hit_points: + return (sdf, sdf_hit_point) + elif include_hit_points_id: + return (sdf, sdf_hit_point_id) + else: + return sdf diff --git a/examples/domino/test.py b/examples/domino/test.py new file mode 100644 index 0000000000..45a43f20e1 --- /dev/null +++ b/examples/domino/test.py @@ -0,0 +1,742 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This code defines a distributed pipeline for testing the DoMINO model on +CFD datasets. It includes the instantiating the DoMINO model and datapipe, +automatically loading the most recent checkpoint, reading the VTP/VTU/STL +testing files, calculation of parameters required for DoMINO model and +evaluating the model in parallel using DataParallel across multiple +GPUs. This is a common recipe that enables training of combined models for surface +and volume as well either of them separately. The model predictions are loaded in +the the VTP/VTU files and saved in the specified directory. The eval tab in +config.yaml can be used to specify the input and output directories. +""" + +import os +import re + +import hydra +import numpy as np +import paddle +import pyvista as pv +import vtk +from hydra.utils import to_absolute_path +from modulus.distributed import DistributedManager +from modulus.models.model import DoMINO +from modulus.utils.domino.utils import KDTree +from modulus.utils.domino.utils import calculate_center_of_mass +from modulus.utils.domino.utils import calculate_normal_positional_encoding +from modulus.utils.domino.utils import create_directory +from modulus.utils.domino.utils import create_grid +from modulus.utils.domino.utils import get_fields +from modulus.utils.domino.utils import get_filenames +from modulus.utils.domino.utils import get_node_to_elem +from modulus.utils.domino.utils import get_volume_data +from modulus.utils.domino.utils import normalize +from modulus.utils.domino.utils import unnormalize +from modulus.utils.domino.utils import write_to_vtp +from modulus.utils.domino.utils import write_to_vtu +from modulus.utils.sdf import signed_distance_field +from omegaconf import DictConfig +from omegaconf import OmegaConf +from paddle import DataParallel +from vtk.util import numpy_support + +AIR_DENSITY = 1.205 +STREAM_VELOCITY = 30.00 + + +def loss_fn(output, target): + masked_loss = paddle.mean(((output - target) ** 2.0), (0, 1, 2)) + loss = paddle.mean(masked_loss) + return loss + + +def test_step(data_dict, model, device, cfg, vol_factors, surf_factors): + running_tloss_vol = 0.0 + running_tloss_surf = 0.0 + + if cfg.model.model_type == "volume" or cfg.model.model_type == "combined": + output_features_vol = True + else: + output_features_vol = None + + if cfg.model.model_type == "surface" or cfg.model.model_type == "combined": + output_features_surf = True + else: + output_features_surf = None + + with paddle.no_grad(): + point_batch_size = 256000 + + # Non-dimensionalization factors + air_density = data_dict["air_density"] + stream_velocity = data_dict["stream_velocity"] + length_scale = data_dict["length_scale"] + + # STL nodes + geo_centers = data_dict["geometry_coordinates"] + + # Bounding box grid + s_grid = data_dict["surf_grid"] + sdf_surf_grid = data_dict["sdf_surf_grid"] + # Scaling factors + surf_max = data_dict["surface_min_max"][:, 1] + surf_min = data_dict["surface_min_max"][:, 0] + + if output_features_vol is not None: + # Represent geometry on computational grid + # Computational domain grid + p_grid = data_dict["grid"] + sdf_grid = data_dict["sdf_grid"] + # Scaling factors + vol_max = data_dict["volume_min_max"][:, 1] + vol_min = data_dict["volume_min_max"][:, 0] + + # Normalize based on computational domain + geo_centers_vol = 2.0 * (geo_centers - vol_min) / (vol_max - vol_min) - 1 + encoding_g_vol = model.module.geo_rep(geo_centers_vol, p_grid, sdf_grid) + + # Normalize based on BBox around surface (car) + geo_centers_surf = ( + 2.0 * (geo_centers - surf_min) / (surf_max - surf_min) - 1 + ) + encoding_g_surf = model.module.geo_rep( + geo_centers_surf, s_grid, sdf_surf_grid + ) + + if output_features_surf is not None: + # Represent geometry on bounding box + geo_centers_surf = ( + 2.0 * (geo_centers - surf_min) / (surf_max - surf_min) - 1 + ) + encoding_g_surf = model.module.geo_rep( + geo_centers_surf, s_grid, sdf_surf_grid + ) + + geo_encoding = 0.5 * encoding_g_surf + # Average the encodings + if output_features_vol is not None: + geo_encoding += 0.5 * encoding_g_vol + + if output_features_vol is not None: + # First calculate volume predictions if required + volume_mesh_centers = data_dict["volume_mesh_centers"] + target_vol = data_dict["volume_fields"] + # SDF on volume mesh nodes + sdf_nodes = data_dict["sdf_nodes"] + # Positional encoding based on closest point on surface to a volume node + pos_volume_closest = data_dict["pos_volume_closest"] + # Positional encoding based on center of mass of geometry to volume node + pos_volume_center_of_mass = data_dict["pos_volume_center_of_mass"] + p_grid = data_dict["grid"] + + prediction_vol = np.zeros_like(target_vol.cpu().numpy()) + num_points = volume_mesh_centers.shape[1] + subdomain_points = int(np.floor(num_points / point_batch_size)) + + for p in range(subdomain_points + 1): + start_idx = p * point_batch_size + end_idx = (p + 1) * point_batch_size + with paddle.no_grad(): + target_batch = target_vol[:, start_idx:end_idx] + volume_mesh_centers_batch = volume_mesh_centers[ + :, start_idx:end_idx + ] + sdf_nodes_batch = sdf_nodes[:, start_idx:end_idx] + pos_volume_closest_batch = pos_volume_closest[:, start_idx:end_idx] + pos_normals_com_batch = pos_volume_center_of_mass[ + :, start_idx:end_idx + ] + geo_encoding_local = model.module.geo_encoding_local( + geo_encoding, volume_mesh_centers_batch, p_grid + ) + if cfg.model.use_sdf_in_basis_func: + pos_encoding = paddle.concat( + ( + sdf_nodes_batch, + pos_volume_closest_batch, + pos_normals_com_batch, + ), + axis=-1, + ) + else: + pos_encoding = pos_normals_com_batch + pos_encoding = model.module.position_encoder( + pos_encoding, eval_mode="volume" + ) + tpredictions_batch = model.module.calculate_solution( + volume_mesh_centers_batch, + geo_encoding_local, + pos_encoding, + stream_velocity, + air_density, + num_sample_points=20, + eval_mode="volume", + ) + running_tloss_vol += loss_fn(tpredictions_batch, target_batch) + prediction_vol[ + :, start_idx:end_idx + ] = tpredictions_batch.cpu().numpy() + + prediction_vol = unnormalize(prediction_vol, vol_factors[0], vol_factors[1]) + + prediction_vol[:, :, :3] = ( + prediction_vol[:, :, :3] * stream_velocity[0, 0].cpu().numpy() + ) + prediction_vol[:, :, 3] = ( + prediction_vol[:, :, 3] + * stream_velocity[0, 0].cpu().numpy() ** 2.0 + * air_density[0, 0].cpu().numpy() + ) + prediction_vol[:, :, 4] = ( + prediction_vol[:, :, 4] + * stream_velocity[0, 0].cpu().numpy() + * length_scale[0].cpu().numpy() + ) + else: + prediction_vol = None + + if output_features_surf is not None: + # Next calculate surface predictions + # Sampled points on surface + surface_mesh_centers = data_dict["surface_mesh_centers"] + surface_normals = data_dict["surface_normals"] + surface_areas = data_dict["surface_areas"] + + # Neighbors of sampled points on surface + surface_mesh_neighbors = data_dict["surface_mesh_neighbors"] + surface_neighbors_normals = data_dict["surface_neighbors_normals"] + surface_neighbors_areas = data_dict["surface_neighbors_areas"] + surface_areas = paddle.unsqueeze(surface_areas, -1) + surface_neighbors_areas = paddle.unsqueeze(surface_neighbors_areas, -1) + pos_surface_center_of_mass = data_dict["pos_surface_center_of_mass"] + num_points = surface_mesh_centers.shape[1] + subdomain_points = int(np.floor(num_points / point_batch_size)) + + target_surf = data_dict["surface_fields"] + prediction_surf = np.zeros_like(target_surf.cpu().numpy()) + + surface_areas = paddle.unsqueeze(surface_areas, -1) + surface_neighbors_areas = paddle.unsqueeze(surface_neighbors_areas, -1) + + for p in range(subdomain_points + 1): + start_idx = p * point_batch_size + end_idx = (p + 1) * point_batch_size + with paddle.no_grad(): + target_batch = target_surf[:, start_idx:end_idx] + surface_mesh_centers_batch = surface_mesh_centers[ + :, start_idx:end_idx + ] + surface_mesh_neighbors_batch = surface_mesh_neighbors[ + :, start_idx:end_idx + ] + surface_normals_batch = surface_normals[:, start_idx:end_idx] + surface_neighbors_normals_batch = surface_neighbors_normals[ + :, start_idx:end_idx + ] + surface_areas_batch = surface_areas[:, start_idx:end_idx] + surface_neighbors_areas_batch = surface_neighbors_areas[ + :, start_idx:end_idx + ] + pos_surface_center_of_mass_batch = pos_surface_center_of_mass[ + :, start_idx:end_idx + ] + geo_encoding_local = model.module.geo_encoding_local_surface( + 0.5 * encoding_g_surf, surface_mesh_centers_batch, s_grid + ) + pos_encoding = pos_surface_center_of_mass_batch + pos_encoding = model.module.position_encoder( + pos_encoding, eval_mode="surface" + ) + + if cfg.model.surface_neighbors: + tpredictions_batch = ( + model.module.calculate_solution_with_neighbors( + surface_mesh_centers_batch, + geo_encoding_local, + pos_encoding, + surface_mesh_neighbors_batch, + surface_normals_batch, + surface_neighbors_normals_batch, + surface_areas_batch, + surface_neighbors_areas_batch, + stream_velocity, + air_density, + ) + ) + else: + tpredictions_batch = model.module.calculate_solution( + surface_mesh_centers_batch, + geo_encoding_local, + pos_encoding, + stream_velocity, + air_density, + num_sample_points=1, + eval_mode="surface", + ) + running_tloss_surf += loss_fn(tpredictions_batch, target_batch) + prediction_surf[ + :, start_idx:end_idx + ] = tpredictions_batch.cpu().numpy() + + prediction_surf = ( + unnormalize(prediction_surf, surf_factors[0], surf_factors[1]) + * stream_velocity[0, 0].cpu().numpy() ** 2.0 + * air_density[0, 0].cpu().numpy() + ) + + else: + prediction_surf = None + + return prediction_vol, prediction_surf + + +@hydra.main(version_base="1.3", config_path="conf", config_name="config") +def main(cfg: DictConfig): + print(f"Config summary:\n{OmegaConf.to_yaml(cfg, sort_keys=True)}") + + input_path = cfg.eval.test_path + + model_type = cfg.model.model_type + + # initialize distributed manager + DistributedManager.initialize() + dist = DistributedManager() + + if model_type == "volume" or model_type == "combined": + volume_variable_names = list(cfg.variables.volume.solution.keys()) + num_vol_vars = 0 + for j in volume_variable_names: + if cfg.variables.volume.solution[j] == "vector": + num_vol_vars += 3 + else: + num_vol_vars += 1 + else: + num_vol_vars = None + + if model_type == "surface" or model_type == "combined": + surface_variable_names = list(cfg.variables.surface.solution.keys()) + num_surf_vars = 0 + for j in surface_variable_names: + if cfg.variables.surface.solution[j] == "vector": + num_surf_vars += 3 + else: + num_surf_vars += 1 + else: + num_surf_vars = None + + vol_save_path = os.path.join( + "outputs", cfg.project.name, "volume_scaling_factors.npy" + ) + surf_save_path = os.path.join( + "outputs", cfg.project.name, "surface_scaling_factors.npy" + ) + if os.path.exists(vol_save_path) and os.path.exists(surf_save_path): + vol_factors = np.load(vol_save_path) + surf_factors = np.load(surf_save_path) + else: + vol_factors = None + surf_factors = None + + model = DoMINO( + input_features=3, + output_features_vol=num_vol_vars, + output_features_surf=num_surf_vars, + model_parameters=cfg.model, + ) + + checkpoint = paddle.load( + to_absolute_path(os.path.join(cfg.resume_dir, cfg.eval.checkpoint_name)), + ) + + model.set_state_dict(checkpoint) + + print("Model loaded") + + if dist.world_size > 1: + model = DataParallel( + model, + find_unused_parameters=dist.find_unused_parameters, + ) + + dirnames_per_gpu = get_filenames(input_path) + + pred_save_path = cfg.eval.save_path + create_directory(pred_save_path) + + for count, dirname in enumerate(dirnames_per_gpu): + # print(f"Processing file {dirname}") + filepath = os.path.join(input_path, dirname) + tag = int(re.findall(r"(\w+?)(\d+)", dirname)[0][1]) + stl_path = os.path.join(filepath, f"drivaer_{tag}.stl") + vtp_path = os.path.join(filepath, f"boundary_{tag}.vtp") + vtu_path = os.path.join(filepath, f"volume_{tag}.vtu") + + vtp_pred_save_path = os.path.join( + pred_save_path, f"boundary_{tag}_predicted.vtp" + ) + vtu_pred_save_path = os.path.join(pred_save_path, f"volume_{tag}_predicted.vtu") + + # Read STL + reader = pv.get_reader(stl_path) + mesh_stl = reader.read() + stl_vertices = mesh_stl.points + stl_faces = np.array(mesh_stl.faces).reshape((-1, 4))[ + :, 1: + ] # Assuming triangular elements + mesh_indices_flattened = stl_faces.flatten() + length_scale = np.amax(np.amax(stl_vertices, 0) - np.amin(stl_vertices, 0)) + stl_sizes = mesh_stl.compute_cell_sizes(length=False, area=True, volume=False) + stl_sizes = np.array(stl_sizes.cell_data["Area"], dtype=np.float32) + stl_centers = np.array(mesh_stl.cell_centers().points, dtype=np.float32) + + # Center of mass calculation + center_of_mass = calculate_center_of_mass(stl_centers, stl_sizes) + + if cfg.data.bounding_box_surface is None: + s_max = np.amax(stl_vertices, 0) + s_min = np.amin(stl_vertices, 0) + else: + bounding_box_dims_surf = [] + bounding_box_dims_surf.append(np.asarray(cfg.data.bounding_box_surface.max)) + bounding_box_dims_surf.append(np.asarray(cfg.data.bounding_box_surface.min)) + s_max = np.float32(bounding_box_dims_surf[0]) + s_min = np.float32(bounding_box_dims_surf[1]) + + nx, ny, nz = cfg.model.interp_res + + surf_grid = create_grid(s_max, s_min, [nx, ny, nz]) + surf_grid_reshaped = surf_grid.reshape(nx * ny * nz, 3) + + # SDF calculation on the grid using WARP + sdf_surf_grid = ( + signed_distance_field( + stl_vertices, + mesh_indices_flattened, + surf_grid_reshaped, + use_sign_winding_number=True, + ) + .numpy() + .reshape(nx, ny, nz) + ) + surf_grid = np.float32(surf_grid) + sdf_surf_grid = np.float32(sdf_surf_grid) + surf_grid_max_min = np.float32(np.asarray([s_min, s_max])) + + # Read VTP + if model_type == "surface" or model_type == "combined": + reader = vtk.vtkXMLPolyDataReader() + reader.SetFileName(vtp_path) + reader.Update() + polydata_surf = reader.GetOutput() + + celldata_all = get_node_to_elem(polydata_surf) + + celldata = celldata_all.GetCellData() + surface_fields = get_fields(celldata, surface_variable_names) + surface_fields = np.concatenate(surface_fields, axis=-1) + + mesh = pv.PolyData(polydata_surf) + surface_coordinates = np.array(mesh.cell_centers().points, dtype=np.float32) + + interp_func = KDTree(surface_coordinates) + dd, ii = interp_func.query( + surface_coordinates, k=cfg.model.num_surface_neighbors + ) + + surface_neighbors = surface_coordinates[ii] + surface_neighbors = surface_neighbors[:, 1:] + + surface_normals = np.array(mesh.cell_normals, dtype=np.float32) + surface_sizes = mesh.compute_cell_sizes( + length=False, area=True, volume=False + ) + surface_sizes = np.array(surface_sizes.cell_data["Area"], dtype=np.float32) + + # Normalize cell normals + surface_normals = ( + surface_normals / np.linalg.norm(surface_normals, axis=1)[:, np.newaxis] + ) + surface_neighbors_normals = surface_normals[ii] + surface_neighbors_normals = surface_neighbors_normals[:, 1:] + surface_neighbors_sizes = surface_sizes[ii] + surface_neighbors_sizes = surface_neighbors_sizes[:, 1:] + + dx, dy, dz = ( + (s_max[0] - s_min[0]) / nx, + (s_max[1] - s_min[1]) / ny, + (s_max[2] - s_min[2]) / nz, + ) + + if cfg.model.positional_encoding: + pos_surface_center_of_mass = calculate_normal_positional_encoding( + surface_coordinates, center_of_mass, cell_length=[dx, dy, dz] + ) + else: + pos_surface_center_of_mass = surface_coordinates - center_of_mass + + surface_coordinates = normalize(surface_coordinates, s_max, s_min) + surface_neighbors = normalize(surface_neighbors, s_max, s_min) + surf_grid = normalize(surf_grid, s_max, s_min) + + else: + surface_coordinates = None + surface_fields = None + surface_sizes = None + surface_normals = None + surface_neighbors = None + surface_neighbors_normals = None + surface_neighbors_sizes = None + pos_surface_center_of_mass = None + + # Read VTU + if model_type == "volume" or model_type == "combined": + reader = vtk.vtkXMLUnstructuredGridReader() + reader.SetFileName(vtu_path) + reader.Update() + polydata_vol = reader.GetOutput() + volume_coordinates, volume_fields = get_volume_data( + polydata_vol, volume_variable_names + ) + volume_fields = np.concatenate(volume_fields, axis=-1) + # print(f"Processed vtu {vtu_path}") + + bounding_box_dims = [] + bounding_box_dims.append(np.asarray(cfg.data.bounding_box.max)) + bounding_box_dims.append(np.asarray(cfg.data.bounding_box.min)) + + if bounding_box_dims is None: + c_max = s_max + (s_max - s_min) / 2 + c_min = s_min - (s_max - s_min) / 2 + c_min[2] = s_min[2] + else: + c_max = np.float32(bounding_box_dims[0]) + c_min = np.float32(bounding_box_dims[1]) + + dx, dy, dz = ( + (c_max[0] - c_min[0]) / nx, + (c_max[1] - c_min[1]) / ny, + (c_max[2] - c_min[2]) / nz, + ) + # Generate a grid of specified resolution to map the bounding box + # The grid is used for capturing structured geometry features and SDF representation of geometry + grid = create_grid(c_max, c_min, [nx, ny, nz]) + grid_reshaped = grid.reshape(nx * ny * nz, 3) + + # SDF calculation on the grid using WARP + sdf_grid = ( + signed_distance_field( + stl_vertices, + mesh_indices_flattened, + grid_reshaped, + use_sign_winding_number=True, + ) + .numpy() + .reshape(nx, ny, nz) + ) + + # SDF calculation + sdf_nodes, sdf_node_closest_point = signed_distance_field( + stl_vertices, + mesh_indices_flattened, + volume_coordinates, + include_hit_points=True, + use_sign_winding_number=True, + ) + sdf_nodes = sdf_nodes.numpy().reshape(-1, 1) + sdf_node_closest_point = sdf_node_closest_point.numpy() + + if cfg.model.positional_encoding: + pos_volume_closest = calculate_normal_positional_encoding( + volume_coordinates, sdf_node_closest_point, cell_length=[dx, dy, dz] + ) + pos_volume_center_of_mass = calculate_normal_positional_encoding( + volume_coordinates, center_of_mass, cell_length=[dx, dy, dz] + ) + else: + pos_volume_closest = volume_coordinates - sdf_node_closest_point + pos_volume_center_of_mass = volume_coordinates - center_of_mass + + volume_coordinates = normalize(volume_coordinates, c_max, c_min) + grid = normalize(grid, c_max, c_min) + vol_grid_max_min = np.asarray([c_min, c_max]) + + else: + volume_coordinates = None + volume_fields = None + pos_volume_closest = None + pos_volume_center_of_mass = None + + # print(f"Processed sdf and normalized") + + geom_centers = np.float32(stl_vertices) + + if model_type == "combined": + # Add the parameters to the dictionary + data_dict = { + "pos_volume_closest": pos_volume_closest, + "pos_volume_center_of_mass": pos_volume_center_of_mass, + "pos_surface_center_of_mass": pos_surface_center_of_mass, + "geometry_coordinates": geom_centers, + "grid": grid, + "surf_grid": surf_grid, + "sdf_grid": sdf_grid, + "sdf_surf_grid": sdf_surf_grid, + "sdf_nodes": sdf_nodes, + "surface_mesh_centers": surface_coordinates, + "surface_mesh_neighbors": surface_neighbors, + "surface_normals": surface_normals, + "surface_neighbors_normals": surface_neighbors_normals, + "surface_areas": surface_sizes, + "surface_neighbors_areas": surface_neighbors_sizes, + "volume_fields": volume_fields, + "volume_mesh_centers": volume_coordinates, + "surface_fields": surface_fields, + "volume_min_max": vol_grid_max_min, + "surface_min_max": surf_grid_max_min, + "length_scale": np.array(length_scale, dtype=np.float32), + "stream_velocity": np.expand_dims( + np.array(STREAM_VELOCITY, dtype=np.float32), axis=-1 + ), + "air_density": np.expand_dims( + np.array(AIR_DENSITY, dtype=np.float32), axis=-1 + ), + } + elif model_type == "surface": + data_dict = { + "pos_surface_center_of_mass": np.float32(pos_surface_center_of_mass), + "geometry_coordinates": np.float32(geom_centers), + "surf_grid": np.float32(surf_grid), + "sdf_surf_grid": np.float32(sdf_surf_grid), + "surface_mesh_centers": np.float32(surface_coordinates), + "surface_mesh_neighbors": np.float32(surface_neighbors), + "surface_normals": np.float32(surface_normals), + "surface_neighbors_normals": np.float32(surface_neighbors_normals), + "surface_areas": np.float32(surface_sizes), + "surface_neighbors_areas": np.float32(surface_neighbors_sizes), + "surface_fields": np.float32(surface_fields), + "surface_min_max": np.float32(surf_grid_max_min), + "length_scale": np.array(length_scale, dtype=np.float32), + "stream_velocity": np.expand_dims( + np.array(STREAM_VELOCITY, dtype=np.float32), axis=-1 + ), + "air_density": np.expand_dims( + np.array(AIR_DENSITY, dtype=np.float32), axis=-1 + ), + } + elif model_type == "volume": + data_dict = { + "pos_volume_closest": pos_volume_closest, + "pos_volume_center_of_mass": pos_volume_center_of_mass, + "geometry_coordinates": geom_centers, + "grid": grid, + "surf_grid": surf_grid, + "sdf_grid": sdf_grid, + "sdf_surf_grid": sdf_surf_grid, + "sdf_nodes": sdf_nodes, + "volume_fields": volume_fields, + "volume_mesh_centers": volume_coordinates, + "volume_min_max": vol_grid_max_min, + "surface_min_max": surf_grid_max_min, + "length_scale": np.array(length_scale, dtype=np.float32), + "stream_velocity": np.expand_dims( + np.array(STREAM_VELOCITY, dtype=np.float32), axis=-1 + ), + "air_density": np.expand_dims( + np.array(AIR_DENSITY, dtype=np.float32), axis=-1 + ), + } + + data_dict = { + key: paddle.to_tensor(np.expand_dims(np.float32(value), 0)) + for key, value in data_dict.items() + } + + prediction_vol, prediction_surf = test_step( + data_dict, model, dist.device, cfg, vol_factors, surf_factors + ) + + if prediction_surf is not None: + surface_sizes = np.expand_dims(surface_sizes, -1) + + force_x_pred = np.sum( + prediction_surf[0, :, 0] * surface_normals[:, 0] * surface_sizes[:, 0] + - prediction_surf[0, :, 1] * surface_sizes[:, 0] + ) + force_x_true = np.sum( + surface_fields[:, 0] * surface_normals[:, 0] * surface_sizes[:, 0] + - surface_fields[:, 1] * surface_sizes[:, 0] + ) + print(dirname, force_x_pred, force_x_true) + + if prediction_vol is not None: + target_vol = volume_fields + prediction_vol = prediction_vol[0] + c_min = vol_grid_max_min[0] + c_max = vol_grid_max_min[1] + volume_coordinates = unnormalize(volume_coordinates, c_max, c_min) + ids_in_bbox = np.where( + (volume_coordinates[:, 0] < c_min[0]) + | (volume_coordinates[:, 0] > c_max[0]) + | (volume_coordinates[:, 1] < c_min[1]) + | (volume_coordinates[:, 1] > c_max[1]) + | (volume_coordinates[:, 2] < c_min[2]) + | (volume_coordinates[:, 2] > c_max[2]) + ) + target_vol[ids_in_bbox] = 0.0 + prediction_vol[ids_in_bbox] = 0.0 + l2_gt = np.sum(np.square(target_vol), (0)) + l2_error = np.sum(np.square(prediction_vol - target_vol), (0)) + print( + "L-2 norm:", + dirname, + np.sqrt(l2_error), + np.sqrt(l2_gt), + np.sqrt(l2_error) / np.sqrt(l2_gt), + ) + + if prediction_surf is not None: + surfParam_vtk = numpy_support.numpy_to_vtk(prediction_surf[0, :, 0:1]) + surfParam_vtk.SetName(f"{surface_variable_names[0]}Pred") + celldata_all.GetCellData().AddArray(surfParam_vtk) + + surfParam_vtk = numpy_support.numpy_to_vtk(prediction_surf[0, :, 1:]) + surfParam_vtk.SetName(f"{surface_variable_names[1]}Pred") + celldata_all.GetCellData().AddArray(surfParam_vtk) + + write_to_vtp(celldata_all, vtp_pred_save_path) + + if prediction_vol is not None: + + volParam_vtk = numpy_support.numpy_to_vtk(prediction_vol[:, 0:3]) + volParam_vtk.SetName(f"{volume_variable_names[0]}Pred") + polydata_vol.GetPointData().AddArray(volParam_vtk) + + volParam_vtk = numpy_support.numpy_to_vtk(prediction_vol[:, 3:4]) + volParam_vtk.SetName(f"{volume_variable_names[1]}Pred") + polydata_vol.GetPointData().AddArray(volParam_vtk) + + volParam_vtk = numpy_support.numpy_to_vtk(prediction_vol[:, 4:5]) + volParam_vtk.SetName(f"{volume_variable_names[2]}Pred") + polydata_vol.GetPointData().AddArray(volParam_vtk) + + write_to_vtu(polydata_vol, vtu_pred_save_path) + + +if __name__ == "__main__": + main() diff --git a/examples/domino/train.py b/examples/domino/train.py new file mode 100644 index 0000000000..dcb96dbb0e --- /dev/null +++ b/examples/domino/train.py @@ -0,0 +1,953 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This code defines a distributed pipeline for training the DoMINO model on +CFD datasets. It includes the computation of scaling factors, instantiating +the DoMINO model and datapipe, automatically loading the most recent checkpoint, +training the model in parallel using DistributedDataParallel across multiple +GPUs, calculating the loss and updating model parameters using mixed precision. +This is a common recipe that enables training of combined models for surface and +volume as well either of them separately. Validation is also conducted every epoch, +where predictions are compared against ground truth values. The code logs training +and validation metrics to TensorBoard. The train tab in config.yaml can be used to +specify batch size, number of epochs and other training parameters. +""" + +import os +import re +import sys +import time + +import hydra +import numpy as np +import paddle +from hydra.utils import to_absolute_path +from omegaconf import DictConfig +from omegaconf import OmegaConf +from paddle import DataParallel +from paddle.amp import GradScaler +from paddle.amp import auto_cast +from paddle.io import DataLoader +from paddle.io import DistributedBatchSampler + +SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) +sys.path.append(os.path.dirname(SCRIPT_DIR)) + +from modulus.datapipes.cae.domino_datapipe import DoMINODataPipe # noqa: E402 +from modulus.distributed import DistributedManager # noqa: E402 +from modulus.launch.utils import load_checkpoint # noqa: E402 +from modulus.launch.utils import save_checkpoint # noqa: E402 +from modulus.models.model import DoMINO # noqa: E402 +from modulus.utils.domino.utils import create_directory # noqa: E402 +from modulus.utils.domino.utils import mean_std_sampling # noqa: E402 + + +def relative_loss_fn(output, target, padded_value=-10): + mask = abs(target - padded_value) > 1e-3 + mask = mask.to(dtype=output.dtype) + masked_loss = paddle.sum(((output - target) ** 2.0) * mask, (0, 1)) / paddle.sum( + mask, (0, 1) + ) + masked_truth = paddle.sum(((target) ** 2.0) * mask, (0, 1)) / paddle.sum( + mask, (0, 1) + ) + loss = paddle.mean(masked_loss / masked_truth) + return loss + + +def mse_loss_fn(output, target, padded_value=-10): + mask = abs(target - padded_value) > 1e-3 + mask = mask.to(dtype=output.dtype) + masked_loss = paddle.sum(((output - target) ** 2.0) * mask, (0, 1)) / paddle.sum( + mask, (0, 1) + ) + loss = paddle.mean(masked_loss) + return loss + + +def mse_loss_fn_surface(output, target, normals, padded_value=-10): + masked_loss_pres = paddle.mean( + ((output[:, :, :1] - target[:, :, :1]) ** 2.0), (0, 1) + ) + + ws_x_true = target[:, :, 1:2] + ws_x_pred = output[:, :, 1:2] + masked_loss_ws_x = paddle.mean(((ws_x_pred - ws_x_true) ** 2.0), (0, 1)) + + ws_y_true = target[:, :, 2:3] + ws_y_pred = output[:, :, 2:3] + masked_loss_ws_y = paddle.mean(((ws_y_pred - ws_y_true) ** 2.0), (0, 1)) + + ws_z_true = target[:, :, 3:4] + ws_z_pred = output[:, :, 3:4] + masked_loss_ws_z = paddle.mean(((ws_z_pred - ws_z_true) ** 2.0), (0, 1)) + + loss = ( + paddle.mean(masked_loss_pres) + + paddle.mean(masked_loss_ws_x) + + paddle.mean(masked_loss_ws_y) + + paddle.mean(masked_loss_ws_z) + ) + loss = loss / 4 + return loss + + +def relative_loss_fn_surface(output, target, normals, padded_value=-10): + masked_loss_pres = paddle.mean( + ((output[:, :, :1] - target[:, :, :1]) ** 2.0), (0, 1) + ) / paddle.mean(((target[:, :, :1]) ** 2.0), (0, 1)) + + ws_x_true = target[:, :, 1:2] + ws_x_pred = output[:, :, 1:2] + masked_loss_ws_x = paddle.mean( + ((ws_x_pred - ws_x_true) ** 2.0), (0, 1) + ) / paddle.mean(((ws_x_true) ** 2.0), (0, 1)) + + ws_y_true = target[:, :, 2:3] + ws_y_pred = output[:, :, 2:3] + masked_loss_ws_y = paddle.mean( + ((ws_y_pred - ws_y_true) ** 2.0), (0, 1) + ) / paddle.mean(((ws_y_true) ** 2.0), (0, 1)) + + ws_z_true = target[:, :, 3:4] + ws_z_pred = output[:, :, 3:4] + masked_loss_ws_z = paddle.mean( + ((ws_z_pred - ws_z_true) ** 2.0), (0, 1) + ) / paddle.mean(((ws_z_true) ** 2.0), (0, 1)) + + loss = ( + paddle.mean(masked_loss_pres) + + paddle.mean(masked_loss_ws_x) + + paddle.mean(masked_loss_ws_y) + + paddle.mean(masked_loss_ws_z) + ) + loss = loss / 4 + return loss + + +def relative_loss_fn_area(output, target, normals, area, padded_value=-10): + scale_factor = 1.0 # Get this from the dataset + area = area * 10**4 + pres_x_true = target[:, :, :1] * normals[:, :, 0:1] * area * scale_factor**2.0 + pres_x_pred = output[:, :, :1] * normals[:, :, 0:1] * area * scale_factor**2.0 + + masked_loss_pres_x = paddle.mean( + ((pres_x_pred - pres_x_true) ** 2.0), (0, 1) + ) / paddle.mean(((pres_x_true) ** 2.0), (0, 1)) + + ws_x_true = target[:, :, 1:2] * area * scale_factor**2.0 + ws_x_pred = output[:, :, 1:2] * area * scale_factor**2.0 + masked_loss_ws_x = paddle.mean( + ((ws_x_pred - ws_x_true) ** 2.0), (0, 1) + ) / paddle.mean(((ws_x_true) ** 2.0), (0, 1)) + + ws_y_true = target[:, :, 2:3] * area * scale_factor**2.0 + ws_y_pred = output[:, :, 2:3] * area * scale_factor**2.0 + masked_loss_ws_y = paddle.mean( + ((ws_y_pred - ws_y_true) ** 2.0), (0, 1) + ) / paddle.mean(((ws_y_true) ** 2.0), (0, 1)) + + ws_z_true = target[:, :, 3:4] * area * scale_factor**2.0 + ws_z_pred = output[:, :, 3:4] * area * scale_factor**2.0 + masked_loss_ws_z = paddle.mean( + ((ws_z_pred - ws_z_true) ** 2.0), (0, 1) + ) / paddle.mean(((ws_z_true) ** 2.0), (0, 1)) + + loss = ( + paddle.mean(masked_loss_pres_x) + + paddle.mean(masked_loss_ws_x) + + paddle.mean(masked_loss_ws_y) + + paddle.mean(masked_loss_ws_z) + ) + loss = loss / 4 + return loss + + +def mse_loss_fn_area(output, target, normals, area, padded_value=-10): + scale_factor = 1.0 # Get this from the dataset + area = area * 10**4 + + pres_x_true = target[:, :, :1] * normals[:, :, 0:1] * area * scale_factor**2.0 + pres_x_pred = output[:, :, :1] * normals[:, :, 0:1] * area * scale_factor**2.0 + + masked_loss_pres_x = paddle.mean(((pres_x_pred - pres_x_true) ** 2.0), (0, 1)) + + ws_x_true = target[:, :, 1:2] * area * scale_factor**2.0 + ws_x_pred = output[:, :, 1:2] * area * scale_factor**2.0 + masked_loss_ws_x = paddle.mean(((ws_x_pred - ws_x_true) ** 2.0), (0, 1)) + + ws_y_true = target[:, :, 2:3] * area * scale_factor**2.0 + ws_y_pred = output[:, :, 2:3] * area * scale_factor**2.0 + masked_loss_ws_y = paddle.mean(((ws_y_pred - ws_y_true) ** 2.0), (0, 1)) + + ws_z_true = target[:, :, 3:4] * area * scale_factor**2.0 + ws_z_pred = output[:, :, 3:4] * area * scale_factor**2.0 + masked_loss_ws_z = paddle.mean(((ws_z_pred - ws_z_true) ** 2.0), (0, 1)) + + loss = ( + paddle.mean(masked_loss_pres_x) + + paddle.mean(masked_loss_ws_x) + + paddle.mean(masked_loss_ws_y) + + paddle.mean(masked_loss_ws_z) + ) + loss = loss / 4 + return loss + + +def integral_loss_fn(output, target, area, normals, padded_value=-10): + vel_inlet = 30.0 # Get this from the dataset + mask = abs(target - padded_value) > 1e-3 + mask = mask.to(dtype=output.dtype) + area = paddle.unsqueeze(area, -1) + output_true = target * mask * area * (vel_inlet) ** 2.0 + output_pred = output * mask * area * (vel_inlet) ** 2.0 + + output_true[:, :, 0] = output_true[:, :, 0] * normals[:, :, 0] + output_pred[:, :, 0] = output_pred[:, :, 0] * normals[:, :, 0] + + masked_pred = paddle.sum(output_pred, (1)) + masked_truth = paddle.sum(output_true, (1)) + + loss = (masked_pred - masked_truth) ** 2.0 + loss = paddle.mean(loss) + return loss + + +def integral_loss_fn_new(output, target, area, normals, padded_value=-10): + drag_loss = drag_loss_fn(output, target, area, normals, padded_value=-10) + lift_loss = lift_loss_fn(output, target, area, normals, padded_value=-10) + return lift_loss + drag_loss + + +def lift_loss_fn(output, target, area, normals, padded_value=-10): + vel_inlet = 30.0 # Get this from the dataset + mask = abs(target - padded_value) > 1e-3 + mask = mask.to(dtype=output.dtype) + area = paddle.unsqueeze(area, -1) + output_true = target * mask * area * (vel_inlet) ** 2.0 + output_pred = output * mask * area * (vel_inlet) ** 2.0 + + pres_true = output_true[:, :, 0] * normals[:, :, 2] + pres_pred = output_pred[:, :, 0] * normals[:, :, 2] + + wz_true = output_true[:, :, -1] + wz_pred = output_pred[:, :, -1] + + masked_pred = paddle.sum(pres_pred + wz_pred, (1)) / ( + paddle.sum(area) * (vel_inlet) ** 2.0 + ) + masked_truth = paddle.sum(pres_true + wz_true, (1)) / ( + paddle.sum(area) * (vel_inlet) ** 2.0 + ) + + loss = (masked_pred - masked_truth) ** 2.0 + loss = paddle.mean(loss) + return loss + + +def drag_loss_fn(output, target, area, normals, padded_value=-10): + vel_inlet = 30.0 # Get this from the dataset + mask = abs(target - padded_value) > 1e-3 + mask = mask.to(dtype=output.dtype) + area = paddle.unsqueeze(area, -1) + output_true = target * mask * area * (vel_inlet) ** 2.0 + output_pred = output * mask * area * (vel_inlet) ** 2.0 + + pres_true = output_true[:, :, 0] * normals[:, :, 0] + pres_pred = output_pred[:, :, 0] * normals[:, :, 0] + + wx_true = output_true[:, :, 1] + wx_pred = output_pred[:, :, 1] + + masked_pred = paddle.sum(pres_pred + wx_pred, (1)) / ( + paddle.sum(area) * (vel_inlet) ** 2.0 + ) + masked_truth = paddle.sum(pres_true + wx_true, (1)) / ( + paddle.sum(area) * (vel_inlet) ** 2.0 + ) + + loss = (masked_pred - masked_truth) ** 2.0 + loss = paddle.mean(loss) + return loss + + +def validation_step( + dataloader, + model, + device, + use_sdf_basis=False, + use_surface_normals=False, + integral_scaling_factor=1.0, + loss_fn_type="mse", +): + running_vloss = 0.0 + with paddle.no_grad(): + for i_batch, sampled_batched in enumerate(dataloader): + prediction_vol, prediction_surf = model(sampled_batched) + + if prediction_vol is not None: + target_vol = sampled_batched["volume_fields"] + if loss_fn_type == "rmse": + loss_norm_vol = relative_loss_fn( + prediction_vol, target_vol, padded_value=-10 + ) + else: + loss_norm_vol = mse_loss_fn( + prediction_vol, target_vol, padded_value=-10 + ) + + if prediction_surf is not None: + target_surf = sampled_batched["surface_fields"] + surface_normals = sampled_batched["surface_normals"] + surface_areas = sampled_batched["surface_areas"] + if loss_fn_type == "rmse": + loss_norm_surf = relative_loss_fn_surface( + prediction_surf, target_surf, surface_normals, padded_value=-10 + ) + loss_norm_surf_area = relative_loss_fn_area( + prediction_surf, + target_surf, + surface_normals, + surface_areas, + padded_value=-10, + ) + else: + loss_norm_surf = mse_loss_fn_surface( + prediction_surf, target_surf, surface_normals, padded_value=-10 + ) + loss_norm_surf_area = mse_loss_fn_area( + prediction_surf, + target_surf, + surface_normals, + surface_areas, + padded_value=-10, + ) + loss_integral = ( + integral_loss_fn_new( + prediction_surf, + target_surf, + surface_areas, + surface_normals, + padded_value=-10, + ) + ) * integral_scaling_factor + + if prediction_surf is not None and prediction_vol is not None: + vloss = ( + loss_norm_vol + + 0.5 * loss_norm_surf + + loss_integral + + 0.5 * loss_norm_surf_area + ) + elif prediction_vol is not None: + vloss = loss_norm_vol + elif prediction_surf is not None: + vloss = 0.5 * loss_norm_surf + loss_integral + 0.5 * loss_norm_surf_area + + running_vloss += vloss + + avg_vloss = running_vloss / (i_batch + 1) + + return avg_vloss + + +def train_epoch( + dataloader, + model, + optimizer, + scaler, + epoch_index, + device, + integral_scaling_factor, + loss_fn_type, +): + + running_loss = 0.0 + last_loss = 0.0 + loss_interval = 1 + + for i_batch, sampled_batched in enumerate(dataloader): + with auto_cast(enable=False): + prediction_vol, prediction_surf = model(sampled_batched) + + if prediction_vol is not None: + target_vol = sampled_batched["volume_fields"] + if loss_fn_type == "rmse": + loss_norm_vol = relative_loss_fn( + prediction_vol, target_vol, padded_value=-10 + ) + else: + loss_norm_vol = mse_loss_fn( + prediction_vol, target_vol, padded_value=-10 + ) + + if prediction_surf is not None: + + target_surf = sampled_batched["surface_fields"] + surface_areas = sampled_batched["surface_areas"] + surface_normals = sampled_batched["surface_normals"] + if loss_fn_type == "rmse": + loss_norm_surf = relative_loss_fn_surface( + prediction_surf, target_surf, surface_normals, padded_value=-10 + ) + loss_norm_surf_area = relative_loss_fn_area( + prediction_surf, + target_surf, + surface_normals, + surface_areas, + padded_value=-10, + ) + else: + loss_norm_surf = mse_loss_fn_surface( + prediction_surf, target_surf, surface_normals, padded_value=-10 + ) + loss_norm_surf_area = mse_loss_fn_area( + prediction_surf, + target_surf, + surface_normals, + surface_areas, + padded_value=-10, + ) + loss_integral = ( + integral_loss_fn_new( + prediction_surf, + target_surf, + surface_areas, + surface_normals, + padded_value=-10, + ) + ) * integral_scaling_factor + + if prediction_vol is not None and prediction_surf is not None: + loss_norm = ( + loss_norm_vol + + 0.5 * loss_norm_surf + + loss_integral + + 0.5 * loss_norm_surf_area + ) + elif prediction_vol is not None: + loss_norm = loss_norm_vol + elif prediction_surf is not None: + loss_norm = ( + 0.5 * loss_norm_surf + loss_integral + 0.5 * loss_norm_surf_area + ) + + loss = loss_norm + loss = loss / loss_interval + scaler.scale(loss).backward() + + if ((i_batch + 1) % loss_interval == 0) or (i_batch + 1 == len(dataloader)): + scaler.step(optimizer) + scaler.update() + optimizer.clear_gradients() + # Gather data and report + running_loss += loss.item() + + if prediction_vol is not None and prediction_surf is not None: + print( + f"Device {device}, batch processed: {i_batch + 1}, loss volume: {loss_norm_vol:.5f} \ + , loss surface: {loss_norm_surf:.5f}, loss integral: {loss_integral:.5f}, loss surface area: {loss_norm_surf_area:.5f}" + ) + elif prediction_vol is not None: + print( + f"Device {device}, batch processed: {i_batch + 1}, loss volume: {loss_norm_vol:.5f}" + ) + elif prediction_surf is not None: + print( + f"Device {device}, batch processed: {i_batch + 1} \ + , loss surface: {loss_norm_surf:.5f}, loss integral: {loss_integral:.5f}, loss surface area: {loss_norm_surf_area:.5f}" + ) + + last_loss = running_loss / (i_batch + 1) # loss per batch + print(f" Device {device}, batch: {i_batch + 1}, loss norm: {loss:.5f}") + tb_x = epoch_index * len(dataloader) + i_batch + 1 + print(f"Loss/train: {last_loss}/{tb_x}") + + return last_loss + + +def compute_scaling_factors(cfg: DictConfig): + + model_type = cfg.model.model_type + + if model_type == "volume" or model_type == "combined": + vol_save_path = os.path.join( + "outputs", cfg.project.name, "volume_scaling_factors.npy" + ) + if not os.path.exists(vol_save_path): + input_path = cfg.data.input_dir + + volume_variable_names = list(cfg.variables.volume.solution.keys()) + + fm_dict = DoMINODataPipe( + input_path, + phase="train", + grid_resolution=cfg.model.interp_res, + volume_variables=volume_variable_names, + surface_variables=None, + normalize_coordinates=True, + sampling=False, + sample_in_bbox=True, + volume_points_sample=cfg.model.volume_points_sample, + geom_points_sample=cfg.model.geom_points_sample, + positional_encoding=cfg.model.positional_encoding, + model_type=cfg.model.model_type, + bounding_box_dims=cfg.data.bounding_box, + bounding_box_dims_surf=cfg.data.bounding_box_surface, + compute_scaling_factors=True, + ) + + # Calculate mean + if cfg.model.normalization == "mean_std_scaling": + for j in range(len(fm_dict)): + d_dict = fm_dict[j] + vol_fields = d_dict["volume_fields"] + + if vol_fields is not None: + if j == 0: + vol_fields_sum = np.mean(vol_fields, 0) + else: + vol_fields_sum += np.mean(vol_fields, 0) + else: + vol_fields_sum = 0.0 + + vol_fields_mean = vol_fields_sum / len(fm_dict) + + for j in range(len(fm_dict)): + d_dict = fm_dict[j] + vol_fields = d_dict["volume_fields"] + + if vol_fields is not None: + if j == 0: + vol_fields_sum_square = np.mean( + (vol_fields - vol_fields_mean) ** 2.0, 0 + ) + else: + vol_fields_sum_square += np.mean( + (vol_fields - vol_fields_mean) ** 2.0, 0 + ) + else: + vol_fields_sum_square = 0.0 + + vol_fields_std = np.sqrt(vol_fields_sum_square / len(fm_dict)) + + vol_scaling_factors = [vol_fields_mean, vol_fields_std] + + if cfg.model.normalization == "min_max_scaling": + for j in range(len(fm_dict)): + d_dict = fm_dict[j] + vol_fields = d_dict["volume_fields"] + + if vol_fields is not None: + vol_mean = np.mean(vol_fields, 0) + vol_std = np.std(vol_fields, 0) + vol_idx = mean_std_sampling( + vol_fields, vol_mean, vol_std, tolerance=12.0 + ) + vol_fields_sampled = np.delete(vol_fields, vol_idx, axis=0) + if j == 0: + vol_fields_max = np.amax(vol_fields_sampled, 0) + vol_fields_min = np.amin(vol_fields_sampled, 0) + else: + vol_fields_max1 = np.amax(vol_fields_sampled, 0) + vol_fields_min1 = np.amin(vol_fields_sampled, 0) + + for k in range(vol_fields.shape[-1]): + if vol_fields_max1[k] > vol_fields_max[k]: + vol_fields_max[k] = vol_fields_max1[k] + + if vol_fields_min1[k] < vol_fields_min[k]: + vol_fields_min[k] = vol_fields_min1[k] + else: + vol_fields_max = 0.0 + vol_fields_min = 0.0 + + if j > 20: + break + vol_scaling_factors = [vol_fields_max, vol_fields_min] + np.save(vol_save_path, vol_scaling_factors) + + if model_type == "surface" or model_type == "combined": + surf_save_path = os.path.join( + "outputs", cfg.project.name, "surface_scaling_factors.npy" + ) + + if not os.path.exists(surf_save_path): + input_path = cfg.data.input_dir + + volume_variable_names = list(cfg.variables.volume.solution.keys()) + surface_variable_names = list(cfg.variables.surface.solution.keys()) + + fm_dict = DoMINODataPipe( + input_path, + phase="train", + grid_resolution=cfg.model.interp_res, + volume_variables=None, + surface_variables=surface_variable_names, + normalize_coordinates=True, + sampling=False, + sample_in_bbox=True, + volume_points_sample=cfg.model.volume_points_sample, + geom_points_sample=cfg.model.geom_points_sample, + positional_encoding=cfg.model.positional_encoding, + model_type=cfg.model.model_type, + bounding_box_dims=cfg.data.bounding_box, + bounding_box_dims_surf=cfg.data.bounding_box_surface, + compute_scaling_factors=True, + ) + + # Calculate mean + if cfg.model.normalization == "mean_std_scaling": + for j in range(len(fm_dict)): + d_dict = fm_dict[j] + surf_fields = d_dict["surface_fields"] + + if surf_fields is not None: + if j == 0: + surf_fields_sum = np.mean(surf_fields, 0) + else: + surf_fields_sum += np.mean(surf_fields, 0) + else: + surf_fields_sum = 0.0 + + surf_fields_mean = surf_fields_sum / len(fm_dict) + + for j in range(len(fm_dict)): + d_dict = fm_dict[j] + surf_fields = d_dict["surface_fields"] + + if surf_fields is not None: + if j == 0: + surf_fields_sum_square = np.mean( + (surf_fields - surf_fields_mean) ** 2.0, 0 + ) + else: + surf_fields_sum_square += np.mean( + (surf_fields - surf_fields_mean) ** 2.0, 0 + ) + else: + surf_fields_sum_square = 0.0 + + surf_fields_std = np.sqrt(surf_fields_sum_square / len(fm_dict)) + + surf_scaling_factors = [surf_fields_mean, surf_fields_std] + + if cfg.model.normalization == "min_max_scaling": + for j in range(len(fm_dict)): + d_dict = fm_dict[j] + surf_fields = d_dict["surface_fields"] + + if surf_fields is not None: + surf_mean = np.mean(surf_fields, 0) + surf_std = np.std(surf_fields, 0) + surf_idx = mean_std_sampling( + surf_fields, surf_mean, surf_std, tolerance=12.0 + ) + surf_fields_sampled = np.delete(surf_fields, surf_idx, axis=0) + if j == 0: + surf_fields_max = np.amax(surf_fields_sampled, 0) + surf_fields_min = np.amin(surf_fields_sampled, 0) + else: + surf_fields_max1 = np.amax(surf_fields_sampled, 0) + surf_fields_min1 = np.amin(surf_fields_sampled, 0) + + for k in range(surf_fields.shape[-1]): + if surf_fields_max1[k] > surf_fields_max[k]: + surf_fields_max[k] = surf_fields_max1[k] + + if surf_fields_min1[k] < surf_fields_min[k]: + surf_fields_min[k] = surf_fields_min1[k] + else: + surf_fields_max = 0.0 + surf_fields_min = 0.0 + + if j > 20: + break + + surf_scaling_factors = [surf_fields_max, surf_fields_min] + np.save(surf_save_path, surf_scaling_factors) + + +@hydra.main(version_base="1.3", config_path="conf", config_name="config") +def main(cfg: DictConfig) -> None: + compute_scaling_factors(cfg) + input_path = cfg.data.input_dir + input_path_val = cfg.data.input_dir_val + model_type = cfg.model.model_type + + # initialize distributed manager + DistributedManager.initialize() + dist = DistributedManager() + + print(f"Config summary:\n{OmegaConf.to_yaml(cfg, sort_keys=True)}") + + num_vol_vars = 0 + volume_variable_names = [] + if model_type == "volume" or model_type == "combined": + volume_variable_names = list(cfg.variables.volume.solution.keys()) + for j in volume_variable_names: + if cfg.variables.volume.solution[j] == "vector": + num_vol_vars += 3 + else: + num_vol_vars += 1 + else: + num_vol_vars = None + + num_surf_vars = 0 + surface_variable_names = [] + if model_type == "surface" or model_type == "combined": + surface_variable_names = list(cfg.variables.surface.solution.keys()) + num_surf_vars = 0 + for j in surface_variable_names: + if cfg.variables.surface.solution[j] == "vector": + num_surf_vars += 3 + else: + num_surf_vars += 1 + else: + num_surf_vars = None + + vol_save_path = os.path.join( + "outputs", cfg.project.name, "volume_scaling_factors.npy" + ) + surf_save_path = os.path.join( + "outputs", cfg.project.name, "surface_scaling_factors.npy" + ) + if os.path.exists(vol_save_path) and os.path.exists(surf_save_path): + vol_factors = np.load(vol_save_path) + surf_factors = np.load(surf_save_path) + else: + vol_factors = None + surf_factors = None + + train_dataset = DoMINODataPipe( + input_path, + phase="train", + grid_resolution=cfg.model.interp_res, + volume_variables=volume_variable_names, + surface_variables=surface_variable_names, + normalize_coordinates=True, + sampling=True, + sample_in_bbox=True, + volume_points_sample=cfg.model.volume_points_sample, + surface_points_sample=cfg.model.surface_points_sample, + geom_points_sample=cfg.model.geom_points_sample, + positional_encoding=cfg.model.positional_encoding, + volume_factors=vol_factors, + surface_factors=surf_factors, + scaling_type=cfg.model.normalization, + model_type=cfg.model.model_type, + bounding_box_dims=cfg.data.bounding_box, + bounding_box_dims_surf=cfg.data.bounding_box_surface, + num_surface_neighbors=cfg.model.num_surface_neighbors, + ) + + val_dataset = DoMINODataPipe( + input_path_val, + phase="val", + grid_resolution=cfg.model.interp_res, + volume_variables=volume_variable_names, + surface_variables=surface_variable_names, + normalize_coordinates=True, + sampling=True, + sample_in_bbox=True, + volume_points_sample=cfg.model.volume_points_sample, + surface_points_sample=cfg.model.surface_points_sample, + geom_points_sample=cfg.model.geom_points_sample, + positional_encoding=cfg.model.positional_encoding, + volume_factors=vol_factors, + surface_factors=surf_factors, + scaling_type=cfg.model.normalization, + model_type=cfg.model.model_type, + bounding_box_dims=cfg.data.bounding_box, + bounding_box_dims_surf=cfg.data.bounding_box_surface, + num_surface_neighbors=cfg.model.num_surface_neighbors, + ) + + train_sampler = DistributedBatchSampler( + train_dataset, + batch_size=1, + num_replicas=dist.world_size, + rank=dist.rank, + **cfg.train.sampler, + ) + + val_sampler = DistributedBatchSampler( + val_dataset, + batch_size=1, + num_replicas=dist.world_size, + rank=dist.rank, + **cfg.val.sampler, + ) + + train_dataloader = DataLoader(train_dataset, **cfg.train.dataloader) + val_dataloader = DataLoader(val_dataset, **cfg.val.dataloader) + + model = DoMINO( + input_features=3, + output_features_vol=num_vol_vars, + output_features_surf=num_surf_vars, + model_parameters=cfg.model, + ) + # model = torch.compile(model, disable=True) # TODO make this configurable + + # Print model summary (structure and parmeter count). + + if dist.world_size > 1: + model = DataParallel( + model, + find_unused_parameters=dist.find_unused_parameters, + ) + + optimizer = paddle.optimizer.Adam( + parameters=model.parameters(), learning_rate=0.001 + ) + scheduler = paddle.optimizer.lr.MultiStepDecay( + learning_rate=optimizer.get_lr(), + milestones=[50, 100, 150, 200, 250, 300, 350, 400], + gamma=0.5, + ) + optimizer.set_lr_scheduler(scheduler) + + # Initialize the scaler for mixed precision + scaler = GradScaler() + + epoch_number = 0 + + model_save_path = os.path.join(cfg.output, "models") + param_save_path = os.path.join(cfg.output, "param") + best_model_path = os.path.join(model_save_path, "best_model") + if dist.rank == 0: + create_directory(model_save_path) + create_directory(param_save_path) + create_directory(best_model_path) + + if dist.world_size > 1: + paddle.distributed.barrier() + + init_epoch = load_checkpoint( + to_absolute_path(cfg.resume_dir), + models=model, + optimizer=optimizer, + scheduler=scheduler, + scaler=scaler, + ) + + if init_epoch != 0: + init_epoch += 1 # Start with the next epoch + epoch_number = init_epoch + + # retrive the smallest validation loss if available + numbers = [] + for filename in os.listdir(best_model_path): + match = re.search(r"\d+\.\d*[1-9]\d*", filename) + if match: + number = float(match.group(0)) + numbers.append(number) + + best_vloss = min(numbers) if numbers else 1_000_000.0 + + initial_integral_factor_orig = cfg.model.integral_loss_scaling_factor + + for epoch in range(init_epoch, cfg.train.epochs): + start_time = time.time() + print(f"Device {dist.device}, epoch {epoch_number}:") + + train_sampler.set_epoch(epoch) + val_sampler.set_epoch(epoch) + + initial_integral_factor = initial_integral_factor_orig + + model.train() + avg_loss = train_epoch( + dataloader=train_dataloader, + model=model, + optimizer=optimizer, + scaler=scaler, + epoch_index=epoch, + device=dist.device, + integral_scaling_factor=initial_integral_factor, + loss_fn_type=cfg.model.loss_function, + ) + + model.eval() + avg_vloss = validation_step( + dataloader=val_dataloader, + model=model, + device=dist.device, + use_sdf_basis=cfg.model.use_sdf_in_basis_func, + use_surface_normals=cfg.model.use_surface_normals, + integral_scaling_factor=initial_integral_factor, + loss_fn_type=cfg.model.loss_function, + ) + + scheduler.step() + print( + f"Device {dist.device} " + f"LOSS train {avg_loss:.5f} " + f"valid {avg_vloss:.5f} " + f"Current lr {scheduler.get_lr()}" + f"Integral factor {initial_integral_factor}" + ) + + # if dist.rank == 0: + # writer.add_scalars( + # "Training vs. Validation Loss", + # {"Training": avg_loss, "Validation": avg_vloss}, + # epoch_number, + # ) + # writer.flush() + + # Track best performance, and save the model's state + if dist.world_size > 1: + paddle.distributed.barrier() + + if avg_vloss < best_vloss: # This only considers GPU: 0, is that okay? + best_vloss = avg_vloss + # if dist.rank == 0: + save_checkpoint( + to_absolute_path(best_model_path), + models=model, + optimizer=optimizer, + scheduler=scheduler, + scaler=scaler, + epoch=str( + best_vloss.item() + ), # hacky way of using epoch to store metadata + ) + print( + f"Device { dist.device}, Best val loss {best_vloss}, Time taken {time.time() - start_time}" + ) + + if dist.rank == 0 and (epoch + 1) % cfg.train.checkpoint_interval == 0.0: + save_checkpoint( + to_absolute_path(model_save_path), + models=model, + optimizer=optimizer, + scheduler=scheduler, + scaler=scaler, + epoch=epoch, + ) + + epoch_number += 1 + + if scheduler.get_lr() == 1e-6: + print("Training ended") + exit() + + +if __name__ == "__main__": + main() From 3eea7c0ac91dff80cc431c13cc1c78d4c7a2f401 Mon Sep 17 00:00:00 2001 From: xiaoyewww <641311428@qq.com> Date: Wed, 21 May 2025 00:37:24 +0800 Subject: [PATCH 03/11] feat(domino): support domino for training and test --- README.md | 1 + docs/index.md | 1 + docs/zh/examples/domino.md | 89 +++++++++++++++++++ examples/domino/conf/config.yaml | 16 ++-- examples/domino/download_aws_dataset.sh | 64 ++++++++++++++ examples/domino/process_data.py | 109 ++++++++++++++++++++++++ examples/domino/requirements.txt | 2 + 7 files changed, 274 insertions(+), 8 deletions(-) create mode 100644 docs/zh/examples/domino.md create mode 100644 examples/domino/download_aws_dataset.sh create mode 100644 examples/domino/process_data.py diff --git a/README.md b/README.md index 2159b96a4b..4011049c4f 100644 --- a/README.md +++ b/README.md @@ -98,6 +98,7 @@ PaddleScience 是一个基于深度学习框架 PaddlePaddle 开发的科学计 | 热仿真 | [1D 换热器热仿真](https://paddlescience-docs.readthedocs.io/zh-cn/latest/zh/examples/heat_exchanger) | 机理驱动 | PI-DeepONet | 无监督学习 | - | - | | 热仿真 | [2D 热仿真](https://paddlescience-docs.readthedocs.io/zh-cn/latest/zh/examples/heat_pinn) | 机理驱动 | PINN | 无监督学习 | - | [Paper](https://arxiv.org/abs/1711.10561)| | 热仿真 | [2D 芯片热仿真](https://paddlescience-docs.readthedocs.io/zh-cn/latest/zh/examples/chip_heat) | 机理驱动 | PI-DeepONet | 无监督学习 | - | [Paper](https://doi.org/10.1063/5.0194245)| +| 外流空气动力学 | [DoMINO](https://paddlescience-docs.readthedocs.io/zh-cn/latest/zh/examples/domino) | 数据驱动 | FNO | 监督学习 | [Data](https://caemldatasets.org/drivaerml/) | [Paper](https://arxiv.org/abs/2501.13350)|

材料科学(AI for Material)

diff --git a/docs/index.md b/docs/index.md index 128e4c9464..44f471ab52 100644 --- a/docs/index.md +++ b/docs/index.md @@ -133,6 +133,7 @@ | 热仿真 | [1D 换热器热仿真](./zh/examples/heat_exchanger.md) | 机理驱动 | PI-DeepONet | 无监督学习 | - | - | | 热仿真 | [2D 热仿真](./zh/examples/heat_pinn.md) | 机理驱动 | PINN | 无监督学习 | - | [Paper](https://arxiv.org/abs/1711.10561)| | 热仿真 | [2D 芯片热仿真](./zh/examples/chip_heat.md) | 机理驱动 | PI-DeepONet | 无监督学习 | - | [Paper](https://doi.org/10.1063/5.0194245)| +| 外流空气动力学 | [DoMINO](./zh/examples/domino.md) | 数据驱动 | FNO | 监督学习 | [Data](https://caemldatasets.org/drivaerml/) | [Paper](https://arxiv.org/abs/2501.13350)|

材料科学(AI for Material)

diff --git a/docs/zh/examples/domino.md b/docs/zh/examples/domino.md new file mode 100644 index 0000000000..d54ed27361 --- /dev/null +++ b/docs/zh/examples/domino.md @@ -0,0 +1,89 @@ +# DoMINO + +=== "模型训练命令" + + ``` sh + cd examples/domino + + # 1. Download the DrivAer ML dataset using the provided download_aws_dataset.sh script or using the Hugging Face repo(https://huggingface.co/datasets/neashton/drivaerml). + sh download_aws_dataset.sh + + # 2. Specify the configuration settings in `examples/domino/conf/config.yaml`. + + # 3. Run process_data.py. This will process VTP/VTU files and save them as npy for faster processing in DoMINO datapipe. Modify data_processor key in config file. Additionally, run cache_data.py to save outputs of DoMINO datapipe in the .npy files. The DoMINO datapipe is set up to calculate Signed Distance Field and Nearest Neighbor interpolations on-the-fly during training. Caching will save these as a preprocessing step and should be used in cases where the STL surface meshes are upwards of 30 million cells. The final processed dataset should be divided and saved into 2 directories, for training and validation. Specify these directories in conf/config.yaml. + python3 process_data.py + + # 4. run train + python3 train.py + ``` + +=== "模型评估命令" + + 暂无 + +=== "模型导出命令" + + 暂无 + +=== "模型推理命令" + + ``` sh + cd examples/domino + python3 test.py + ``` + +## 1. 背景简介 + +外部空气动力学涉及高雷诺数Navier-Stokes方程求解,传统CFD方法计算成本高昂。神经算子通过端到端映射提升了效率,但面临多尺度耦合建模与长期预测稳定性不足的挑战。Decomposable Multi-scale Iterative Neural Operator(Domino)提出可分解多尺度架构,通过分层特征解耦、迭代残差校正及参数独立编码,显著提升跨尺度流动建模精度与泛化能力。实验显示,其计算速度较CFD快2-3个量级,分离流预测精度较FNO等模型提升约40%,为飞行器设计等工程问题提供高效解决方案。 + +## 2. 模型原理 + +DOMINO (Decomposable Multi-scale Iterative Neural Operator)是一种新颖的机器学习模型架构,旨在解决大规模工程仿真代理建模中的挑战。它是一个基于点云的机器学习模型,利用局部几何信息来预测离散点上的流场 。 + +以下是DOMINO模型的主要原理: + +- 全局几何表示学习(Global Geometry Representation): + - 模型首先以几何体的三维表面网格作为输入。 + - 在几何体周围构建一个紧密贴合的表面包围盒和一个表示计算域的包围盒。 + - 几何点云的特征(如空间坐标)通过可学习的点卷积核投影到表面包围盒上的N维结构化网格上(分辨率为$m×m×m×f$)。 + - 点卷积核的实现使用了NVIDIA Warp加速的自定义球查询层 。 + - 通过两种方法将几何特征传播到计算域包围盒中:1)学习一组单独的多尺度点卷积核,将几何信息投影到计算域网格上;2)使用包含卷积、池化和反池化层的CNN块,将表面包围盒网格上的特征$G_s$​传播到计算域包围盒网格$G_c$。CNN块会迭代评估。 + - 计算域网格上计算出的$m×m×m×f$特征代表了几何点云的全局编码。此外,还会计算符号距离场(SDF)及其梯度分量,并附加到学习到的特征中,以提供关于几何拓扑的额外信息。 + +- 局部几何表示(Local Geometry Representation): + - 局部几何表示取决于计算域中评估解场的物理位置。 + - 在计算局部几何表示之前,会在计算域中采样一批离散点。 + - 对于批次中每个采样点,在其周围定义一个大小为$l×l×l$的子区域,并计算局部几何编码。 + - 局部编码本质上是全局编码的一个子集,取决于其在计算域中的位置,并通过点卷积计算。 + - 提取的局部特征通过全连接神经网络进一步转换。 + - 这种局部几何表示用于使用聚合网络评估采样点上的解场。 + +- 聚合网络(Aggregation Network): + - 局部几何表示代表了采样点及其邻居的计算模板附近几何和解的学习特征。 + - 计算模板中的每个点都由其在计算域中的物理坐标、这些坐标处的SDF、来自域质心的法向量以及表面法向量(如果点在表面上)表示。 + - 这些输入特征通过一个全连接神经网络(称为基函数神经网络),计算出一个潜在向量,代表计算模板中每个点的这些特征。 + - 每个潜在向量与局部几何编码连接,并通过另一组全连接层,以预测计算模板中每个点上的解向量。 + - 解向量通过逆距离加权方案进行平均,以预测采样点处的最终解向量。 + - 对于每个解变量,都使用聚合网络的一个独立实例,但全局几何编码网络在它们之间是共享的。 + +DOMINO模型通过这种分解式、多尺度和迭代的方法,能够有效地处理大规模仿真数据,捕捉长距离和短距离的相互作用,并在不牺牲准确性的情况下提供可扩展、准确和可推广的代理模型 。 + +## 3. 完整代码 + +``` py linenums="1" title="examples/domino/train.py" +--8<-- +examples/domino/train.py +--8<-- +``` + +``` py linenums="1" title="examples/domino/test.py" +--8<-- +examples/domino/test.py +--8<-- +``` + +## 4. 结果展示 + +## 5. 参考资料 + +- [DoMINO: A Decomposable Multi-scale Iterative Neural Operator for Modeling Large Scale Engineering Simulations](https://arxiv.org/abs/2501.13350) diff --git a/examples/domino/conf/config.yaml b/examples/domino/conf/config.yaml index 2a124b1107..8b649e9968 100644 --- a/examples/domino/conf/config.yaml +++ b/examples/domino/conf/config.yaml @@ -28,8 +28,8 @@ hydra: # Hydra config output_subdir: hydra # Default is .hydra which causes files not being uploaded in W&B. data: # Input directory for training and validation data - input_dir: /home/aistudio/modulus/examples/cfd/external_aerodynamics/domino/outputs/volume_data/ - input_dir_val: /home/aistudio/modulus/examples/cfd/external_aerodynamics/domino/outputs/volume_data/ + input_dir: outputs/volume_data/ + input_dir_val: outputs/volume_data/ bounding_box: # Bounding box dimensions for computational domain min: [-3.5, -2.25 , -0.32] max: [8.5 , 2.25 , 3.00] @@ -103,7 +103,7 @@ train: # Training configurable parameters sampler: shuffle: true drop_last: false - checkpoint_dir: /lustre/rranade/modulus_dev/modulus_forked/modulus/examples/cfd/external_aerodynamics/domino/outputs/AWS_Dataset/3/models/ + checkpoint_dir: outputs/AWS_Dataset/3/models/ val: # Validation configurable parameters dataloader: @@ -113,12 +113,12 @@ val: # Validation configurable parameters drop_last: false eval: # Testing configurable parameters - test_path: /home/aistudio/modulus/examples/cfd/external_aerodynamics/domino/drivaer_data_full_new - save_path: /home/aistudio/modulus/examples/cfd/external_aerodynamics/domino/outputs/mesh_predictions_surf_final1/ - checkpoint_name: /home/aistudio/xiaoyewww/PaddleScience/examples/domino/outputs/AWS_Dataset/1/models/DoMINO.0.30.pdparams + test_path: drivaer_data_full + save_path: outputs/mesh_predictions_surf_final1/ + checkpoint_name: outputs/AWS_Dataset/1/models/DoMINO.0.30.pdparams data_processor: # Data processor configurable parameters kind: drivaer_aws # must be either drivesim or drivaer_aws - output_dir: /lustre/rranade/modulus_dev/data/volume_data/ - input_dir: /lustre/datasets/drivaer_aws/drivaer_data_full/ + output_dir: data/volume_data/ + input_dir: drivaer_aws/drivaer_data_full/ num_processors: 12 diff --git a/examples/domino/download_aws_dataset.sh b/examples/domino/download_aws_dataset.sh new file mode 100644 index 0000000000..f793dd021f --- /dev/null +++ b/examples/domino/download_aws_dataset.sh @@ -0,0 +1,64 @@ +#!/bin/bash + +# This Bash script downloads the AWS DrivAer files from the Amazon S3 bucket to a local directory. +# Only the volume files (.vtu), STL files (.stl), and VTP files (.vtp) are downloaded. +# It uses a function, download_run_files, to check for the existence of three specific files (".vtu", ".stl", ".vtp") in a run directory. +# If a file doesn't exist, it's downloaded from the S3 bucket. If it does exist, the download is skipped. +# The script runs multiple downloads in parallel, both within a single run and across multiple runs. +# It also includes checks to prevent overloading the system by limiting the number of parallel downloads. + +# Set the local directory to download the files +LOCAL_DIR="./drivaer_data_full" # <--- This is the directory where the files will be downloaded. + +# Set the S3 bucket and prefix +S3_BUCKET="caemldatasets" +S3_PREFIX="drivaer/dataset" + +# Create the local directory if it doesn't exist +mkdir -p "$LOCAL_DIR" + +# Function to download files for a specific run +download_run_files() { + local i=$1 + RUN_DIR="run_$i" + RUN_LOCAL_DIR="$LOCAL_DIR/$RUN_DIR" + + # Create the run directory if it doesn't exist + mkdir -p "$RUN_LOCAL_DIR" + + # Check if the .vtu file exists before downloading + if [ ! -f "$RUN_LOCAL_DIR/volume_$i.vtu" ]; then + aws s3 cp --no-sign-request "s3://$S3_BUCKET/$S3_PREFIX/$RUN_DIR/volume_$i.vtu" "$RUN_LOCAL_DIR/" & + else + echo "File volume_$i.vtu already exists, skipping download." + fi + + # Check if the .stl file exists before downloading + if [ ! -f "$RUN_LOCAL_DIR/drivaer_$i.stl" ]; then + aws s3 cp --no-sign-request "s3://$S3_BUCKET/$S3_PREFIX/$RUN_DIR/drivaer_$i.stl" "$RUN_LOCAL_DIR/" & + else + echo "File drivaer_$i.stl already exists, skipping download." + fi + + # Check if the .vtp file exists before downloading + if [ ! -f "$RUN_LOCAL_DIR/boundary_$i.vtp" ]; then + aws s3 cp --no-sign-request "s3://$S3_BUCKET/$S3_PREFIX/$RUN_DIR/boundary_$i.vtp" "$RUN_LOCAL_DIR/" & + else + echo "File boundary_$i.vtp already exists, skipping download." + fi + + wait # Ensure that both files for this run are downloaded before moving to the next run +} + +# Loop through the run folders and download the files +for i in $(seq 1 500); do + download_run_files "$i" & + + # Limit the number of parallel jobs to avoid overloading the system + if (( $(jobs -r | wc -l) >= 8 )); then + wait -n # Wait for the next background job to finish before starting a new one + fi +done + +# Wait for all remaining background jobs to finish +wait diff --git a/examples/domino/process_data.py b/examples/domino/process_data.py new file mode 100644 index 0000000000..d8240354a8 --- /dev/null +++ b/examples/domino/process_data.py @@ -0,0 +1,109 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This code runs the data processing in parallel to load OpenFoam files, process them +and save in the npy format for faster processing in the DoMINO datapipes. Several +parameters such as number of processors, input and output paths, etc. can be +configured in config.yaml in the data_processing tab. +""" + +import multiprocessing +import os +import time + +import hydra +import numpy as np +from omegaconf import DictConfig +from omegaconf import OmegaConf +from openfoam_datapipe import OpenFoamDataset +from physicsnemo.utils.domino.utils import * # noqa: F403 + + +def process_files(*args_list): + ids = args_list[0] + processor_id = args_list[1] + fm_data = args_list[2] + output_dir = args_list[3] + for j in ids: + fname = fm_data.filenames[j] + if len(os.listdir(os.path.join(fm_data.data_path, fname))) == 0: + print(f"Skipping {fname} - empty.") + continue + outname = os.path.join(output_dir, fname) + print("Filename:%s on processor: %d" % (outname, processor_id)) + filename = f"{outname}.npy" + if os.path.exists(filename): + print(f"Skipping {filename} - already exists.") + continue + start_time = time.time() + data_dict = fm_data[j] + np.save(filename, data_dict) + print("Time taken for %d = %f" % (j, time.time() - start_time)) + + +@hydra.main(version_base="1.3", config_path="conf", config_name="config") +def main(cfg: DictConfig): + print(f"Config summary:\n{OmegaConf.to_yaml(cfg, sort_keys=True)}") + volume_variable_names = list(cfg.variables.volume.solution.keys()) + num_vol_vars = 0 + for j in volume_variable_names: + if cfg.variables.volume.solution[j] == "vector": + num_vol_vars += 3 + else: + num_vol_vars += 1 + + surface_variable_names = list(cfg.variables.surface.solution.keys()) + num_surf_vars = 0 + for j in surface_variable_names: + if cfg.variables.surface.solution[j] == "vector": + num_surf_vars += 3 + else: + num_surf_vars += 1 + + fm_data = OpenFoamDataset( + cfg.data_processor.input_dir, + kind=cfg.data_processor.kind, + volume_variables=volume_variable_names, + surface_variables=surface_variable_names, + model_type=cfg.model.model_type, + ) + output_dir = cfg.data_processor.output_dir + create_directory(output_dir) # noqa: F405 + n_processors = cfg.data_processor.num_processors + + num_files = len(fm_data) + ids = np.arange(num_files) + num_elements = int(num_files / n_processors) + 1 + process_list = [] + ctx = multiprocessing.get_context("spawn") + for i in range(n_processors): + if i != n_processors - 1: + sf = ids[i * num_elements : i * num_elements + num_elements] + else: + sf = ids[i * num_elements :] + # print(sf) + process = ctx.Process(target=process_files, args=(sf, i, fm_data, output_dir)) + + process.start() + process_list.append(process) + + for process in process_list: + process.join() + + +if __name__ == "__main__": + main() diff --git a/examples/domino/requirements.txt b/examples/domino/requirements.txt index 2ff5dd1ccf..57cc3b0363 100644 --- a/examples/domino/requirements.txt +++ b/examples/domino/requirements.txt @@ -1,4 +1,6 @@ hydra-core importlib_metadata pyvista==0.34.2 +termcolor +treelib warp-lang From 5d0f63e73215cec1bf112ea15db7cd4032d13f63 Mon Sep 17 00:00:00 2001 From: xiaoyewww <641311428@qq.com> Date: Wed, 21 May 2025 00:37:48 +0800 Subject: [PATCH 04/11] feat(domino): support domino for training and test --- examples/domino/README.md | 43 --------------------------------------- 1 file changed, 43 deletions(-) delete mode 100644 examples/domino/README.md diff --git a/examples/domino/README.md b/examples/domino/README.md deleted file mode 100644 index 0994a2e878..0000000000 --- a/examples/domino/README.md +++ /dev/null @@ -1,43 +0,0 @@ -# DoMINO: Decomposable Multi-scale Iterative Neural Operator for External Aerodynamics - -DoMINO代码复现。 - -## 安装依赖 - -```shell -# 安装PaddlePaddle - -cd /path/PaddleScience -pip install -e . - -cd /path/PaddleScience/examples/domino -pip install -r requirements.txt -``` - -## 数据集下载与处理 - -1. 参考[DoMINO](https://github.com/NVIDIA/modulus/tree/main/examples/cfd/external_aerodynamics/domino#training-the-domino-model)数据下载和处理方式,执行`download_aws_dataset.sh`和`process_data.py`,获取数据。 - -2. 本仓库训练数据为`process_data.py`的后处理数据。 - -## 训练 - -1. 修改`conf/config.yaml`路径, 原始配置文件参考[DoMINO](https://github.com/NVIDIA/modulus/blob/main/examples/cfd/external_aerodynamics/domino/src/conf/config.yaml)。 - -2. 训练 - -```shell -cd /path/PaddleScience/examples/domino -python train.py -``` - -## 推理 - -1. 修改`conf/config.yaml`路径, 原始配置文件参考[DoMINO](https://github.com/NVIDIA/modulus/blob/main/examples/cfd/external_aerodynamics/domino/src/conf/config.yaml)。 - -2. 推理 - -```shell -cd /path/PaddleScience/examples/domino -python test.py -``` From baa40d39f331c0383bbf2331f3295027bb63a26e Mon Sep 17 00:00:00 2001 From: xiaoyewww <641311428@qq.com> Date: Thu, 22 May 2025 22:38:38 +0800 Subject: [PATCH 05/11] feat(domino): remove something of torch --- .../domino/modulus/distributed/manager.py | 14 ++++----- .../domino/modulus/launch/logging/launch.py | 29 +++---------------- .../domino/modulus/launch/logging/mlflow.py | 4 --- .../domino/modulus/launch/logging/utils.py | 14 ++++----- .../domino/modulus/launch/utils/checkpoint.py | 4 +-- examples/domino/modulus/models/model.py | 28 +----------------- examples/domino/train.py | 3 -- 7 files changed, 19 insertions(+), 77 deletions(-) diff --git a/examples/domino/modulus/distributed/manager.py b/examples/domino/modulus/distributed/manager.py index eccc69f6a3..e24aaa4103 100644 --- a/examples/domino/modulus/distributed/manager.py +++ b/examples/domino/modulus/distributed/manager.py @@ -200,7 +200,7 @@ def group_name(self, group=None): @property def broadcast_buffers(self): - """broadcast_buffers in PyTorch DDP""" + """broadcast_buffers in DDP""" return self._broadcast_buffers @broadcast_buffers.setter @@ -210,7 +210,7 @@ def broadcast_buffers(self, broadcast: bool): @property def find_unused_parameters(self): - """find_unused_parameters in PyTorch DDP""" + """find_unused_parameters in DDP""" return self._find_unused_parameters @find_unused_parameters.setter @@ -315,8 +315,7 @@ def initialize(): Initialize distributed manager Current supported initialization methods are: - `ENV`: PyTorch environment variable initialization - https://pytorch.org/docs/stable/distributed.html#environment-variable-initialization + `ENV`: Environment variable initialization `SLURM`: Initialization on SLURM systems. Uses `SLURM_PROCID`, `SLURM_NPROCS`, `SLURM_LOCALID` and `SLURM_LAUNCH_NODE_IPADDR` environment variables. @@ -335,7 +334,6 @@ def initialize(): addr = os.getenv("MASTER_ADDR", "localhost") port = os.getenv("MASTER_PORT", "12355") - # https://pytorch.org/docs/master/notes/cuda.html#id5 os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "0" initialization_method = os.getenv("MODULUS_DISTRIBUTED_INITIALIZATION_METHOD") if initialization_method is None: @@ -379,7 +377,7 @@ def setup( backend="nccl", method="env", ): - """Set up PyTorch distributed process group and update manager attributes""" + """Set up distributed process group and update manager attributes""" os.environ["MASTER_ADDR"] = addr os.environ["MASTER_PORT"] = str(port) @@ -443,7 +441,7 @@ def create_process_subgroup( raise AssertionError( "paddle.distributed is unavailable. " "Check paddle build to ensure the distributed package is available. " - "If building PyTorch from source, set `USE_DISTRIBUTED=1` " + "If building from source, set `USE_DISTRIBUTED=1` " "to enable the distributed package" ) @@ -513,7 +511,7 @@ def create_orthogonal_process_group( raise AssertionError( "paddle.distributed is unavailable. " "Check paddle build to ensure the distributed package is available. " - "If building PyTorch from source, set `USE_DISTRIBUTED=1` " + "If building from source, set `USE_DISTRIBUTED=1` " "to enable the distributed package" ) diff --git a/examples/domino/modulus/launch/logging/launch.py b/examples/domino/modulus/launch/logging/launch.py index c8714da61a..8200032d76 100644 --- a/examples/domino/modulus/launch/logging/launch.py +++ b/examples/domino/modulus/launch/logging/launch.py @@ -26,8 +26,6 @@ from typing import Tuple from typing import Union -import torch -import torch.cuda.profiler as profiler from modulus.distributed import DistributedManager from modulus.distributed import reduce_loss @@ -97,14 +95,7 @@ def __new__(cls, name_space, *args, **kwargs): if DistributedManager.is_initialized(): self.root = DistributedManager().rank == 0 # Profiler utils - if torch.cuda.is_available(): - self.profiler = torch.autograd.profiler.emit_nvtx( - enabled=cls.enable_profiling - ) - self.start_event = torch.cuda.Event(enable_timing=True) - self.end_event = torch.cuda.Event(enable_timing=True) - else: - self.profiler = None + self.profiler = None return self @@ -203,13 +194,9 @@ def __enter__(self): if self.profile and self.profiler: self.logger.warning(f"Starting profile for epoch {self.epoch}") self.profiler.__enter__() - profiler.start() # Timing stuff - if torch.cuda.is_available(): - self.start_event.record() - else: - self.start_event = time.time() + self.start_event = time.time() if self.mlflow_backend: self.mlflow_client.update_run(self.mlflow_run.info.run_id, "RUNNING") @@ -250,18 +237,10 @@ def __exit__(self, exc_type, exc_value, exc_tb): if self.profile and self.profiler: self.logger.warning("Ending profile") self.profiler.__exit__() - profiler.end() # Timing stuff, TODO: histograms not line plots - if torch.cuda.is_available(): - self.end_event.record() - torch.cuda.synchronize() - # Returns milliseconds - # https://pytorch.org/docs/stable/generated/torch.cuda.Event.html#torch.cuda.Event.elapsed_time - epoch_time = self.start_event.elapsed_time(self.end_event) / 1000.0 - else: - end_event = time.time() - epoch_time = end_event - self.start_event + end_event = time.time() + epoch_time = end_event - self.start_event # Return MS for time / iter time_per_iter = 1000 * epoch_time / max([1, self.mini_batch_index]) diff --git a/examples/domino/modulus/launch/logging/mlflow.py b/examples/domino/modulus/launch/logging/mlflow.py index 69e85e0e8e..ae1b68ca51 100644 --- a/examples/domino/modulus/launch/logging/mlflow.py +++ b/examples/domino/modulus/launch/logging/mlflow.py @@ -20,8 +20,6 @@ from typing import Literal from typing import Tuple -import torch - try: import mlflow # noqa: F401 for docs from mlflow.entities.run import Run @@ -156,8 +154,6 @@ def initialize_mlflow( time_string = start_time.strftime("%m/%d/%y %H:%M:%S") client.set_tag(run.info.run_id, "date", time_string) client.set_tag(run.info.run_id, "host", os.uname()[1]) - if torch.cuda.is_available(): - client.set_tag(run.info.run_id, "gpu", torch.cuda.get_device_name(dist.device)) client.set_tag(run.info.run_id, "group", group_name) run = client.get_run(run.info.run_id) diff --git a/examples/domino/modulus/launch/logging/utils.py b/examples/domino/modulus/launch/logging/utils.py index 0a9e66dde8..ad9d032f29 100644 --- a/examples/domino/modulus/launch/logging/utils.py +++ b/examples/domino/modulus/launch/logging/utils.py @@ -16,14 +16,14 @@ from datetime import datetime -import torch +import paddle from modulus.distributed import DistributedManager def create_ddp_group_tag(group_name: str = None) -> str: """Creates a common group tag for logging - For some reason this does not work with multi-node. Seems theres a bug in PyTorch + For some reason this does not work with multi-node. Seems theres a bug in when one uses a distributed util before DDP Parameters @@ -42,15 +42,13 @@ def create_ddp_group_tag(group_name: str = None) -> str: def tint(x): return int(datetime.now().strftime(f"%{x}")) - time_index = torch.IntTensor( - [tint(x) for x in ["m", "d", "y", "H", "M", "S"]] - ).to(dist.device) + time_index = paddle.to_tensor([tint(x) for x in ["m", "d", "y", "H", "M", "S"]]) else: - time_index = torch.IntTensor([0, 0, 0, 0, 0, 0]).to(dist.device) + time_index = paddle.to_tensor([0, 0, 0, 0, 0, 0]) - if torch.distributed.is_available(): + if paddle.distributed.is_available(): # Broadcast group ID to all processes - torch.distributed.broadcast(time_index, src=0) + paddle.distributed.broadcast(time_index, src=0) time_string = f"{time_index[0]}/{time_index[1]}/{time_index[2]}_\ {time_index[3]}-{time_index[4]}-{time_index[5]}" diff --git a/examples/domino/modulus/launch/utils/checkpoint.py b/examples/domino/modulus/launch/utils/checkpoint.py index 1f5f062155..b795dc4048 100644 --- a/examples/domino/modulus/launch/utils/checkpoint.py +++ b/examples/domino/modulus/launch/utils/checkpoint.py @@ -67,7 +67,7 @@ def _get_checkpoint_filename( saving : bool, optional Get filename for saving a new checkpoint, by default False model_type : str - Model type, by default "mdlus" for Modulus models and "pdparams" for PyTorch models + Model type, by default "mdlus" for Modulus models and "pdparams" for models Returns @@ -296,7 +296,7 @@ def load_checkpoint( path : str Path to training checkpoint models : Union[paddle.nn.Layer, List[paddle.nn.Layer], None], optional - A single or list of PyTorch models, by default None + A single or list of models, by default None optimizer : Union[optimizer, None], optional Optimizer, by default None scheduler : Union[scheduler, None], optional diff --git a/examples/domino/modulus/models/model.py b/examples/domino/modulus/models/model.py index ca87ceb4aa..3372e8ecb0 100644 --- a/examples/domino/modulus/models/model.py +++ b/examples/domino/modulus/models/model.py @@ -172,7 +172,6 @@ def forward(self, x, radius=0.025, neighbors_in_radius=10): # paddle does not support multiplication with boolean tensors, # so we convert the mask to float - # x = torch.sum(x * mask, 2) x = paddle.sum(x * mask.to(dtype=x.dtype), 2) x = paddle.reshape(x, (batch_size, x.shape[-1], nx, ny, nz)) @@ -440,23 +439,6 @@ def forward(self, x): return out -# @dataclass -# class MetaData(ModelMetaData): -# name: str = "DoMINO" -# # Optimization -# jit: bool = False -# cuda_graphs: bool = True -# amp: bool = True -# # Inference -# onnx_cpu: bool = True -# onnx_gpu: bool = True -# onnx_runtime: bool = True -# # Physics informed -# var_dim: int = 1 -# func_torch: bool = False -# auto_grad: bool = False - - class DoMINO(nn.Layer): """DoMINO model architecture Parameters @@ -473,10 +455,9 @@ class DoMINO(nn.Layer): Example ------- >>> from modulus.models.domino.model import DoMINO - >>> import torch, os + >>> import os >>> from hydra import compose, initialize >>> from omegaconf import OmegaConf - >>> device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") >>> cfg = OmegaConf.register_new_resolver("eval", eval) >>> with initialize(version_base="1.3", config_path="examples/cfd/external_aerodynamics/domino/src/conf"): ... cfg = compose(config_name="config") @@ -537,7 +518,6 @@ class DoMINO(nn.Layer): >>> output = model(input_dict) Module ... >>> print(f"{output[0].shape}, {output[1].shape}") - torch.Size([1, 100, 5]), torch.Size([1, 100, 4]) """ def __init__( @@ -726,9 +706,6 @@ def geo_encoding_local_surface(self, encoding_g, volume_mesh_centers, p_grid): [batch_size, volume_mesh_centers.shape[1], geo_encoding_long.shape[2]] ) - # geo_encoding_sampled = torch.gather(geo_encoding, 2, mapping) * mask - # sdf_encoding_sampled = torch.gather(sdf_encoding, 2, mapping) * mask - # geo_encoding_long_sampled = torch.gather(geo_encoding_long, 2, mapping) * mask geo_encoding_sampled = paddle.take_along_axis( geo_encoding, axis=2, indices=mapping ) * mask.to(dtype=geo_encoding.dtype) @@ -778,9 +755,6 @@ def geo_encoding_local(self, encoding_g, volume_mesh_centers, p_grid): [batch_size, volume_mesh_centers.shape[1], geo_encoding_long.shape[2]] ) - # geo_encoding_sampled = torch.gather(geo_encoding, 2, mapping) * mask - # sdf_encoding_sampled = torch.gather(sdf_encoding, 2, mapping) * mask - # geo_encoding_long_sampled = torch.gather(geo_encoding_long, 2, mapping) * mask geo_encoding_sampled = paddle.take_along_axis( geo_encoding, axis=2, indices=mapping ) * mask.to(dtype=geo_encoding.dtype) diff --git a/examples/domino/train.py b/examples/domino/train.py index dcb96dbb0e..4a9d536d33 100644 --- a/examples/domino/train.py +++ b/examples/domino/train.py @@ -802,9 +802,6 @@ def main(cfg: DictConfig) -> None: output_features_surf=num_surf_vars, model_parameters=cfg.model, ) - # model = torch.compile(model, disable=True) # TODO make this configurable - - # Print model summary (structure and parmeter count). if dist.world_size > 1: model = DataParallel( From f7918c7312a304670082e6a7e32f6838366b4df2 Mon Sep 17 00:00:00 2001 From: xiaoyewww <641311428@qq.com> Date: Thu, 22 May 2025 23:21:35 +0800 Subject: [PATCH 06/11] feat(domino): move physicsnemo to arch --- .../domino/modulus/launch/logging/__init__.py | 17 - .../domino/modulus/launch/logging/console.py | 88 ---- .../domino/modulus/launch/logging/launch.py | 422 ------------------ .../domino/modulus/launch/logging/mlflow.py | 195 -------- .../domino/modulus/launch/logging/utils.py | 58 --- .../domino/modulus/launch/logging/wandb.py | 131 ------ examples/domino/test.py | 39 +- examples/domino/train.py | 18 +- .../arch/physicsnemo}/__init__.py | 0 .../arch/physicsnemo}/datapipes/__init__.py | 0 .../physicsnemo}/datapipes/cae/__init__.py | 0 .../datapipes/cae/domino_datapipe.py | 8 +- .../arch/physicsnemo}/distributed/__init__.py | 0 .../arch/physicsnemo}/distributed/config.py | 0 .../arch/physicsnemo}/distributed/manager.py | 0 .../arch/physicsnemo}/launch/__init__.py | 0 .../physicsnemo}/launch/utils/__init__.py | 0 .../physicsnemo}/launch/utils/checkpoint.py | 37 +- .../arch/physicsnemo}/models/__init__.py | 0 .../physicsnemo}/models/layers/__init__.py | 0 .../physicsnemo}/models/layers/ball_query.py | 0 .../arch/physicsnemo}/models/model.py | 0 .../arch/physicsnemo}/utils/__init__.py | 0 .../arch/physicsnemo}/utils/domino/utils.py | 4 +- .../arch/physicsnemo}/utils/sdf.py | 0 25 files changed, 46 insertions(+), 971 deletions(-) delete mode 100644 examples/domino/modulus/launch/logging/__init__.py delete mode 100644 examples/domino/modulus/launch/logging/console.py delete mode 100644 examples/domino/modulus/launch/logging/launch.py delete mode 100644 examples/domino/modulus/launch/logging/mlflow.py delete mode 100644 examples/domino/modulus/launch/logging/utils.py delete mode 100644 examples/domino/modulus/launch/logging/wandb.py rename {examples/domino/modulus => ppsci/arch/physicsnemo}/__init__.py (100%) rename {examples/domino/modulus => ppsci/arch/physicsnemo}/datapipes/__init__.py (100%) rename {examples/domino/modulus => ppsci/arch/physicsnemo}/datapipes/cae/__init__.py (100%) rename {examples/domino/modulus => ppsci/arch/physicsnemo}/datapipes/cae/domino_datapipe.py (98%) rename {examples/domino/modulus => ppsci/arch/physicsnemo}/distributed/__init__.py (100%) rename {examples/domino/modulus => ppsci/arch/physicsnemo}/distributed/config.py (100%) rename {examples/domino/modulus => ppsci/arch/physicsnemo}/distributed/manager.py (100%) rename {examples/domino/modulus => ppsci/arch/physicsnemo}/launch/__init__.py (100%) rename {examples/domino/modulus => ppsci/arch/physicsnemo}/launch/utils/__init__.py (100%) rename {examples/domino/modulus => ppsci/arch/physicsnemo}/launch/utils/checkpoint.py (91%) rename {examples/domino/modulus => ppsci/arch/physicsnemo}/models/__init__.py (100%) rename {examples/domino/modulus => ppsci/arch/physicsnemo}/models/layers/__init__.py (100%) rename {examples/domino/modulus => ppsci/arch/physicsnemo}/models/layers/ball_query.py (100%) rename {examples/domino/modulus => ppsci/arch/physicsnemo}/models/model.py (100%) rename {examples/domino/modulus => ppsci/arch/physicsnemo}/utils/__init__.py (100%) rename {examples/domino/modulus => ppsci/arch/physicsnemo}/utils/domino/utils.py (99%) rename {examples/domino/modulus => ppsci/arch/physicsnemo}/utils/sdf.py (100%) diff --git a/examples/domino/modulus/launch/logging/__init__.py b/examples/domino/modulus/launch/logging/__init__.py deleted file mode 100644 index f9b51fe41d..0000000000 --- a/examples/domino/modulus/launch/logging/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. -# SPDX-FileCopyrightText: All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from .console import PythonLogger # noqa: F401 diff --git a/examples/domino/modulus/launch/logging/console.py b/examples/domino/modulus/launch/logging/console.py deleted file mode 100644 index 423157693e..0000000000 --- a/examples/domino/modulus/launch/logging/console.py +++ /dev/null @@ -1,88 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. -# SPDX-FileCopyrightText: All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -import os - -from termcolor import colored - - -class PythonLogger: - """Simple console logger for DL training - This is a WIP - """ - - def __init__(self, name: str = "launch"): - self.logger = logging.getLogger(name) - - def file_logging(self, file_name: str = "launch.log"): - """Log to file""" - if os.path.exists(file_name): - try: - os.remove(file_name) - except FileNotFoundError: - # ignore if already removed (can happen with multiple processes) - pass - formatter = logging.Formatter( - "[%(asctime)s - %(name)s - %(levelname)s] %(message)s", - datefmt="%H:%M:%S", - ) - filehandler = logging.FileHandler(file_name) - filehandler.setFormatter(formatter) - filehandler.setLevel(logging.DEBUG) - self.logger.addHandler(filehandler) - - def log(self, message: str): - """Log message""" - self.logger.info(message) - - def info(self, message: str): - """Log info""" - self.logger.info(colored(message, "light_blue")) - - def success(self, message: str): - """Log success""" - self.logger.info(colored(message, "light_green")) - - def warning(self, message: str): - """Log warning""" - self.logger.warning(colored(message, "light_yellow")) - - def error(self, message: str): - """Log error""" - self.logger.error(colored(message, "light_red")) - - -class RankZeroLoggingWrapper: - """Wrapper class to only log from rank 0 process in distributed training.""" - - def __init__(self, obj, dist): - self.obj = obj - self.dist = dist - - def __getattr__(self, name): - attr = getattr(self.obj, name) - if callable(attr): - - def wrapper(*args, **kwargs): - if self.dist.rank == 0: - return attr(*args, **kwargs) - else: - return None - - return wrapper - else: - return attr diff --git a/examples/domino/modulus/launch/logging/launch.py b/examples/domino/modulus/launch/logging/launch.py deleted file mode 100644 index 8200032d76..0000000000 --- a/examples/domino/modulus/launch/logging/launch.py +++ /dev/null @@ -1,422 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. -# SPDX-FileCopyrightText: All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import re -import sys -import time -from os import getcwd -from os import makedirs -from os.path import abspath -from os.path import exists -from os.path import join -from typing import Dict -from typing import Tuple -from typing import Union - -from modulus.distributed import DistributedManager -from modulus.distributed import reduce_loss - -from .console import PythonLogger - - -class LaunchLogger(object): - """Modulus Launch logger - - An abstracted logger class that takes care of several fundamental logging functions. - This class should first be initialized and then used via a context manager. This will - auto compute epoch metrics. This is the standard logger for Modulus examples. - - Parameters - ---------- - name_space : str - Namespace of logger to use. This will define the loggers title in the console and - the wandb group the metric is plotted - epoch : int, optional - Current epoch, by default 1 - num_mini_batch : Union[int, None], optional - Number of mini-batches used to calculate the epochs progress, by default None - profile : bool, optional - Profile code using nvtx markers, by default False - mini_batch_log_freq : int, optional - Frequency to log mini-batch losses, by default 100 - epoch_alert_freq : Union[int, None], optional - Epoch frequency to send training alert, by default None - - Example - ------- - >>> from modulus.launch.logging import LaunchLogger - >>> LaunchLogger.initialize() - >>> epochs = 3 - >>> for i in range(epochs): - ... with LaunchLogger("Train", epoch=i) as log: - ... # Log 3 mini-batches manually - ... log.log_minibatch({"loss": 1.0}) - ... log.log_minibatch({"loss": 2.0}) - ... log.log_minibatch({"loss": 3.0}) - """ - - _instances = {} - console_backend = True - wandb_backend = False - mlflow_backend = False - tensorboard_backend = False - enable_profiling = False - - mlflow_run = None - mlflow_client = None - - def __new__(cls, name_space, *args, **kwargs): - # If namespace already has an instance just return that - if name_space in cls._instances: - return cls._instances[name_space] - - # Otherwise create new singleton instance for this namespace - self = super().__new__(cls) # don't pass remaining parameters to object.__new__ - cls._instances[name_space] = self - - # Constructor set up to only be ran once by a logger - self.pyLogger = PythonLogger(name_space) - self.total_iteration_index = None - # Distributed - self.root = True - if DistributedManager.is_initialized(): - self.root = DistributedManager().rank == 0 - # Profiler utils - self.profiler = None - - return self - - def __init__( - self, - name_space: str, - epoch: int = 1, - num_mini_batch: Union[int, None] = None, - profile: bool = False, - mini_batch_log_freq: int = 100, - epoch_alert_freq: Union[int, None] = None, - ): - self.name_space = name_space - self.mini_batch_index = 0 - self.minibatch_losses = {} - self.epoch_losses = {} - - self.mini_batch_log_freq = mini_batch_log_freq - self.epoch_alert_freq = epoch_alert_freq - self.epoch = epoch - self.num_mini_batch = num_mini_batch - self.profile = profile - # Init initial iteration based on current epoch - if self.total_iteration_index is None: - if num_mini_batch is not None: - self.total_iteration_index = (epoch - 1) * num_mini_batch - else: - self.total_iteration_index = 0 - - # Set x axis metric to epoch for this namespace - if self.wandb_backend: - import wandb - - wandb.define_metric(name_space + "/mini_batch_*", step_metric="iter") - wandb.define_metric(name_space + "/*", step_metric="epoch") - - def log_minibatch(self, losses: Dict[str, float]): - """Logs metrics for a mini-batch epoch - - This function should be called every mini-batch iteration. It will accumulate - loss values over a datapipe. At the end of a epoch the average of these losses - from each mini-batch will get calculated. - - Parameters - ---------- - losses : Dict[str, float] - Dictionary of metrics/loss values to log - """ - self.mini_batch_index += 1 - self.total_iteration_index += 1 - for name, value in losses.items(): - if name not in self.minibatch_losses: - self.minibatch_losses[name] = 0 - self.minibatch_losses[name] += value - - # Log of mini-batch loss - if self.mini_batch_index % self.mini_batch_log_freq == 0: - # Backend Logging - mini_batch_metrics = {} - for name, value in losses.items(): - mini_batch_metrics[f"{self.name_space}/mini_batch_{name}"] = value - self._log_backends( - mini_batch_metrics, step=("iter", self.total_iteration_index) - ) - - # Console - if self.root: - message = "Mini-Batch Losses:" - for name, value in losses.items(): - message += f" {name} = {value:10.3e}," - message = message[:-1] - # If we have datapipe length we can get a percent complete - if self.num_mini_batch: - mbp = 100 * (float(self.mini_batch_index) / self.num_mini_batch) - message = f"[{mbp:.02f}%] " + message - - self.pyLogger.log(message) - - def log_epoch(self, losses: Dict[str, float]): - """Logs metrics for a single epoch - - Parameters - ---------- - losses : Dict[str, float] - Dictionary of metrics/loss values to log - """ - for name, value in losses.items(): - self.epoch_losses[name] = value - - def __enter__(self): - self.mini_batch_index = 0 - self.minibatch_losses = {} - self.epoch_losses = {} - - # Trigger profiling - if self.profile and self.profiler: - self.logger.warning(f"Starting profile for epoch {self.epoch}") - self.profiler.__enter__() - - # Timing stuff - self.start_event = time.time() - - if self.mlflow_backend: - self.mlflow_client.update_run(self.mlflow_run.info.run_id, "RUNNING") - - return self - - def __exit__(self, exc_type, exc_value, exc_tb): - # Abnormal exit dont log - if exc_type is not None: - if self.mlflow_backend: - self.mlflow_client.set_terminated( - self.mlflow_run.info.run_id, status="KILLED" - ) - return - # Reduce mini-batch losses - for name, value in self.minibatch_losses.items(): - process_loss = value / self.mini_batch_index - self.epoch_losses[name] = process_loss - # Compute global loss - if DistributedManager.is_initialized() and DistributedManager().distributed: - self.epoch_losses[name] = reduce_loss(process_loss) - - if self.root: - # Console printing - # TODO: add out of total epochs progress - message = f"Epoch {self.epoch} Metrics:" - for name, value in self.epoch_losses.items(): - message += f" {name} = {value:10.3e}," - message = message[:-1] - self.pyLogger.info(message) - - metrics = { - f"{self.name_space}/{key}": value - for key, value in self.epoch_losses.items() - } - - # Exit profiling - if self.profile and self.profiler: - self.logger.warning("Ending profile") - self.profiler.__exit__() - - # Timing stuff, TODO: histograms not line plots - end_event = time.time() - epoch_time = end_event - self.start_event - - # Return MS for time / iter - time_per_iter = 1000 * epoch_time / max([1, self.mini_batch_index]) - - if self.root: - message = f"Epoch Execution Time: {epoch_time:10.3e}s" - message += f", Time/Iter: {time_per_iter:10.3e}ms" - self.pyLogger.info(message) - - metrics[f"{self.name_space}/Epoch Time (s)"] = epoch_time - metrics[f"{self.name_space}/Time per iter (ms)"] = time_per_iter - - self._log_backends(metrics, step=("epoch", self.epoch)) - - # TODO this should be in some on delete method / clean up - if self.mlflow_backend: - self.mlflow_client.set_terminated( - self.mlflow_run.info.run_id, status="FINISHED" - ) - - # Alert - if ( - self.epoch_alert_freq - and self.root - and self.epoch % self.epoch_alert_freq == 0 - ): - if self.wandb_backend: - import wandb - - from .wandb import alert - - # TODO: Make this a little more informative? - alert( - title=f"{sys.argv[0]} training progress report", - text=f"Run {wandb.run.name} is at epoch {self.epoch}.", - ) - - def _log_backends( - self, - metric_dict: Dict[str, float], - step: Tuple[str, int] = None, - ): - """Logs a dictionary of metrics to different supported backends - - Parameters - ---------- - metric_dict : Dict[str, float] - Metric dictionary - step : Tuple[str, int], optional - Tuple containing (step name, step index), by default None - print : bool, optional - Print metrics, by default False - """ - - # MLFlow Logging - if self.mlflow_backend: - for key, value in metric_dict.items(): - # If value is None just skip - if value is None: - continue - # Keys only allow alpha numeric, ., -, /, _ and spaces - key = re.sub("[^a-zA-Z0-9\.\-\s\/\_]+", "", key) - self.mlflow_client.log_metric( - self.mlflow_run.info.run_id, key, value, step=step[1] - ) - - # WandB Logging - if self.wandb_backend: - import wandb - - # For WandB send step in as a metric - # Step argument in lod function does not work with multiple log calls at - # different intervals - metric_dict[step[0]] = step[1] - wandb.log(metric_dict) - - def log_figure( - self, - figure, - artifact_file: str = "artifact", - plot_dir: str = "./", - log_to_file: bool = False, - ): - """Logs figures on root process to wand or mlflow. Will store it to file in case neither are selected. - - Parameters - ---------- - figure : Figure - matplotlib or plotly figure to plot - artifact_file : str, optional - File name. CAUTION overrides old files of same name - plot_dir : str, optional - output directory for plot - log_to_file : bool, optional - set to true in case figure shall be stored to file in addition to logging it to mlflow/wandb - """ - dist = DistributedManager() - if dist.rank != 0: - return - - if self.wandb_backend: - import wandb - - wandb.log({artifact_file: figure}) - - if self.mlflow_backend: - self.mlflow_client.log_figure( - figure=figure, - artifact_file=artifact_file, - run_id=self.mlflow_run.info.run_id, - ) - - if (not self.wandb_backend) and (not self.mlflow_backend): - log_to_file = True - - if log_to_file: - plot_dir = abspath(join(getcwd(), plot_dir)) - if not exists(plot_dir): - makedirs(plot_dir) - if not artifact_file.endswith(".png"): - artifact_file += ".png" - figure.savefig(join(plot_dir, artifact_file)) - - @classmethod - def toggle_wandb(cls, value: bool): - """Toggle WandB logging - - Parameters - ---------- - value : bool - Use WandB logging - """ - cls.wandb_backend = value - - @classmethod - def toggle_mlflow(cls, value: bool): - """Toggle MLFlow logging - - Parameters - ---------- - value : bool - Use MLFlow logging - """ - cls.mlflow_backend = value - - @staticmethod - def initialize(use_wandb: bool = False, use_mlflow: bool = False): - """Initialize logging singleton - - Parameters - ---------- - use_wandb : bool, optional - Use WandB logging, by default False - use_mlflow : bool, optional - Use MLFlow logging, by default False - """ - if use_wandb: - import wandb - - if wandb.run is None: - PythonLogger().warning("WandB not initialized, turning off") - use_wandb = False - - if use_wandb: - LaunchLogger.toggle_wandb(True) - wandb.define_metric("epoch") - wandb.define_metric("iter") - - # let only root process log to mlflow - if DistributedManager.is_initialized(): - if DistributedManager().rank != 0: - return - - if LaunchLogger.mlflow_run is None and use_mlflow: - PythonLogger().warning("MLFlow not initialized, turning off") - use_mlflow = False - - if use_mlflow: - LaunchLogger.toggle_mlflow(True) diff --git a/examples/domino/modulus/launch/logging/mlflow.py b/examples/domino/modulus/launch/logging/mlflow.py deleted file mode 100644 index ae1b68ca51..0000000000 --- a/examples/domino/modulus/launch/logging/mlflow.py +++ /dev/null @@ -1,195 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. -# SPDX-FileCopyrightText: All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -from datetime import datetime -from pathlib import Path -from typing import Literal -from typing import Tuple - -try: - import mlflow # noqa: F401 for docs - from mlflow.entities.run import Run - from mlflow.tracking import MlflowClient -except ImportError: - raise ImportError( - "These utilities require the MLFlow library. Install MLFlow using `pip install mlflow`. " - + "For more info, refer: https://www.mlflow.org/docs/2.5.0/quickstart.html#install-mlflow" - ) - -from modulus.distributed import DistributedManager - -from .console import PythonLogger -from .launch import LaunchLogger - -logger = PythonLogger("mlflow") - - -def initialize_mlflow( - experiment_name: str, - experiment_desc: str = None, - run_name: str = None, - run_desc: str = None, - user_name: str = None, - mode: Literal["offline", "online", "ngc"] = "offline", - tracking_location: str = None, - artifact_location: str = None, -) -> Tuple[MlflowClient, Run]: - """Initializes MLFlow logging client and run. - - Parameters - ---------- - experiment_name : str - Experiment name - experiment_desc : str, optional - Experiment description, by default None - run_name : str, optional - Run name, by default None - run_desc : str, optional - Run description, by default None - user_name : str, optional - User name, by default None - mode : str, optional - MLFlow mode. Supports "offline", "online" and "ngc". Offline mode records logs to - local file system. Online mode is for remote tracking servers. NGC is specific - standardized setup for NGC runs, default "offline" - tracking_location : str, optional - Tracking location for MLFlow. For offline this would be an absolute folder directory. - For online mode this would be a http URI or databricks. For NGC, this option is - ignored, by default "//mlruns" - artifact_location : str, optional - Optional separate artifact location, by default None - - Note - ---- - For NGC mode, one needs to mount a NGC workspace / folder system with a metric folder - at `/mlflow/mlflow_metrics/` and a artifact folder at `/mlflow/mlflow_artifacts/`. - - Note - ---- - This will set up Modulus Launch logger for MLFlow logging. Only one MLFlow logging - client is supported with the Modulus Launch logger. - - Returns - ------- - Tuple[MlflowClient, Run] - Returns MLFlow logging client and active run object - """ - dist = DistributedManager() - if dist.rank != 0: # only root process should be logging to mlflow - return - - start_time = datetime.now().astimezone() - time_string = start_time.strftime("%m/%d/%y_%H-%M-%S") - group_name = f"{run_name}_{time_string}" - - # Set default value here for Hydra - if tracking_location is None: - tracking_location = str(Path("./mlruns").absolute()) - - # Set up URI (remote or local) - if mode == "online": - tracking_uri = tracking_location - elif mode == "offline": - if not tracking_location.startswith("file://"): - tracking_location = "file://" + tracking_location - tracking_uri = tracking_location - elif mode == "ngc": - if not Path("/mlflow/mlflow_metrics").is_dir(): - raise IOError( - "NGC MLFlow config select but metrics folder '/mlflow/mlflow_metrics'" - + " not found. Aborting MLFlow setup." - ) - return - - if not Path("/mlflow/mlflow_artifacts").is_dir(): - raise IOError( - "NGC MLFlow config select but artifact folder '/mlflow/mlflow_artifacts'" - + " not found. Aborting MLFlow setup." - ) - return - tracking_uri = "file:///mlflow/mlflow_metrics" - artifact_location = "file:///mlflow/mlflow_artifacts" - else: - logger.warning(f"Unsupported MLFlow mode '{mode}' provided") - tracking_uri = "file://" + str(Path("./mlruns").absolute()) - - mlflow.set_tracking_uri(tracking_uri) - client = MlflowClient() - - check_mlflow_logged_in(client) - - experiment = client.get_experiment_by_name(experiment_name) - # If experiment does not exist create one - if experiment is None: - logger.info(f"No {experiment_name} experiment found, creating...") - experiment_id = client.create_experiment( - experiment_name, artifact_location=artifact_location - ) - client.set_experiment_tag(experiment_id, "mlflow.note.content", experiment_desc) - else: - logger.success(f"Existing {experiment_name} experiment found") - experiment_id = experiment.experiment_id - - # Create an run and set its tags - run = client.create_run( - experiment_id, tags={"mlflow.user": user_name}, run_name=run_name - ) - client.set_tag(run.info.run_id, "mlflow.note.content", run_desc) - - start_time = datetime.now().astimezone() - time_string = start_time.strftime("%m/%d/%y %H:%M:%S") - client.set_tag(run.info.run_id, "date", time_string) - client.set_tag(run.info.run_id, "host", os.uname()[1]) - client.set_tag(run.info.run_id, "group", group_name) - - run = client.get_run(run.info.run_id) - - # Set run instance in Modulus logger - LaunchLogger.mlflow_run = run - LaunchLogger.mlflow_client = client - - return client, run - - -def check_mlflow_logged_in(client: MlflowClient): - """Checks to see if MLFlow URI is functioning - - This isn't the best solution right now and overrides http timeout. Can update if MLFlow - use is increased. - """ - - logger.warning( - "Checking MLFlow logging location is working (if this hangs it's not)" - ) - t0 = os.environ.get("MLFLOW_HTTP_REQUEST_TIMEOUT", None) - try: - # Adjust http timeout to 5 seconds - os.environ["MLFLOW_HTTP_REQUEST_TIMEOUT"] = str(max(int(t0), 5)) if t0 else "5" - experiment = client.create_experiment("test") - client.delete_experiment(experiment) - - except Exception as e: - logger.error("Failed to validate MLFlow logging location works") - raise e - finally: - # Restore http request - if t0: - os.environ["MLFLOW_HTTP_REQUEST_TIMEOUT"] = t0 - else: - del os.environ["MLFLOW_HTTP_REQUEST_TIMEOUT"] - - logger.success("MLFlow logging location is working") diff --git a/examples/domino/modulus/launch/logging/utils.py b/examples/domino/modulus/launch/logging/utils.py deleted file mode 100644 index ad9d032f29..0000000000 --- a/examples/domino/modulus/launch/logging/utils.py +++ /dev/null @@ -1,58 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. -# SPDX-FileCopyrightText: All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from datetime import datetime - -import paddle -from modulus.distributed import DistributedManager - - -def create_ddp_group_tag(group_name: str = None) -> str: - """Creates a common group tag for logging - - For some reason this does not work with multi-node. Seems theres a bug in - when one uses a distributed util before DDP - - Parameters - ---------- - group_name : str, optional - Optional group name prefix. If None will use ``"DDP_Group_"``, by default None - - Returns - ------- - str - Group tag - """ - dist = DistributedManager() - if dist.rank == 0: - # Store time stamp as int tensor for broadcasting - def tint(x): - return int(datetime.now().strftime(f"%{x}")) - - time_index = paddle.to_tensor([tint(x) for x in ["m", "d", "y", "H", "M", "S"]]) - else: - time_index = paddle.to_tensor([0, 0, 0, 0, 0, 0]) - - if paddle.distributed.is_available(): - # Broadcast group ID to all processes - paddle.distributed.broadcast(time_index, src=0) - - time_string = f"{time_index[0]}/{time_index[1]}/{time_index[2]}_\ - {time_index[3]}-{time_index[4]}-{time_index[5]}" - - if group_name is None: - group_name = "DDP_Group" - return group_name + "_" + time_string diff --git a/examples/domino/modulus/launch/logging/wandb.py b/examples/domino/modulus/launch/logging/wandb.py deleted file mode 100644 index f147cb6888..0000000000 --- a/examples/domino/modulus/launch/logging/wandb.py +++ /dev/null @@ -1,131 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. -# SPDX-FileCopyrightText: All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Weights and Biases Routines and Utilities""" - -import logging -import os -from datetime import datetime -from pathlib import Path -from typing import Literal - -import wandb -from modulus.distributed import DistributedManager -from wandb import AlertLevel - -from .utils import create_ddp_group_tag - -DEFAULT_WANDB_CONFIG = "~/.netrc" -logger = logging.getLogger(__name__) - -_WANDB_INITIALIZED = False - - -def initialize_wandb( - project: str, - entity: str, - name: str = "train", - group: str = None, - sync_tensorboard: bool = False, - save_code: bool = False, - resume: str = None, - wandb_id: str = None, - config=None, - mode: Literal["offline", "online", "disabled"] = "offline", - results_dir: str = None, -): - """Function to initialize wandb client with the weights and biases server. - - Parameters - ---------- - project : str - Name of the project to sync data with - entity : str, - Name of the wanbd entity - sync_tensorboard : bool, optional - sync tensorboard summary writer with wandb, by default False - save_code : bool, optional - Whether to push a copy of the code to wandb dashboard, by default False - name : str, optional - Name of the task running, by default "train" - group : str, optional - Group name of the task running. Good to set for ddp runs, by default None - resume: str, optional - Sets the resuming behavior. Options: "allow", "must", "never", "auto" or None, - by default None. - wandb_id: str, optional - A unique ID for this run, used for resuming. Used in conjunction with `resume` - parameter to enable experiment resuming. - See W&B documentation for more details: - https://docs.wandb.ai/guides/runs/resuming/ - config : optional - a dictionary-like object for saving inputs , like hyperparameters. - If dict, argparse or absl.flags, it will load the key value pairs into the - wandb.config object. If str, it will look for a yaml file by that name, - by default None. - mode: str, optional - Can be "offline", "online" or "disabled", by default "offline" - results_dir : str, optional - Output directory of the experiment, by default "//wandb" - """ - - # Set default value here for Hydra - if results_dir is None: - results_dir = str(Path("./wandb").absolute()) - - wandb_dir = results_dir - if DistributedManager.is_initialized() and DistributedManager().distributed: - if group is None: - group = create_ddp_group_tag() - start_time = datetime.now().astimezone() - time_string = start_time.strftime("%m/%d/%y_%H:%M:%S") - wandb_name = f"{name}_Process_{DistributedManager().rank}_{time_string}" - else: - start_time = datetime.now().astimezone() - time_string = start_time.strftime("%m/%d/%y_%H:%M:%S") - wandb_name = f"{name}_{time_string}" - - if not os.path.exists(wandb_dir): - os.makedirs(wandb_dir, exist_ok=True) - - wandb.init( - project=project, - entity=entity, - sync_tensorboard=sync_tensorboard, - name=wandb_name, - resume=resume, - config=config, - mode=mode, - dir=wandb_dir, - group=group, - save_code=save_code, - id=wandb_id, - ) - - -def alert(title, text, duration=300, level=0, is_master=True): - """Send alert.""" - alert_levels = {0: AlertLevel.INFO, 1: AlertLevel.WARN, 2: AlertLevel.ERROR} - if is_wandb_initialized() and is_master: - wandb.alert( - title=title, text=text, level=alert_levels[level], wait_duration=duration - ) - - -def is_wandb_initialized(): - """Check if wandb has been initialized.""" - global _WANDB_INITIALIZED - return _WANDB_INITIALIZED diff --git a/examples/domino/test.py b/examples/domino/test.py index 45a43f20e1..a98c79e3ee 100644 --- a/examples/domino/test.py +++ b/examples/domino/test.py @@ -35,27 +35,28 @@ import pyvista as pv import vtk from hydra.utils import to_absolute_path -from modulus.distributed import DistributedManager -from modulus.models.model import DoMINO -from modulus.utils.domino.utils import KDTree -from modulus.utils.domino.utils import calculate_center_of_mass -from modulus.utils.domino.utils import calculate_normal_positional_encoding -from modulus.utils.domino.utils import create_directory -from modulus.utils.domino.utils import create_grid -from modulus.utils.domino.utils import get_fields -from modulus.utils.domino.utils import get_filenames -from modulus.utils.domino.utils import get_node_to_elem -from modulus.utils.domino.utils import get_volume_data -from modulus.utils.domino.utils import normalize -from modulus.utils.domino.utils import unnormalize -from modulus.utils.domino.utils import write_to_vtp -from modulus.utils.domino.utils import write_to_vtu -from modulus.utils.sdf import signed_distance_field from omegaconf import DictConfig from omegaconf import OmegaConf from paddle import DataParallel from vtk.util import numpy_support +from ppsci.arch.physicsnemo.distributed import DistributedManager +from ppsci.arch.physicsnemo.models.model import DoMINO +from ppsci.arch.physicsnemo.utils.domino.utils import KDTree +from ppsci.arch.physicsnemo.utils.domino.utils import cal_normal_positional_encoding +from ppsci.arch.physicsnemo.utils.domino.utils import calculate_center_of_mass +from ppsci.arch.physicsnemo.utils.domino.utils import create_directory +from ppsci.arch.physicsnemo.utils.domino.utils import create_grid +from ppsci.arch.physicsnemo.utils.domino.utils import get_fields +from ppsci.arch.physicsnemo.utils.domino.utils import get_filenames +from ppsci.arch.physicsnemo.utils.domino.utils import get_node_to_elem +from ppsci.arch.physicsnemo.utils.domino.utils import get_volume_data +from ppsci.arch.physicsnemo.utils.domino.utils import normalize +from ppsci.arch.physicsnemo.utils.domino.utils import unnormalize +from ppsci.arch.physicsnemo.utils.domino.utils import write_to_vtp +from ppsci.arch.physicsnemo.utils.domino.utils import write_to_vtu +from ppsci.arch.physicsnemo.utils.sdf import signed_distance_field + AIR_DENSITY = 1.205 STREAM_VELOCITY = 30.00 @@ -484,7 +485,7 @@ def main(cfg: DictConfig): ) if cfg.model.positional_encoding: - pos_surface_center_of_mass = calculate_normal_positional_encoding( + pos_surface_center_of_mass = cal_normal_positional_encoding( surface_coordinates, center_of_mass, cell_length=[dx, dy, dz] ) else: @@ -562,10 +563,10 @@ def main(cfg: DictConfig): sdf_node_closest_point = sdf_node_closest_point.numpy() if cfg.model.positional_encoding: - pos_volume_closest = calculate_normal_positional_encoding( + pos_volume_closest = cal_normal_positional_encoding( volume_coordinates, sdf_node_closest_point, cell_length=[dx, dy, dz] ) - pos_volume_center_of_mass = calculate_normal_positional_encoding( + pos_volume_center_of_mass = cal_normal_positional_encoding( volume_coordinates, center_of_mass, cell_length=[dx, dy, dz] ) else: diff --git a/examples/domino/train.py b/examples/domino/train.py index 4a9d536d33..74c5267e90 100644 --- a/examples/domino/train.py +++ b/examples/domino/train.py @@ -29,7 +29,6 @@ import os import re -import sys import time import hydra @@ -44,16 +43,13 @@ from paddle.io import DataLoader from paddle.io import DistributedBatchSampler -SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) -sys.path.append(os.path.dirname(SCRIPT_DIR)) - -from modulus.datapipes.cae.domino_datapipe import DoMINODataPipe # noqa: E402 -from modulus.distributed import DistributedManager # noqa: E402 -from modulus.launch.utils import load_checkpoint # noqa: E402 -from modulus.launch.utils import save_checkpoint # noqa: E402 -from modulus.models.model import DoMINO # noqa: E402 -from modulus.utils.domino.utils import create_directory # noqa: E402 -from modulus.utils.domino.utils import mean_std_sampling # noqa: E402 +from ppsci.arch.physicsnemo.datapipes.cae.domino_datapipe import DoMINODataPipe +from ppsci.arch.physicsnemo.distributed import DistributedManager +from ppsci.arch.physicsnemo.launch.utils import load_checkpoint +from ppsci.arch.physicsnemo.launch.utils import save_checkpoint +from ppsci.arch.physicsnemo.models.model import DoMINO +from ppsci.arch.physicsnemo.utils.domino.utils import create_directory +from ppsci.arch.physicsnemo.utils.domino.utils import mean_std_sampling def relative_loss_fn(output, target, padded_value=-10): diff --git a/examples/domino/modulus/__init__.py b/ppsci/arch/physicsnemo/__init__.py similarity index 100% rename from examples/domino/modulus/__init__.py rename to ppsci/arch/physicsnemo/__init__.py diff --git a/examples/domino/modulus/datapipes/__init__.py b/ppsci/arch/physicsnemo/datapipes/__init__.py similarity index 100% rename from examples/domino/modulus/datapipes/__init__.py rename to ppsci/arch/physicsnemo/datapipes/__init__.py diff --git a/examples/domino/modulus/datapipes/cae/__init__.py b/ppsci/arch/physicsnemo/datapipes/cae/__init__.py similarity index 100% rename from examples/domino/modulus/datapipes/cae/__init__.py rename to ppsci/arch/physicsnemo/datapipes/cae/__init__.py diff --git a/examples/domino/modulus/datapipes/cae/domino_datapipe.py b/ppsci/arch/physicsnemo/datapipes/cae/domino_datapipe.py similarity index 98% rename from examples/domino/modulus/datapipes/cae/domino_datapipe.py rename to ppsci/arch/physicsnemo/datapipes/cae/domino_datapipe.py index c5768eb4f7..17157ea2a0 100644 --- a/examples/domino/modulus/datapipes/cae/domino_datapipe.py +++ b/ppsci/arch/physicsnemo/datapipes/cae/domino_datapipe.py @@ -37,8 +37,8 @@ from ...utils.domino.utils import KDTree from ...utils.domino.utils import area_weighted_shuffle_array +from ...utils.domino.utils import cal_normal_positional_encoding from ...utils.domino.utils import calculate_center_of_mass -from ...utils.domino.utils import calculate_normal_positional_encoding from ...utils.domino.utils import create_grid from ...utils.domino.utils import get_filenames from ...utils.domino.utils import normalize @@ -276,12 +276,12 @@ def __getitem__(self, idx): sdf_node_closest_point = sdf_node_closest_point.numpy() if self.positional_encoding: - pos_normals_closest_vol = calculate_normal_positional_encoding( + pos_normals_closest_vol = cal_normal_positional_encoding( volume_coordinates, sdf_node_closest_point, cell_length=[dx, dy, dz], ) - pos_normals_com_vol = calculate_normal_positional_encoding( + pos_normals_com_vol = cal_normal_positional_encoding( volume_coordinates, center_of_mass, cell_length=[dx, dy, dz] ) else: @@ -377,7 +377,7 @@ def __getitem__(self, idx): ) if self.positional_encoding: - pos_normals_com_surface = calculate_normal_positional_encoding( + pos_normals_com_surface = cal_normal_positional_encoding( surface_coordinates, center_of_mass, cell_length=[dx, dy, dz] ) else: diff --git a/examples/domino/modulus/distributed/__init__.py b/ppsci/arch/physicsnemo/distributed/__init__.py similarity index 100% rename from examples/domino/modulus/distributed/__init__.py rename to ppsci/arch/physicsnemo/distributed/__init__.py diff --git a/examples/domino/modulus/distributed/config.py b/ppsci/arch/physicsnemo/distributed/config.py similarity index 100% rename from examples/domino/modulus/distributed/config.py rename to ppsci/arch/physicsnemo/distributed/config.py diff --git a/examples/domino/modulus/distributed/manager.py b/ppsci/arch/physicsnemo/distributed/manager.py similarity index 100% rename from examples/domino/modulus/distributed/manager.py rename to ppsci/arch/physicsnemo/distributed/manager.py diff --git a/examples/domino/modulus/launch/__init__.py b/ppsci/arch/physicsnemo/launch/__init__.py similarity index 100% rename from examples/domino/modulus/launch/__init__.py rename to ppsci/arch/physicsnemo/launch/__init__.py diff --git a/examples/domino/modulus/launch/utils/__init__.py b/ppsci/arch/physicsnemo/launch/utils/__init__.py similarity index 100% rename from examples/domino/modulus/launch/utils/__init__.py rename to ppsci/arch/physicsnemo/launch/utils/__init__.py diff --git a/examples/domino/modulus/launch/utils/checkpoint.py b/ppsci/arch/physicsnemo/launch/utils/checkpoint.py similarity index 91% rename from examples/domino/modulus/launch/utils/checkpoint.py rename to ppsci/arch/physicsnemo/launch/utils/checkpoint.py index b795dc4048..4352cebdd4 100644 --- a/examples/domino/modulus/launch/utils/checkpoint.py +++ b/ppsci/arch/physicsnemo/launch/utils/checkpoint.py @@ -28,17 +28,14 @@ from paddle.amp import GradScaler from paddle.optimizer.lr import LRScheduler -from ...distributed import DistributedManager -from ...launch.logging import PythonLogger +from ppsci.utils import logger -# from modulus.utils.capture import _StaticCapture +from ...distributed import DistributedManager optimizer = NewType("optimizer", paddle.optimizer) scheduler = NewType("scheduler", LRScheduler) scaler = NewType("scaler", GradScaler) -checkpoint_logging = PythonLogger("checkpoint") - def _get_checkpoint_filename( path: str, @@ -80,7 +77,7 @@ def _get_checkpoint_filename( # model_parallel_rank should be the same as the process rank itself and # only rank 0 saves if not DistributedManager.is_initialized(): - checkpoint_logging.warning( + logger.warning( "`DistributedManager` not initialized already. Initializing now, but this might lead to unexpected errors" ) DistributedManager.initialize() @@ -219,7 +216,7 @@ def save_checkpoint( """ # Create checkpoint directory if it does not exist if not Path(path).is_dir(): - checkpoint_logging.warning( + logger.warning( f"Output directory {path} does not exist, will " "attempt to create" ) Path(path).mkdir(parents=True, exist_ok=True) @@ -240,7 +237,7 @@ def save_checkpoint( # Save state dictionary paddle.save(model.state_dict(), file_name) - checkpoint_logging.success(f"Saved model state dictionary: {file_name}") + logger.info(f"Saved model state dictionary: {file_name}") # == Saving training checkpoint == checkpoint_dict = {} @@ -273,7 +270,7 @@ def save_checkpoint( checkpoint_dict, output_filename, ) - checkpoint_logging.success(f"Saved training checkpoint: {output_filename}") + logger.info(f"Saved training checkpoint: {output_filename}") def load_checkpoint( @@ -316,7 +313,7 @@ def load_checkpoint( """ # Check if checkpoint directory exists if not Path(path).is_dir(): - checkpoint_logging.warning( + logger.warning( f"Provided checkpoint directory {path} does not exist, skipping load" ) return 0 @@ -335,46 +332,40 @@ def load_checkpoint( path, name, index=epoch, model_type=model_type ) if not Path(file_name).exists(): - checkpoint_logging.error( + logger.error( f"Could not find valid model file {file_name}, skipping load" ) continue # Load state dictionary model.set_state_dict(paddle.load(file_name)) - checkpoint_logging.success(f"Loaded model state dictionary {file_name}") + logger.info(f"Loaded model state dictionary {file_name}") # == Loading training checkpoint == checkpoint_filename = _get_checkpoint_filename( path, index=epoch, model_type="pdparams" ) if not Path(checkpoint_filename).is_file(): - checkpoint_logging.warning( - "Could not find valid checkpoint file, skipping load" - ) + logger.warning("Could not find valid checkpoint file, skipping load") return 0 checkpoint_dict = paddle.load(checkpoint_filename) - checkpoint_logging.success(f"Loaded checkpoint file {checkpoint_filename}") + logger.info(f"Loaded checkpoint file {checkpoint_filename}") # Optimizer state dict if optimizer and "optimizer_state_dict" in checkpoint_dict: optimizer.set_state_dict(checkpoint_dict["optimizer_state_dict"]) - checkpoint_logging.success("Loaded optimizer state dictionary") + logger.info("Loaded optimizer state dictionary") # Scheduler state dict if scheduler and "scheduler_state_dict" in checkpoint_dict: scheduler.set_state_dict(checkpoint_dict["scheduler_state_dict"]) - checkpoint_logging.success("Loaded scheduler state dictionary") + logger.info("Loaded scheduler state dictionary") # Scaler state dict if scaler and "scaler_state_dict" in checkpoint_dict: scaler.load_state_dict(checkpoint_dict["scaler_state_dict"]) - checkpoint_logging.success("Loaded grad scaler state dictionary") - - # if "static_capture_state_dict" in checkpoint_dict: - # _StaticCapture.set_state_dict(checkpoint_dict["static_capture_state_dict"]) - # checkpoint_logging.success("Loaded static capture state dictionary") + logger.info("Loaded grad scaler state dictionary") epoch = 0 if "epoch" in checkpoint_dict: diff --git a/examples/domino/modulus/models/__init__.py b/ppsci/arch/physicsnemo/models/__init__.py similarity index 100% rename from examples/domino/modulus/models/__init__.py rename to ppsci/arch/physicsnemo/models/__init__.py diff --git a/examples/domino/modulus/models/layers/__init__.py b/ppsci/arch/physicsnemo/models/layers/__init__.py similarity index 100% rename from examples/domino/modulus/models/layers/__init__.py rename to ppsci/arch/physicsnemo/models/layers/__init__.py diff --git a/examples/domino/modulus/models/layers/ball_query.py b/ppsci/arch/physicsnemo/models/layers/ball_query.py similarity index 100% rename from examples/domino/modulus/models/layers/ball_query.py rename to ppsci/arch/physicsnemo/models/layers/ball_query.py diff --git a/examples/domino/modulus/models/model.py b/ppsci/arch/physicsnemo/models/model.py similarity index 100% rename from examples/domino/modulus/models/model.py rename to ppsci/arch/physicsnemo/models/model.py diff --git a/examples/domino/modulus/utils/__init__.py b/ppsci/arch/physicsnemo/utils/__init__.py similarity index 100% rename from examples/domino/modulus/utils/__init__.py rename to ppsci/arch/physicsnemo/utils/__init__.py diff --git a/examples/domino/modulus/utils/domino/utils.py b/ppsci/arch/physicsnemo/utils/domino/utils.py similarity index 99% rename from examples/domino/modulus/utils/domino/utils.py rename to ppsci/arch/physicsnemo/utils/domino/utils.py index 351eead555..28f61de49f 100644 --- a/examples/domino/modulus/utils/domino/utils.py +++ b/ppsci/arch/physicsnemo/utils/domino/utils.py @@ -229,9 +229,7 @@ def get_surface_data(polydata, variables): return vertices, fields, edges -def calculate_normal_positional_encoding( - coordinates_a, coordinates_b=None, cell_length=[] -): +def cal_normal_positional_encoding(coordinates_a, coordinates_b=None, cell_length=[]): """Function to get normal positional encoding""" dx = cell_length[0] dy = cell_length[1] diff --git a/examples/domino/modulus/utils/sdf.py b/ppsci/arch/physicsnemo/utils/sdf.py similarity index 100% rename from examples/domino/modulus/utils/sdf.py rename to ppsci/arch/physicsnemo/utils/sdf.py From 905be9ad273890d7fa316d748fb9a9e6b4f998cf Mon Sep 17 00:00:00 2001 From: xiaoyewww <641311428@qq.com> Date: Thu, 29 May 2025 02:00:26 +0800 Subject: [PATCH 07/11] feat(domino): refactor physicsnemo --- examples/domino/process_data.py | 36 +- examples/domino/test.py | 24 +- examples/domino/train.py | 95 +- ppsci/arch/physicsnemo/__init__.py | 0 ppsci/arch/physicsnemo/datapipes/__init__.py | 15 - .../arch/physicsnemo/distributed/__init__.py | 20 - ppsci/arch/physicsnemo/distributed/config.py | 250 ---- ppsci/arch/physicsnemo/distributed/manager.py | 609 --------- ppsci/arch/physicsnemo/launch/__init__.py | 15 - .../arch/physicsnemo/launch/utils/__init__.py | 18 - .../physicsnemo/launch/utils/checkpoint.py | 378 ------ ppsci/arch/physicsnemo/models/__init__.py | 0 .../physicsnemo/models/layers/__init__.py | 15 - .../physicsnemo/models/layers/ball_query.py | 322 ----- ppsci/arch/physicsnemo/models/model.py | 1178 ----------------- ppsci/arch/physicsnemo/utils/__init__.py | 15 - ppsci/arch/physicsnemo/utils/domino/utils.py | 383 ------ ppsci/data/dataset/__init__.py | 2 + .../cae => data/dataset}/domino_datapipe.py | 128 +- ppsci/data/process/__init__.py | 1 + .../cae => data/process/openfoam}/__init__.py | 18 +- ppsci/data/process/openfoam/preprocess.py | 42 + ppsci/utils/__init__.py | 2 + ppsci/{arch/physicsnemo => }/utils/sdf.py | 15 +- requirements.txt | 6 + 25 files changed, 232 insertions(+), 3355 deletions(-) delete mode 100644 ppsci/arch/physicsnemo/__init__.py delete mode 100644 ppsci/arch/physicsnemo/datapipes/__init__.py delete mode 100644 ppsci/arch/physicsnemo/distributed/__init__.py delete mode 100644 ppsci/arch/physicsnemo/distributed/config.py delete mode 100644 ppsci/arch/physicsnemo/distributed/manager.py delete mode 100644 ppsci/arch/physicsnemo/launch/__init__.py delete mode 100644 ppsci/arch/physicsnemo/launch/utils/__init__.py delete mode 100644 ppsci/arch/physicsnemo/launch/utils/checkpoint.py delete mode 100644 ppsci/arch/physicsnemo/models/__init__.py delete mode 100644 ppsci/arch/physicsnemo/models/layers/__init__.py delete mode 100644 ppsci/arch/physicsnemo/models/layers/ball_query.py delete mode 100644 ppsci/arch/physicsnemo/models/model.py delete mode 100644 ppsci/arch/physicsnemo/utils/__init__.py delete mode 100644 ppsci/arch/physicsnemo/utils/domino/utils.py rename ppsci/{arch/physicsnemo/datapipes/cae => data/dataset}/domino_datapipe.py (86%) rename ppsci/{arch/physicsnemo/datapipes/cae => data/process/openfoam}/__init__.py (68%) create mode 100644 ppsci/data/process/openfoam/preprocess.py rename ppsci/{arch/physicsnemo => }/utils/sdf.py (95%) diff --git a/examples/domino/process_data.py b/examples/domino/process_data.py index d8240354a8..9fb0d5d572 100644 --- a/examples/domino/process_data.py +++ b/examples/domino/process_data.py @@ -1,18 +1,18 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. -# SPDX-FileCopyrightText: All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. + # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at -# + # http://www.apache.org/licenses/LICENSE-2.0 -# + # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +# +# refs: https://github.com/NVIDIA/physicsnemo/tree/main/examples/cfd/external_aerodynamics/domino """ This code runs the data processing in parallel to load OpenFoam files, process them @@ -22,8 +22,6 @@ """ import multiprocessing -import os -import time import hydra import numpy as np @@ -32,27 +30,7 @@ from openfoam_datapipe import OpenFoamDataset from physicsnemo.utils.domino.utils import * # noqa: F403 - -def process_files(*args_list): - ids = args_list[0] - processor_id = args_list[1] - fm_data = args_list[2] - output_dir = args_list[3] - for j in ids: - fname = fm_data.filenames[j] - if len(os.listdir(os.path.join(fm_data.data_path, fname))) == 0: - print(f"Skipping {fname} - empty.") - continue - outname = os.path.join(output_dir, fname) - print("Filename:%s on processor: %d" % (outname, processor_id)) - filename = f"{outname}.npy" - if os.path.exists(filename): - print(f"Skipping {filename} - already exists.") - continue - start_time = time.time() - data_dict = fm_data[j] - np.save(filename, data_dict) - print("Time taken for %d = %f" % (j, time.time() - start_time)) +from ppsci.data.process.openfoam import process_files # noqa: F401 @hydra.main(version_base="1.3", config_path="conf", config_name="config") diff --git a/examples/domino/test.py b/examples/domino/test.py index a98c79e3ee..2a0326f818 100644 --- a/examples/domino/test.py +++ b/examples/domino/test.py @@ -1,30 +1,18 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. -# SPDX-FileCopyrightText: All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. + # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at -# + # http://www.apache.org/licenses/LICENSE-2.0 -# + # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -""" -This code defines a distributed pipeline for testing the DoMINO model on -CFD datasets. It includes the instantiating the DoMINO model and datapipe, -automatically loading the most recent checkpoint, reading the VTP/VTU/STL -testing files, calculation of parameters required for DoMINO model and -evaluating the model in parallel using DataParallel across multiple -GPUs. This is a common recipe that enables training of combined models for surface -and volume as well either of them separately. The model predictions are loaded in -the the VTP/VTU files and saved in the specified directory. The eval tab in -config.yaml can be used to specify the input and output directories. -""" +# +# refs: https://github.com/NVIDIA/physicsnemo/tree/main/examples/cfd/external_aerodynamics/domino import os import re diff --git a/examples/domino/train.py b/examples/domino/train.py index 74c5267e90..40ea1f3b91 100644 --- a/examples/domino/train.py +++ b/examples/domino/train.py @@ -1,31 +1,18 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. -# SPDX-FileCopyrightText: All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. + # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at -# + # http://www.apache.org/licenses/LICENSE-2.0 -# + # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -""" -This code defines a distributed pipeline for training the DoMINO model on -CFD datasets. It includes the computation of scaling factors, instantiating -the DoMINO model and datapipe, automatically loading the most recent checkpoint, -training the model in parallel using DistributedDataParallel across multiple -GPUs, calculating the loss and updating model parameters using mixed precision. -This is a common recipe that enables training of combined models for surface and -volume as well either of them separately. Validation is also conducted every epoch, -where predictions are compared against ground truth values. The code logs training -and validation metrics to TensorBoard. The train tab in config.yaml can be used to -specify batch size, number of epochs and other training parameters. -""" +# +# refs: https://github.com/NVIDIA/physicsnemo/tree/main/examples/cfd/external_aerodynamics/domino import os import re @@ -34,6 +21,7 @@ import hydra import numpy as np import paddle +import paddle.distributed as dist from hydra.utils import to_absolute_path from omegaconf import DictConfig from omegaconf import OmegaConf @@ -43,13 +31,14 @@ from paddle.io import DataLoader from paddle.io import DistributedBatchSampler -from ppsci.arch.physicsnemo.datapipes.cae.domino_datapipe import DoMINODataPipe -from ppsci.arch.physicsnemo.distributed import DistributedManager -from ppsci.arch.physicsnemo.launch.utils import load_checkpoint -from ppsci.arch.physicsnemo.launch.utils import save_checkpoint -from ppsci.arch.physicsnemo.models.model import DoMINO -from ppsci.arch.physicsnemo.utils.domino.utils import create_directory -from ppsci.arch.physicsnemo.utils.domino.utils import mean_std_sampling +from ppsci.arch.physicsnemo import DoMINO +from ppsci.arch.physicsnemo import create_directory +from ppsci.arch.physicsnemo import load_checkpoint +from ppsci.arch.physicsnemo import mean_std_sampling +from ppsci.arch.physicsnemo import save_checkpoint +from ppsci.data.dataset.domino_datapipe import DoMINODataPipe + +paddle.set_device("gpu") def relative_loss_fn(output, target, padded_value=-10): @@ -685,9 +674,7 @@ def main(cfg: DictConfig) -> None: input_path_val = cfg.data.input_dir_val model_type = cfg.model.model_type - # initialize distributed manager - DistributedManager.initialize() - dist = DistributedManager() + dist.init_parallel_env() print(f"Config summary:\n{OmegaConf.to_yaml(cfg, sort_keys=True)}") @@ -772,20 +759,23 @@ def main(cfg: DictConfig) -> None: bounding_box_dims_surf=cfg.data.bounding_box_surface, num_surface_neighbors=cfg.model.num_surface_neighbors, ) - + print(f">>>>>> paddle.distributed.get_rank(): {paddle.distributed.get_rank()}") + print( + f">>>>>> paddle.distributed.get_world_size(): {paddle.distributed.get_world_size()}" + ) train_sampler = DistributedBatchSampler( train_dataset, batch_size=1, - num_replicas=dist.world_size, - rank=dist.rank, + num_replicas=paddle.distributed.get_world_size(), + rank=paddle.distributed.get_rank(), **cfg.train.sampler, ) val_sampler = DistributedBatchSampler( val_dataset, batch_size=1, - num_replicas=dist.world_size, - rank=dist.rank, + num_replicas=paddle.distributed.get_world_size(), + rank=paddle.distributed.get_rank(), **cfg.val.sampler, ) @@ -799,10 +789,9 @@ def main(cfg: DictConfig) -> None: model_parameters=cfg.model, ) - if dist.world_size > 1: + if paddle.distributed.get_world_size() > 1: model = DataParallel( model, - find_unused_parameters=dist.find_unused_parameters, ) optimizer = paddle.optimizer.Adam( @@ -823,12 +812,12 @@ def main(cfg: DictConfig) -> None: model_save_path = os.path.join(cfg.output, "models") param_save_path = os.path.join(cfg.output, "param") best_model_path = os.path.join(model_save_path, "best_model") - if dist.rank == 0: + if paddle.distributed.get_rank() == 0: create_directory(model_save_path) create_directory(param_save_path) create_directory(best_model_path) - if dist.world_size > 1: + if paddle.distributed.get_world_size() > 1: paddle.distributed.barrier() init_epoch = load_checkpoint( @@ -857,7 +846,7 @@ def main(cfg: DictConfig) -> None: for epoch in range(init_epoch, cfg.train.epochs): start_time = time.time() - print(f"Device {dist.device}, epoch {epoch_number}:") + print(f"Device {paddle.distributed.get_rank()}, epoch {epoch_number}:") train_sampler.set_epoch(epoch) val_sampler.set_epoch(epoch) @@ -871,7 +860,7 @@ def main(cfg: DictConfig) -> None: optimizer=optimizer, scaler=scaler, epoch_index=epoch, - device=dist.device, + device=paddle.distributed.get_rank(), integral_scaling_factor=initial_integral_factor, loss_fn_type=cfg.model.loss_function, ) @@ -880,7 +869,7 @@ def main(cfg: DictConfig) -> None: avg_vloss = validation_step( dataloader=val_dataloader, model=model, - device=dist.device, + device=paddle.distributed.get_rank(), use_sdf_basis=cfg.model.use_sdf_in_basis_func, use_surface_normals=cfg.model.use_surface_normals, integral_scaling_factor=initial_integral_factor, @@ -889,14 +878,14 @@ def main(cfg: DictConfig) -> None: scheduler.step() print( - f"Device {dist.device} " + f"Device {paddle.distributed.get_rank()} " f"LOSS train {avg_loss:.5f} " f"valid {avg_vloss:.5f} " f"Current lr {scheduler.get_lr()}" f"Integral factor {initial_integral_factor}" ) - # if dist.rank == 0: + # if paddle.distributed.get_rank() == 0: # writer.add_scalars( # "Training vs. Validation Loss", # {"Training": avg_loss, "Validation": avg_vloss}, @@ -905,27 +894,19 @@ def main(cfg: DictConfig) -> None: # writer.flush() # Track best performance, and save the model's state - if dist.world_size > 1: + if paddle.distributed.get_world_size() > 1: paddle.distributed.barrier() if avg_vloss < best_vloss: # This only considers GPU: 0, is that okay? best_vloss = avg_vloss - # if dist.rank == 0: - save_checkpoint( - to_absolute_path(best_model_path), - models=model, - optimizer=optimizer, - scheduler=scheduler, - scaler=scaler, - epoch=str( - best_vloss.item() - ), # hacky way of using epoch to store metadata - ) print( - f"Device { dist.device}, Best val loss {best_vloss}, Time taken {time.time() - start_time}" + f"Device { paddle.distributed.get_rank()}, Best val loss {best_vloss}, Time taken {time.time() - start_time}" ) - if dist.rank == 0 and (epoch + 1) % cfg.train.checkpoint_interval == 0.0: + if ( + paddle.distributed.get_rank() == 0 + and (epoch + 1) % cfg.train.checkpoint_interval == 0.0 + ): save_checkpoint( to_absolute_path(model_save_path), models=model, diff --git a/ppsci/arch/physicsnemo/__init__.py b/ppsci/arch/physicsnemo/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/ppsci/arch/physicsnemo/datapipes/__init__.py b/ppsci/arch/physicsnemo/datapipes/__init__.py deleted file mode 100644 index b2f171d4ac..0000000000 --- a/ppsci/arch/physicsnemo/datapipes/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. -# SPDX-FileCopyrightText: All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/ppsci/arch/physicsnemo/distributed/__init__.py b/ppsci/arch/physicsnemo/distributed/__init__.py deleted file mode 100644 index 4bc2bff921..0000000000 --- a/ppsci/arch/physicsnemo/distributed/__init__.py +++ /dev/null @@ -1,20 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. -# SPDX-FileCopyrightText: All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from .manager import DistributedManager # noqa: F401 -from .manager import ModulusUndefinedGroupError # noqa: F401 -from .manager import ModulusUninitializedDistributedManagerWarning # noqa: F401 diff --git a/ppsci/arch/physicsnemo/distributed/config.py b/ppsci/arch/physicsnemo/distributed/config.py deleted file mode 100644 index 831fed1aef..0000000000 --- a/ppsci/arch/physicsnemo/distributed/config.py +++ /dev/null @@ -1,250 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. -# SPDX-FileCopyrightText: All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Dict -from typing import List -from typing import Optional -from typing import Union - -from treelib import Tree - - -class ProcessGroupNode: - """ - Class to store the attributes of a distributed process group - - Attributes - ---------- - name : str - Name of the process group - size : Optional[int] - Optional, number of processes in the process group - """ - - def __init__( - self, - name: str, - size: Optional[int] = None, - ): - """ - Constructor for the ProcessGroupNode class - - Parameters - ---------- - name : str - Name of the process group - size : Optional[int] - Optional, size of the process group - """ - self.name = name - self.size = size - - def __str__(self): - """ - String representation of the process group node - - Returns - ------- - str - String representation of the process group node - """ - return "ProcessGroupNode(" f"name={self.name}, " f"size={self.size}, " - - def __repr__(self): - """ - String representation of the process group node - - Returns - ------- - str - String representation of the process group node - """ - return self.__str__() - - -class ProcessGroupConfig: - """ - Class to define the configuration of a model's parallel process group structure as a - tree. Each node of the tree is of type `ProcessGroupNode`. - - Once the process group config structure (i.e, the tree structure) is set, it is - sufficient to set only the sizes for each leaf process group. Then, the size of - every parent group can be automatically computed as the product reduction of the - sub-tree of that parent group node. - - Examples - -------- - >>> from modulus.distributed import ProcessGroupNode, ProcessGroupConfig - >>> - >>> # Create world group that contains all processes that are part of this job - >>> world = ProcessGroupNode("world") - >>> - >>> # Create the process group config with the highest level process group - >>> config = ProcessGroupConfig(world) - >>> - >>> # Create model and data parallel sub-groups - >>> # Sub-groups of a single node are guaranteed to be orthogonal by construction - >>> # Nodes can be added with either the name of the node or the node itself - >>> config.add_node(ProcessGroupNode("model_parallel"), parent=world) - >>> config.add_node(ProcessGroupNode("data_parallel"), parent="world") - >>> - >>> # Create spatial and channel parallel sub-groups - >>> config.add_node(ProcessGroupNode("spatial_parallel"), parent="model_parallel") - >>> config.add_node(ProcessGroupNode("channel_parallel"), parent="model_parallel") - >>> - >>> config.leaf_groups() - ['data_parallel', 'spatial_parallel', 'channel_parallel'] - >>> - >>> # Set leaf group sizes - >>> # Note: product of all leaf-node sizes should be the world size - >>> group_sizes = {"channel_parallel": 3, "spatial_parallel": 2, "data_parallel": 4} - >>> config.set_leaf_group_sizes(group_sizes) # Update all parent group sizes too - >>> config.get_node("model_parallel").size - 6 - """ - - def __init__(self, node: ProcessGroupNode): - """ - Constructor to the ProcessGroupConfig class - - Parameters - ---------- - node : ProcessGroupNode - Root node of the tree, typically would be 'world' - Note, it is generally recommended to set the child groups for 'world' - to 'model_parallel' and 'data_parallel' to aid with distributed - data parallel training unless there is a specific reason to choose a - different structure - """ - self.root = node - self.root_id = node.name - self.tree = Tree() - self.tree.create_node(node.name, node.name, data=node) - - def add_node(self, node: ProcessGroupNode, parent=Union[str, ProcessGroupNode]): - """ - Add a node to the process group config - - Parameters - ---------- - node : ProcessGroupNode - The new node to be added to the config - parent : Union[str, ProcessGroupNode] - Parent node of the node to be added. Should already be in the config. - If str, it is the name of the parent node. Otherwise, the parent - ProcessGroupNode itself. - """ - if isinstance(parent, ProcessGroupNode): - parent = parent.name - self.tree.create_node(node.name, node.name, data=node, parent=parent) - - def get_node(self, name: str) -> ProcessGroupNode: - """ - Method to get the node given the name of the node - - Parameters - ---------- - name : str - Name of the node to retrieve - - Returns - ------- - ProcessGroupNode - Node with the given name from the config - """ - return self.tree.get_node(name).data - - def update_parent_sizes(self, verbose: bool = False) -> int: - """ - Method to update parent node sizes after setting the sizes for each leaf node - - Parameters - ---------- - verbose : bool - If True, print a message each time a parent node size was updated - - Returns - ------- - int - Size of the root node - """ - return _tree_product_reduction(self.tree, self.root_id, verbose=verbose) - - def leaf_groups(self) -> List[str]: - """ - Get a list of all leaf group names - - Returns - ------- - List[str] - List of all leaf node names - """ - return [n.identifier for n in self.tree.leaves()] - - def set_leaf_group_sizes( - self, group_sizes: Dict[str, int], update_parent_sizes: bool = True - ): - """ - Set process group sizes for all leaf groups - - Parameters - ---------- - group_sizes : Dict[str, int] - Dictionary with a mapping of each leaf group name to its size - update_parent_sizes : bool - Update all parent group sizes based on the leaf group if True - If False, only set the leaf group sizes. - """ - for id, size in group_sizes.items(): - if not self.tree.contains(id): - raise AssertionError( - f"Process group {id} is not in this process group config" - ) - node = self.tree.get_node(id) - if not node.is_leaf(): - raise AssertionError(f"Process group {id} is not a leaf group") - node.data.size = size - - if update_parent_sizes: - self.update_parent_sizes() - - -def _tree_product_reduction(tree, node_id, verbose=False): - """ - Function to traverse a tree and compute the product reduction of - the sub-tree for each node starting from `node_id` - """ - children = tree.children(node_id) - node = tree.get_node(node_id) - if not children: - if node.data.size is None: - raise AssertionError("Leaf nodes should have a valid size set") - return node.data.size - - product = 1 - - for child in children: - product *= _tree_product_reduction(tree, child.identifier) - - if node.data.size != product: - if verbose: - print( - "Updating size of node " - f"{node.data.name} from {node.data.size} to {product}" - ) - node.data.size = product - - return product diff --git a/ppsci/arch/physicsnemo/distributed/manager.py b/ppsci/arch/physicsnemo/distributed/manager.py deleted file mode 100644 index e24aaa4103..0000000000 --- a/ppsci/arch/physicsnemo/distributed/manager.py +++ /dev/null @@ -1,609 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. -# SPDX-FileCopyrightText: All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -import queue -from typing import Optional -from warnings import warn - -import numpy as np -import paddle -import paddle.distributed as dist - -from .config import ProcessGroupConfig -from .config import ProcessGroupNode - - -class ModulusUndefinedGroupError(Exception): - """Exception for querying an undefined process group using the Modulus DistributedManager""" - - def __init__(self, name: str): - """ - - Parameters - ---------- - name : str - Name of the process group being queried. - - """ - message = ( - f"Cannot query process group '{name}' before it is explicitly created." - ) - super().__init__(message) - - -class ModulusUninitializedDistributedManagerWarning(Warning): - """Warning to indicate usage of an uninitialized DistributedManager""" - - def __init__(self): - message = ( - "A DistributedManager object is being instantiated before " - + "this singleton class has been initialized. Instantiating a manager before " - + "initialization can lead to unexpected results where processes fail " - + "to communicate. Initialize the distributed manager via " - + "DistributedManager.initialize() before instantiating." - ) - super().__init__(message) - - -class DistributedManager(object): - """Distributed Manager for setting up distributed training environment. - - This is a singleton that creates a persistance class instance for storing parallel - environment information through out the life time of the program. This should be - used to help set up Distributed Data Parallel and parallel datapipes. - - Note - ---- - One should call `DistributedManager.initialize()` prior to constructing a manager - object - - Example - ------- - >>> DistributedManager.initialize() - >>> manager = DistributedManager() - >>> manager.rank - 0 - >>> manager.world_size - 1 - """ - - _shared_state = {} - - def __new__(cls): - obj = super(DistributedManager, cls).__new__(cls) - obj.__dict__ = cls._shared_state - - # Set the defaults - if not hasattr(obj, "_rank"): - obj._rank = 0 - if not hasattr(obj, "_world_size"): - obj._world_size = 1 - if not hasattr(obj, "_local_rank"): - obj._local_rank = 0 - if not hasattr(obj, "_distributed"): - obj._distributed = False - if not hasattr(obj, "_device"): - obj._device = "gpu:0" if paddle.device.cuda.device_count() else "cpu" - if not hasattr(obj, "_cuda"): - obj._cuda = paddle.device.cuda.device_count() >= 1 - if not hasattr(obj, "_broadcast_buffers"): - obj._broadcast_buffers = False - if not hasattr(obj, "_find_unused_parameters"): - obj._find_unused_parameters = False - if not hasattr(obj, "_initialization_method"): - obj._initialization_method = "None" - if not hasattr(obj, "_groups"): - obj._groups = {} - if not hasattr(obj, "_group_ranks"): - obj._group_ranks = {} - if not hasattr(obj, "_group_names"): - obj._group_names = {} - if not hasattr(obj, "_is_initialized"): - obj._is_initialized = False - - return obj - - def __init__(self): - if not self._is_initialized: - raise ModulusUninitializedDistributedManagerWarning() - super().__init__() - - @property - def rank(self): - """Process rank""" - return self._rank - - @property - def local_rank(self): - """Process rank on local machine""" - return self._local_rank - - @property - def world_size(self): - """Number of processes in distributed enviroment""" - return self._world_size - - @property - def device(self): - """Process device""" - return self._device - - @property - def distributed(self): - """Distributed enviroment""" - return self._distributed - - @property - def cuda(self): - """If cuda is available""" - return self._cuda - - @property - def group_names(self): - """ - Returns a list of all named process groups created - """ - return self._groups.keys() - - def group(self, name=None): - """ - Returns a process group with the given name - If name is None, group is also None indicating the default process group - If named group does not exist, ModulusUndefinedGroupError exception is raised - """ - if name in self._groups.keys(): - return self._groups[name] - elif name is None: - return None - else: - raise ModulusUndefinedGroupError(name) - - def group_size(self, name=None): - """ - Returns the size of named process group - """ - if name is None: - return self._world_size - group = self.group(name) - return dist.get_world_size(group=group) - - def group_rank(self, name=None): - """ - Returns the rank in named process group - """ - if name is None: - return self._rank - group = self.group(name) - return dist.get_rank(group=group) - - def group_name(self, group=None): - """ - Returns the name of process group - """ - if group is None: - return None - return self._group_names[group] - - @property - def broadcast_buffers(self): - """broadcast_buffers in DDP""" - return self._broadcast_buffers - - @broadcast_buffers.setter - def broadcast_buffers(self, broadcast: bool): - """Setter for broadcast_buffers""" - self._broadcast_buffers = broadcast - - @property - def find_unused_parameters(self): - """find_unused_parameters in DDP""" - return self._find_unused_parameters - - @find_unused_parameters.setter - def find_unused_parameters(self, find_params: bool): - """Setter for find_unused_parameters""" - if find_params: - warn( - "Setting `find_unused_parameters` in DDP to true, " - "use only if necessary." - ) - self._find_unused_parameters = find_params - - def __str__(self): - output = ( - f"Initialized process {self.rank} of {self.world_size} using " - f"method '{self._initialization_method}'. Device set to {str(self.device)}" - ) - return output - - @classmethod - def is_initialized(cls) -> bool: - """If manager singleton has been initialized""" - return cls._shared_state.get("_is_initialized", False) - - @staticmethod - def get_available_backend(): - """Get communication backend""" - if ( - paddle.device.cuda.device_count() >= 1 - and paddle.core.is_compiled_with_nccl() - ): - return "nccl" - else: - return "gloo" - - @staticmethod - def initialize_env(): - """Setup method using generic initialization""" - rank = int(os.environ.get("RANK")) - world_size = int(os.environ.get("WORLD_SIZE")) - if "LOCAL_RANK" in os.environ: - local_rank = os.environ.get("LOCAL_RANK") - if local_rank is not None: - local_rank = int(local_rank) - else: - local_rank = rank % paddle.device.cuda.device_count() - - else: - local_rank = rank % paddle.device.cuda.device_count() - - # Read env variables - addr = os.environ.get("MASTER_ADDR") - port = os.environ.get("MASTER_PORT") - - DistributedManager.setup( - rank=rank, - world_size=world_size, - local_rank=local_rank, - addr=addr, - port=port, - backend=DistributedManager.get_available_backend(), - ) - - @staticmethod - def initialize_open_mpi(addr, port): - """Setup method using OpenMPI initialization""" - rank = int(os.environ.get("OMPI_COMM_WORLD_RANK")) - world_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE")) - local_rank = int(os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK")) - - DistributedManager.setup( - rank=rank, - world_size=world_size, - local_rank=local_rank, - addr=addr, - port=port, - backend=DistributedManager.get_available_backend(), - method="openmpi", - ) - - @staticmethod - def initialize_slurm(port): - """Setup method using SLURM initialization""" - rank = int(os.environ.get("SLURM_PROCID")) - world_size = int(os.environ.get("SLURM_NPROCS")) - local_rank = int(os.environ.get("SLURM_LOCALID")) - addr = os.environ.get("SLURM_LAUNCH_NODE_IPADDR") - - DistributedManager.setup( - rank=rank, - world_size=world_size, - local_rank=local_rank, - addr=addr, - port=port, - backend=DistributedManager.get_available_backend(), - method="slurm", - ) - - @staticmethod - def initialize(): - """ - Initialize distributed manager - - Current supported initialization methods are: - `ENV`: Environment variable initialization - `SLURM`: Initialization on SLURM systems. - Uses `SLURM_PROCID`, `SLURM_NPROCS`, `SLURM_LOCALID` and - `SLURM_LAUNCH_NODE_IPADDR` environment variables. - `OPENMPI`: Initialization for OpenMPI launchers. - Uses `OMPI_COMM_WORLD_RANK`, `OMPI_COMM_WORLD_SIZE` and - `OMPI_COMM_WORLD_LOCAL_RANK` environment variables. - - Initialization by default is done using the first valid method in the order - listed above. Initialization method can also be explicitly controlled using the - `MODULUS_DISTRIBUTED_INITIALIZATION_METHOD` environment variable and setting it - to one of the options above. - """ - if DistributedManager.is_initialized(): - warn("Distributed manager is already intialized") - return - - addr = os.getenv("MASTER_ADDR", "localhost") - port = os.getenv("MASTER_PORT", "12355") - os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "0" - initialization_method = os.getenv("MODULUS_DISTRIBUTED_INITIALIZATION_METHOD") - if initialization_method is None: - try: - DistributedManager.initialize_env() - except TypeError: - if "SLURM_PROCID" in os.environ: - DistributedManager.initialize_slurm(port) - elif "OMPI_COMM_WORLD_RANK" in os.environ: - DistributedManager.initialize_open_mpi(addr, port) - else: - warn( - "Could not initialize using ENV, SLURM or OPENMPI methods. Assuming this is a single process job" - ) - DistributedManager._shared_state["_is_initialized"] = True - elif initialization_method == "ENV": - DistributedManager.initialize_env() - elif initialization_method == "SLURM": - DistributedManager.initialize_slurm(port) - elif initialization_method == "OPENMPI": - DistributedManager.initialize_open_mpi(addr, port) - else: - raise RuntimeError( - "Unknown initialization method " - f"{initialization_method}. " - "Supported values for " - "MODULUS_DISTRIBUTED_INITIALIZATION_METHOD are " - "ENV, SLURM and OPENMPI" - ) - - # Set per rank numpy random seed for data sampling - np.random.seed(seed=DistributedManager().rank) - - @staticmethod - def setup( - rank=0, - world_size=1, - local_rank=None, - addr="localhost", - port="12355", - backend="nccl", - method="env", - ): - """Set up distributed process group and update manager attributes""" - os.environ["MASTER_ADDR"] = addr - os.environ["MASTER_PORT"] = str(port) - - DistributedManager._shared_state["_is_initialized"] = True - manager = DistributedManager() - - manager._distributed = paddle.distributed.is_available() - if manager._distributed: - # Update rank and world_size if using distributed - manager._rank = rank - manager._world_size = world_size - if local_rank is None: - manager._local_rank = rank % paddle.device.cuda.device_count() - else: - manager._local_rank = local_rank - - manager._device = ( - f"gpu:{manager.local_rank}" - if paddle.device.cuda.device_count() >= 1 - else "cpu" - ) - - if manager._distributed: - # Setup distributed process group - dist.init_process_group() - - if paddle.device.cuda.device_count() >= 0: - # Set device for this process and empty cache to optimize memory usage - paddle.set_device(manager.device) - paddle.device.cuda.empty_cache() - - manager._initialization_method = method - - @staticmethod - def create_process_subgroup( - name: str, size: int, group_name: Optional[str] = None, verbose: bool = False - ): # pragma: no cover - """ - Create a process subgroup of a parent process group. This must be a collective - call by all processes participating in this application. - - Parameters - ---------- - name : str - Name of the process subgroup to be created. - - size : int - Size of the process subgroup to be created. This must be an integer factor of - the parent group's size. - - group_name : Optional[str] - Name of the parent process group, optional. If None, the default process group - will be used. Default None. - - verbose : bool - Print out ranks of each created process group, default False. - - """ - manager = DistributedManager() - if not manager.distributed: - raise AssertionError( - "paddle.distributed is unavailable. " - "Check paddle build to ensure the distributed package is available. " - "If building from source, set `USE_DISTRIBUTED=1` " - "to enable the distributed package" - ) - - if name in manager._groups: - raise AssertionError(f"Group with name {name} already exists") - - # Get parent group's params - group = manager._groups[group_name] if group_name else None - group_size = dist.get_world_size(group=group) - num_groups = manager.world_size // group_size - - # Get number of sub-groups per parent group - if group_size % size != 0: - raise AssertionError( - f"Cannot divide group size {group_size} evenly into subgroups of" - f" size {size}" - ) - num_subgroups = group_size // size - - # Create all the sub-groups - # Note: all ranks in the job need to create all sub-groups in - # the same order even if a rank is not part of a sub-group - manager._group_ranks[name] = [] - for g in range(num_groups): - for i in range(num_subgroups): - # Get global ranks that are part of this sub-group - start = i * size - end = start + size - if group_name: - ranks = manager._group_ranks[group_name][g][start:end] - else: - ranks = list(range(start, end)) - # Create sub-group and keep track of ranks - tmp_group = dist.new_group(ranks=ranks) - manager._group_ranks[name].append(ranks) - if manager.rank in ranks: - # Set group in manager only if this rank is part of the group - manager._groups[name] = tmp_group - manager._group_names[tmp_group] = name - - if verbose and manager.rank == 0: - print(f"Process group '{name}':") - for grp in manager._group_ranks[name]: - print(" ", grp) - - @staticmethod - def create_orthogonal_process_group( - orthogonal_group_name: str, group_name: str, verbose: bool = False - ): # pragma: no cover - """ - Create a process group that is orthogonal to the specified process group. - - Parameters - ---------- - orthogonal_group_name : str - Name of the orthogonal process group to be created. - - group_name : str - Name of the existing process group. - - verbose : bool - Print out ranks of each created process group, default False. - - """ - manager = DistributedManager() - if not manager.distributed: - raise AssertionError( - "paddle.distributed is unavailable. " - "Check paddle build to ensure the distributed package is available. " - "If building from source, set `USE_DISTRIBUTED=1` " - "to enable the distributed package" - ) - - if group_name not in manager._groups: - raise ValueError(f"Group with name {group_name} does not exist") - if orthogonal_group_name in manager._groups: - raise ValueError(f"Group with name {orthogonal_group_name} already exists") - - group_ranks = manager._group_ranks[group_name] - orthogonal_ranks = [list(i) for i in zip(*group_ranks)] - - for ranks in orthogonal_ranks: - tmp_group = dist.new_group(ranks=ranks) - if manager.rank in ranks: - # Set group in manager only if this rank is part of the group - manager._groups[orthogonal_group_name] = tmp_group - manager._group_names[tmp_group] = orthogonal_group_name - - manager._group_ranks[orthogonal_group_name] = orthogonal_ranks - - if verbose and manager.rank == 0: - print(f"Process group '{orthogonal_group_name}':") - for grp in manager._group_ranks[orthogonal_group_name]: - print(" ", grp) - - @staticmethod - def create_group_from_node( - node: ProcessGroupNode, - parent: Optional[str] = None, - verbose: bool = False, - ): # pragma: no cover - if node.size is None: - raise AssertionError( - "Cannot create groups from a ProcessGroupNode that is not fully" - " populated. Ensure that config.set_leaf_group_sizes is called first" - " with `update_parent_sizes = True`" - ) - - DistributedManager.create_process_subgroup( - node.name, node.size, group_name=parent, verbose=verbose - ) - # Create orthogonal process group - orthogonal_group = f"__orthogonal_to_{node.name}" - DistributedManager.create_orthogonal_process_group( - orthogonal_group, node.name, verbose=verbose - ) - return orthogonal_group - - @staticmethod - def create_groups_from_config( - config: ProcessGroupConfig, verbose: bool = False - ): # pragma: no cover - # Traverse process group tree in breadth first order - # to create nested process groups - q = queue.Queue() - q.put(config.root_id) - DistributedManager.create_group_from_node(config.root) - - while not q.empty(): - node_id = q.get() - if verbose: - print(f"Node ID: {node_id}") - - children = config.tree.children(node_id) - if verbose: - print(f" Children: {children}") - - parent_group = node_id - for child in children: - # Create child group and replace parent group by orthogonal group so - # that each child forms an independent block of processes - parent_group = DistributedManager.create_group_from_node( - child.data, - parent=parent_group, - ) - - # Add child ids to the queue - q.put(child.identifier) - - @staticmethod - def cleanup(): - """Clean up distributed group and singleton""" - # Destroying group.WORLD is enough for all process groups to get destroyed - if ( - "_is_initialized" in DistributedManager._shared_state - and DistributedManager._shared_state["_is_initialized"] - and "_distributed" in DistributedManager._shared_state - and DistributedManager._shared_state["_distributed"] - ): - if paddle.device.cuda.device_count() >= 1: - dist.barrier() - else: - dist.barrier() - dist.destroy_process_group() - DistributedManager._shared_state = {} diff --git a/ppsci/arch/physicsnemo/launch/__init__.py b/ppsci/arch/physicsnemo/launch/__init__.py deleted file mode 100644 index b2f171d4ac..0000000000 --- a/ppsci/arch/physicsnemo/launch/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. -# SPDX-FileCopyrightText: All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/ppsci/arch/physicsnemo/launch/utils/__init__.py b/ppsci/arch/physicsnemo/launch/utils/__init__.py deleted file mode 100644 index ddd9be3cdf..0000000000 --- a/ppsci/arch/physicsnemo/launch/utils/__init__.py +++ /dev/null @@ -1,18 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. -# SPDX-FileCopyrightText: All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from .checkpoint import load_checkpoint # noqa: F401 -from .checkpoint import save_checkpoint # noqa: F401 diff --git a/ppsci/arch/physicsnemo/launch/utils/checkpoint.py b/ppsci/arch/physicsnemo/launch/utils/checkpoint.py deleted file mode 100644 index 4352cebdd4..0000000000 --- a/ppsci/arch/physicsnemo/launch/utils/checkpoint.py +++ /dev/null @@ -1,378 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. -# SPDX-FileCopyrightText: All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import glob -import re -from pathlib import Path -from typing import Any -from typing import Dict -from typing import List -from typing import NewType -from typing import Optional -from typing import Union - -import paddle -from paddle.amp import GradScaler -from paddle.optimizer.lr import LRScheduler - -from ppsci.utils import logger - -from ...distributed import DistributedManager - -optimizer = NewType("optimizer", paddle.optimizer) -scheduler = NewType("scheduler", LRScheduler) -scaler = NewType("scaler", GradScaler) - - -def _get_checkpoint_filename( - path: str, - base_name: str = "checkpoint", - index: Union[int, None] = None, - saving: bool = False, - model_type: str = "mdlus", -) -> str: - """Gets the file name /path of checkpoint - - This function has three different ways of providing a checkout filename: - - If supplied an index this will return the checkpoint name using that index. - - If index is None and saving is false, this will get the checkpoint with the - largest index (latest save). - - If index is None and saving is true, it will return the next valid index file name - which is calculated by indexing the largest checkpoint index found by one. - - Parameters - ---------- - path : str - Path to checkpoints - base_name: str, optional - Base file name, by default checkpoint - index : Union[int, None], optional - Checkpoint index, by default None - saving : bool, optional - Get filename for saving a new checkpoint, by default False - model_type : str - Model type, by default "mdlus" for Modulus models and "pdparams" for models - - - Returns - ------- - str - Checkpoint file name - """ - # Get model parallel rank so all processes in the first model parallel group - # can save their checkpoint. In the case without model parallelism, - # model_parallel_rank should be the same as the process rank itself and - # only rank 0 saves - if not DistributedManager.is_initialized(): - logger.warning( - "`DistributedManager` not initialized already. Initializing now, but this might lead to unexpected errors" - ) - DistributedManager.initialize() - manager = DistributedManager() - model_parallel_rank = ( - manager.group_rank("model_parallel") - if "model_parallel" in manager.group_names - else 0 - ) - - # Input file name - checkpoint_filename = str( - Path(path).resolve() / f"{base_name}.{model_parallel_rank}" - ) - - # File extension for Modulus models or PaddlePaddle models - file_extension = ".pdparams" - - # If epoch is provided load that file - if index is not None: - checkpoint_filename = checkpoint_filename + f".{index}" - checkpoint_filename += file_extension - # Otherwise try loading the latest epoch or rolling checkpoint - else: - file_names = [ - Path(fname).name - for fname in glob.glob( - checkpoint_filename + "*" + file_extension, recursive=False - ) - ] - - if len(file_names) > 0: - # If checkpoint from a null index save exists load that - # This is the most likely line to error since it will fail with - # invalid checkpoint names - file_idx = [ - int( - re.sub( - f"^{base_name}.{model_parallel_rank}.|" + file_extension, - "", - fname, - ) - ) - for fname in file_names - ] - file_idx.sort() - # If we are saving index by 1 to get the next free file name - if saving: - checkpoint_filename = checkpoint_filename + f".{file_idx[-1]+1}" - else: - checkpoint_filename = checkpoint_filename + f".{file_idx[-1]}" - checkpoint_filename += file_extension - else: - checkpoint_filename += ".0" + file_extension - - return checkpoint_filename - - -def _unique_model_names( - models: List[paddle.nn.Layer], -) -> Dict[str, paddle.nn.Layer]: - """Util to clean model names and index if repeat names, will also strip DDP wrappers - if they exist. - - Parameters - ---------- - model : List[paddle.nn.Layer] - List of models to generate names for - - Returns - ------- - Dict[str, paddle.nn.Layer] - Dictionary of model names and respective modules - """ - # Loop through provided models and set up base names - model_dict = {} - for model0 in models: - if hasattr(model0, "module"): - # Strip out DDP layer - model0 = model0.module - # Base name of model is meta.name unless paddle model - base_name = model0.__class__.__name__ - # if isinstance(model0, modulus): - # base_name = model0.meta.name - # If we have multiple models of the same name, introduce another index - if base_name in model_dict: - model_dict[base_name].append(model0) - else: - model_dict[base_name] = [model0] - - # Set up unique model names if needed - output_dict = {} - for key, model in model_dict.items(): - if len(model) > 1: - for i, model0 in enumerate(model): - output_dict[key + str(i)] = model0 - else: - output_dict[key] = model[0] - - return output_dict - - -def save_checkpoint( - path: str, - models: Union[paddle.nn.Layer, List[paddle.nn.Layer], None] = None, - optimizer: Union[optimizer, None] = None, - scheduler: Union[scheduler, None] = None, - scaler: Union[scaler, None] = None, - epoch: Union[int, None] = None, - metadata: Optional[Dict[str, Any]] = None, -) -> None: - """Training checkpoint saving utility - - This will save a training checkpoint in the provided path following the file naming - convention "checkpoint.{model parallel id}.{epoch/index}.mdlus". The load checkpoint - method in Modulus core can then be used to read this file. - - Parameters - ---------- - path : str - Path to save the training checkpoint - models : Union[paddle.nn.Layer, List[paddle.nn.Layer], None], optional - A single or list of PaddlePaddle models, by default None - optimizer : Union[optimizer, None], optional - Optimizer, by default None - scheduler : Union[scheduler, None], optional - Learning rate scheduler, by default None - scaler : Union[scaler, None], optional - AMP grad scaler. Will attempt to save on in static capture if none provided, by - default None - epoch : Union[int, None], optional - Epoch checkpoint to load. If none this will save the checkpoint in the next - valid index, by default None - metadata : Optional[Dict[str, Any]], optional - Additional metadata to save, by default None - """ - # Create checkpoint directory if it does not exist - if not Path(path).is_dir(): - logger.warning( - f"Output directory {path} does not exist, will " "attempt to create" - ) - Path(path).mkdir(parents=True, exist_ok=True) - - # == Saving model checkpoint == - if models: - if not isinstance(models, list): - models = [models] - models = _unique_model_names(models) - for name, model in models.items(): - # Get model type - model_type = "pdparams" - - # Get full file path / name - file_name = _get_checkpoint_filename( - path, name, index=epoch, saving=True, model_type=model_type - ) - - # Save state dictionary - paddle.save(model.state_dict(), file_name) - logger.info(f"Saved model state dictionary: {file_name}") - - # == Saving training checkpoint == - checkpoint_dict = {} - # Optimizer state dict - if optimizer: - checkpoint_dict["optimizer_state_dict"] = optimizer.state_dict() - - # Scheduler state dict - if scheduler: - checkpoint_dict["scheduler_state_dict"] = scheduler.state_dict() - - # Scheduler state dict - if scaler: - checkpoint_dict["scaler_state_dict"] = scaler.state_dict() - # Static capture is being used, save its grad scaler - # if _StaticCapture._amp_scalers: - # checkpoint_dict["static_capture_state_dict"] = _StaticCapture.state_dict() - - # Output file name - output_filename = _get_checkpoint_filename( - path, index=epoch, saving=True, model_type="pdparams" - ) - if epoch: - checkpoint_dict["epoch"] = epoch - if metadata: - checkpoint_dict["metadata"] = metadata - # Save checkpoint to memory - if bool(checkpoint_dict): - paddle.save( - checkpoint_dict, - output_filename, - ) - logger.info(f"Saved training checkpoint: {output_filename}") - - -def load_checkpoint( - path: str, - models: Union[paddle.nn.Layer, List[paddle.nn.Layer], None] = None, - optimizer: Union[optimizer, None] = None, - scheduler: Union[scheduler, None] = None, - scaler: Union[scaler, None] = None, - epoch: Union[int, None] = None, - metadata_dict: Optional[Dict[str, Any]] = {}, -) -> int: - """Checkpoint loading utility - - This loader is designed to be used with the save checkpoint utility in Modulus - Launch. Given a path, this method will try to find a checkpoint and load state - dictionaries into the provided training objects. - - Parameters - ---------- - path : str - Path to training checkpoint - models : Union[paddle.nn.Layer, List[paddle.nn.Layer], None], optional - A single or list of models, by default None - optimizer : Union[optimizer, None], optional - Optimizer, by default None - scheduler : Union[scheduler, None], optional - Learning rate scheduler, by default None - scaler : Union[scaler, None], optional - AMP grad scaler, by default None - epoch : Union[int, None], optional - Epoch checkpoint to load. If none is provided this will attempt to load the - checkpoint with the largest index, by default None - metadata_dict: Optional[Dict[str, Any]], optional - Dictionary to store metadata from the checkpoint, by default None - - Returns - ------- - int - Loaded epoch - """ - # Check if checkpoint directory exists - if not Path(path).is_dir(): - logger.warning( - f"Provided checkpoint directory {path} does not exist, skipping load" - ) - return 0 - - # == Loading model checkpoint == - if models: - if not isinstance(models, list): - models = [models] - models = _unique_model_names(models) - for name, model in models.items(): - # Get model type - model_type = "pdparams" - - # Get full file path / name - file_name = _get_checkpoint_filename( - path, name, index=epoch, model_type=model_type - ) - if not Path(file_name).exists(): - logger.error( - f"Could not find valid model file {file_name}, skipping load" - ) - continue - # Load state dictionary - model.set_state_dict(paddle.load(file_name)) - - logger.info(f"Loaded model state dictionary {file_name}") - - # == Loading training checkpoint == - checkpoint_filename = _get_checkpoint_filename( - path, index=epoch, model_type="pdparams" - ) - if not Path(checkpoint_filename).is_file(): - logger.warning("Could not find valid checkpoint file, skipping load") - return 0 - - checkpoint_dict = paddle.load(checkpoint_filename) - logger.info(f"Loaded checkpoint file {checkpoint_filename}") - - # Optimizer state dict - if optimizer and "optimizer_state_dict" in checkpoint_dict: - optimizer.set_state_dict(checkpoint_dict["optimizer_state_dict"]) - logger.info("Loaded optimizer state dictionary") - - # Scheduler state dict - if scheduler and "scheduler_state_dict" in checkpoint_dict: - scheduler.set_state_dict(checkpoint_dict["scheduler_state_dict"]) - logger.info("Loaded scheduler state dictionary") - - # Scaler state dict - if scaler and "scaler_state_dict" in checkpoint_dict: - scaler.load_state_dict(checkpoint_dict["scaler_state_dict"]) - logger.info("Loaded grad scaler state dictionary") - - epoch = 0 - if "epoch" in checkpoint_dict: - epoch = checkpoint_dict["epoch"] - # Update metadata if exists and the dictionary object is provided - metadata = checkpoint_dict.get("metadata", {}) - for key, value in metadata.items(): - metadata_dict[key] = value - - return epoch diff --git a/ppsci/arch/physicsnemo/models/__init__.py b/ppsci/arch/physicsnemo/models/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/ppsci/arch/physicsnemo/models/layers/__init__.py b/ppsci/arch/physicsnemo/models/layers/__init__.py deleted file mode 100644 index b2f171d4ac..0000000000 --- a/ppsci/arch/physicsnemo/models/layers/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. -# SPDX-FileCopyrightText: All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/ppsci/arch/physicsnemo/models/layers/ball_query.py b/ppsci/arch/physicsnemo/models/layers/ball_query.py deleted file mode 100644 index 58f5cdf594..0000000000 --- a/ppsci/arch/physicsnemo/models/layers/ball_query.py +++ /dev/null @@ -1,322 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. -# SPDX-FileCopyrightText: All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import paddle -import warp as wp - - -class BallQuery(paddle.autograd.PyLayer): - """ - Warp based Ball Query. - """ - - @wp.kernel - def ball_query( - points1: wp.array(dtype=wp.vec3), - points2: wp.array(dtype=wp.vec3), - grid: wp.uint64, - k: wp.int32, - radius: wp.float32, - mapping: wp.array3d(dtype=wp.int32), - num_neighbors: wp.array2d(dtype=wp.int32), - ): - - # Get index of point1 - tid = wp.tid() - - # Get position from points1 - pos = points1[tid] - - # particle contact - neighbors = wp.hash_grid_query(grid, pos, radius) - - # Keep track of the number of neighbors found - nr_found = wp.int32(0) - - # loop through neighbors to compute density - for index in neighbors: - # Check if outside the radius - pos2 = points2[index] - if wp.length(pos - pos2) > radius: - continue - - # Add neighbor to the list - mapping[0, tid, nr_found] = index - - # Increment the number of neighbors found - nr_found += 1 - - # Break if we have found enough neighbors - if nr_found == k: - num_neighbors[0, tid] = k - break - - # Set the number of neighbors - num_neighbors[0, tid] = nr_found - - @wp.kernel - def sparse_ball_query( - points2: wp.array(dtype=wp.vec3), - mapping: wp.array3d(dtype=wp.int32), - num_neighbors: wp.array2d(dtype=wp.int32), - outputs: wp.array4d(dtype=wp.float32), - ): - # Get index of point1 - p1 = wp.tid() - - # Get number of neighbors - k = num_neighbors[0, p1] - - # Loop through neighbors - for _k in range(k): - # Get point2 index - index = mapping[0, p1, _k] - - # Get position from points2 - pos = points2[index] - - # Set the output - outputs[0, p1, _k, 0] = pos[0] - outputs[0, p1, _k, 1] = pos[1] - outputs[0, p1, _k, 2] = pos[2] - - @staticmethod - def forward( - ctx, - points1, - points2, - lengths1, - lengths2, - k, - radius, - hash_grid, - ): - # Only works for batch size 1 - if points1.shape[0] != 1: - raise AssertionError("nly works for batch size 1") - - # Convert from paddle to warp - ctx.points1 = wp.from_paddle( - points1[0], dtype=wp.vec3, requires_grad=points1.stop_gradient - ) - ctx.points2 = wp.from_paddle( - points2[0], dtype=wp.vec3, requires_grad=points2.stop_gradient - ) - ctx.lengths1 = wp.from_paddle(lengths1, dtype=wp.int32, requires_grad=False) - ctx.lengths2 = wp.from_paddle(lengths2, dtype=wp.int32, requires_grad=False) - ctx.k = k - ctx.radius = radius - - # Allocate the mapping and outputs - mapping = paddle.zeros([1, points1.shape[1], k], dtype=paddle.int32) - mapping.stop_gradient = False - ctx.mapping = wp.from_paddle(mapping, dtype=wp.int32, requires_grad=False) - num_neighbors = paddle.zeros([1, points1.shape[1]], dtype=paddle.int32) - num_neighbors.stop_gradient = False - ctx.num_neighbors = wp.from_paddle( - num_neighbors, dtype=wp.int32, requires_grad=False - ) - outputs = paddle.zeros([1, points1.shape[1], k, 3], dtype=paddle.float32) - outputs.stop_gradient = points1.stop_gradient or points2.stop_gradient - ctx.outputs = wp.from_paddle(outputs, dtype=wp.float32) - outputs.stop_gradient = points1.stop_gradient or points2.stop_gradient - - # Make grid - ctx.hash_grid = hash_grid - - # Build the grid - ctx.hash_grid.build(ctx.points2, radius) - - # Run the kernel to get mapping - wp.launch( - BallQuery.ball_query, - inputs=[ - ctx.points1, - ctx.points2, - ctx.hash_grid.id, - k, - radius, - ], - outputs=[ - ctx.mapping, - ctx.num_neighbors, - ], - dim=[ctx.points1.shape[0]], - ) - - # Run the kernel to get outputs - wp.launch( - BallQuery.sparse_ball_query, - inputs=[ - ctx.points2, - ctx.mapping, - ctx.num_neighbors, - ], - outputs=[ - ctx.outputs, - ], - dim=[ctx.points1.shape[0]], - ) - - return ( - wp.to_paddle(ctx.mapping), - wp.to_paddle(ctx.num_neighbors), - wp.to_paddle(ctx.outputs), - ) - - @staticmethod - def backward(ctx, grad_mapping, grad_num_neighbors, grad_outputs): - # Map incoming paddle grads to our output variable - ctx.outputs.grad = wp.from_paddle(grad_outputs, dtype=wp.float32) - - # Run the kernel in adjoint mode - wp.launch( - BallQuery.sparse_ball_query, - inputs=[ - ctx.points2, - ctx.mapping, - ctx.num_neighbors, - ], - outputs=[ - ctx.outputs, - ], - adj_inputs=[ctx.points2.grad, ctx.mapping.grad, ctx.num_neighbors.grad], - adj_outputs=[ - ctx.outputs.grad, - ], - dim=[ctx.points1.shape[0]], - adjoint=True, - ) - - # Return the gradients - return ( - wp.to_paddle(ctx.points1.grad).unsqueeze(0), - wp.to_paddle(ctx.points2.grad).unsqueeze(0), - None, - None, - None, - None, - None, - ) - - -class BallQueryLayer(paddle.nn.Layer): - """ - Paddle layer for differentiable and accelerated Ball Query - operation using Warp. - Args: - k (int): Number of neighbors. - radius (float): Radius of influence. - grid_size (int): Uniform grid resolution - """ - - def __init__(self, k, radius, grid_size=32): - super().__init__() - wp.init() - self.k = k - self.radius = radius - self.hash_grid = wp.HashGrid(grid_size, grid_size, grid_size) - - def forward(self, points1, points2, lengths1, lengths2): - return BallQuery.apply( - points1, - points2, - lengths1, - lengths2, - self.k, - self.radius, - self.hash_grid, - ) - - -if __name__ == "__main__": - # Make function for saving point clouds - import pyvista as pv - - def save_point_cloud(points, name): - cloud = pv.PolyData(points.detach().cpu().numpy()) - cloud.save(name) - - # Check forward pass - # Initialize tensors - n = 1 # number of point clouds - p1 = 128000 # 100000 # number of points in point cloud 1 - d = 3 # dimension of the points - p2 = 39321 # 100000 # number of points in point cloud 2 - points1 = paddle.rand([n, p1, d]) - points1.stop_gradient = False - - points2 = paddle.rand([n, p2, d]) - points2.stop_gradient = False - lengths1 = paddle.full((n,), p1, dtype=paddle.int32) - lengths2 = paddle.full((n,), p2, dtype=paddle.int32) - k = 256 # maximum number of neighbors - radius = 0.1 - - # Make ball query layer - layer = BallQueryLayer(k, radius) - - # Make ball query - with wp.ScopedTimer("ball query", active=True): - mapping, num_neighbors, outputs = layer( - points1, - points2, - lengths1, - lengths2, - ) - - for i in range(2): - p1 += 100 - p2 += 100 - points1 = paddle.rand([n, p1, d]) - points1.stop_gradient = True - points2 = paddle.rand([n, p2, d]) - points2.stop_gradient = True - lengths1 = paddle.full((n,), p1, dtype=paddle.int32) - lengths2 = paddle.full((n,), p2, dtype=paddle.int32) - - mapping, num_neighbors, outputs = layer( - points1, - points2, - lengths1, - lengths2, - ) - - # Perform matrix multiplication as comparison for timing - with wp.ScopedTimer("matrix multiplication 256", active=True): - outputs2 = paddle.matmul(points1[0], paddle.ones([3, k])) - - # Save the point clouds - save_point_cloud(points1[0], "point1.vtk") - save_point_cloud(points2[0], "point2.vtk") - save_point_cloud(outputs[0].reshape([-1, 3]), "outputs.vtk") - - # Optimize the background points to move to the query points - optimizer = paddle.optimizer.SGD(parameters=[points2], learning_rate=0.01) - - # Test optimization - for i in range(2): - optimizer.clear_gradients() - mapping, num_neighbors, outputs = layer(points1, points2, lengths1, lengths2) - - loss = (points1.unsqueeze(2) - outputs).pow(2).sum() - loss.backward() - optimizer.step() - - # Save the point clouds - save_point_cloud(points1[0], "point1_{}.vtk".format(i)) - save_point_cloud(outputs[0].reshape([-1, 3]), "outputs_{}.vtk".format(i)) diff --git a/ppsci/arch/physicsnemo/models/model.py b/ppsci/arch/physicsnemo/models/model.py deleted file mode 100644 index 3372e8ecb0..0000000000 --- a/ppsci/arch/physicsnemo/models/model.py +++ /dev/null @@ -1,1178 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. -# SPDX-FileCopyrightText: All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -This code contains the DoMINO model architecture. -The DoMINO class contains an architecture to model both surface and -volume quantities together as well as separately (controlled using -the config.yaml file) -""" - -# from dataclasses import dataclass -import math - -import paddle -import paddle.nn as nn -import paddle.nn.functional as F - -import ppsci - -from .layers.ball_query import BallQueryLayer - -# from modulus.models.meta import ModelMetaData -# from modulus.models.module import Module - - -def kaiming_init(layer): - if isinstance(layer, (nn.layer.conv._ConvNd, nn.Linear)): - print(f"layer: {layer} ") - init_kaimingUniform = paddle.nn.initializer.KaimingUniform( - nonlinearity="leaky_relu", negative_slope=math.sqrt(5) - ) - init_kaimingUniform(layer.weight) - if layer.bias is not None: - fan_in, _ = ppsci.utils.initializer._calculate_fan_in_and_fan_out( - layer.weight - ) - if fan_in != 0: - bound = 1 / math.sqrt(fan_in) - init_uniform = paddle.nn.initializer.Uniform(low=-bound, high=bound) - init_uniform(layer.bias) - - -def calculate_pos_encoding(nx, d=8): - """Function to caluculate positional encoding""" - vec = [] - for k in range(int(d / 2)): - vec.append(paddle.sin(nx / 10000 ** (2 * (k) / d))) - vec.append(paddle.cos(nx / 10000 ** (2 * (k) / d))) - return vec - - -def scale_sdf(sdf): - """Function to scale SDF""" - return sdf / (0.4 + abs(sdf)) - - -def calculate_gradient(sdf): - """Function to calculate the gradients of SDF""" - m, n, o = sdf.shape[2], sdf.shape[3], sdf.shape[4] - sdf_x = sdf[:, :, 2:m, :, :] - sdf[:, :, 0 : m - 2, :, :] - sdf_y = sdf[:, :, :, 2:n, :] - sdf[:, :, :, 0 : n - 2, :] - sdf_z = sdf[:, :, :, :, 2:o] - sdf[:, :, :, :, 0 : o - 2] - - sdf_x = F.pad(x=sdf_x, pad=(0, 0, 0, 0, 0, 1), mode="constant", value=0.0) - sdf_x = F.pad(x=sdf_x, pad=(0, 0, 0, 0, 1, 0), mode="constant", value=0.0) - sdf_y = F.pad(x=sdf_y, pad=(0, 0, 0, 1, 0, 0), mode="constant", value=0.0) - sdf_y = F.pad(x=sdf_y, pad=(0, 0, 1, 0, 0, 0), mode="constant", value=0.0) - sdf_z = F.pad(x=sdf_z, pad=(0, 1, 0, 0, 0, 0), mode="constant", value=0.0) - sdf_z = F.pad(x=sdf_z, pad=(1, 0, 0, 0, 0, 0), mode="constant", value=0.0) - - return sdf_x, sdf_y, sdf_z - - -def binarize_sdf(sdf): - """Function to calculate the binarize the SDF""" - sdf = paddle.where(sdf >= 0, 0.0, 1.0).to(dtype=sdf.dtype) - return sdf - - -class BQWarp(nn.Layer): - """Warp based ball-query layer""" - - def __init__( - self, - input_features, - grid_resolution=[256, 96, 64], - radius=0.25, - neighbors_in_radius=10, - ): - super().__init__() - self.ball_query_layer = BallQueryLayer(neighbors_in_radius, radius) - self.grid_resolution = grid_resolution - - def forward(self, x, p_grid, reverse_mapping=True): - batch_size = x.shape[0] - nx, ny, nz = ( - self.grid_resolution[0], - self.grid_resolution[1], - self.grid_resolution[2], - ) - - p_grid = paddle.reshape(p_grid, (batch_size, nx * ny * nz, 3)) - p1 = nx * ny * nz - p2 = x.shape[1] - - if reverse_mapping: - lengths1 = paddle.full((batch_size,), p1, dtype=paddle.int32) - lengths2 = paddle.full((batch_size,), p2, dtype=paddle.int32) - mapping, num_neighbors, outputs = self.ball_query_layer( - p_grid, - x, - lengths1, - lengths2, - ) - else: - lengths1 = paddle.full((batch_size,), p2, dtype=paddle.int32) - lengths2 = paddle.full((batch_size,), p1, dtype=paddle.int32) - mapping, num_neighbors, outputs = self.ball_query_layer( - x, - p_grid, - lengths1, - lengths2, - ) - - return mapping, outputs - - -class GeoConvOut(nn.Layer): - """Geometry layer to project STLs on grids""" - - def __init__(self, input_features, model_parameters, grid_resolution=[256, 96, 64]): - super().__init__() - base_neurons = model_parameters.base_neurons - - self.fc1 = nn.Linear(input_features, base_neurons) - self.fc2 = nn.Linear(base_neurons, int(base_neurons / 2)) - self.fc3 = nn.Linear(int(base_neurons / 2), model_parameters.base_neurons_out) - - self.grid_resolution = grid_resolution - - self.activation = F.relu - - def forward(self, x, radius=0.025, neighbors_in_radius=10): - batch_size = x.shape[0] - nx, ny, nz = ( - self.grid_resolution[0], - self.grid_resolution[1], - self.grid_resolution[2], - ) - - mask = abs(x - 0) > 1e-6 - - x = self.activation(self.fc1(x)) - x = self.activation(self.fc2(x)) - x = F.tanh(self.fc3(x)) - mask = mask[:, :, :, 0:1].expand( - [mask.shape[0], mask.shape[1], mask.shape[2], x.shape[-1]] - ) - - # paddle does not support multiplication with boolean tensors, - # so we convert the mask to float - x = paddle.sum(x * mask.to(dtype=x.dtype), 2) - - x = paddle.reshape(x, (batch_size, x.shape[-1], nx, ny, nz)) - return x - - -class GeoProcessor(nn.Layer): - """Geometry processing layer using CNNs""" - - def __init__(self, input_filters, model_parameters): - super().__init__() - base_filters = model_parameters.base_filters - self.conv1 = nn.Conv3D( - input_filters, base_filters, kernel_size=3, padding="same" - ) - self.conv_bn1 = nn.BatchNorm3D(int(base_filters)) - self.conv2 = nn.Conv3D( - base_filters, 2 * base_filters, kernel_size=3, padding="same" - ) - self.conv_bn2 = nn.BatchNorm3D(int(2 * base_filters)) - self.conv3 = nn.Conv3D( - 2 * base_filters, 4 * base_filters, kernel_size=3, padding="same" - ) - self.conv_bn3 = nn.BatchNorm3D(int(4 * base_filters)) - self.conv3_1 = nn.Conv3D( - 4 * base_filters, 4 * base_filters, kernel_size=3, padding="same" - ) - self.conv4 = nn.Conv3D( - 4 * base_filters, 2 * base_filters, kernel_size=3, padding="same" - ) - self.conv_bn4 = nn.BatchNorm3D(int(2 * base_filters)) - self.conv5 = nn.Conv3D( - 4 * base_filters, base_filters, kernel_size=3, padding="same" - ) - self.conv_bn5 = nn.BatchNorm3D(int(base_filters)) - self.conv6 = nn.Conv3D( - 2 * base_filters, input_filters, kernel_size=3, padding="same" - ) - self.conv_bn6 = nn.BatchNorm3D(int(input_filters)) - self.conv7 = nn.Conv3D( - 2 * input_filters, input_filters, kernel_size=3, padding="same" - ) - self.conv8 = nn.Conv3D(input_filters, 1, kernel_size=3, padding="same") - self.avg_pool = paddle.nn.AvgPool3D((2, 2, 2)) - self.max_pool = nn.MaxPool3D(2) - self.upsample = nn.Upsample(scale_factor=2, mode="nearest") - self.activation = F.relu - self.batch_norm = False - - def forward(self, x): - # Encoder - x0 = x - if self.batch_norm: - x = self.activation(self.conv_bn1(self.conv1(x))) - else: - x = self.activation(self.conv1(x)) - x = self.max_pool(x) - x1 = x - if self.batch_norm: - x = self.activation(self.conv_bn2(self.conv2(x))) - else: - x = self.activation((self.conv2(x))) - x = self.max_pool(x) - - x2 = x - if self.batch_norm: - x = self.activation(self.conv_bn3(self.conv2(x))) - else: - x = self.activation((self.conv3(x))) - x = self.max_pool(x) - - # Processor loop - x = F.relu(self.conv3_1(x)) - - # Decoder - if self.batch_norm: - x = self.activation(self.conv_bn4(self.conv4(x))) - else: - x = self.activation((self.conv4(x))) - x = self.upsample(x) - x = paddle.concat((x, x2), axis=1) - - if self.batch_norm: - x = self.activation(self.conv_bn5(self.conv5(x))) - else: - x = self.activation((self.conv5(x))) - x = self.upsample(x) - x = paddle.concat((x, x1), axis=1) - if self.batch_norm: - x = self.activation(self.conv_bn6(self.conv6(x))) - else: - x = self.activation((self.conv6(x))) - x = self.upsample(x) - x = paddle.concat((x, x0), axis=1) - - x = self.activation(self.conv7(x)) - x = self.conv8(x) - - return x - - -class GeometryRep(nn.Layer): - """Geometry representation from STLs block""" - - def __init__(self, input_features, model_parameters=None): - super().__init__() - geometry_rep = model_parameters.geometry_rep - - self.bq_warp_short = BQWarp( - input_features=input_features, - grid_resolution=model_parameters.interp_res, - radius=geometry_rep.geo_conv.radius_short, - ) - - self.bq_warp_long = BQWarp( - input_features=input_features, - grid_resolution=model_parameters.interp_res, - radius=geometry_rep.geo_conv.radius_long, - ) - - self.geo_conv_out = GeoConvOut( - input_features=input_features, - model_parameters=geometry_rep.geo_conv, - grid_resolution=model_parameters.interp_res, - ) - - self.geo_processor_short_range = GeoProcessor( - input_filters=geometry_rep.geo_conv.base_neurons_out, - model_parameters=geometry_rep.geo_processor, - ) - self.geo_processor_long_range = GeoProcessor( - input_filters=geometry_rep.geo_conv.base_neurons_out, - model_parameters=geometry_rep.geo_processor, - ) - self.geo_processor_sdf = GeoProcessor( - input_filters=6, model_parameters=geometry_rep.geo_processor - ) - self.activation = F.relu - self.radius_short = geometry_rep.geo_conv.radius_short - self.radius_long = geometry_rep.geo_conv.radius_long - self.hops = geometry_rep.geo_conv.hops - - def forward(self, x, p_grid, sdf): - - # Expand SDF - sdf = paddle.unsqueeze(sdf, 1) - - # Calculate short-range geoemtry dependency - mapping, k_short = self.bq_warp_short(x, p_grid) - x_encoding_short = self.geo_conv_out(k_short) - - # Calculate long-range geometry dependency - mapping, k_long = self.bq_warp_long(x, p_grid) - x_encoding_long = self.geo_conv_out(k_long) - - # Scaled sdf to emphasis on surface - scaled_sdf = scale_sdf(sdf) - # Binary sdf - binary_sdf = binarize_sdf(sdf) - # Gradients of SDF - sdf_x, sdf_y, sdf_z = calculate_gradient(sdf) - - # Propagate information in the geometry enclosed BBox - for _ in range(self.hops): - dx = self.geo_processor_short_range(x_encoding_short) / self.hops - x_encoding_short = x_encoding_short + dx - - # Propagate information in the computational domain BBox - for _ in range(self.hops): - dx = self.geo_processor_long_range(x_encoding_long) / self.hops - x_encoding_long = x_encoding_long + dx - - # Process SDF and its computed features - sdf = paddle.concat((sdf, scaled_sdf, binary_sdf, sdf_x, sdf_y, sdf_z), 1) - sdf_encoding = self.geo_processor_sdf(sdf) - - # Geometry encoding comprised of short-range, long-range and SDF features - encoding_g = paddle.concat((x_encoding_short, sdf_encoding, x_encoding_long), 1) - - return encoding_g - - -class NNBasisFunctions(nn.Layer): - """Basis function layer for point clouds""" - - def __init__(self, input_features, model_parameters=None): - super(NNBasisFunctions, self).__init__() - self.input_features = input_features - - base_layer = model_parameters.base_layer - self.fc1 = nn.Linear(self.input_features, base_layer) - self.fc2 = nn.Linear(base_layer, int(base_layer)) - self.fc3 = nn.Linear(int(base_layer), int(base_layer)) - self.bn1 = nn.BatchNorm1D(base_layer) - self.bn2 = nn.BatchNorm1D(int(base_layer)) - self.bn3 = nn.BatchNorm1D(int(base_layer)) - - self.activation = F.relu - - def forward(self, x, padded_value=-10): - facets = x - facets = self.activation(self.fc1(facets)) - facets = self.activation(self.fc2(facets)) - facets = self.fc3(facets) - - return facets - - -class ParameterModel(nn.Layer): - """Layer to encode parameters such as inlet velocity and air density""" - - def __init__(self, input_features, model_parameters=None): - super(ParameterModel, self).__init__() - self.input_features = input_features - - base_layer = model_parameters.base_layer - self.fc1 = nn.Linear(self.input_features, base_layer) - self.fc2 = nn.Linear(base_layer, int(base_layer)) - self.fc3 = nn.Linear(int(base_layer), int(base_layer)) - self.bn1 = nn.BatchNorm1D(base_layer) - self.bn2 = nn.BatchNorm1D(int(base_layer)) - self.bn3 = nn.BatchNorm1D(int(base_layer)) - - self.activation = F.relu - - def forward(self, x, padded_value=-10): - params = x - params = self.activation(self.fc1(params)) - params = self.activation(self.fc2(params)) - params = self.fc3(params) - - return params - - -class AggregationModel(nn.Layer): - """Layer to aggregate local geometry encoding with basis functions""" - - def __init__( - self, input_features, output_features, model_parameters=None, new_change=True - ): - super(AggregationModel, self).__init__() - self.input_features = input_features - self.output_features = output_features - self.new_change = new_change - base_layer = model_parameters.base_layer - self.fc1 = nn.Linear(self.input_features, base_layer) - self.fc2 = nn.Linear(base_layer, int(base_layer)) - self.fc3 = nn.Linear(int(base_layer), int(base_layer)) - self.fc4 = nn.Linear(int(base_layer), int(base_layer)) - self.fc5 = nn.Linear(int(base_layer), self.output_features) - self.bn1 = nn.BatchNorm1D(base_layer) - self.bn2 = nn.BatchNorm1D(int(base_layer)) - self.bn3 = nn.BatchNorm1D(int(base_layer)) - self.bn4 = nn.BatchNorm1D(int(base_layer)) - self.activation = F.relu - - def forward(self, x): - out = self.activation(self.fc1(x)) - out = self.activation(self.fc2(out)) - out = self.activation(self.fc3(out)) - out = self.activation(self.fc4(out)) - - out = self.fc5(out) - - return out - - -class DoMINO(nn.Layer): - """DoMINO model architecture - Parameters - ---------- - input_features : int - Number of point input features - output_features_vol : int - Number of output features in volume - output_features_surf : int - Number of output features on surface - model_parameters: dict - Dictionary of model parameters controlled by config.yaml - - Example - ------- - >>> from modulus.models.domino.model import DoMINO - >>> import os - >>> from hydra import compose, initialize - >>> from omegaconf import OmegaConf - >>> cfg = OmegaConf.register_new_resolver("eval", eval) - >>> with initialize(version_base="1.3", config_path="examples/cfd/external_aerodynamics/domino/src/conf"): - ... cfg = compose(config_name="config") - >>> cfg.model.model_type = "combined" - >>> model = DoMINO( - ... input_features=3, - ... output_features_vol=5, - ... output_features_surf=4, - ... model_parameters=cfg.model - ... ) - - Warp ... - >>> bsize = 1 - >>> nx, ny, nz = 128, 64, 48 - >>> num_neigh = 7 - >>> pos_normals_closest_vol = paddle.randn([bsize, 100, 3]) - >>> pos_normals_com_vol = paddle.randn([bsize, 100, 3]) - >>> pos_normals_com_surface = paddle.randn([bsize, 100, 3]) - >>> geom_centers = paddle.randn([bsize, 100, 3]) - >>> grid = paddle.randn([bsize, nx, ny, nz, 3]) - >>> surf_grid = paddle.randn([bsize, nx, ny, nz, 3]) - >>> sdf_grid = paddle.randn([bsize, nx, ny, nz]) - >>> sdf_surf_grid = paddle.randn([bsize, nx, ny, nz]) - >>> sdf_nodes = paddle.randn([bsize, 100, 1]) - >>> surface_coordinates = paddle.randn([bsize, 100, 3]) - >>> surface_neighbors = paddle.randn([bsize, 100, num_neigh, 3]) - >>> surface_normals = paddle.randn([bsize, 100, 3]) - >>> surface_neighbors_normals = paddle.randn([bsize, 100, num_neigh, 3]) - >>> surface_sizes = paddle.randn([bsize, 100, 3]) - >>> surface_neighbors_sizes = paddle.randn([bsize, 100, num_neigh, 3]) - >>> volume_coordinates = paddle.randn([bsize, 100, 3]) - >>> vol_grid_max_min = paddle.randn([bsize, 2, 3]) - >>> surf_grid_max_min = paddle.randn([bsize, 2, 3]) - >>> stream_velocity = paddle.randn([bsize, 1]) - >>> air_density = paddle.randn([bsize, 1]) - >>> input_dict = { - ... "pos_volume_closest": pos_normals_closest_vol, - ... "pos_volume_center_of_mass": pos_normals_com_vol, - ... "pos_surface_center_of_mass": pos_normals_com_surface, - ... "geometry_coordinates": geom_centers, - ... "grid": grid, - ... "surf_grid": surf_grid, - ... "sdf_grid": sdf_grid, - ... "sdf_surf_grid": sdf_surf_grid, - ... "sdf_nodes": sdf_nodes, - ... "surface_mesh_centers": surface_coordinates, - ... "surface_mesh_neighbors": surface_neighbors, - ... "surface_normals": surface_normals, - ... "surface_neighbors_normals": surface_neighbors_normals, - ... "surface_areas": surface_sizes, - ... "surface_neighbors_areas": surface_neighbors_sizes, - ... "volume_mesh_centers": volume_coordinates, - ... "volume_min_max": vol_grid_max_min, - ... "surface_min_max": surf_grid_max_min, - ... "stream_velocity": stream_velocity, - ... "air_density": air_density, - ... } - >>> output = model(input_dict) - Module ... - >>> print(f"{output[0].shape}, {output[1].shape}") - """ - - def __init__( - self, - input_features, - output_features_vol=None, - output_features_surf=None, - model_parameters=None, - ): - super(DoMINO, self).__init__() - self.input_features = input_features - self.output_features_vol = output_features_vol - self.output_features_surf = output_features_surf - - if self.output_features_vol is None and self.output_features_surf is None: - raise ValueError("Need to specify number of volume or surface features") - - self.num_variables_vol = output_features_vol - self.num_variables_surf = output_features_surf - self.grid_resolution = model_parameters.interp_res - self.surface_neighbors = model_parameters.surface_neighbors - self.use_surface_normals = model_parameters.use_surface_normals - self.use_only_normals = model_parameters.use_only_normals - self.encode_parameters = model_parameters.encode_parameters - self.param_scaling_factors = model_parameters.parameter_model.scaling_params - - if self.use_surface_normals: - if self.use_only_normals: - input_features_surface = input_features + 3 - else: - input_features_surface = input_features + 4 - else: - input_features_surface = input_features - - if self.encode_parameters: - # Defining the parameter model - base_layer_p = model_parameters.parameter_model.base_layer - self.parameter_model = ParameterModel( - input_features=2, model_parameters=model_parameters.parameter_model - ) - else: - base_layer_p = 0 - - self.geo_rep = GeometryRep( - input_features=input_features, - model_parameters=model_parameters, - ) - - # Basis functions for surface and volume - base_layer_nn = model_parameters.nn_basis_functions.base_layer - if self.output_features_surf is not None: - self.nn_basis_surf = nn.LayerList() - for _ in range(self.num_variables_surf): - self.nn_basis_surf.append( - NNBasisFunctions( - input_features=input_features_surface, - model_parameters=model_parameters.nn_basis_functions, - ) - ) - - if self.output_features_vol is not None: - self.nn_basis_vol = nn.LayerList() - for _ in range(self.num_variables_vol): - self.nn_basis_vol.append( - NNBasisFunctions( - input_features=input_features, - model_parameters=model_parameters.nn_basis_functions, - ) - ) - - # Positional encoding - position_encoder_base_neurons = model_parameters.position_encoder.base_neurons - if self.output_features_vol is not None: - if model_parameters.positional_encoding: - inp_pos_vol = 25 if model_parameters.use_sdf_in_basis_func else 12 - else: - inp_pos_vol = 7 if model_parameters.use_sdf_in_basis_func else 3 - - self.fc_p_vol = nn.Linear(inp_pos_vol, position_encoder_base_neurons) - - if self.output_features_surf is not None: - if model_parameters.positional_encoding: - inp_pos_surf = 12 - else: - inp_pos_surf = 3 - - self.fc_p_surf = nn.Linear(inp_pos_surf, position_encoder_base_neurons) - - # Positional encoding hidden layers - self.fc_p1 = nn.Linear( - position_encoder_base_neurons, position_encoder_base_neurons - ) - self.fc_p2 = nn.Linear( - position_encoder_base_neurons, position_encoder_base_neurons - ) - - # BQ for surface and volume - self.neighbors_in_radius = model_parameters.geometry_local.neighbors_in_radius - self.radius = model_parameters.geometry_local.radius - self.bq_warp = BQWarp( - input_features=input_features, - grid_resolution=model_parameters.interp_res, - radius=self.radius, - neighbors_in_radius=self.neighbors_in_radius, - ) - - base_layer_geo = model_parameters.geometry_local.base_layer - self.fc_1 = nn.Linear(self.neighbors_in_radius * 3, base_layer_geo) - self.fc_2 = nn.Linear(base_layer_geo, base_layer_geo) - self.activation = F.relu - - # Aggregation model - if self.output_features_surf is not None: - # Surface - self.agg_model_surf = nn.LayerList() - for _ in range(self.num_variables_surf): - self.agg_model_surf.append( - AggregationModel( - input_features=position_encoder_base_neurons - + base_layer_nn - + base_layer_geo - + base_layer_p, - output_features=1, - model_parameters=model_parameters.aggregation_model, - ) - ) - - if self.output_features_vol is not None: - # Volume - self.agg_model_vol = nn.LayerList() - for _ in range(self.num_variables_vol): - self.agg_model_vol.append( - AggregationModel( - input_features=position_encoder_base_neurons - + base_layer_nn - + base_layer_geo - + base_layer_p, - output_features=1, - model_parameters=model_parameters.aggregation_model, - ) - ) - - self.apply(kaiming_init) - - def geometry_encoder(self, geo_centers, p_grid, sdf): - """Function to return local geometry encoding""" - return self.geo_rep(geo_centers, p_grid, sdf) - - def position_encoder(self, encoding_node, eval_mode="volume"): - """Function to calculate positional encoding""" - if eval_mode == "volume": - x = self.activation(self.fc_p_vol(encoding_node)) - elif eval_mode == "surface": - x = self.activation(self.fc_p_surf(encoding_node)) - x = self.activation(self.fc_p1(x)) - x = self.fc_p2(x) - return x - - def geo_encoding_local_surface(self, encoding_g, volume_mesh_centers, p_grid): - """Function to calculate local geometry encoding from global encoding for surface""" - batch_size = volume_mesh_centers.shape[0] - nx, ny, nz = ( - self.grid_resolution[0], - self.grid_resolution[1], - self.grid_resolution[2], - ) - p_grid = paddle.reshape(p_grid, (batch_size, nx * ny * nz, 3)) - mapping, outputs = self.bq_warp( - volume_mesh_centers, p_grid, reverse_mapping=False - ) - mapping = mapping.astype(paddle.int64) - mask = mapping != 0 - - geo_encoding = paddle.reshape(encoding_g[:, 0], (batch_size, 1, nx * ny * nz)) - geo_encoding = geo_encoding.expand( - [batch_size, volume_mesh_centers.shape[1], geo_encoding.shape[2]] - ) - sdf_encoding = paddle.reshape(encoding_g[:, 1], (batch_size, 1, nx * ny * nz)) - sdf_encoding = sdf_encoding.expand( - [batch_size, volume_mesh_centers.shape[1], sdf_encoding.shape[2]] - ) - geo_encoding_long = paddle.reshape( - encoding_g[:, 2], (batch_size, 1, nx * ny * nz) - ) - geo_encoding_long = geo_encoding_long.expand( - [batch_size, volume_mesh_centers.shape[1], geo_encoding_long.shape[2]] - ) - - geo_encoding_sampled = paddle.take_along_axis( - geo_encoding, axis=2, indices=mapping - ) * mask.to(dtype=geo_encoding.dtype) - sdf_encoding_sampled = paddle.take_along_axis( - sdf_encoding, axis=2, indices=mapping - ) * mask.to(dtype=geo_encoding.dtype) - geo_encoding_long_sampled = paddle.take_along_axis( - geo_encoding_long, axis=2, indices=mapping - ) * mask.to(dtype=geo_encoding.dtype) - - encoding_g = paddle.concat( - (geo_encoding_sampled, sdf_encoding_sampled, geo_encoding_long_sampled), - axis=2, - ) - encoding_g = self.activation(self.fc_1(encoding_g)) - encoding_g = self.fc_2(encoding_g) - - return encoding_g - - def geo_encoding_local(self, encoding_g, volume_mesh_centers, p_grid): - """Function to calculate local geometry encoding from global encoding""" - batch_size = volume_mesh_centers.shape[0] - nx, ny, nz = ( - self.grid_resolution[0], - self.grid_resolution[1], - self.grid_resolution[2], - ) - p_grid = paddle.reshape(p_grid, (batch_size, nx * ny * nz, 3)) - mapping, outputs = self.bq_warp( - volume_mesh_centers, p_grid, reverse_mapping=False - ) - mapping = mapping.astype(paddle.int64) - mask = mapping != 0 - - geo_encoding = paddle.reshape(encoding_g[:, 0], (batch_size, 1, nx * ny * nz)) - geo_encoding = geo_encoding.expand( - [batch_size, volume_mesh_centers.shape[1], geo_encoding.shape[2]] - ) - sdf_encoding = paddle.reshape(encoding_g[:, 1], (batch_size, 1, nx * ny * nz)) - sdf_encoding = sdf_encoding.expand( - [batch_size, volume_mesh_centers.shape[1], sdf_encoding.shape[2]] - ) - geo_encoding_long = paddle.reshape( - encoding_g[:, 2], (batch_size, 1, nx * ny * nz) - ) - geo_encoding_long = geo_encoding_long.expand( - [batch_size, volume_mesh_centers.shape[1], geo_encoding_long.shape[2]] - ) - - geo_encoding_sampled = paddle.take_along_axis( - geo_encoding, axis=2, indices=mapping - ) * mask.to(dtype=geo_encoding.dtype) - sdf_encoding_sampled = paddle.take_along_axis( - sdf_encoding, axis=2, indices=mapping - ) * mask.to(dtype=geo_encoding.dtype) - geo_encoding_long_sampled = paddle.take_along_axis( - geo_encoding_long, axis=2, indices=mapping - ) * mask.to(dtype=geo_encoding.dtype) - - encoding_g = paddle.concat( - (geo_encoding_sampled, sdf_encoding_sampled, geo_encoding_long_sampled), - axis=2, - ) - encoding_g = self.activation(self.fc_1(encoding_g)) - encoding_g = self.fc_2(encoding_g) - - return encoding_g - - def calculate_solution_with_neighbors( - self, - surface_mesh_centers, - encoding_g, - encoding_node, - surface_mesh_neighbors, - surface_normals, - surface_neighbors_normals, - surface_areas, - surface_neighbors_areas, - inlet_velocity, - air_density, - ): - """Function to approximate solution given the neighborhood information""" - num_variables = self.num_variables_surf - nn_basis = self.nn_basis_surf - agg_model = self.agg_model_surf - num_sample_points = surface_mesh_neighbors.shape[2] + 1 - - if self.encode_parameters: - inlet_velocity = paddle.unsqueeze(inlet_velocity, 1) - inlet_velocity = inlet_velocity.expand( - [ - inlet_velocity.shape[0], - surface_mesh_centers.shape[1], - inlet_velocity.shape[2], - ] - ) - inlet_velocity = inlet_velocity / self.param_scaling_factors[0] - - air_density = paddle.unsqueeze(air_density, 1) - air_density = air_density.expand( - [ - air_density.shape[0], - surface_mesh_centers.shape[1], - air_density.shape[2], - ] - ) - air_density = air_density / self.param_scaling_factors[1] - - params = paddle.concat((inlet_velocity, air_density), axis=-1) - param_encoding = self.parameter_model(params) - - if self.use_surface_normals: - if self.use_only_normals: - surface_mesh_centers = paddle.concat( - (surface_mesh_centers, surface_normals), - axis=-1, - ) - surface_mesh_neighbors = paddle.concat( - ( - surface_mesh_neighbors, - surface_neighbors_normals, - ), - axis=-1, - ) - - else: - surface_mesh_centers = paddle.concat( - (surface_mesh_centers, surface_normals, 10**5 * surface_areas), - axis=-1, - ) - surface_mesh_neighbors = paddle.concat( - ( - surface_mesh_neighbors, - surface_neighbors_normals, - 10**5 * surface_neighbors_areas, - ), - axis=-1, - ) - - for f in range(num_variables): - for p in range(num_sample_points): - if p == 0: - volume_m_c = surface_mesh_centers - else: - volume_m_c = surface_mesh_neighbors[:, :, p - 1] - noise = surface_mesh_centers - volume_m_c - dist = paddle.sqrt( - noise[:, :, 0:1] ** 2.0 - + noise[:, :, 1:2] ** 2.0 - + noise[:, :, 2:3] ** 2.0 - ) - basis_f = nn_basis[f](volume_m_c) - output = paddle.concat((basis_f, encoding_node, encoding_g), axis=-1) - if self.encode_parameters: - output = paddle.concat((output, param_encoding), axis=-1) - if p == 0: - output_center = agg_model[f](output) - else: - if p == 1: - output_neighbor = agg_model[f](output) * (1.0 / dist) - dist_sum = 1.0 / dist - else: - output_neighbor += agg_model[f](output) * (1.0 / dist) - dist_sum += 1.0 / dist - if num_sample_points > 1: - output_res = 0.5 * output_center + 0.5 * output_neighbor / dist_sum - else: - output_res = output_center - if f == 0: - output_all = output_res - else: - output_all = paddle.concat((output_all, output_res), axis=-1) - - return output_all - - def calculate_solution( - self, - volume_mesh_centers, - encoding_g, - encoding_node, - inlet_velocity, - air_density, - eval_mode, - num_sample_points=20, - noise_intensity=50, - ): - """Function to approximate solution sampling the neighborhood information""" - if eval_mode == "volume": - num_variables = self.num_variables_vol - nn_basis = self.nn_basis_vol - agg_model = self.agg_model_vol - elif eval_mode == "surface": - num_variables = self.num_variables_surf - nn_basis = self.nn_basis_surf - agg_model = self.agg_model_surf - - if self.encode_parameters: - inlet_velocity = paddle.unsqueeze(inlet_velocity, 1) - inlet_velocity = inlet_velocity.expand( - [ - inlet_velocity.shape[0], - volume_mesh_centers.shape[1], - inlet_velocity.shape[2], - ] - ) - inlet_velocity = inlet_velocity / self.param_scaling_factors[0] - - air_density = paddle.unsqueeze(air_density, 1) - air_density = air_density.expand( - [ - air_density.shape[0], - volume_mesh_centers.shape[1], - air_density.shape[2], - ] - ) - air_density = air_density / self.param_scaling_factors[1] - - params = paddle.concat((inlet_velocity, air_density), axis=-1) - param_encoding = self.parameter_model(params) - - for f in range(num_variables): - for p in range(num_sample_points): - if p == 0: - volume_m_c = volume_mesh_centers - else: - noise = paddle.rand( - shape=volume_mesh_centers.shape, dtype=volume_mesh_centers.dtype - ) - noise = 2 * (noise - 0.5) - noise = noise / noise_intensity - dist = paddle.sqrt( - noise[:, :, 0:1] ** 2.0 - + noise[:, :, 1:2] ** 2.0 - + noise[:, :, 2:3] ** 2.0 - ) - volume_m_c = volume_mesh_centers + noise - basis_f = nn_basis[f](volume_m_c) - output = paddle.concat((basis_f, encoding_node, encoding_g), axis=-1) - if self.encode_parameters: - output = paddle.concat((output, param_encoding), axis=-1) - if p == 0: - output_center = agg_model[f](output) - else: - if p == 1: - output_neighbor = agg_model[f](output) * (1.0 / dist) - dist_sum = 1.0 / dist - else: - output_neighbor += agg_model[f](output) * (1.0 / dist) - dist_sum += 1.0 / dist - if num_sample_points > 1: - output_res = 0.5 * output_center + 0.5 * output_neighbor / dist_sum - else: - output_res = output_center - if f == 0: - output_all = output_res - else: - output_all = paddle.concat((output_all, output_res), axis=-1) - - return output_all - - def forward( - self, - data_dict, - ): - # Loading STL inputs, bounding box grids, precomputed SDF and scaling factors - - # STL nodes - geo_centers = data_dict["geometry_coordinates"] - - # Bounding box grid - s_grid = data_dict["surf_grid"] - sdf_surf_grid = data_dict["sdf_surf_grid"] - # Scaling factors - surf_max = data_dict["surface_min_max"][:, 1] - surf_min = data_dict["surface_min_max"][:, 0] - - # Parameters - stream_velocity = data_dict["stream_velocity"] - air_density = data_dict["air_density"] - - if self.output_features_vol is not None: - # Represent geometry on computational grid - # Computational domain grid - p_grid = data_dict["grid"] - sdf_grid = data_dict["sdf_grid"] - # Scaling factors - vol_max = data_dict["volume_min_max"][:, 1] - vol_min = data_dict["volume_min_max"][:, 0] - - # Normalize based on computational domain - geo_centers_vol = 2.0 * (geo_centers - vol_min) / (vol_max - vol_min) - 1 - encoding_g_vol = self.geo_rep(geo_centers_vol, p_grid, sdf_grid) - - # Normalize based on BBox around surface (car) - geo_centers_surf = ( - 2.0 * (geo_centers - surf_min) / (surf_max - surf_min) - 1 - ) - encoding_g_surf = self.geo_rep(geo_centers_surf, s_grid, sdf_surf_grid) - - # SDF on volume mesh nodes - sdf_nodes = data_dict["sdf_nodes"] - # Positional encoding based on closest point on surface to a volume node - pos_volume_closest = data_dict["pos_volume_closest"] - # Positional encoding based on center of mass of geometry to volume node - pos_volume_center_of_mass = data_dict["pos_volume_center_of_mass"] - encoding_node_vol = paddle.concat( - (sdf_nodes, pos_volume_closest, pos_volume_center_of_mass), axis=-1 - ) - - # Calculate positional encoding on volume nodes - encoding_node_vol = self.position_encoder( - encoding_node_vol, eval_mode="volume" - ) - - if self.output_features_surf is not None: - # Represent geometry on bounding box - geo_centers_surf = ( - 2.0 * (geo_centers - surf_min) / (surf_max - surf_min) - 1 - ) - encoding_g_surf = self.geo_rep(geo_centers_surf, s_grid, sdf_surf_grid) - - # Positional encoding based on center of mass of geometry to surface node - pos_surface_center_of_mass = data_dict["pos_surface_center_of_mass"] - encoding_node_surf = pos_surface_center_of_mass - - # Calculate positional encoding on surface centers - encoding_node_surf = self.position_encoder( - encoding_node_surf, eval_mode="surface" - ) - - encoding_g = 0.5 * encoding_g_surf - # Average the encodings - if self.output_features_vol is not None: - encoding_g += 0.5 * encoding_g_vol - - if self.output_features_vol is not None: - # Calculate local geometry encoding for volume - # Sampled points on volume - volume_mesh_centers = data_dict["volume_mesh_centers"] - encoding_g_vol = self.geo_encoding_local( - encoding_g, volume_mesh_centers, p_grid - ) - - # Approximate solution on volume node - output_vol = self.calculate_solution( - volume_mesh_centers, - encoding_g_vol, - encoding_node_vol, - stream_velocity, - air_density, - eval_mode="volume", - ) - else: - output_vol = None - - if self.output_features_surf is not None: - # Sampled points on surface - surface_mesh_centers = data_dict["surface_mesh_centers"] - surface_normals = data_dict["surface_normals"] - surface_areas = data_dict["surface_areas"] - - # Neighbors of sampled points on surface - surface_mesh_neighbors = data_dict["surface_mesh_neighbors"] - surface_neighbors_normals = data_dict["surface_neighbors_normals"] - surface_neighbors_areas = data_dict["surface_neighbors_areas"] - surface_areas = paddle.unsqueeze(surface_areas, -1) - surface_neighbors_areas = paddle.unsqueeze(surface_neighbors_areas, -1) - # Calculate local geometry encoding for surface - encoding_g_surf = self.geo_encoding_local_surface( - 0.5 * encoding_g_surf, surface_mesh_centers, s_grid - ) - - # Approximate solution on surface cell center - if not self.surface_neighbors: - output_surf = self.calculate_solution( - surface_mesh_centers, - encoding_g_surf, - encoding_node_surf, - stream_velocity, - air_density, - eval_mode="surface", - num_sample_points=1, - noise_intensity=500, - ) - else: - output_surf = self.calculate_solution_with_neighbors( - surface_mesh_centers, - encoding_g_surf, - encoding_node_surf, - surface_mesh_neighbors, - surface_normals, - surface_neighbors_normals, - surface_areas, - surface_neighbors_areas, - stream_velocity, - air_density, - ) - else: - output_surf = None - - return output_vol, output_surf - - -if __name__ == "__main__": - from hydra import compose - from hydra import initialize - from omegaconf import OmegaConf - - if paddle.device.cuda.device_count() >= 1: - paddle.set_device("gpu") - else: - paddle.set_device("cpu") - cfg = OmegaConf.register_new_resolver("eval", eval) - with initialize(version_base="1.3", config_path="../../scripts/conf"): - cfg = compose(config_name="config") - cfg.model.model_type = "combined" - model = DoMINO( - input_features=3, - output_features_vol=5, - output_features_surf=4, - model_parameters=cfg.model, - ) - - bsize = 1 - nx, ny, nz = 128, 64, 48 - num_neigh = 7 - pos_normals_closest_vol = paddle.randn([bsize, 100, 3]) - pos_normals_com_vol = paddle.randn([bsize, 100, 3]) - pos_normals_com_surface = paddle.randn([bsize, 100, 3]) - geom_centers = paddle.randn([bsize, 100, 3]) - grid = paddle.randn([bsize, nx, ny, nz, 3]) - surf_grid = paddle.randn([bsize, nx, ny, nz, 3]) - sdf_grid = paddle.randn([bsize, nx, ny, nz]) - sdf_surf_grid = paddle.randn([bsize, nx, ny, nz]) - sdf_nodes = paddle.randn([bsize, 100, 1]) - surface_coordinates = paddle.randn([bsize, 100, 3]) - surface_neighbors = paddle.randn([bsize, 100, num_neigh, 3]) - surface_normals = paddle.randn([bsize, 100, 3]) - surface_neighbors_normals = paddle.randn([bsize, 100, num_neigh, 3]) - surface_sizes = paddle.randn([bsize, 100, 3]) - surface_neighbors_sizes = paddle.randn([bsize, 100, num_neigh, 3]) - volume_coordinates = paddle.randn([bsize, 100, 3]) - vol_grid_max_min = paddle.randn([bsize, 2, 3]) - surf_grid_max_min = paddle.randn([bsize, 2, 3]) - stream_velocity = paddle.randn([bsize, 1]) - air_density = paddle.randn([bsize, 1]) - input_dict = { - "pos_volume_closest": pos_normals_closest_vol, - "pos_volume_center_of_mass": pos_normals_com_vol, - "pos_surface_center_of_mass": pos_normals_com_surface, - "geometry_coordinates": geom_centers, - "grid": grid, - "surf_grid": surf_grid, - "sdf_grid": sdf_grid, - "sdf_surf_grid": sdf_surf_grid, - "sdf_nodes": sdf_nodes, - "surface_mesh_centers": surface_coordinates, - "surface_mesh_neighbors": surface_neighbors, - "surface_normals": surface_normals, - "surface_neighbors_normals": surface_neighbors_normals, - "surface_areas": surface_sizes, - "surface_neighbors_areas": surface_neighbors_sizes, - "volume_mesh_centers": volume_coordinates, - "volume_min_max": vol_grid_max_min, - "surface_min_max": surf_grid_max_min, - "stream_velocity": stream_velocity, - "air_density": air_density, - } - output = model(input_dict) - print(f"{output[0].shape}, {output[1].shape}") diff --git a/ppsci/arch/physicsnemo/utils/__init__.py b/ppsci/arch/physicsnemo/utils/__init__.py deleted file mode 100644 index b2f171d4ac..0000000000 --- a/ppsci/arch/physicsnemo/utils/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. -# SPDX-FileCopyrightText: All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/ppsci/arch/physicsnemo/utils/domino/utils.py b/ppsci/arch/physicsnemo/utils/domino/utils.py deleted file mode 100644 index 28f61de49f..0000000000 --- a/ppsci/arch/physicsnemo/utils/domino/utils.py +++ /dev/null @@ -1,383 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. -# SPDX-FileCopyrightText: All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Important utilities for data processing and training, testing DoMINO. -""" - -import os -import time - -import numpy as np -from scipy.spatial import KDTree - -try: - import pyvista as pv - - PV_AVAILABLE = True -except ImportError: - PV_AVAILABLE = False -try: - import vtk - from vtk import vtkDataSetTriangleFilter - from vtk.util import numpy_support - - VTK_AVAILABLE = True -except ImportError: - VTK_AVAILABLE = False - - -def calculate_center_of_mass(stl_centers, stl_sizes): - """Function to calculate center of mass""" - stl_sizes = np.expand_dims(stl_sizes, -1) - center_of_mass = np.sum(stl_centers * stl_sizes, axis=0) / np.sum(stl_sizes, axis=0) - return center_of_mass - - -def normalize(field, mx, mn): - """Function to normalize fields""" - return 2.0 * (field - mn) / (mx - mn) - 1.0 - - -def unnormalize(field, mx, mn): - """Function to unnormalize fields""" - return (field + 1.0) * (mx - mn) * 0.5 + mn - - -def standardize(field, mean, std): - """Function to standardize fields""" - return (field - mean) / std - - -def unstandardize(field, mean, std): - """Function to unstandardize fields""" - return field * std + mean - - -def write_to_vtp(polydata, filename): - """Function to write polydata to vtp""" - if not VTK_AVAILABLE: - raise ImportError("VTK or is not installed. This function cannot be used.") - writer = vtk.vtkXMLPolyDataWriter() - writer.SetFileName(filename) - writer.SetInputData(polydata) - writer.Write() - - -def write_to_vtu(polydata, filename): - """Function to write polydata to vtu""" - if not VTK_AVAILABLE: - raise ImportError("VTK or is not installed. This function cannot be used.") - writer = vtk.vtkXMLUnstructuredGridWriter() - writer.SetFileName(filename) - writer.SetInputData(polydata) - writer.Write() - - -def extract_surface_triangles(tet_mesh): - """Extracts the surface triangles from a triangular mesh.""" - if not VTK_AVAILABLE: - raise ImportError("VTK or is not installed. This function cannot be used.") - if not PV_AVAILABLE: - raise ImportError("PyVista is not installed. This function cannot be used.") - surface_filter = vtk.vtkDataSetSurfaceFilter() - surface_filter.SetInputData(tet_mesh) - surface_filter.Update() - - surface_mesh = pv.wrap(surface_filter.GetOutput()) - triangle_indices = [] - faces = surface_mesh.faces.reshape((-1, 4)) - for face in faces: - if face[0] == 3: - triangle_indices.extend([face[1], face[2], face[3]]) - else: - raise ValueError("Face is not a triangle") - - return triangle_indices - - -def convert_to_tet_mesh(polydata): - """Function to convert tet to stl""" - if not VTK_AVAILABLE: - raise ImportError("VTK or is not installed. This function cannot be used.") - # Create a VTK DataSetTriangleFilter object - tet_filter = vtkDataSetTriangleFilter() - tet_filter.SetInputData(polydata) - tet_filter.Update() # Update to apply the filter - - # Get the output as an UnstructuredGrid - # tet_mesh = pv.wrap(tet_filter.GetOutput()) - tet_mesh = tet_filter.GetOutput() - return tet_mesh - - -def get_node_to_elem(polydata): - """Function to convert node to elem""" - if not VTK_AVAILABLE: - raise ImportError("VTK or is not installed. This function cannot be used.") - c2p = vtk.vtkPointDataToCellData() - c2p.SetInputData(polydata) - c2p.Update() - cell_data = c2p.GetOutput() - return cell_data - - -def get_fields_from_cell(ptdata, var_list): - """Function to get fields from elem""" - fields = [] - for var in var_list: - variable = ptdata.GetArray(var) - num_tuples = variable.GetNumberOfTuples() - cell_fields = [] - for j in range(num_tuples): - variable_value = np.array(variable.GetTuple(j)) - cell_fields.append(variable_value) - cell_fields = np.asarray(cell_fields) - fields.append(cell_fields) - fields = np.transpose(np.asarray(fields), (1, 0)) - - return fields - - -def get_fields(data, variables): - """Function to get fields from VTP/VTU""" - if not VTK_AVAILABLE: - raise ImportError("VTK or is not installed. This function cannot be used.") - fields = [] - for array_name in variables: - try: - array = data.GetArray(array_name) - except ValueError: - raise ValueError( - f"Failed to get array {array_name} from the unstructured grid." - ) - array_data = numpy_support.vtk_to_numpy(array).reshape( - array.GetNumberOfTuples(), array.GetNumberOfComponents() - ) - fields.append(array_data) - return fields - - -def get_vertices(polydata): - """Function to get vertices""" - if not VTK_AVAILABLE: - raise ImportError("VTK or is not installed. This function cannot be used.") - points = polydata.GetPoints() - vertices = numpy_support.vtk_to_numpy(points.GetData()) - return vertices - - -def get_volume_data(polydata, variables): - """Function to get volume data""" - vertices = get_vertices(polydata) - point_data = polydata.GetPointData() - - fields = get_fields(point_data, variables) - - return vertices, fields - - -def get_surface_data(polydata, variables): - """Function to get surface data""" - if not VTK_AVAILABLE: - raise ImportError("VTK or is not installed. This function cannot be used.") - points = polydata.GetPoints() - vertices = np.array([points.GetPoint(i) for i in range(points.GetNumberOfPoints())]) - - point_data = polydata.GetPointData() - fields = [] - for array_name in variables: - try: - array = point_data.GetArray(array_name) - except ValueError: - raise ValueError( - f"Failed to get array {array_name} from the unstructured grid." - ) - array_data = np.zeros( - (points.GetNumberOfPoints(), array.GetNumberOfComponents()) - ) - for j in range(points.GetNumberOfPoints()): - array.GetTuple(j, array_data[j]) - fields.append(array_data) - - polys = polydata.GetPolys() - if polys is None: - raise ValueError("Failed to get polygons from the polydata.") - polys.InitTraversal() - edges = [] - id_list = vtk.vtkIdList() - for _ in range(polys.GetNumberOfCells()): - polys.GetNextCell(id_list) - num_ids = id_list.GetNumberOfIds() - edges = [ - (id_list.GetId(j), id_list.GetId((j + 1) % num_ids)) for j in range(num_ids) - ] - - return vertices, fields, edges - - -def cal_normal_positional_encoding(coordinates_a, coordinates_b=None, cell_length=[]): - """Function to get normal positional encoding""" - dx = cell_length[0] - dy = cell_length[1] - dz = cell_length[2] - if coordinates_b is not None: - normals = coordinates_a - coordinates_b - pos_x = np.asarray(calculate_pos_encoding(normals[:, 0] / dx, d=4)) - pos_y = np.asarray(calculate_pos_encoding(normals[:, 1] / dy, d=4)) - pos_z = np.asarray(calculate_pos_encoding(normals[:, 2] / dz, d=4)) - pos_normals = np.concatenate((pos_x, pos_y, pos_z), axis=0).reshape(-1, 12) - else: - normals = coordinates_a - pos_x = np.asarray(calculate_pos_encoding(normals[:, 0] / dx, d=4)) - pos_y = np.asarray(calculate_pos_encoding(normals[:, 1] / dy, d=4)) - pos_z = np.asarray(calculate_pos_encoding(normals[:, 2] / dz, d=4)) - pos_normals = np.concatenate((pos_x, pos_y, pos_z), axis=0).reshape(-1, 12) - - return pos_normals - - -def nd_interpolator(coodinates, field, grid): - """Function to for nd interpolation""" - interp_func = KDTree(coodinates[0]) - dd, ii = interp_func.query(grid, k=2) - - field_grid = field[ii] - field_grid = np.float32(np.mean(field_grid, (3))) - return field_grid - - -def pad(arr, npoin, pad_value=0.0): - """Function for padding""" - arr_pad = pad_value * np.ones( - (npoin - arr.shape[0], arr.shape[1]), dtype=np.float32 - ) - arr_padded = np.concatenate((arr, arr_pad), axis=0) - return arr_padded - - -def pad_inp(arr, npoin, pad_value=0.0): - """Function for padding arrays""" - arr_pad = pad_value * np.ones( - (npoin - arr.shape[0], arr.shape[1], arr.shape[2]), dtype=np.float32 - ) - arr_padded = np.concatenate((arr, arr_pad), axis=0) - return arr_padded - - -def shuffle_array(arr, npoin): - """Function for shuffling arrays""" - np.random.seed(seed=int(time.time())) - idx = np.arange(arr.shape[0]) - np.random.shuffle(idx) - idx = idx[:npoin] - return arr[idx], idx - - -def shuffle_array_without_sampling(arr): - """Function for shuffline arrays without sampling.""" - idx = np.arange(arr.shape[0]) - np.random.shuffle(idx) - return arr[idx], idx - - -def create_directory(filepath): - """Function to create directories""" - if not os.path.exists(filepath): - os.makedirs(filepath) - - -def get_filenames(filepath): - """Function to get filenames from a directory""" - if os.path.exists(filepath): - filenames = os.listdir(filepath) - return filenames - else: - FileNotFoundError() - - -def calculate_pos_encoding(nx, d=8): - """Function for calculating positional encoding""" - vec = [] - for k in range(int(d / 2)): - vec.append(np.sin(nx / 10000 ** (2 * (k) / d))) - vec.append(np.cos(nx / 10000 ** (2 * (k) / d))) - return vec - - -def combine_dict(old_dict, new_dict): - """Function to combine dictionaries""" - for j in old_dict.keys(): - old_dict[j] += new_dict[j] - return old_dict - - -def merge(*lists): - """Function to merge lists""" - newlist = lists[:] - for x in lists: - if x not in newlist: - newlist.extend(x) - return newlist - - -def create_grid(mx, mn, nres): - """Function to create grid""" - dx = np.linspace(mn[0], mx[0], nres[0]) - dy = np.linspace(mn[1], mx[1], nres[1]) - dz = np.linspace(mn[2], mx[2], nres[2]) - - xv, yv, zv = np.meshgrid(dx, dy, dz) - xv = np.expand_dims(xv, -1) - yv = np.expand_dims(yv, -1) - zv = np.expand_dims(zv, -1) - grid = np.concatenate((xv, yv, zv), axis=-1) - grid = np.transpose(grid, (1, 0, 2, 3)) - - return grid - - -def mean_std_sampling(field, mean, std, tolerance=3.0): - """Function for mean/std based sampling""" - idx_all = [] - for v in range(field.shape[-1]): - fv = field[:, v] - idx = np.where( - (fv > mean[v] + tolerance * std[v]) | (fv < mean[v] - tolerance * std[v]) - ) - if len(idx[0]) != 0: - idx_all += list(idx[0]) - - return idx_all - - -def dict_to_device(state_dict, device): - """Function to load dictionary to device""" - new_state_dict = {} - for k, v in state_dict.items(): - new_state_dict[k] = v.to(device) - return new_state_dict - - -def area_weighted_shuffle_array(arr, npoin, area): - factor = 1.0 - total_area = np.sum(area**factor) - probs = area**factor / total_area - np.random.seed(seed=int(time.time())) - idx = np.arange(arr.shape[0]) - np.random.shuffle(idx) - ids = np.random.choice(idx, npoin, p=probs[idx]) - return arr[ids], ids diff --git a/ppsci/data/dataset/__init__.py b/ppsci/data/dataset/__init__.py index 9ece354700..4146846cf2 100644 --- a/ppsci/data/dataset/__init__.py +++ b/ppsci/data/dataset/__init__.py @@ -27,6 +27,7 @@ from ppsci.data.dataset.cylinder_dataset import MeshCylinderDataset from ppsci.data.dataset.darcyflow_dataset import DarcyFlowDataset from ppsci.data.dataset.dgmr_dataset import DGMRDataset +from ppsci.data.dataset.domino_datapipe import DoMINODataPipe from ppsci.data.dataset.drivaernet_dataset import DrivAerNetDataset from ppsci.data.dataset.drivaernetplusplus_dataset import DrivAerNetPlusPlusDataset from ppsci.data.dataset.enso_dataset import ENSODataset @@ -93,6 +94,7 @@ "DrivAerNetDataset", "DrivAerNetPlusPlusDataset", "IFMMoeDataset", + "DoMINODataPipe", ] diff --git a/ppsci/arch/physicsnemo/datapipes/cae/domino_datapipe.py b/ppsci/data/dataset/domino_datapipe.py similarity index 86% rename from ppsci/arch/physicsnemo/datapipes/cae/domino_datapipe.py rename to ppsci/data/dataset/domino_datapipe.py index 17157ea2a0..be9499bfeb 100644 --- a/ppsci/arch/physicsnemo/datapipes/cae/domino_datapipe.py +++ b/ppsci/data/dataset/domino_datapipe.py @@ -1,18 +1,18 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. -# SPDX-FileCopyrightText: All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. + # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at -# + # http://www.apache.org/licenses/LICENSE-2.0 -# + # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +# +# refs: https://github.com/NVIDIA/physicsnemo/tree/main/examples/cfd/external_aerodynamics/domino """ This code provides the datapipe for reading the processed npy files, @@ -26,6 +26,8 @@ variable names, domain resolution, sampling size etc. are configurable in config.yaml. """ +import os +import time from pathlib import Path from typing import Literal from typing import Optional @@ -34,18 +36,110 @@ import numpy as np from paddle.io import Dataset +from scipy.spatial import KDTree + +from ppsci.utils.sdf import signed_distance_field + + +def calculate_center_of_mass(stl_centers, stl_sizes): + """Function to calculate center of mass""" + stl_sizes = np.expand_dims(stl_sizes, -1) + center_of_mass = np.sum(stl_centers * stl_sizes, axis=0) / np.sum(stl_sizes, axis=0) + return center_of_mass -from ...utils.domino.utils import KDTree -from ...utils.domino.utils import area_weighted_shuffle_array -from ...utils.domino.utils import cal_normal_positional_encoding -from ...utils.domino.utils import calculate_center_of_mass -from ...utils.domino.utils import create_grid -from ...utils.domino.utils import get_filenames -from ...utils.domino.utils import normalize -from ...utils.domino.utils import pad -from ...utils.domino.utils import shuffle_array -from ...utils.domino.utils import standardize -from ...utils.sdf import signed_distance_field + +def normalize(field, mx, mn): + """Function to normalize fields""" + return 2.0 * (field - mn) / (mx - mn) - 1.0 + + +def standardize(field, mean, std): + """Function to standardize fields""" + return (field - mean) / std + + +def cal_normal_positional_encoding(coordinates_a, coordinates_b=None, cell_length=[]): + """Function to get normal positional encoding""" + dx = cell_length[0] + dy = cell_length[1] + dz = cell_length[2] + if coordinates_b is not None: + normals = coordinates_a - coordinates_b + pos_x = np.asarray(calculate_pos_encoding(normals[:, 0] / dx, d=4)) + pos_y = np.asarray(calculate_pos_encoding(normals[:, 1] / dy, d=4)) + pos_z = np.asarray(calculate_pos_encoding(normals[:, 2] / dz, d=4)) + pos_normals = np.concatenate((pos_x, pos_y, pos_z), axis=0).reshape(-1, 12) + else: + normals = coordinates_a + pos_x = np.asarray(calculate_pos_encoding(normals[:, 0] / dx, d=4)) + pos_y = np.asarray(calculate_pos_encoding(normals[:, 1] / dy, d=4)) + pos_z = np.asarray(calculate_pos_encoding(normals[:, 2] / dz, d=4)) + pos_normals = np.concatenate((pos_x, pos_y, pos_z), axis=0).reshape(-1, 12) + + return pos_normals + + +def pad(arr, npoin, pad_value=0.0): + """Function for padding""" + arr_pad = pad_value * np.ones( + (npoin - arr.shape[0], arr.shape[1]), dtype=np.float32 + ) + arr_padded = np.concatenate((arr, arr_pad), axis=0) + return arr_padded + + +def shuffle_array(arr, npoin): + """Function for shuffling arrays""" + np.random.seed(seed=int(time.time())) + idx = np.arange(arr.shape[0]) + np.random.shuffle(idx) + idx = idx[:npoin] + return arr[idx], idx + + +def get_filenames(filepath): + """Function to get filenames from a directory""" + if os.path.exists(filepath): + filenames = os.listdir(filepath) + return filenames + else: + FileNotFoundError() + + +def calculate_pos_encoding(nx, d=8): + """Function for calculating positional encoding""" + vec = [] + for k in range(int(d / 2)): + vec.append(np.sin(nx / 10000 ** (2 * (k) / d))) + vec.append(np.cos(nx / 10000 ** (2 * (k) / d))) + return vec + + +def create_grid(mx, mn, nres): + """Function to create grid""" + dx = np.linspace(mn[0], mx[0], nres[0]) + dy = np.linspace(mn[1], mx[1], nres[1]) + dz = np.linspace(mn[2], mx[2], nres[2]) + + xv, yv, zv = np.meshgrid(dx, dy, dz) + xv = np.expand_dims(xv, -1) + yv = np.expand_dims(yv, -1) + zv = np.expand_dims(zv, -1) + grid = np.concatenate((xv, yv, zv), axis=-1) + grid = np.transpose(grid, (1, 0, 2, 3)) + + return grid + + +def area_weighted_shuffle_array(arr, npoin, area): + factor = 1.0 + total_area = np.sum(area**factor) + probs = area**factor / total_area + np.random.seed(seed=int(time.time())) + idx = np.arange(arr.shape[0]) + np.random.shuffle(idx) + ids = np.random.choice(idx, npoin, p=probs[idx]) + return arr[ids], ids class DoMINODataPipe(Dataset): diff --git a/ppsci/data/process/__init__.py b/ppsci/data/process/__init__.py index f46c8dd9cf..cf217ef7f9 100644 --- a/ppsci/data/process/__init__.py +++ b/ppsci/data/process/__init__.py @@ -17,5 +17,6 @@ __all__ = [ "batch_transform", + "openfoam", "transform", ] diff --git a/ppsci/arch/physicsnemo/datapipes/cae/__init__.py b/ppsci/data/process/openfoam/__init__.py similarity index 68% rename from ppsci/arch/physicsnemo/datapipes/cae/__init__.py rename to ppsci/data/process/openfoam/__init__.py index 71e4d00436..92dc8f2974 100644 --- a/ppsci/arch/physicsnemo/datapipes/cae/__init__.py +++ b/ppsci/data/process/openfoam/__init__.py @@ -1,17 +1,21 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. -# SPDX-FileCopyrightText: All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. + # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at -# + # http://www.apache.org/licenses/LICENSE-2.0 -# + # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +# +# refs: https://github.com/NVIDIA/physicsnemo/tree/main/examples/cfd/external_aerodynamics/domino + +from ppsci.data.process.openfoam import process_files -from .domino_datapipe import DoMINODataPipe # noqa: F401 +__all__ = [ + "process_files", +] diff --git a/ppsci/data/process/openfoam/preprocess.py b/ppsci/data/process/openfoam/preprocess.py new file mode 100644 index 0000000000..573674d879 --- /dev/null +++ b/ppsci/data/process/openfoam/preprocess.py @@ -0,0 +1,42 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# refs: https://github.com/NVIDIA/physicsnemo/tree/main/examples/cfd/external_aerodynamics/domino + +import os +import time + +import numpy as np + + +def process_files(*args_list): + ids = args_list[0] + processor_id = args_list[1] + fm_data = args_list[2] + output_dir = args_list[3] + for j in ids: + fname = fm_data.filenames[j] + if len(os.listdir(os.path.join(fm_data.data_path, fname))) == 0: + print(f"Skipping {fname} - empty.") + continue + outname = os.path.join(output_dir, fname) + print("Filename:%s on processor: %d" % (outname, processor_id)) + filename = f"{outname}.npy" + if os.path.exists(filename): + print(f"Skipping {filename} - already exists.") + continue + start_time = time.time() + data_dict = fm_data[j] + np.save(filename, data_dict) + print("Time taken for %d = %f" % (j, time.time() - start_time)) diff --git a/ppsci/utils/__init__.py b/ppsci/utils/__init__.py index 3382eee856..bc528205a6 100644 --- a/ppsci/utils/__init__.py +++ b/ppsci/utils/__init__.py @@ -35,6 +35,7 @@ from ppsci.utils.save_load import load_checkpoint from ppsci.utils.save_load import load_pretrain from ppsci.utils.save_load import save_checkpoint +from ppsci.utils.sdf import signed_distance_field from ppsci.utils.symbolic import lambdify from ppsci.utils.writer import save_csv_file from ppsci.utils.writer import save_tecplot_file @@ -63,4 +64,5 @@ "load_pretrain", "save_checkpoint", "lambdify", + "signed_distance_field", ] diff --git a/ppsci/arch/physicsnemo/utils/sdf.py b/ppsci/utils/sdf.py similarity index 95% rename from ppsci/arch/physicsnemo/utils/sdf.py rename to ppsci/utils/sdf.py index 914b54bd76..e789cdfe33 100644 --- a/ppsci/arch/physicsnemo/utils/sdf.py +++ b/ppsci/utils/sdf.py @@ -1,22 +1,19 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. -# SPDX-FileCopyrightText: All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. + # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at -# + # http://www.apache.org/licenses/LICENSE-2.0 -# + # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +# +# refs: https://github.com/NVIDIA/physicsnemo/tree/main/examples/cfd/external_aerodynamics/domino -# ruff: noqa: F401 - -import numpy as np import warp as wp from numpy.typing import NDArray diff --git a/requirements.txt b/requirements.txt index 7efcb16d5f..58c8be4a3d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,12 +2,15 @@ colorlog einops h5py hydra-core +hydra-core imageio +importlib_metadata matplotlib meshio==5.3.4 numpy>=1.20.0,<2.0.0 pydantic>=2.5.0 pyevtk +pyvista==0.34.2 pyyaml requests scikit-learn<1.5.0 @@ -15,6 +18,9 @@ scikit-optimize scipy seaborn sympy +termcolor tqdm +treelib typing-extensions +warp-lang wget From 7074ccf6957723a14dc377f7654542e61815548e Mon Sep 17 00:00:00 2001 From: xiaoyewww <641311428@qq.com> Date: Thu, 29 May 2025 02:12:28 +0800 Subject: [PATCH 08/11] feat(domino): refactor physicsnemo --- examples/domino/test.py | 72 +- ppsci/arch/physicsnemo.py | 1993 +++++++++++++++++++++++++ ppsci/data/dataset/domino_datapipe.py | 5 + 3 files changed, 2042 insertions(+), 28 deletions(-) create mode 100644 ppsci/arch/physicsnemo.py diff --git a/examples/domino/test.py b/examples/domino/test.py index 2a0326f818..9b48e0de28 100644 --- a/examples/domino/test.py +++ b/examples/domino/test.py @@ -1,18 +1,30 @@ -# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. - +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at - +# # http://www.apache.org/licenses/LICENSE-2.0 - +# # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# -# refs: https://github.com/NVIDIA/physicsnemo/tree/main/examples/cfd/external_aerodynamics/domino + +""" +This code defines a distributed pipeline for testing the DoMINO model on +CFD datasets. It includes the instantiating the DoMINO model and datapipe, +automatically loading the most recent checkpoint, reading the VTP/VTU/STL +testing files, calculation of parameters required for DoMINO model and +evaluating the model in parallel using DataParallel across multiple +GPUs. This is a common recipe that enables training of combined models for surface +and volume as well either of them separately. The model predictions are loaded in +the the VTP/VTU files and saved in the specified directory. The eval tab in +config.yaml can be used to specify the input and output directories. +""" import os import re @@ -20,34 +32,36 @@ import hydra import numpy as np import paddle +import paddle.distributed as dist import pyvista as pv import vtk from hydra.utils import to_absolute_path from omegaconf import DictConfig from omegaconf import OmegaConf from paddle import DataParallel +from scipy.spatial import KDTree from vtk.util import numpy_support -from ppsci.arch.physicsnemo.distributed import DistributedManager -from ppsci.arch.physicsnemo.models.model import DoMINO -from ppsci.arch.physicsnemo.utils.domino.utils import KDTree -from ppsci.arch.physicsnemo.utils.domino.utils import cal_normal_positional_encoding -from ppsci.arch.physicsnemo.utils.domino.utils import calculate_center_of_mass -from ppsci.arch.physicsnemo.utils.domino.utils import create_directory -from ppsci.arch.physicsnemo.utils.domino.utils import create_grid -from ppsci.arch.physicsnemo.utils.domino.utils import get_fields -from ppsci.arch.physicsnemo.utils.domino.utils import get_filenames -from ppsci.arch.physicsnemo.utils.domino.utils import get_node_to_elem -from ppsci.arch.physicsnemo.utils.domino.utils import get_volume_data -from ppsci.arch.physicsnemo.utils.domino.utils import normalize -from ppsci.arch.physicsnemo.utils.domino.utils import unnormalize -from ppsci.arch.physicsnemo.utils.domino.utils import write_to_vtp -from ppsci.arch.physicsnemo.utils.domino.utils import write_to_vtu -from ppsci.arch.physicsnemo.utils.sdf import signed_distance_field +from ppsci.arch.physicsnemo import DoMINO +from ppsci.arch.physicsnemo import create_directory +from ppsci.arch.physicsnemo import get_fields +from ppsci.arch.physicsnemo import get_node_to_elem +from ppsci.arch.physicsnemo import get_volume_data +from ppsci.arch.physicsnemo import write_to_vtp +from ppsci.arch.physicsnemo import write_to_vtu +from ppsci.data.dataset.domino_datapipe import cal_normal_positional_encoding +from ppsci.data.dataset.domino_datapipe import calculate_center_of_mass +from ppsci.data.dataset.domino_datapipe import create_grid +from ppsci.data.dataset.domino_datapipe import get_filenames +from ppsci.data.dataset.domino_datapipe import normalize +from ppsci.data.dataset.domino_datapipe import unnormalize +from ppsci.utils.sdf import signed_distance_field AIR_DENSITY = 1.205 STREAM_VELOCITY = 30.00 +paddle.set_device("gpu") + def loss_fn(output, target): masked_loss = paddle.mean(((output - target) ** 2.0), (0, 1, 2)) @@ -303,9 +317,7 @@ def main(cfg: DictConfig): model_type = cfg.model.model_type - # initialize distributed manager - DistributedManager.initialize() - dist = DistributedManager() + dist.init_parallel_env() if model_type == "volume" or model_type == "combined": volume_variable_names = list(cfg.variables.volume.solution.keys()) @@ -357,10 +369,9 @@ def main(cfg: DictConfig): print("Model loaded") - if dist.world_size > 1: + if paddle.distributed.get_world_size() > 1: model = DataParallel( model, - find_unused_parameters=dist.find_unused_parameters, ) dirnames_per_gpu = get_filenames(input_path) @@ -657,7 +668,12 @@ def main(cfg: DictConfig): } prediction_vol, prediction_surf = test_step( - data_dict, model, dist.device, cfg, vol_factors, surf_factors + data_dict, + model, + paddle.distributed.get_rank(), + cfg, + vol_factors, + surf_factors, ) if prediction_surf is not None: diff --git a/ppsci/arch/physicsnemo.py b/ppsci/arch/physicsnemo.py new file mode 100644 index 0000000000..1516694bfd --- /dev/null +++ b/ppsci/arch/physicsnemo.py @@ -0,0 +1,1993 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# refs: https://github.com/NVIDIA/physicsnemo/tree/main/examples/cfd/external_aerodynamics/domino + + +import glob +import math +import os +import re +from pathlib import Path +from typing import Any +from typing import Dict +from typing import List +from typing import NewType +from typing import Optional +from typing import Union + +import numpy as np +import paddle +import paddle.nn as nn +import paddle.nn.functional as F +import warp as wp +from paddle.amp import GradScaler +from paddle.optimizer.lr import LRScheduler +from scipy.spatial import KDTree + +import ppsci +from ppsci.utils import logger + +optimizer = NewType("optimizer", paddle.optimizer) +scheduler = NewType("scheduler", LRScheduler) +scaler = NewType("scaler", GradScaler) + +try: + import pyvista as pv + + PV_AVAILABLE = True +except ImportError: + PV_AVAILABLE = False +try: + import vtk + from vtk import vtkDataSetTriangleFilter + from vtk.util import numpy_support + + VTK_AVAILABLE = True +except ImportError: + VTK_AVAILABLE = False + + +def unstandardize(field, mean, std): + """Function to unstandardize fields""" + return field * std + mean + + +def write_to_vtp(polydata, filename): + """Function to write polydata to vtp""" + if not VTK_AVAILABLE: + raise ImportError("VTK or is not installed. This function cannot be used.") + writer = vtk.vtkXMLPolyDataWriter() + writer.SetFileName(filename) + writer.SetInputData(polydata) + writer.Write() + + +def write_to_vtu(polydata, filename): + """Function to write polydata to vtu""" + if not VTK_AVAILABLE: + raise ImportError("VTK or is not installed. This function cannot be used.") + writer = vtk.vtkXMLUnstructuredGridWriter() + writer.SetFileName(filename) + writer.SetInputData(polydata) + writer.Write() + + +def extract_surface_triangles(tet_mesh): + """Extracts the surface triangles from a triangular mesh.""" + if not VTK_AVAILABLE: + raise ImportError("VTK or is not installed. This function cannot be used.") + if not PV_AVAILABLE: + raise ImportError("PyVista is not installed. This function cannot be used.") + surface_filter = vtk.vtkDataSetSurfaceFilter() + surface_filter.SetInputData(tet_mesh) + surface_filter.Update() + + surface_mesh = pv.wrap(surface_filter.GetOutput()) + triangle_indices = [] + faces = surface_mesh.faces.reshape((-1, 4)) + for face in faces: + if face[0] == 3: + triangle_indices.extend([face[1], face[2], face[3]]) + else: + raise ValueError("Face is not a triangle") + + return triangle_indices + + +def convert_to_tet_mesh(polydata): + """Function to convert tet to stl""" + if not VTK_AVAILABLE: + raise ImportError("VTK or is not installed. This function cannot be used.") + # Create a VTK DataSetTriangleFilter object + tet_filter = vtkDataSetTriangleFilter() + tet_filter.SetInputData(polydata) + tet_filter.Update() # Update to apply the filter + + # Get the output as an UnstructuredGrid + # tet_mesh = pv.wrap(tet_filter.GetOutput()) + tet_mesh = tet_filter.GetOutput() + return tet_mesh + + +def get_node_to_elem(polydata): + """Function to convert node to elem""" + if not VTK_AVAILABLE: + raise ImportError("VTK or is not installed. This function cannot be used.") + c2p = vtk.vtkPointDataToCellData() + c2p.SetInputData(polydata) + c2p.Update() + cell_data = c2p.GetOutput() + return cell_data + + +def get_fields_from_cell(ptdata, var_list): + """Function to get fields from elem""" + fields = [] + for var in var_list: + variable = ptdata.GetArray(var) + num_tuples = variable.GetNumberOfTuples() + cell_fields = [] + for j in range(num_tuples): + variable_value = np.array(variable.GetTuple(j)) + cell_fields.append(variable_value) + cell_fields = np.asarray(cell_fields) + fields.append(cell_fields) + fields = np.transpose(np.asarray(fields), (1, 0)) + + return fields + + +def get_fields(data, variables): + """Function to get fields from VTP/VTU""" + if not VTK_AVAILABLE: + raise ImportError("VTK or is not installed. This function cannot be used.") + fields = [] + for array_name in variables: + try: + array = data.GetArray(array_name) + except ValueError: + raise ValueError( + f"Failed to get array {array_name} from the unstructured grid." + ) + array_data = numpy_support.vtk_to_numpy(array).reshape( + array.GetNumberOfTuples(), array.GetNumberOfComponents() + ) + fields.append(array_data) + return fields + + +def get_vertices(polydata): + """Function to get vertices""" + if not VTK_AVAILABLE: + raise ImportError("VTK or is not installed. This function cannot be used.") + points = polydata.GetPoints() + vertices = numpy_support.vtk_to_numpy(points.GetData()) + return vertices + + +def get_volume_data(polydata, variables): + """Function to get volume data""" + vertices = get_vertices(polydata) + point_data = polydata.GetPointData() + + fields = get_fields(point_data, variables) + + return vertices, fields + + +def get_surface_data(polydata, variables): + """Function to get surface data""" + if not VTK_AVAILABLE: + raise ImportError("VTK or is not installed. This function cannot be used.") + points = polydata.GetPoints() + vertices = np.array([points.GetPoint(i) for i in range(points.GetNumberOfPoints())]) + + point_data = polydata.GetPointData() + fields = [] + for array_name in variables: + try: + array = point_data.GetArray(array_name) + except ValueError: + raise ValueError( + f"Failed to get array {array_name} from the unstructured grid." + ) + array_data = np.zeros( + (points.GetNumberOfPoints(), array.GetNumberOfComponents()) + ) + for j in range(points.GetNumberOfPoints()): + array.GetTuple(j, array_data[j]) + fields.append(array_data) + + polys = polydata.GetPolys() + if polys is None: + raise ValueError("Failed to get polygons from the polydata.") + polys.InitTraversal() + edges = [] + id_list = vtk.vtkIdList() + for _ in range(polys.GetNumberOfCells()): + polys.GetNextCell(id_list) + num_ids = id_list.GetNumberOfIds() + edges = [ + (id_list.GetId(j), id_list.GetId((j + 1) % num_ids)) for j in range(num_ids) + ] + + return vertices, fields, edges + + +def nd_interpolator(coodinates, field, grid): + """Function to for nd interpolation""" + interp_func = KDTree(coodinates[0]) + dd, ii = interp_func.query(grid, k=2) + + field_grid = field[ii] + field_grid = np.float32(np.mean(field_grid, (3))) + return field_grid + + +def pad_inp(arr, npoin, pad_value=0.0): + """Function for padding arrays""" + arr_pad = pad_value * np.ones( + (npoin - arr.shape[0], arr.shape[1], arr.shape[2]), dtype=np.float32 + ) + arr_padded = np.concatenate((arr, arr_pad), axis=0) + return arr_padded + + +def shuffle_array_without_sampling(arr): + """Function for shuffline arrays without sampling.""" + idx = np.arange(arr.shape[0]) + np.random.shuffle(idx) + return arr[idx], idx + + +def create_directory(filepath): + """Function to create directories""" + if not os.path.exists(filepath): + os.makedirs(filepath) + + +def calculate_pos_encoding(nx, d=8): + """Function for calculating positional encoding""" + vec = [] + for k in range(int(d / 2)): + vec.append(np.sin(nx / 10000 ** (2 * (k) / d))) + vec.append(np.cos(nx / 10000 ** (2 * (k) / d))) + return vec + + +def combine_dict(old_dict, new_dict): + """Function to combine dictionaries""" + for j in old_dict.keys(): + old_dict[j] += new_dict[j] + return old_dict + + +def merge(*lists): + """Function to merge lists""" + newlist = lists[:] + for x in lists: + if x not in newlist: + newlist.extend(x) + return newlist + + +def mean_std_sampling(field, mean, std, tolerance=3.0): + """Function for mean/std based sampling""" + idx_all = [] + for v in range(field.shape[-1]): + fv = field[:, v] + idx = np.where( + (fv > mean[v] + tolerance * std[v]) | (fv < mean[v] - tolerance * std[v]) + ) + if len(idx[0]) != 0: + idx_all += list(idx[0]) + + return idx_all + + +def dict_to_device(state_dict, device): + """Function to load dictionary to device""" + new_state_dict = {} + for k, v in state_dict.items(): + new_state_dict[k] = v.to(device) + return new_state_dict + + +class BallQuery(paddle.autograd.PyLayer): + """ + Warp based Ball Query. + """ + + @wp.kernel + def ball_query( + points1: wp.array(dtype=wp.vec3), + points2: wp.array(dtype=wp.vec3), + grid: wp.uint64, + k: wp.int32, + radius: wp.float32, + mapping: wp.array3d(dtype=wp.int32), + num_neighbors: wp.array2d(dtype=wp.int32), + ): + + # Get index of point1 + tid = wp.tid() + + # Get position from points1 + pos = points1[tid] + + # particle contact + neighbors = wp.hash_grid_query(grid, pos, radius) + + # Keep track of the number of neighbors found + nr_found = wp.int32(0) + + # loop through neighbors to compute density + for index in neighbors: + # Check if outside the radius + pos2 = points2[index] + if wp.length(pos - pos2) > radius: + continue + + # Add neighbor to the list + mapping[0, tid, nr_found] = index + + # Increment the number of neighbors found + nr_found += 1 + + # Break if we have found enough neighbors + if nr_found == k: + num_neighbors[0, tid] = k + break + + # Set the number of neighbors + num_neighbors[0, tid] = nr_found + + @wp.kernel + def sparse_ball_query( + points2: wp.array(dtype=wp.vec3), + mapping: wp.array3d(dtype=wp.int32), + num_neighbors: wp.array2d(dtype=wp.int32), + outputs: wp.array4d(dtype=wp.float32), + ): + # Get index of point1 + p1 = wp.tid() + + # Get number of neighbors + k = num_neighbors[0, p1] + + # Loop through neighbors + for _k in range(k): + # Get point2 index + index = mapping[0, p1, _k] + + # Get position from points2 + pos = points2[index] + + # Set the output + outputs[0, p1, _k, 0] = pos[0] + outputs[0, p1, _k, 1] = pos[1] + outputs[0, p1, _k, 2] = pos[2] + + @staticmethod + def forward( + ctx, + points1, + points2, + lengths1, + lengths2, + k, + radius, + hash_grid, + ): + # Only works for batch size 1 + if points1.shape[0] != 1: + raise AssertionError("nly works for batch size 1") + + # Convert from paddle to warp + ctx.points1 = wp.from_paddle( + points1[0], dtype=wp.vec3, requires_grad=points1.stop_gradient + ) + ctx.points2 = wp.from_paddle( + points2[0], dtype=wp.vec3, requires_grad=points2.stop_gradient + ) + ctx.lengths1 = wp.from_paddle(lengths1, dtype=wp.int32, requires_grad=False) + ctx.lengths2 = wp.from_paddle(lengths2, dtype=wp.int32, requires_grad=False) + ctx.k = k + ctx.radius = radius + + # Allocate the mapping and outputs + mapping = paddle.zeros([1, points1.shape[1], k], dtype=paddle.int32) + mapping.stop_gradient = False + ctx.mapping = wp.from_paddle(mapping, dtype=wp.int32, requires_grad=False) + num_neighbors = paddle.zeros([1, points1.shape[1]], dtype=paddle.int32) + num_neighbors.stop_gradient = False + ctx.num_neighbors = wp.from_paddle( + num_neighbors, dtype=wp.int32, requires_grad=False + ) + outputs = paddle.zeros([1, points1.shape[1], k, 3], dtype=paddle.float32) + outputs.stop_gradient = points1.stop_gradient or points2.stop_gradient + ctx.outputs = wp.from_paddle(outputs, dtype=wp.float32) + outputs.stop_gradient = points1.stop_gradient or points2.stop_gradient + + # Make grid + ctx.hash_grid = hash_grid + + # Build the grid + ctx.hash_grid.build(ctx.points2, radius) + + # Run the kernel to get mapping + wp.launch( + BallQuery.ball_query, + inputs=[ + ctx.points1, + ctx.points2, + ctx.hash_grid.id, + k, + radius, + ], + outputs=[ + ctx.mapping, + ctx.num_neighbors, + ], + dim=[ctx.points1.shape[0]], + ) + + # Run the kernel to get outputs + wp.launch( + BallQuery.sparse_ball_query, + inputs=[ + ctx.points2, + ctx.mapping, + ctx.num_neighbors, + ], + outputs=[ + ctx.outputs, + ], + dim=[ctx.points1.shape[0]], + ) + + return ( + wp.to_paddle(ctx.mapping), + wp.to_paddle(ctx.num_neighbors), + wp.to_paddle(ctx.outputs), + ) + + @staticmethod + def backward(ctx, grad_mapping, grad_num_neighbors, grad_outputs): + # Map incoming paddle grads to our output variable + ctx.outputs.grad = wp.from_paddle(grad_outputs, dtype=wp.float32) + + # Run the kernel in adjoint mode + wp.launch( + BallQuery.sparse_ball_query, + inputs=[ + ctx.points2, + ctx.mapping, + ctx.num_neighbors, + ], + outputs=[ + ctx.outputs, + ], + adj_inputs=[ctx.points2.grad, ctx.mapping.grad, ctx.num_neighbors.grad], + adj_outputs=[ + ctx.outputs.grad, + ], + dim=[ctx.points1.shape[0]], + adjoint=True, + ) + + # Return the gradients + return ( + wp.to_paddle(ctx.points1.grad).unsqueeze(0), + wp.to_paddle(ctx.points2.grad).unsqueeze(0), + None, + None, + None, + None, + None, + ) + + +def kaiming_init(layer): + if isinstance(layer, (nn.layer.conv._ConvNd, nn.Linear)): + print(f"layer: {layer} ") + init_kaimingUniform = paddle.nn.initializer.KaimingUniform( + nonlinearity="leaky_relu", negative_slope=math.sqrt(5) + ) + init_kaimingUniform(layer.weight) + if layer.bias is not None: + fan_in, _ = ppsci.utils.initializer._calculate_fan_in_and_fan_out( + layer.weight + ) + if fan_in != 0: + bound = 1 / math.sqrt(fan_in) + init_uniform = paddle.nn.initializer.Uniform(low=-bound, high=bound) + init_uniform(layer.bias) + + +def scale_sdf(sdf): + """Function to scale SDF""" + return sdf / (0.4 + abs(sdf)) + + +def calculate_gradient(sdf): + """Function to calculate the gradients of SDF""" + m, n, o = sdf.shape[2], sdf.shape[3], sdf.shape[4] + sdf_x = sdf[:, :, 2:m, :, :] - sdf[:, :, 0 : m - 2, :, :] + sdf_y = sdf[:, :, :, 2:n, :] - sdf[:, :, :, 0 : n - 2, :] + sdf_z = sdf[:, :, :, :, 2:o] - sdf[:, :, :, :, 0 : o - 2] + + sdf_x = F.pad(x=sdf_x, pad=(0, 0, 0, 0, 0, 1), mode="constant", value=0.0) + sdf_x = F.pad(x=sdf_x, pad=(0, 0, 0, 0, 1, 0), mode="constant", value=0.0) + sdf_y = F.pad(x=sdf_y, pad=(0, 0, 0, 1, 0, 0), mode="constant", value=0.0) + sdf_y = F.pad(x=sdf_y, pad=(0, 0, 1, 0, 0, 0), mode="constant", value=0.0) + sdf_z = F.pad(x=sdf_z, pad=(0, 1, 0, 0, 0, 0), mode="constant", value=0.0) + sdf_z = F.pad(x=sdf_z, pad=(1, 0, 0, 0, 0, 0), mode="constant", value=0.0) + + return sdf_x, sdf_y, sdf_z + + +def binarize_sdf(sdf): + """Function to calculate the binarize the SDF""" + sdf = paddle.where(sdf >= 0, 0.0, 1.0).to(dtype=sdf.dtype) + return sdf + + +class BallQueryLayer(paddle.nn.Layer): + """ + Paddle layer for differentiable and accelerated Ball Query + operation using Warp. + Args: + k (int): Number of neighbors. + radius (float): Radius of influence. + grid_size (int): Uniform grid resolution + """ + + def __init__(self, k, radius, grid_size=32): + super().__init__() + wp.init() + self.k = k + self.radius = radius + self.hash_grid = wp.HashGrid(grid_size, grid_size, grid_size) + + def forward(self, points1, points2, lengths1, lengths2): + return BallQuery.apply( + points1, + points2, + lengths1, + lengths2, + self.k, + self.radius, + self.hash_grid, + ) + + +def _get_checkpoint_filename( + path: str, + base_name: str = "checkpoint", + index: Union[int, None] = None, + saving: bool = False, + model_type: str = "mdlus", +) -> str: + """Gets the file name /path of checkpoint + + This function has three different ways of providing a checkout filename: + - If supplied an index this will return the checkpoint name using that index. + - If index is None and saving is false, this will get the checkpoint with the + largest index (latest save). + - If index is None and saving is true, it will return the next valid index file name + which is calculated by indexing the largest checkpoint index found by one. + + Parameters + ---------- + path : str + Path to checkpoints + base_name: str, optional + Base file name, by default checkpoint + index : Union[int, None], optional + Checkpoint index, by default None + saving : bool, optional + Get filename for saving a new checkpoint, by default False + model_type : str + Model type, by default "mdlus" for Modulus models and "pdparams" for models + + + Returns + ------- + str + Checkpoint file name + """ + # Get model parallel rank so all processes in the first model parallel group + # can save their checkpoint. In the case without model parallelism, + # model_parallel_rank should be the same as the process rank itself and + # only rank 0 saves + model_parallel_rank = 0 + + # Input file name + checkpoint_filename = str( + Path(path).resolve() / f"{base_name}.{model_parallel_rank}" + ) + + # File extension for Modulus models or PaddlePaddle models + file_extension = ".pdparams" + + # If epoch is provided load that file + if index is not None: + checkpoint_filename = checkpoint_filename + f".{index}" + checkpoint_filename += file_extension + # Otherwise try loading the latest epoch or rolling checkpoint + else: + file_names = [ + Path(fname).name + for fname in glob.glob( + checkpoint_filename + "*" + file_extension, recursive=False + ) + ] + + if len(file_names) > 0: + # If checkpoint from a null index save exists load that + # This is the most likely line to error since it will fail with + # invalid checkpoint names + file_idx = [ + int( + re.sub( + f"^{base_name}.{model_parallel_rank}.|" + file_extension, + "", + fname, + ) + ) + for fname in file_names + ] + file_idx.sort() + # If we are saving index by 1 to get the next free file name + if saving: + checkpoint_filename = checkpoint_filename + f".{file_idx[-1]+1}" + else: + checkpoint_filename = checkpoint_filename + f".{file_idx[-1]}" + checkpoint_filename += file_extension + else: + checkpoint_filename += ".0" + file_extension + + return checkpoint_filename + + +def _unique_model_names( + models: List[paddle.nn.Layer], +) -> Dict[str, paddle.nn.Layer]: + """Util to clean model names and index if repeat names, will also strip DDP wrappers + if they exist. + + Parameters + ---------- + model : List[paddle.nn.Layer] + List of models to generate names for + + Returns + ------- + Dict[str, paddle.nn.Layer] + Dictionary of model names and respective modules + """ + # Loop through provided models and set up base names + model_dict = {} + for model0 in models: + if hasattr(model0, "module"): + # Strip out DDP layer + model0 = model0.module + # Base name of model is meta.name unless paddle model + base_name = model0.__class__.__name__ + # if isinstance(model0, modulus): + # base_name = model0.meta.name + # If we have multiple models of the same name, introduce another index + if base_name in model_dict: + model_dict[base_name].append(model0) + else: + model_dict[base_name] = [model0] + + # Set up unique model names if needed + output_dict = {} + for key, model in model_dict.items(): + if len(model) > 1: + for i, model0 in enumerate(model): + output_dict[key + str(i)] = model0 + else: + output_dict[key] = model[0] + + return output_dict + + +def save_checkpoint( + path: str, + models: Union[paddle.nn.Layer, List[paddle.nn.Layer], None] = None, + optimizer: Union[optimizer, None] = None, + scheduler: Union[scheduler, None] = None, + scaler: Union[scaler, None] = None, + epoch: Union[int, None] = None, + metadata: Optional[Dict[str, Any]] = None, +) -> None: + """Training checkpoint saving utility + + This will save a training checkpoint in the provided path following the file naming + convention "checkpoint.{model parallel id}.{epoch/index}.mdlus". The load checkpoint + method in Modulus core can then be used to read this file. + + Parameters + ---------- + path : str + Path to save the training checkpoint + models : Union[paddle.nn.Layer, List[paddle.nn.Layer], None], optional + A single or list of PaddlePaddle models, by default None + optimizer : Union[optimizer, None], optional + Optimizer, by default None + scheduler : Union[scheduler, None], optional + Learning rate scheduler, by default None + scaler : Union[scaler, None], optional + AMP grad scaler. Will attempt to save on in static capture if none provided, by + default None + epoch : Union[int, None], optional + Epoch checkpoint to load. If none this will save the checkpoint in the next + valid index, by default None + metadata : Optional[Dict[str, Any]], optional + Additional metadata to save, by default None + """ + # Create checkpoint directory if it does not exist + if not Path(path).is_dir(): + logger.warning( + f"Output directory {path} does not exist, will " "attempt to create" + ) + Path(path).mkdir(parents=True, exist_ok=True) + + # == Saving model checkpoint == + if models: + if not isinstance(models, list): + models = [models] + models = _unique_model_names(models) + for name, model in models.items(): + # Get model type + model_type = "pdparams" + + # Get full file path / name + file_name = _get_checkpoint_filename( + path, name, index=epoch, saving=True, model_type=model_type + ) + + # Save state dictionary + paddle.save(model.state_dict(), file_name) + logger.info(f"Saved model state dictionary: {file_name}") + + # == Saving training checkpoint == + checkpoint_dict = {} + # Optimizer state dict + if optimizer: + checkpoint_dict["optimizer_state_dict"] = optimizer.state_dict() + + # Scheduler state dict + if scheduler: + checkpoint_dict["scheduler_state_dict"] = scheduler.state_dict() + + # Scheduler state dict + if scaler: + checkpoint_dict["scaler_state_dict"] = scaler.state_dict() + # Static capture is being used, save its grad scaler + # if _StaticCapture._amp_scalers: + # checkpoint_dict["static_capture_state_dict"] = _StaticCapture.state_dict() + + # Output file name + output_filename = _get_checkpoint_filename( + path, index=epoch, saving=True, model_type="pdparams" + ) + if epoch: + checkpoint_dict["epoch"] = epoch + if metadata: + checkpoint_dict["metadata"] = metadata + # Save checkpoint to memory + if bool(checkpoint_dict): + paddle.save( + checkpoint_dict, + output_filename, + ) + logger.info(f"Saved training checkpoint: {output_filename}") + + +def load_checkpoint( + path: str, + models: Union[paddle.nn.Layer, List[paddle.nn.Layer], None] = None, + optimizer: Union[optimizer, None] = None, + scheduler: Union[scheduler, None] = None, + scaler: Union[scaler, None] = None, + epoch: Union[int, None] = None, + metadata_dict: Optional[Dict[str, Any]] = {}, +) -> int: + """Checkpoint loading utility + + This loader is designed to be used with the save checkpoint utility in Modulus + Launch. Given a path, this method will try to find a checkpoint and load state + dictionaries into the provided training objects. + + Parameters + ---------- + path : str + Path to training checkpoint + models : Union[paddle.nn.Layer, List[paddle.nn.Layer], None], optional + A single or list of models, by default None + optimizer : Union[optimizer, None], optional + Optimizer, by default None + scheduler : Union[scheduler, None], optional + Learning rate scheduler, by default None + scaler : Union[scaler, None], optional + AMP grad scaler, by default None + epoch : Union[int, None], optional + Epoch checkpoint to load. If none is provided this will attempt to load the + checkpoint with the largest index, by default None + metadata_dict: Optional[Dict[str, Any]], optional + Dictionary to store metadata from the checkpoint, by default None + + Returns + ------- + int + Loaded epoch + """ + # Check if checkpoint directory exists + if not Path(path).is_dir(): + logger.warning( + f"Provided checkpoint directory {path} does not exist, skipping load" + ) + return 0 + + # == Loading model checkpoint == + if models: + if not isinstance(models, list): + models = [models] + models = _unique_model_names(models) + for name, model in models.items(): + # Get model type + model_type = "pdparams" + + # Get full file path / name + file_name = _get_checkpoint_filename( + path, name, index=epoch, model_type=model_type + ) + if not Path(file_name).exists(): + logger.error( + f"Could not find valid model file {file_name}, skipping load" + ) + continue + # Load state dictionary + model.set_state_dict(paddle.load(file_name)) + + logger.info(f"Loaded model state dictionary {file_name}") + + # == Loading training checkpoint == + checkpoint_filename = _get_checkpoint_filename( + path, index=epoch, model_type="pdparams" + ) + if not Path(checkpoint_filename).is_file(): + logger.warning("Could not find valid checkpoint file, skipping load") + return 0 + + checkpoint_dict = paddle.load(checkpoint_filename) + logger.info(f"Loaded checkpoint file {checkpoint_filename}") + + # Optimizer state dict + if optimizer and "optimizer_state_dict" in checkpoint_dict: + optimizer.set_state_dict(checkpoint_dict["optimizer_state_dict"]) + logger.info("Loaded optimizer state dictionary") + + # Scheduler state dict + if scheduler and "scheduler_state_dict" in checkpoint_dict: + scheduler.set_state_dict(checkpoint_dict["scheduler_state_dict"]) + logger.info("Loaded scheduler state dictionary") + + # Scaler state dict + if scaler and "scaler_state_dict" in checkpoint_dict: + scaler.load_state_dict(checkpoint_dict["scaler_state_dict"]) + logger.info("Loaded grad scaler state dictionary") + + epoch = 0 + if "epoch" in checkpoint_dict: + epoch = checkpoint_dict["epoch"] + # Update metadata if exists and the dictionary object is provided + metadata = checkpoint_dict.get("metadata", {}) + for key, value in metadata.items(): + metadata_dict[key] = value + + return epoch + + +class BQWarp(nn.Layer): + """Warp based ball-query layer""" + + def __init__( + self, + input_features, + grid_resolution=[256, 96, 64], + radius=0.25, + neighbors_in_radius=10, + ): + super().__init__() + self.ball_query_layer = BallQueryLayer(neighbors_in_radius, radius) + self.grid_resolution = grid_resolution + + def forward(self, x, p_grid, reverse_mapping=True): + batch_size = x.shape[0] + nx, ny, nz = ( + self.grid_resolution[0], + self.grid_resolution[1], + self.grid_resolution[2], + ) + + p_grid = paddle.reshape(p_grid, (batch_size, nx * ny * nz, 3)) + p1 = nx * ny * nz + p2 = x.shape[1] + + if reverse_mapping: + lengths1 = paddle.full((batch_size,), p1, dtype=paddle.int32) + lengths2 = paddle.full((batch_size,), p2, dtype=paddle.int32) + mapping, num_neighbors, outputs = self.ball_query_layer( + p_grid, + x, + lengths1, + lengths2, + ) + else: + lengths1 = paddle.full((batch_size,), p2, dtype=paddle.int32) + lengths2 = paddle.full((batch_size,), p1, dtype=paddle.int32) + mapping, num_neighbors, outputs = self.ball_query_layer( + x, + p_grid, + lengths1, + lengths2, + ) + + return mapping, outputs + + +class GeoConvOut(nn.Layer): + """Geometry layer to project STLs on grids""" + + def __init__(self, input_features, model_parameters, grid_resolution=[256, 96, 64]): + super().__init__() + base_neurons = model_parameters.base_neurons + + self.fc1 = nn.Linear(input_features, base_neurons) + self.fc2 = nn.Linear(base_neurons, int(base_neurons / 2)) + self.fc3 = nn.Linear(int(base_neurons / 2), model_parameters.base_neurons_out) + + self.grid_resolution = grid_resolution + + self.activation = F.relu + + def forward(self, x, radius=0.025, neighbors_in_radius=10): + batch_size = x.shape[0] + nx, ny, nz = ( + self.grid_resolution[0], + self.grid_resolution[1], + self.grid_resolution[2], + ) + + mask = abs(x - 0) > 1e-6 + + x = self.activation(self.fc1(x)) + x = self.activation(self.fc2(x)) + x = F.tanh(self.fc3(x)) + mask = mask[:, :, :, 0:1].expand( + [mask.shape[0], mask.shape[1], mask.shape[2], x.shape[-1]] + ) + + # paddle does not support multiplication with boolean tensors, + # so we convert the mask to float + x = paddle.sum(x * mask.to(dtype=x.dtype), 2) + + x = paddle.reshape(x, (batch_size, x.shape[-1], nx, ny, nz)) + return x + + +class GeoProcessor(nn.Layer): + """Geometry processing layer using CNNs""" + + def __init__(self, input_filters, model_parameters): + super().__init__() + base_filters = model_parameters.base_filters + self.conv1 = nn.Conv3D( + input_filters, base_filters, kernel_size=3, padding="same" + ) + self.conv_bn1 = nn.BatchNorm3D(int(base_filters)) + self.conv2 = nn.Conv3D( + base_filters, 2 * base_filters, kernel_size=3, padding="same" + ) + self.conv_bn2 = nn.BatchNorm3D(int(2 * base_filters)) + self.conv3 = nn.Conv3D( + 2 * base_filters, 4 * base_filters, kernel_size=3, padding="same" + ) + self.conv_bn3 = nn.BatchNorm3D(int(4 * base_filters)) + self.conv3_1 = nn.Conv3D( + 4 * base_filters, 4 * base_filters, kernel_size=3, padding="same" + ) + self.conv4 = nn.Conv3D( + 4 * base_filters, 2 * base_filters, kernel_size=3, padding="same" + ) + self.conv_bn4 = nn.BatchNorm3D(int(2 * base_filters)) + self.conv5 = nn.Conv3D( + 4 * base_filters, base_filters, kernel_size=3, padding="same" + ) + self.conv_bn5 = nn.BatchNorm3D(int(base_filters)) + self.conv6 = nn.Conv3D( + 2 * base_filters, input_filters, kernel_size=3, padding="same" + ) + self.conv_bn6 = nn.BatchNorm3D(int(input_filters)) + self.conv7 = nn.Conv3D( + 2 * input_filters, input_filters, kernel_size=3, padding="same" + ) + self.conv8 = nn.Conv3D(input_filters, 1, kernel_size=3, padding="same") + self.avg_pool = paddle.nn.AvgPool3D((2, 2, 2)) + self.max_pool = nn.MaxPool3D(2) + self.upsample = nn.Upsample(scale_factor=2, mode="nearest") + self.activation = F.relu + self.batch_norm = False + + def forward(self, x): + # Encoder + x0 = x + if self.batch_norm: + x = self.activation(self.conv_bn1(self.conv1(x))) + else: + x = self.activation(self.conv1(x)) + x = self.max_pool(x) + x1 = x + if self.batch_norm: + x = self.activation(self.conv_bn2(self.conv2(x))) + else: + x = self.activation((self.conv2(x))) + x = self.max_pool(x) + + x2 = x + if self.batch_norm: + x = self.activation(self.conv_bn3(self.conv2(x))) + else: + x = self.activation((self.conv3(x))) + x = self.max_pool(x) + + # Processor loop + x = F.relu(self.conv3_1(x)) + + # Decoder + if self.batch_norm: + x = self.activation(self.conv_bn4(self.conv4(x))) + else: + x = self.activation((self.conv4(x))) + x = self.upsample(x) + x = paddle.concat((x, x2), axis=1) + + if self.batch_norm: + x = self.activation(self.conv_bn5(self.conv5(x))) + else: + x = self.activation((self.conv5(x))) + x = self.upsample(x) + x = paddle.concat((x, x1), axis=1) + if self.batch_norm: + x = self.activation(self.conv_bn6(self.conv6(x))) + else: + x = self.activation((self.conv6(x))) + x = self.upsample(x) + x = paddle.concat((x, x0), axis=1) + + x = self.activation(self.conv7(x)) + x = self.conv8(x) + + return x + + +class GeometryRep(nn.Layer): + """Geometry representation from STLs block""" + + def __init__(self, input_features, model_parameters=None): + super().__init__() + geometry_rep = model_parameters.geometry_rep + + self.bq_warp_short = BQWarp( + input_features=input_features, + grid_resolution=model_parameters.interp_res, + radius=geometry_rep.geo_conv.radius_short, + ) + + self.bq_warp_long = BQWarp( + input_features=input_features, + grid_resolution=model_parameters.interp_res, + radius=geometry_rep.geo_conv.radius_long, + ) + + self.geo_conv_out = GeoConvOut( + input_features=input_features, + model_parameters=geometry_rep.geo_conv, + grid_resolution=model_parameters.interp_res, + ) + + self.geo_processor_short_range = GeoProcessor( + input_filters=geometry_rep.geo_conv.base_neurons_out, + model_parameters=geometry_rep.geo_processor, + ) + self.geo_processor_long_range = GeoProcessor( + input_filters=geometry_rep.geo_conv.base_neurons_out, + model_parameters=geometry_rep.geo_processor, + ) + self.geo_processor_sdf = GeoProcessor( + input_filters=6, model_parameters=geometry_rep.geo_processor + ) + self.activation = F.relu + self.radius_short = geometry_rep.geo_conv.radius_short + self.radius_long = geometry_rep.geo_conv.radius_long + self.hops = geometry_rep.geo_conv.hops + + def forward(self, x, p_grid, sdf): + + # Expand SDF + sdf = paddle.unsqueeze(sdf, 1) + + # Calculate short-range geoemtry dependency + mapping, k_short = self.bq_warp_short(x, p_grid) + x_encoding_short = self.geo_conv_out(k_short) + + # Calculate long-range geometry dependency + mapping, k_long = self.bq_warp_long(x, p_grid) + x_encoding_long = self.geo_conv_out(k_long) + + # Scaled sdf to emphasis on surface + scaled_sdf = scale_sdf(sdf) + # Binary sdf + binary_sdf = binarize_sdf(sdf) + # Gradients of SDF + sdf_x, sdf_y, sdf_z = calculate_gradient(sdf) + + # Propagate information in the geometry enclosed BBox + for _ in range(self.hops): + dx = self.geo_processor_short_range(x_encoding_short) / self.hops + x_encoding_short = x_encoding_short + dx + + # Propagate information in the computational domain BBox + for _ in range(self.hops): + dx = self.geo_processor_long_range(x_encoding_long) / self.hops + x_encoding_long = x_encoding_long + dx + + # Process SDF and its computed features + sdf = paddle.concat((sdf, scaled_sdf, binary_sdf, sdf_x, sdf_y, sdf_z), 1) + sdf_encoding = self.geo_processor_sdf(sdf) + + # Geometry encoding comprised of short-range, long-range and SDF features + encoding_g = paddle.concat((x_encoding_short, sdf_encoding, x_encoding_long), 1) + + return encoding_g + + +class NNBasisFunctions(nn.Layer): + """Basis function layer for point clouds""" + + def __init__(self, input_features, model_parameters=None): + super(NNBasisFunctions, self).__init__() + self.input_features = input_features + + base_layer = model_parameters.base_layer + self.fc1 = nn.Linear(self.input_features, base_layer) + self.fc2 = nn.Linear(base_layer, int(base_layer)) + self.fc3 = nn.Linear(int(base_layer), int(base_layer)) + self.bn1 = nn.BatchNorm1D(base_layer) + self.bn2 = nn.BatchNorm1D(int(base_layer)) + self.bn3 = nn.BatchNorm1D(int(base_layer)) + + self.activation = F.relu + + def forward(self, x, padded_value=-10): + facets = x + facets = self.activation(self.fc1(facets)) + facets = self.activation(self.fc2(facets)) + facets = self.fc3(facets) + + return facets + + +class ParameterModel(nn.Layer): + """Layer to encode parameters such as inlet velocity and air density""" + + def __init__(self, input_features, model_parameters=None): + super(ParameterModel, self).__init__() + self.input_features = input_features + + base_layer = model_parameters.base_layer + self.fc1 = nn.Linear(self.input_features, base_layer) + self.fc2 = nn.Linear(base_layer, int(base_layer)) + self.fc3 = nn.Linear(int(base_layer), int(base_layer)) + self.bn1 = nn.BatchNorm1D(base_layer) + self.bn2 = nn.BatchNorm1D(int(base_layer)) + self.bn3 = nn.BatchNorm1D(int(base_layer)) + + self.activation = F.relu + + def forward(self, x, padded_value=-10): + params = x + params = self.activation(self.fc1(params)) + params = self.activation(self.fc2(params)) + params = self.fc3(params) + + return params + + +class AggregationModel(nn.Layer): + """Layer to aggregate local geometry encoding with basis functions""" + + def __init__( + self, input_features, output_features, model_parameters=None, new_change=True + ): + super(AggregationModel, self).__init__() + self.input_features = input_features + self.output_features = output_features + self.new_change = new_change + base_layer = model_parameters.base_layer + self.fc1 = nn.Linear(self.input_features, base_layer) + self.fc2 = nn.Linear(base_layer, int(base_layer)) + self.fc3 = nn.Linear(int(base_layer), int(base_layer)) + self.fc4 = nn.Linear(int(base_layer), int(base_layer)) + self.fc5 = nn.Linear(int(base_layer), self.output_features) + self.bn1 = nn.BatchNorm1D(base_layer) + self.bn2 = nn.BatchNorm1D(int(base_layer)) + self.bn3 = nn.BatchNorm1D(int(base_layer)) + self.bn4 = nn.BatchNorm1D(int(base_layer)) + self.activation = F.relu + + def forward(self, x): + out = self.activation(self.fc1(x)) + out = self.activation(self.fc2(out)) + out = self.activation(self.fc3(out)) + out = self.activation(self.fc4(out)) + + out = self.fc5(out) + + return out + + +class DoMINO(nn.Layer): + """DoMINO model architecture + Parameters + ---------- + input_features : int + Number of point input features + output_features_vol : int + Number of output features in volume + output_features_surf : int + Number of output features on surface + model_parameters: dict + Dictionary of model parameters controlled by config.yaml + + Example + ------- + >>> from modulus.models.domino.model import DoMINO + >>> import os + >>> from hydra import compose, initialize + >>> from omegaconf import OmegaConf + >>> cfg = OmegaConf.register_new_resolver("eval", eval) + >>> with initialize(version_base="1.3", config_path="examples/cfd/external_aerodynamics/domino/src/conf"): + ... cfg = compose(config_name="config") + >>> cfg.model.model_type = "combined" + >>> model = DoMINO( + ... input_features=3, + ... output_features_vol=5, + ... output_features_surf=4, + ... model_parameters=cfg.model + ... ) + + Warp ... + >>> bsize = 1 + >>> nx, ny, nz = 128, 64, 48 + >>> num_neigh = 7 + >>> pos_normals_closest_vol = paddle.randn([bsize, 100, 3]) + >>> pos_normals_com_vol = paddle.randn([bsize, 100, 3]) + >>> pos_normals_com_surface = paddle.randn([bsize, 100, 3]) + >>> geom_centers = paddle.randn([bsize, 100, 3]) + >>> grid = paddle.randn([bsize, nx, ny, nz, 3]) + >>> surf_grid = paddle.randn([bsize, nx, ny, nz, 3]) + >>> sdf_grid = paddle.randn([bsize, nx, ny, nz]) + >>> sdf_surf_grid = paddle.randn([bsize, nx, ny, nz]) + >>> sdf_nodes = paddle.randn([bsize, 100, 1]) + >>> surface_coordinates = paddle.randn([bsize, 100, 3]) + >>> surface_neighbors = paddle.randn([bsize, 100, num_neigh, 3]) + >>> surface_normals = paddle.randn([bsize, 100, 3]) + >>> surface_neighbors_normals = paddle.randn([bsize, 100, num_neigh, 3]) + >>> surface_sizes = paddle.randn([bsize, 100, 3]) + >>> surface_neighbors_sizes = paddle.randn([bsize, 100, num_neigh, 3]) + >>> volume_coordinates = paddle.randn([bsize, 100, 3]) + >>> vol_grid_max_min = paddle.randn([bsize, 2, 3]) + >>> surf_grid_max_min = paddle.randn([bsize, 2, 3]) + >>> stream_velocity = paddle.randn([bsize, 1]) + >>> air_density = paddle.randn([bsize, 1]) + >>> input_dict = { + ... "pos_volume_closest": pos_normals_closest_vol, + ... "pos_volume_center_of_mass": pos_normals_com_vol, + ... "pos_surface_center_of_mass": pos_normals_com_surface, + ... "geometry_coordinates": geom_centers, + ... "grid": grid, + ... "surf_grid": surf_grid, + ... "sdf_grid": sdf_grid, + ... "sdf_surf_grid": sdf_surf_grid, + ... "sdf_nodes": sdf_nodes, + ... "surface_mesh_centers": surface_coordinates, + ... "surface_mesh_neighbors": surface_neighbors, + ... "surface_normals": surface_normals, + ... "surface_neighbors_normals": surface_neighbors_normals, + ... "surface_areas": surface_sizes, + ... "surface_neighbors_areas": surface_neighbors_sizes, + ... "volume_mesh_centers": volume_coordinates, + ... "volume_min_max": vol_grid_max_min, + ... "surface_min_max": surf_grid_max_min, + ... "stream_velocity": stream_velocity, + ... "air_density": air_density, + ... } + >>> output = model(input_dict) + Module ... + >>> print(f"{output[0].shape}, {output[1].shape}") + """ + + def __init__( + self, + input_features, + output_features_vol=None, + output_features_surf=None, + model_parameters=None, + ): + super(DoMINO, self).__init__() + self.input_features = input_features + self.output_features_vol = output_features_vol + self.output_features_surf = output_features_surf + + if self.output_features_vol is None and self.output_features_surf is None: + raise ValueError("Need to specify number of volume or surface features") + + self.num_variables_vol = output_features_vol + self.num_variables_surf = output_features_surf + self.grid_resolution = model_parameters.interp_res + self.surface_neighbors = model_parameters.surface_neighbors + self.use_surface_normals = model_parameters.use_surface_normals + self.use_only_normals = model_parameters.use_only_normals + self.encode_parameters = model_parameters.encode_parameters + self.param_scaling_factors = model_parameters.parameter_model.scaling_params + + if self.use_surface_normals: + if self.use_only_normals: + input_features_surface = input_features + 3 + else: + input_features_surface = input_features + 4 + else: + input_features_surface = input_features + + if self.encode_parameters: + # Defining the parameter model + base_layer_p = model_parameters.parameter_model.base_layer + self.parameter_model = ParameterModel( + input_features=2, model_parameters=model_parameters.parameter_model + ) + else: + base_layer_p = 0 + + self.geo_rep = GeometryRep( + input_features=input_features, + model_parameters=model_parameters, + ) + + # Basis functions for surface and volume + base_layer_nn = model_parameters.nn_basis_functions.base_layer + if self.output_features_surf is not None: + self.nn_basis_surf = nn.LayerList() + for _ in range(self.num_variables_surf): + self.nn_basis_surf.append( + NNBasisFunctions( + input_features=input_features_surface, + model_parameters=model_parameters.nn_basis_functions, + ) + ) + + if self.output_features_vol is not None: + self.nn_basis_vol = nn.LayerList() + for _ in range(self.num_variables_vol): + self.nn_basis_vol.append( + NNBasisFunctions( + input_features=input_features, + model_parameters=model_parameters.nn_basis_functions, + ) + ) + + # Positional encoding + position_encoder_base_neurons = model_parameters.position_encoder.base_neurons + if self.output_features_vol is not None: + if model_parameters.positional_encoding: + inp_pos_vol = 25 if model_parameters.use_sdf_in_basis_func else 12 + else: + inp_pos_vol = 7 if model_parameters.use_sdf_in_basis_func else 3 + + self.fc_p_vol = nn.Linear(inp_pos_vol, position_encoder_base_neurons) + + if self.output_features_surf is not None: + if model_parameters.positional_encoding: + inp_pos_surf = 12 + else: + inp_pos_surf = 3 + + self.fc_p_surf = nn.Linear(inp_pos_surf, position_encoder_base_neurons) + + # Positional encoding hidden layers + self.fc_p1 = nn.Linear( + position_encoder_base_neurons, position_encoder_base_neurons + ) + self.fc_p2 = nn.Linear( + position_encoder_base_neurons, position_encoder_base_neurons + ) + + # BQ for surface and volume + self.neighbors_in_radius = model_parameters.geometry_local.neighbors_in_radius + self.radius = model_parameters.geometry_local.radius + self.bq_warp = BQWarp( + input_features=input_features, + grid_resolution=model_parameters.interp_res, + radius=self.radius, + neighbors_in_radius=self.neighbors_in_radius, + ) + + base_layer_geo = model_parameters.geometry_local.base_layer + self.fc_1 = nn.Linear(self.neighbors_in_radius * 3, base_layer_geo) + self.fc_2 = nn.Linear(base_layer_geo, base_layer_geo) + self.activation = F.relu + + # Aggregation model + if self.output_features_surf is not None: + # Surface + self.agg_model_surf = nn.LayerList() + for _ in range(self.num_variables_surf): + self.agg_model_surf.append( + AggregationModel( + input_features=position_encoder_base_neurons + + base_layer_nn + + base_layer_geo + + base_layer_p, + output_features=1, + model_parameters=model_parameters.aggregation_model, + ) + ) + + if self.output_features_vol is not None: + # Volume + self.agg_model_vol = nn.LayerList() + for _ in range(self.num_variables_vol): + self.agg_model_vol.append( + AggregationModel( + input_features=position_encoder_base_neurons + + base_layer_nn + + base_layer_geo + + base_layer_p, + output_features=1, + model_parameters=model_parameters.aggregation_model, + ) + ) + + self.apply(kaiming_init) + + def geometry_encoder(self, geo_centers, p_grid, sdf): + """Function to return local geometry encoding""" + return self.geo_rep(geo_centers, p_grid, sdf) + + def position_encoder(self, encoding_node, eval_mode="volume"): + """Function to calculate positional encoding""" + if eval_mode == "volume": + x = self.activation(self.fc_p_vol(encoding_node)) + elif eval_mode == "surface": + x = self.activation(self.fc_p_surf(encoding_node)) + x = self.activation(self.fc_p1(x)) + x = self.fc_p2(x) + return x + + def geo_encoding_local_surface(self, encoding_g, volume_mesh_centers, p_grid): + """Function to calculate local geometry encoding from global encoding for surface""" + batch_size = volume_mesh_centers.shape[0] + nx, ny, nz = ( + self.grid_resolution[0], + self.grid_resolution[1], + self.grid_resolution[2], + ) + p_grid = paddle.reshape(p_grid, (batch_size, nx * ny * nz, 3)) + mapping, outputs = self.bq_warp( + volume_mesh_centers, p_grid, reverse_mapping=False + ) + mapping = mapping.astype(paddle.int64) + mask = mapping != 0 + + geo_encoding = paddle.reshape(encoding_g[:, 0], (batch_size, 1, nx * ny * nz)) + geo_encoding = geo_encoding.expand( + [batch_size, volume_mesh_centers.shape[1], geo_encoding.shape[2]] + ) + sdf_encoding = paddle.reshape(encoding_g[:, 1], (batch_size, 1, nx * ny * nz)) + sdf_encoding = sdf_encoding.expand( + [batch_size, volume_mesh_centers.shape[1], sdf_encoding.shape[2]] + ) + geo_encoding_long = paddle.reshape( + encoding_g[:, 2], (batch_size, 1, nx * ny * nz) + ) + geo_encoding_long = geo_encoding_long.expand( + [batch_size, volume_mesh_centers.shape[1], geo_encoding_long.shape[2]] + ) + + geo_encoding_sampled = paddle.take_along_axis( + geo_encoding, axis=2, indices=mapping + ) * mask.to(dtype=geo_encoding.dtype) + sdf_encoding_sampled = paddle.take_along_axis( + sdf_encoding, axis=2, indices=mapping + ) * mask.to(dtype=geo_encoding.dtype) + geo_encoding_long_sampled = paddle.take_along_axis( + geo_encoding_long, axis=2, indices=mapping + ) * mask.to(dtype=geo_encoding.dtype) + + encoding_g = paddle.concat( + (geo_encoding_sampled, sdf_encoding_sampled, geo_encoding_long_sampled), + axis=2, + ) + encoding_g = self.activation(self.fc_1(encoding_g)) + encoding_g = self.fc_2(encoding_g) + + return encoding_g + + def geo_encoding_local(self, encoding_g, volume_mesh_centers, p_grid): + """Function to calculate local geometry encoding from global encoding""" + batch_size = volume_mesh_centers.shape[0] + nx, ny, nz = ( + self.grid_resolution[0], + self.grid_resolution[1], + self.grid_resolution[2], + ) + p_grid = paddle.reshape(p_grid, (batch_size, nx * ny * nz, 3)) + mapping, outputs = self.bq_warp( + volume_mesh_centers, p_grid, reverse_mapping=False + ) + mapping = mapping.astype(paddle.int64) + mask = mapping != 0 + + geo_encoding = paddle.reshape(encoding_g[:, 0], (batch_size, 1, nx * ny * nz)) + geo_encoding = geo_encoding.expand( + [batch_size, volume_mesh_centers.shape[1], geo_encoding.shape[2]] + ) + sdf_encoding = paddle.reshape(encoding_g[:, 1], (batch_size, 1, nx * ny * nz)) + sdf_encoding = sdf_encoding.expand( + [batch_size, volume_mesh_centers.shape[1], sdf_encoding.shape[2]] + ) + geo_encoding_long = paddle.reshape( + encoding_g[:, 2], (batch_size, 1, nx * ny * nz) + ) + geo_encoding_long = geo_encoding_long.expand( + [batch_size, volume_mesh_centers.shape[1], geo_encoding_long.shape[2]] + ) + + geo_encoding_sampled = paddle.take_along_axis( + geo_encoding, axis=2, indices=mapping + ) * mask.to(dtype=geo_encoding.dtype) + sdf_encoding_sampled = paddle.take_along_axis( + sdf_encoding, axis=2, indices=mapping + ) * mask.to(dtype=geo_encoding.dtype) + geo_encoding_long_sampled = paddle.take_along_axis( + geo_encoding_long, axis=2, indices=mapping + ) * mask.to(dtype=geo_encoding.dtype) + + encoding_g = paddle.concat( + (geo_encoding_sampled, sdf_encoding_sampled, geo_encoding_long_sampled), + axis=2, + ) + encoding_g = self.activation(self.fc_1(encoding_g)) + encoding_g = self.fc_2(encoding_g) + + return encoding_g + + def calculate_solution_with_neighbors( + self, + surface_mesh_centers, + encoding_g, + encoding_node, + surface_mesh_neighbors, + surface_normals, + surface_neighbors_normals, + surface_areas, + surface_neighbors_areas, + inlet_velocity, + air_density, + ): + """Function to approximate solution given the neighborhood information""" + num_variables = self.num_variables_surf + nn_basis = self.nn_basis_surf + agg_model = self.agg_model_surf + num_sample_points = surface_mesh_neighbors.shape[2] + 1 + + if self.encode_parameters: + inlet_velocity = paddle.unsqueeze(inlet_velocity, 1) + inlet_velocity = inlet_velocity.expand( + [ + inlet_velocity.shape[0], + surface_mesh_centers.shape[1], + inlet_velocity.shape[2], + ] + ) + inlet_velocity = inlet_velocity / self.param_scaling_factors[0] + + air_density = paddle.unsqueeze(air_density, 1) + air_density = air_density.expand( + [ + air_density.shape[0], + surface_mesh_centers.shape[1], + air_density.shape[2], + ] + ) + air_density = air_density / self.param_scaling_factors[1] + + params = paddle.concat((inlet_velocity, air_density), axis=-1) + param_encoding = self.parameter_model(params) + + if self.use_surface_normals: + if self.use_only_normals: + surface_mesh_centers = paddle.concat( + (surface_mesh_centers, surface_normals), + axis=-1, + ) + surface_mesh_neighbors = paddle.concat( + ( + surface_mesh_neighbors, + surface_neighbors_normals, + ), + axis=-1, + ) + + else: + surface_mesh_centers = paddle.concat( + (surface_mesh_centers, surface_normals, 10**5 * surface_areas), + axis=-1, + ) + surface_mesh_neighbors = paddle.concat( + ( + surface_mesh_neighbors, + surface_neighbors_normals, + 10**5 * surface_neighbors_areas, + ), + axis=-1, + ) + + for f in range(num_variables): + for p in range(num_sample_points): + if p == 0: + volume_m_c = surface_mesh_centers + else: + volume_m_c = surface_mesh_neighbors[:, :, p - 1] + noise = surface_mesh_centers - volume_m_c + dist = paddle.sqrt( + noise[:, :, 0:1] ** 2.0 + + noise[:, :, 1:2] ** 2.0 + + noise[:, :, 2:3] ** 2.0 + ) + basis_f = nn_basis[f](volume_m_c) + output = paddle.concat((basis_f, encoding_node, encoding_g), axis=-1) + if self.encode_parameters: + output = paddle.concat((output, param_encoding), axis=-1) + if p == 0: + output_center = agg_model[f](output) + else: + if p == 1: + output_neighbor = agg_model[f](output) * (1.0 / dist) + dist_sum = 1.0 / dist + else: + output_neighbor += agg_model[f](output) * (1.0 / dist) + dist_sum += 1.0 / dist + if num_sample_points > 1: + output_res = 0.5 * output_center + 0.5 * output_neighbor / dist_sum + else: + output_res = output_center + if f == 0: + output_all = output_res + else: + output_all = paddle.concat((output_all, output_res), axis=-1) + + return output_all + + def calculate_solution( + self, + volume_mesh_centers, + encoding_g, + encoding_node, + inlet_velocity, + air_density, + eval_mode, + num_sample_points=20, + noise_intensity=50, + ): + """Function to approximate solution sampling the neighborhood information""" + if eval_mode == "volume": + num_variables = self.num_variables_vol + nn_basis = self.nn_basis_vol + agg_model = self.agg_model_vol + elif eval_mode == "surface": + num_variables = self.num_variables_surf + nn_basis = self.nn_basis_surf + agg_model = self.agg_model_surf + + if self.encode_parameters: + inlet_velocity = paddle.unsqueeze(inlet_velocity, 1) + inlet_velocity = inlet_velocity.expand( + [ + inlet_velocity.shape[0], + volume_mesh_centers.shape[1], + inlet_velocity.shape[2], + ] + ) + inlet_velocity = inlet_velocity / self.param_scaling_factors[0] + + air_density = paddle.unsqueeze(air_density, 1) + air_density = air_density.expand( + [ + air_density.shape[0], + volume_mesh_centers.shape[1], + air_density.shape[2], + ] + ) + air_density = air_density / self.param_scaling_factors[1] + + params = paddle.concat((inlet_velocity, air_density), axis=-1) + param_encoding = self.parameter_model(params) + + for f in range(num_variables): + for p in range(num_sample_points): + if p == 0: + volume_m_c = volume_mesh_centers + else: + noise = paddle.rand( + shape=volume_mesh_centers.shape, dtype=volume_mesh_centers.dtype + ) + noise = 2 * (noise - 0.5) + noise = noise / noise_intensity + dist = paddle.sqrt( + noise[:, :, 0:1] ** 2.0 + + noise[:, :, 1:2] ** 2.0 + + noise[:, :, 2:3] ** 2.0 + ) + volume_m_c = volume_mesh_centers + noise + basis_f = nn_basis[f](volume_m_c) + output = paddle.concat((basis_f, encoding_node, encoding_g), axis=-1) + if self.encode_parameters: + output = paddle.concat((output, param_encoding), axis=-1) + if p == 0: + output_center = agg_model[f](output) + else: + if p == 1: + output_neighbor = agg_model[f](output) * (1.0 / dist) + dist_sum = 1.0 / dist + else: + output_neighbor += agg_model[f](output) * (1.0 / dist) + dist_sum += 1.0 / dist + if num_sample_points > 1: + output_res = 0.5 * output_center + 0.5 * output_neighbor / dist_sum + else: + output_res = output_center + if f == 0: + output_all = output_res + else: + output_all = paddle.concat((output_all, output_res), axis=-1) + + return output_all + + def forward( + self, + data_dict, + ): + # Loading STL inputs, bounding box grids, precomputed SDF and scaling factors + + # STL nodes + geo_centers = data_dict["geometry_coordinates"] + + # Bounding box grid + s_grid = data_dict["surf_grid"] + sdf_surf_grid = data_dict["sdf_surf_grid"] + # Scaling factors + surf_max = data_dict["surface_min_max"][:, 1] + surf_min = data_dict["surface_min_max"][:, 0] + + # Parameters + stream_velocity = data_dict["stream_velocity"] + air_density = data_dict["air_density"] + + if self.output_features_vol is not None: + # Represent geometry on computational grid + # Computational domain grid + p_grid = data_dict["grid"] + sdf_grid = data_dict["sdf_grid"] + # Scaling factors + vol_max = data_dict["volume_min_max"][:, 1] + vol_min = data_dict["volume_min_max"][:, 0] + + # Normalize based on computational domain + geo_centers_vol = 2.0 * (geo_centers - vol_min) / (vol_max - vol_min) - 1 + encoding_g_vol = self.geo_rep(geo_centers_vol, p_grid, sdf_grid) + + # Normalize based on BBox around surface (car) + geo_centers_surf = ( + 2.0 * (geo_centers - surf_min) / (surf_max - surf_min) - 1 + ) + encoding_g_surf = self.geo_rep(geo_centers_surf, s_grid, sdf_surf_grid) + + # SDF on volume mesh nodes + sdf_nodes = data_dict["sdf_nodes"] + # Positional encoding based on closest point on surface to a volume node + pos_volume_closest = data_dict["pos_volume_closest"] + # Positional encoding based on center of mass of geometry to volume node + pos_volume_center_of_mass = data_dict["pos_volume_center_of_mass"] + encoding_node_vol = paddle.concat( + (sdf_nodes, pos_volume_closest, pos_volume_center_of_mass), axis=-1 + ) + + # Calculate positional encoding on volume nodes + encoding_node_vol = self.position_encoder( + encoding_node_vol, eval_mode="volume" + ) + + if self.output_features_surf is not None: + # Represent geometry on bounding box + geo_centers_surf = ( + 2.0 * (geo_centers - surf_min) / (surf_max - surf_min) - 1 + ) + encoding_g_surf = self.geo_rep(geo_centers_surf, s_grid, sdf_surf_grid) + + # Positional encoding based on center of mass of geometry to surface node + pos_surface_center_of_mass = data_dict["pos_surface_center_of_mass"] + encoding_node_surf = pos_surface_center_of_mass + + # Calculate positional encoding on surface centers + encoding_node_surf = self.position_encoder( + encoding_node_surf, eval_mode="surface" + ) + + encoding_g = 0.5 * encoding_g_surf + # Average the encodings + if self.output_features_vol is not None: + encoding_g += 0.5 * encoding_g_vol + + if self.output_features_vol is not None: + # Calculate local geometry encoding for volume + # Sampled points on volume + volume_mesh_centers = data_dict["volume_mesh_centers"] + encoding_g_vol = self.geo_encoding_local( + encoding_g, volume_mesh_centers, p_grid + ) + + # Approximate solution on volume node + output_vol = self.calculate_solution( + volume_mesh_centers, + encoding_g_vol, + encoding_node_vol, + stream_velocity, + air_density, + eval_mode="volume", + ) + else: + output_vol = None + + if self.output_features_surf is not None: + # Sampled points on surface + surface_mesh_centers = data_dict["surface_mesh_centers"] + surface_normals = data_dict["surface_normals"] + surface_areas = data_dict["surface_areas"] + + # Neighbors of sampled points on surface + surface_mesh_neighbors = data_dict["surface_mesh_neighbors"] + surface_neighbors_normals = data_dict["surface_neighbors_normals"] + surface_neighbors_areas = data_dict["surface_neighbors_areas"] + surface_areas = paddle.unsqueeze(surface_areas, -1) + surface_neighbors_areas = paddle.unsqueeze(surface_neighbors_areas, -1) + # Calculate local geometry encoding for surface + encoding_g_surf = self.geo_encoding_local_surface( + 0.5 * encoding_g_surf, surface_mesh_centers, s_grid + ) + + # Approximate solution on surface cell center + if not self.surface_neighbors: + output_surf = self.calculate_solution( + surface_mesh_centers, + encoding_g_surf, + encoding_node_surf, + stream_velocity, + air_density, + eval_mode="surface", + num_sample_points=1, + noise_intensity=500, + ) + else: + output_surf = self.calculate_solution_with_neighbors( + surface_mesh_centers, + encoding_g_surf, + encoding_node_surf, + surface_mesh_neighbors, + surface_normals, + surface_neighbors_normals, + surface_areas, + surface_neighbors_areas, + stream_velocity, + air_density, + ) + else: + output_surf = None + + return output_vol, output_surf + + +if __name__ == "__main__": + from hydra import compose + from hydra import initialize + from omegaconf import OmegaConf + + if paddle.device.cuda.device_count() >= 1: + paddle.set_device("gpu") + else: + paddle.set_device("cpu") + cfg = OmegaConf.register_new_resolver("eval", eval) + with initialize(version_base="1.3", config_path="../../scripts/conf"): + cfg = compose(config_name="config") + cfg.model.model_type = "combined" + model = DoMINO( + input_features=3, + output_features_vol=5, + output_features_surf=4, + model_parameters=cfg.model, + ) + + bsize = 1 + nx, ny, nz = 128, 64, 48 + num_neigh = 7 + pos_normals_closest_vol = paddle.randn([bsize, 100, 3]) + pos_normals_com_vol = paddle.randn([bsize, 100, 3]) + pos_normals_com_surface = paddle.randn([bsize, 100, 3]) + geom_centers = paddle.randn([bsize, 100, 3]) + grid = paddle.randn([bsize, nx, ny, nz, 3]) + surf_grid = paddle.randn([bsize, nx, ny, nz, 3]) + sdf_grid = paddle.randn([bsize, nx, ny, nz]) + sdf_surf_grid = paddle.randn([bsize, nx, ny, nz]) + sdf_nodes = paddle.randn([bsize, 100, 1]) + surface_coordinates = paddle.randn([bsize, 100, 3]) + surface_neighbors = paddle.randn([bsize, 100, num_neigh, 3]) + surface_normals = paddle.randn([bsize, 100, 3]) + surface_neighbors_normals = paddle.randn([bsize, 100, num_neigh, 3]) + surface_sizes = paddle.randn([bsize, 100, 3]) + surface_neighbors_sizes = paddle.randn([bsize, 100, num_neigh, 3]) + volume_coordinates = paddle.randn([bsize, 100, 3]) + vol_grid_max_min = paddle.randn([bsize, 2, 3]) + surf_grid_max_min = paddle.randn([bsize, 2, 3]) + stream_velocity = paddle.randn([bsize, 1]) + air_density = paddle.randn([bsize, 1]) + input_dict = { + "pos_volume_closest": pos_normals_closest_vol, + "pos_volume_center_of_mass": pos_normals_com_vol, + "pos_surface_center_of_mass": pos_normals_com_surface, + "geometry_coordinates": geom_centers, + "grid": grid, + "surf_grid": surf_grid, + "sdf_grid": sdf_grid, + "sdf_surf_grid": sdf_surf_grid, + "sdf_nodes": sdf_nodes, + "surface_mesh_centers": surface_coordinates, + "surface_mesh_neighbors": surface_neighbors, + "surface_normals": surface_normals, + "surface_neighbors_normals": surface_neighbors_normals, + "surface_areas": surface_sizes, + "surface_neighbors_areas": surface_neighbors_sizes, + "volume_mesh_centers": volume_coordinates, + "volume_min_max": vol_grid_max_min, + "surface_min_max": surf_grid_max_min, + "stream_velocity": stream_velocity, + "air_density": air_density, + } + output = model(input_dict) + print(f"{output[0].shape}, {output[1].shape}") diff --git a/ppsci/data/dataset/domino_datapipe.py b/ppsci/data/dataset/domino_datapipe.py index be9499bfeb..b349633214 100644 --- a/ppsci/data/dataset/domino_datapipe.py +++ b/ppsci/data/dataset/domino_datapipe.py @@ -53,6 +53,11 @@ def normalize(field, mx, mn): return 2.0 * (field - mn) / (mx - mn) - 1.0 +def unnormalize(field, mx, mn): + """Function to unnormalize fields""" + return (field + 1.0) * (mx - mn) * 0.5 + mn + + def standardize(field, mean, std): """Function to standardize fields""" return (field - mean) / std From 5df9b6d613cfe43d511ed96b87341e33b348a211 Mon Sep 17 00:00:00 2001 From: xiaoyewww <641311428@qq.com> Date: Fri, 30 May 2025 00:00:46 +0800 Subject: [PATCH 09/11] feat(domino): refactor physicsnemo --- docs/zh/examples/domino.md | 20 +- examples/domino/conf/config.yaml | 1 + examples/domino/{train.py => domino.py} | 775 +++++++++++++++++++++++- examples/domino/process_data.py | 87 --- examples/domino/requirements.txt | 6 - examples/domino/test.py | 747 ----------------------- ppsci/arch/physicsnemo.py | 182 ------ ppsci/data/dataset/domino_datapipe.py | 365 ++++++++++- ppsci/data/process/__init__.py | 3 +- ppsci/data/process/openfoam/__init__.py | 2 +- 10 files changed, 1134 insertions(+), 1054 deletions(-) rename examples/domino/{train.py => domino.py} (51%) delete mode 100644 examples/domino/process_data.py delete mode 100644 examples/domino/requirements.txt delete mode 100644 examples/domino/test.py diff --git a/docs/zh/examples/domino.md b/docs/zh/examples/domino.md index d54ed27361..e7b3eeae11 100644 --- a/docs/zh/examples/domino.md +++ b/docs/zh/examples/domino.md @@ -11,10 +11,11 @@ # 2. Specify the configuration settings in `examples/domino/conf/config.yaml`. # 3. Run process_data.py. This will process VTP/VTU files and save them as npy for faster processing in DoMINO datapipe. Modify data_processor key in config file. Additionally, run cache_data.py to save outputs of DoMINO datapipe in the .npy files. The DoMINO datapipe is set up to calculate Signed Distance Field and Nearest Neighbor interpolations on-the-fly during training. Caching will save these as a preprocessing step and should be used in cases where the STL surface meshes are upwards of 30 million cells. The final processed dataset should be divided and saved into 2 directories, for training and validation. Specify these directories in conf/config.yaml. - python3 process_data.py + # specify mode using `process`, set path to data_processor.output_dir and data_processor.input_dir + python3 domino.py - # 4. run train - python3 train.py + # 4. run train, specify mode using `train`, set path to data.input_dir and data.input_dir_val + python3 domino.py ``` === "模型评估命令" @@ -29,7 +30,8 @@ ``` sh cd examples/domino - python3 test.py + # specify mode using `eval`, and set path to eval.test_path, eval.save_path and eval.checkpoint_name + python3 domino.py ``` ## 1. 背景简介 @@ -70,15 +72,9 @@ DOMINO模型通过这种分解式、多尺度和迭代的方法,能够有效 ## 3. 完整代码 -``` py linenums="1" title="examples/domino/train.py" +``` py linenums="1" title="examples/domino/domino.py" --8<-- -examples/domino/train.py ---8<-- -``` - -``` py linenums="1" title="examples/domino/test.py" ---8<-- -examples/domino/test.py +examples/domino/domino.py --8<-- ``` diff --git a/examples/domino/conf/config.yaml b/examples/domino/conf/config.yaml index 8b649e9968..1c029b7c89 100644 --- a/examples/domino/conf/config.yaml +++ b/examples/domino/conf/config.yaml @@ -17,6 +17,7 @@ project: # Project name name: AWS_Dataset +mode: train # process, train, eval seed: 42 exp_tag: 1 # Experiment tag # Main output directory. diff --git a/examples/domino/train.py b/examples/domino/domino.py similarity index 51% rename from examples/domino/train.py rename to examples/domino/domino.py index 40ea1f3b91..073d22723c 100644 --- a/examples/domino/train.py +++ b/examples/domino/domino.py @@ -14,6 +14,7 @@ # # refs: https://github.com/NVIDIA/physicsnemo/tree/main/examples/cfd/external_aerodynamics/domino +import multiprocessing import os import re import time @@ -22,6 +23,8 @@ import numpy as np import paddle import paddle.distributed as dist +import pyvista as pv +import vtk from hydra.utils import to_absolute_path from omegaconf import DictConfig from omegaconf import OmegaConf @@ -30,15 +33,81 @@ from paddle.amp import auto_cast from paddle.io import DataLoader from paddle.io import DistributedBatchSampler +from scipy.spatial import KDTree +from vtk.util import numpy_support from ppsci.arch.physicsnemo import DoMINO from ppsci.arch.physicsnemo import create_directory +from ppsci.arch.physicsnemo import get_fields +from ppsci.arch.physicsnemo import get_node_to_elem +from ppsci.arch.physicsnemo import get_volume_data from ppsci.arch.physicsnemo import load_checkpoint from ppsci.arch.physicsnemo import mean_std_sampling from ppsci.arch.physicsnemo import save_checkpoint +from ppsci.arch.physicsnemo import write_to_vtp +from ppsci.arch.physicsnemo import write_to_vtu from ppsci.data.dataset.domino_datapipe import DoMINODataPipe +from ppsci.data.dataset.domino_datapipe import OpenFoamDataset +from ppsci.data.dataset.domino_datapipe import cal_normal_positional_encoding +from ppsci.data.dataset.domino_datapipe import calculate_center_of_mass +from ppsci.data.dataset.domino_datapipe import create_grid +from ppsci.data.dataset.domino_datapipe import get_filenames +from ppsci.data.dataset.domino_datapipe import normalize +from ppsci.data.dataset.domino_datapipe import unnormalize +from ppsci.data.process.openfoam import process_files +from ppsci.utils.sdf import signed_distance_field -paddle.set_device("gpu") +AIR_DENSITY = 1.205 +STREAM_VELOCITY = 30.00 + + +def process(cfg: DictConfig): + print(f"Config summary:\n{OmegaConf.to_yaml(cfg, sort_keys=True)}") + volume_variable_names = list(cfg.variables.volume.solution.keys()) + num_vol_vars = 0 + for j in volume_variable_names: + if cfg.variables.volume.solution[j] == "vector": + num_vol_vars += 3 + else: + num_vol_vars += 1 + + surface_variable_names = list(cfg.variables.surface.solution.keys()) + num_surf_vars = 0 + for j in surface_variable_names: + if cfg.variables.surface.solution[j] == "vector": + num_surf_vars += 3 + else: + num_surf_vars += 1 + + fm_data = OpenFoamDataset( + cfg.data_processor.input_dir, + kind=cfg.data_processor.kind, + volume_variables=volume_variable_names, + surface_variables=surface_variable_names, + model_type=cfg.model.model_type, + ) + output_dir = cfg.data_processor.output_dir + create_directory(output_dir) # noqa: F405 + n_processors = cfg.data_processor.num_processors + + num_files = len(fm_data) + ids = np.arange(num_files) + num_elements = int(num_files / n_processors) + 1 + process_list = [] + ctx = multiprocessing.get_context("spawn") + for i in range(n_processors): + if i != n_processors - 1: + sf = ids[i * num_elements : i * num_elements + num_elements] + else: + sf = ids[i * num_elements :] + # print(sf) + process = ctx.Process(target=process_files, args=(sf, i, fm_data, output_dir)) + + process.start() + process_list.append(process) + + for process in process_list: + process.join() def relative_loss_fn(output, target, padded_value=-10): @@ -667,8 +736,7 @@ def compute_scaling_factors(cfg: DictConfig): np.save(surf_save_path, surf_scaling_factors) -@hydra.main(version_base="1.3", config_path="conf", config_name="config") -def main(cfg: DictConfig) -> None: +def train(cfg: DictConfig) -> None: compute_scaling_factors(cfg) input_path = cfg.data.input_dir input_path_val = cfg.data.input_dir_val @@ -885,14 +953,6 @@ def main(cfg: DictConfig) -> None: f"Integral factor {initial_integral_factor}" ) - # if paddle.distributed.get_rank() == 0: - # writer.add_scalars( - # "Training vs. Validation Loss", - # {"Training": avg_loss, "Validation": avg_vloss}, - # epoch_number, - # ) - # writer.flush() - # Track best performance, and save the model's state if paddle.distributed.get_world_size() > 1: paddle.distributed.barrier() @@ -923,5 +983,698 @@ def main(cfg: DictConfig) -> None: exit() +def loss_fn(output, target): + masked_loss = paddle.mean(((output - target) ** 2.0), (0, 1, 2)) + loss = paddle.mean(masked_loss) + return loss + + +def test_step(data_dict, model, device, cfg, vol_factors, surf_factors): + running_tloss_vol = 0.0 + running_tloss_surf = 0.0 + + if cfg.model.model_type == "volume" or cfg.model.model_type == "combined": + output_features_vol = True + else: + output_features_vol = None + + if cfg.model.model_type == "surface" or cfg.model.model_type == "combined": + output_features_surf = True + else: + output_features_surf = None + + with paddle.no_grad(): + point_batch_size = 256000 + + # Non-dimensionalization factors + air_density = data_dict["air_density"] + stream_velocity = data_dict["stream_velocity"] + length_scale = data_dict["length_scale"] + + # STL nodes + geo_centers = data_dict["geometry_coordinates"] + + # Bounding box grid + s_grid = data_dict["surf_grid"] + sdf_surf_grid = data_dict["sdf_surf_grid"] + # Scaling factors + surf_max = data_dict["surface_min_max"][:, 1] + surf_min = data_dict["surface_min_max"][:, 0] + + if output_features_vol is not None: + # Represent geometry on computational grid + # Computational domain grid + p_grid = data_dict["grid"] + sdf_grid = data_dict["sdf_grid"] + # Scaling factors + vol_max = data_dict["volume_min_max"][:, 1] + vol_min = data_dict["volume_min_max"][:, 0] + + # Normalize based on computational domain + geo_centers_vol = 2.0 * (geo_centers - vol_min) / (vol_max - vol_min) - 1 + encoding_g_vol = model.module.geo_rep(geo_centers_vol, p_grid, sdf_grid) + + # Normalize based on BBox around surface (car) + geo_centers_surf = ( + 2.0 * (geo_centers - surf_min) / (surf_max - surf_min) - 1 + ) + encoding_g_surf = model.module.geo_rep( + geo_centers_surf, s_grid, sdf_surf_grid + ) + + if output_features_surf is not None: + # Represent geometry on bounding box + geo_centers_surf = ( + 2.0 * (geo_centers - surf_min) / (surf_max - surf_min) - 1 + ) + encoding_g_surf = model.module.geo_rep( + geo_centers_surf, s_grid, sdf_surf_grid + ) + + geo_encoding = 0.5 * encoding_g_surf + # Average the encodings + if output_features_vol is not None: + geo_encoding += 0.5 * encoding_g_vol + + if output_features_vol is not None: + # First calculate volume predictions if required + volume_mesh_centers = data_dict["volume_mesh_centers"] + target_vol = data_dict["volume_fields"] + # SDF on volume mesh nodes + sdf_nodes = data_dict["sdf_nodes"] + # Positional encoding based on closest point on surface to a volume node + pos_volume_closest = data_dict["pos_volume_closest"] + # Positional encoding based on center of mass of geometry to volume node + pos_volume_center_of_mass = data_dict["pos_volume_center_of_mass"] + p_grid = data_dict["grid"] + + prediction_vol = np.zeros_like(target_vol.cpu().numpy()) + num_points = volume_mesh_centers.shape[1] + subdomain_points = int(np.floor(num_points / point_batch_size)) + + for p in range(subdomain_points + 1): + start_idx = p * point_batch_size + end_idx = (p + 1) * point_batch_size + with paddle.no_grad(): + target_batch = target_vol[:, start_idx:end_idx] + volume_mesh_centers_batch = volume_mesh_centers[ + :, start_idx:end_idx + ] + sdf_nodes_batch = sdf_nodes[:, start_idx:end_idx] + pos_volume_closest_batch = pos_volume_closest[:, start_idx:end_idx] + pos_normals_com_batch = pos_volume_center_of_mass[ + :, start_idx:end_idx + ] + geo_encoding_local = model.module.geo_encoding_local( + geo_encoding, volume_mesh_centers_batch, p_grid + ) + if cfg.model.use_sdf_in_basis_func: + pos_encoding = paddle.concat( + ( + sdf_nodes_batch, + pos_volume_closest_batch, + pos_normals_com_batch, + ), + axis=-1, + ) + else: + pos_encoding = pos_normals_com_batch + pos_encoding = model.module.position_encoder( + pos_encoding, eval_mode="volume" + ) + tpredictions_batch = model.module.calculate_solution( + volume_mesh_centers_batch, + geo_encoding_local, + pos_encoding, + stream_velocity, + air_density, + num_sample_points=20, + eval_mode="volume", + ) + running_tloss_vol += loss_fn(tpredictions_batch, target_batch) + prediction_vol[ + :, start_idx:end_idx + ] = tpredictions_batch.cpu().numpy() + + prediction_vol = unnormalize(prediction_vol, vol_factors[0], vol_factors[1]) + + prediction_vol[:, :, :3] = ( + prediction_vol[:, :, :3] * stream_velocity[0, 0].cpu().numpy() + ) + prediction_vol[:, :, 3] = ( + prediction_vol[:, :, 3] + * stream_velocity[0, 0].cpu().numpy() ** 2.0 + * air_density[0, 0].cpu().numpy() + ) + prediction_vol[:, :, 4] = ( + prediction_vol[:, :, 4] + * stream_velocity[0, 0].cpu().numpy() + * length_scale[0].cpu().numpy() + ) + else: + prediction_vol = None + + if output_features_surf is not None: + # Next calculate surface predictions + # Sampled points on surface + surface_mesh_centers = data_dict["surface_mesh_centers"] + surface_normals = data_dict["surface_normals"] + surface_areas = data_dict["surface_areas"] + + # Neighbors of sampled points on surface + surface_mesh_neighbors = data_dict["surface_mesh_neighbors"] + surface_neighbors_normals = data_dict["surface_neighbors_normals"] + surface_neighbors_areas = data_dict["surface_neighbors_areas"] + surface_areas = paddle.unsqueeze(surface_areas, -1) + surface_neighbors_areas = paddle.unsqueeze(surface_neighbors_areas, -1) + pos_surface_center_of_mass = data_dict["pos_surface_center_of_mass"] + num_points = surface_mesh_centers.shape[1] + subdomain_points = int(np.floor(num_points / point_batch_size)) + + target_surf = data_dict["surface_fields"] + prediction_surf = np.zeros_like(target_surf.cpu().numpy()) + + surface_areas = paddle.unsqueeze(surface_areas, -1) + surface_neighbors_areas = paddle.unsqueeze(surface_neighbors_areas, -1) + + for p in range(subdomain_points + 1): + start_idx = p * point_batch_size + end_idx = (p + 1) * point_batch_size + with paddle.no_grad(): + target_batch = target_surf[:, start_idx:end_idx] + surface_mesh_centers_batch = surface_mesh_centers[ + :, start_idx:end_idx + ] + surface_mesh_neighbors_batch = surface_mesh_neighbors[ + :, start_idx:end_idx + ] + surface_normals_batch = surface_normals[:, start_idx:end_idx] + surface_neighbors_normals_batch = surface_neighbors_normals[ + :, start_idx:end_idx + ] + surface_areas_batch = surface_areas[:, start_idx:end_idx] + surface_neighbors_areas_batch = surface_neighbors_areas[ + :, start_idx:end_idx + ] + pos_surface_center_of_mass_batch = pos_surface_center_of_mass[ + :, start_idx:end_idx + ] + geo_encoding_local = model.module.geo_encoding_local_surface( + 0.5 * encoding_g_surf, surface_mesh_centers_batch, s_grid + ) + pos_encoding = pos_surface_center_of_mass_batch + pos_encoding = model.module.position_encoder( + pos_encoding, eval_mode="surface" + ) + + if cfg.model.surface_neighbors: + tpredictions_batch = ( + model.module.calculate_solution_with_neighbors( + surface_mesh_centers_batch, + geo_encoding_local, + pos_encoding, + surface_mesh_neighbors_batch, + surface_normals_batch, + surface_neighbors_normals_batch, + surface_areas_batch, + surface_neighbors_areas_batch, + stream_velocity, + air_density, + ) + ) + else: + tpredictions_batch = model.module.calculate_solution( + surface_mesh_centers_batch, + geo_encoding_local, + pos_encoding, + stream_velocity, + air_density, + num_sample_points=1, + eval_mode="surface", + ) + running_tloss_surf += loss_fn(tpredictions_batch, target_batch) + prediction_surf[ + :, start_idx:end_idx + ] = tpredictions_batch.cpu().numpy() + + prediction_surf = ( + unnormalize(prediction_surf, surf_factors[0], surf_factors[1]) + * stream_velocity[0, 0].cpu().numpy() ** 2.0 + * air_density[0, 0].cpu().numpy() + ) + + else: + prediction_surf = None + + return prediction_vol, prediction_surf + + +def test(cfg: DictConfig): + print(f"Config summary:\n{OmegaConf.to_yaml(cfg, sort_keys=True)}") + + input_path = cfg.eval.test_path + + model_type = cfg.model.model_type + + dist.init_parallel_env() + + if model_type == "volume" or model_type == "combined": + volume_variable_names = list(cfg.variables.volume.solution.keys()) + num_vol_vars = 0 + for j in volume_variable_names: + if cfg.variables.volume.solution[j] == "vector": + num_vol_vars += 3 + else: + num_vol_vars += 1 + else: + num_vol_vars = None + + if model_type == "surface" or model_type == "combined": + surface_variable_names = list(cfg.variables.surface.solution.keys()) + num_surf_vars = 0 + for j in surface_variable_names: + if cfg.variables.surface.solution[j] == "vector": + num_surf_vars += 3 + else: + num_surf_vars += 1 + else: + num_surf_vars = None + + vol_save_path = os.path.join( + "outputs", cfg.project.name, "volume_scaling_factors.npy" + ) + surf_save_path = os.path.join( + "outputs", cfg.project.name, "surface_scaling_factors.npy" + ) + if os.path.exists(vol_save_path) and os.path.exists(surf_save_path): + vol_factors = np.load(vol_save_path) + surf_factors = np.load(surf_save_path) + else: + vol_factors = None + surf_factors = None + + model = DoMINO( + input_features=3, + output_features_vol=num_vol_vars, + output_features_surf=num_surf_vars, + model_parameters=cfg.model, + ) + + checkpoint = paddle.load( + to_absolute_path(os.path.join(cfg.resume_dir, cfg.eval.checkpoint_name)), + ) + + model.set_state_dict(checkpoint) + + print("Model loaded") + + if paddle.distributed.get_world_size() > 1: + model = DataParallel( + model, + ) + + dirnames_per_gpu = get_filenames(input_path) + + pred_save_path = cfg.eval.save_path + create_directory(pred_save_path) + + for count, dirname in enumerate(dirnames_per_gpu): + # print(f"Processing file {dirname}") + filepath = os.path.join(input_path, dirname) + tag = int(re.findall(r"(\w+?)(\d+)", dirname)[0][1]) + stl_path = os.path.join(filepath, f"drivaer_{tag}.stl") + vtp_path = os.path.join(filepath, f"boundary_{tag}.vtp") + vtu_path = os.path.join(filepath, f"volume_{tag}.vtu") + + vtp_pred_save_path = os.path.join( + pred_save_path, f"boundary_{tag}_predicted.vtp" + ) + vtu_pred_save_path = os.path.join(pred_save_path, f"volume_{tag}_predicted.vtu") + + # Read STL + reader = pv.get_reader(stl_path) + mesh_stl = reader.read() + stl_vertices = mesh_stl.points + stl_faces = np.array(mesh_stl.faces).reshape((-1, 4))[ + :, 1: + ] # Assuming triangular elements + mesh_indices_flattened = stl_faces.flatten() + length_scale = np.amax(np.amax(stl_vertices, 0) - np.amin(stl_vertices, 0)) + stl_sizes = mesh_stl.compute_cell_sizes(length=False, area=True, volume=False) + stl_sizes = np.array(stl_sizes.cell_data["Area"], dtype=np.float32) + stl_centers = np.array(mesh_stl.cell_centers().points, dtype=np.float32) + + # Center of mass calculation + center_of_mass = calculate_center_of_mass(stl_centers, stl_sizes) + + if cfg.data.bounding_box_surface is None: + s_max = np.amax(stl_vertices, 0) + s_min = np.amin(stl_vertices, 0) + else: + bounding_box_dims_surf = [] + bounding_box_dims_surf.append(np.asarray(cfg.data.bounding_box_surface.max)) + bounding_box_dims_surf.append(np.asarray(cfg.data.bounding_box_surface.min)) + s_max = np.float32(bounding_box_dims_surf[0]) + s_min = np.float32(bounding_box_dims_surf[1]) + + nx, ny, nz = cfg.model.interp_res + + surf_grid = create_grid(s_max, s_min, [nx, ny, nz]) + surf_grid_reshaped = surf_grid.reshape(nx * ny * nz, 3) + + # SDF calculation on the grid using WARP + sdf_surf_grid = ( + signed_distance_field( + stl_vertices, + mesh_indices_flattened, + surf_grid_reshaped, + use_sign_winding_number=True, + ) + .numpy() + .reshape(nx, ny, nz) + ) + surf_grid = np.float32(surf_grid) + sdf_surf_grid = np.float32(sdf_surf_grid) + surf_grid_max_min = np.float32(np.asarray([s_min, s_max])) + + # Read VTP + if model_type == "surface" or model_type == "combined": + reader = vtk.vtkXMLPolyDataReader() + reader.SetFileName(vtp_path) + reader.Update() + polydata_surf = reader.GetOutput() + + celldata_all = get_node_to_elem(polydata_surf) + + celldata = celldata_all.GetCellData() + surface_fields = get_fields(celldata, surface_variable_names) + surface_fields = np.concatenate(surface_fields, axis=-1) + + mesh = pv.PolyData(polydata_surf) + surface_coordinates = np.array(mesh.cell_centers().points, dtype=np.float32) + + interp_func = KDTree(surface_coordinates) + dd, ii = interp_func.query( + surface_coordinates, k=cfg.model.num_surface_neighbors + ) + + surface_neighbors = surface_coordinates[ii] + surface_neighbors = surface_neighbors[:, 1:] + + surface_normals = np.array(mesh.cell_normals, dtype=np.float32) + surface_sizes = mesh.compute_cell_sizes( + length=False, area=True, volume=False + ) + surface_sizes = np.array(surface_sizes.cell_data["Area"], dtype=np.float32) + + # Normalize cell normals + surface_normals = ( + surface_normals / np.linalg.norm(surface_normals, axis=1)[:, np.newaxis] + ) + surface_neighbors_normals = surface_normals[ii] + surface_neighbors_normals = surface_neighbors_normals[:, 1:] + surface_neighbors_sizes = surface_sizes[ii] + surface_neighbors_sizes = surface_neighbors_sizes[:, 1:] + + dx, dy, dz = ( + (s_max[0] - s_min[0]) / nx, + (s_max[1] - s_min[1]) / ny, + (s_max[2] - s_min[2]) / nz, + ) + + if cfg.model.positional_encoding: + pos_surface_center_of_mass = cal_normal_positional_encoding( + surface_coordinates, center_of_mass, cell_length=[dx, dy, dz] + ) + else: + pos_surface_center_of_mass = surface_coordinates - center_of_mass + + surface_coordinates = normalize(surface_coordinates, s_max, s_min) + surface_neighbors = normalize(surface_neighbors, s_max, s_min) + surf_grid = normalize(surf_grid, s_max, s_min) + + else: + surface_coordinates = None + surface_fields = None + surface_sizes = None + surface_normals = None + surface_neighbors = None + surface_neighbors_normals = None + surface_neighbors_sizes = None + pos_surface_center_of_mass = None + + # Read VTU + if model_type == "volume" or model_type == "combined": + reader = vtk.vtkXMLUnstructuredGridReader() + reader.SetFileName(vtu_path) + reader.Update() + polydata_vol = reader.GetOutput() + volume_coordinates, volume_fields = get_volume_data( + polydata_vol, volume_variable_names + ) + volume_fields = np.concatenate(volume_fields, axis=-1) + # print(f"Processed vtu {vtu_path}") + + bounding_box_dims = [] + bounding_box_dims.append(np.asarray(cfg.data.bounding_box.max)) + bounding_box_dims.append(np.asarray(cfg.data.bounding_box.min)) + + if bounding_box_dims is None: + c_max = s_max + (s_max - s_min) / 2 + c_min = s_min - (s_max - s_min) / 2 + c_min[2] = s_min[2] + else: + c_max = np.float32(bounding_box_dims[0]) + c_min = np.float32(bounding_box_dims[1]) + + dx, dy, dz = ( + (c_max[0] - c_min[0]) / nx, + (c_max[1] - c_min[1]) / ny, + (c_max[2] - c_min[2]) / nz, + ) + # Generate a grid of specified resolution to map the bounding box + # The grid is used for capturing structured geometry features and SDF representation of geometry + grid = create_grid(c_max, c_min, [nx, ny, nz]) + grid_reshaped = grid.reshape(nx * ny * nz, 3) + + # SDF calculation on the grid using WARP + sdf_grid = ( + signed_distance_field( + stl_vertices, + mesh_indices_flattened, + grid_reshaped, + use_sign_winding_number=True, + ) + .numpy() + .reshape(nx, ny, nz) + ) + + # SDF calculation + sdf_nodes, sdf_node_closest_point = signed_distance_field( + stl_vertices, + mesh_indices_flattened, + volume_coordinates, + include_hit_points=True, + use_sign_winding_number=True, + ) + sdf_nodes = sdf_nodes.numpy().reshape(-1, 1) + sdf_node_closest_point = sdf_node_closest_point.numpy() + + if cfg.model.positional_encoding: + pos_volume_closest = cal_normal_positional_encoding( + volume_coordinates, sdf_node_closest_point, cell_length=[dx, dy, dz] + ) + pos_volume_center_of_mass = cal_normal_positional_encoding( + volume_coordinates, center_of_mass, cell_length=[dx, dy, dz] + ) + else: + pos_volume_closest = volume_coordinates - sdf_node_closest_point + pos_volume_center_of_mass = volume_coordinates - center_of_mass + + volume_coordinates = normalize(volume_coordinates, c_max, c_min) + grid = normalize(grid, c_max, c_min) + vol_grid_max_min = np.asarray([c_min, c_max]) + + else: + volume_coordinates = None + volume_fields = None + pos_volume_closest = None + pos_volume_center_of_mass = None + + # print(f"Processed sdf and normalized") + + geom_centers = np.float32(stl_vertices) + + if model_type == "combined": + # Add the parameters to the dictionary + data_dict = { + "pos_volume_closest": pos_volume_closest, + "pos_volume_center_of_mass": pos_volume_center_of_mass, + "pos_surface_center_of_mass": pos_surface_center_of_mass, + "geometry_coordinates": geom_centers, + "grid": grid, + "surf_grid": surf_grid, + "sdf_grid": sdf_grid, + "sdf_surf_grid": sdf_surf_grid, + "sdf_nodes": sdf_nodes, + "surface_mesh_centers": surface_coordinates, + "surface_mesh_neighbors": surface_neighbors, + "surface_normals": surface_normals, + "surface_neighbors_normals": surface_neighbors_normals, + "surface_areas": surface_sizes, + "surface_neighbors_areas": surface_neighbors_sizes, + "volume_fields": volume_fields, + "volume_mesh_centers": volume_coordinates, + "surface_fields": surface_fields, + "volume_min_max": vol_grid_max_min, + "surface_min_max": surf_grid_max_min, + "length_scale": np.array(length_scale, dtype=np.float32), + "stream_velocity": np.expand_dims( + np.array(STREAM_VELOCITY, dtype=np.float32), axis=-1 + ), + "air_density": np.expand_dims( + np.array(AIR_DENSITY, dtype=np.float32), axis=-1 + ), + } + elif model_type == "surface": + data_dict = { + "pos_surface_center_of_mass": np.float32(pos_surface_center_of_mass), + "geometry_coordinates": np.float32(geom_centers), + "surf_grid": np.float32(surf_grid), + "sdf_surf_grid": np.float32(sdf_surf_grid), + "surface_mesh_centers": np.float32(surface_coordinates), + "surface_mesh_neighbors": np.float32(surface_neighbors), + "surface_normals": np.float32(surface_normals), + "surface_neighbors_normals": np.float32(surface_neighbors_normals), + "surface_areas": np.float32(surface_sizes), + "surface_neighbors_areas": np.float32(surface_neighbors_sizes), + "surface_fields": np.float32(surface_fields), + "surface_min_max": np.float32(surf_grid_max_min), + "length_scale": np.array(length_scale, dtype=np.float32), + "stream_velocity": np.expand_dims( + np.array(STREAM_VELOCITY, dtype=np.float32), axis=-1 + ), + "air_density": np.expand_dims( + np.array(AIR_DENSITY, dtype=np.float32), axis=-1 + ), + } + elif model_type == "volume": + data_dict = { + "pos_volume_closest": pos_volume_closest, + "pos_volume_center_of_mass": pos_volume_center_of_mass, + "geometry_coordinates": geom_centers, + "grid": grid, + "surf_grid": surf_grid, + "sdf_grid": sdf_grid, + "sdf_surf_grid": sdf_surf_grid, + "sdf_nodes": sdf_nodes, + "volume_fields": volume_fields, + "volume_mesh_centers": volume_coordinates, + "volume_min_max": vol_grid_max_min, + "surface_min_max": surf_grid_max_min, + "length_scale": np.array(length_scale, dtype=np.float32), + "stream_velocity": np.expand_dims( + np.array(STREAM_VELOCITY, dtype=np.float32), axis=-1 + ), + "air_density": np.expand_dims( + np.array(AIR_DENSITY, dtype=np.float32), axis=-1 + ), + } + + data_dict = { + key: paddle.to_tensor(np.expand_dims(np.float32(value), 0)) + for key, value in data_dict.items() + } + + prediction_vol, prediction_surf = test_step( + data_dict, + model, + paddle.distributed.get_rank(), + cfg, + vol_factors, + surf_factors, + ) + + if prediction_surf is not None: + surface_sizes = np.expand_dims(surface_sizes, -1) + + force_x_pred = np.sum( + prediction_surf[0, :, 0] * surface_normals[:, 0] * surface_sizes[:, 0] + - prediction_surf[0, :, 1] * surface_sizes[:, 0] + ) + force_x_true = np.sum( + surface_fields[:, 0] * surface_normals[:, 0] * surface_sizes[:, 0] + - surface_fields[:, 1] * surface_sizes[:, 0] + ) + print(dirname, force_x_pred, force_x_true) + + if prediction_vol is not None: + target_vol = volume_fields + prediction_vol = prediction_vol[0] + c_min = vol_grid_max_min[0] + c_max = vol_grid_max_min[1] + volume_coordinates = unnormalize(volume_coordinates, c_max, c_min) + ids_in_bbox = np.where( + (volume_coordinates[:, 0] < c_min[0]) + | (volume_coordinates[:, 0] > c_max[0]) + | (volume_coordinates[:, 1] < c_min[1]) + | (volume_coordinates[:, 1] > c_max[1]) + | (volume_coordinates[:, 2] < c_min[2]) + | (volume_coordinates[:, 2] > c_max[2]) + ) + target_vol[ids_in_bbox] = 0.0 + prediction_vol[ids_in_bbox] = 0.0 + l2_gt = np.sum(np.square(target_vol), (0)) + l2_error = np.sum(np.square(prediction_vol - target_vol), (0)) + print( + "L-2 norm:", + dirname, + np.sqrt(l2_error), + np.sqrt(l2_gt), + np.sqrt(l2_error) / np.sqrt(l2_gt), + ) + + if prediction_surf is not None: + surfParam_vtk = numpy_support.numpy_to_vtk(prediction_surf[0, :, 0:1]) + surfParam_vtk.SetName(f"{surface_variable_names[0]}Pred") + celldata_all.GetCellData().AddArray(surfParam_vtk) + + surfParam_vtk = numpy_support.numpy_to_vtk(prediction_surf[0, :, 1:]) + surfParam_vtk.SetName(f"{surface_variable_names[1]}Pred") + celldata_all.GetCellData().AddArray(surfParam_vtk) + + write_to_vtp(celldata_all, vtp_pred_save_path) + + if prediction_vol is not None: + + volParam_vtk = numpy_support.numpy_to_vtk(prediction_vol[:, 0:3]) + volParam_vtk.SetName(f"{volume_variable_names[0]}Pred") + polydata_vol.GetPointData().AddArray(volParam_vtk) + + volParam_vtk = numpy_support.numpy_to_vtk(prediction_vol[:, 3:4]) + volParam_vtk.SetName(f"{volume_variable_names[1]}Pred") + polydata_vol.GetPointData().AddArray(volParam_vtk) + + volParam_vtk = numpy_support.numpy_to_vtk(prediction_vol[:, 4:5]) + volParam_vtk.SetName(f"{volume_variable_names[2]}Pred") + polydata_vol.GetPointData().AddArray(volParam_vtk) + + write_to_vtu(polydata_vol, vtu_pred_save_path) + + +@hydra.main(version_base=None, config_path="conf", config_name="config") +def main(cfg: DictConfig): + if cfg.mode == "train": + train(cfg) + elif cfg.mode == "eval": + test(cfg) + elif cfg.mode == "process": + process(cfg) + else: + raise ValueError( + f"cfg.mode should in ['process', 'train', 'eval'], but got '{cfg.mode}'" + ) + + if __name__ == "__main__": main() diff --git a/examples/domino/process_data.py b/examples/domino/process_data.py deleted file mode 100644 index 9fb0d5d572..0000000000 --- a/examples/domino/process_data.py +++ /dev/null @@ -1,87 +0,0 @@ -# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. - -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at - -# http://www.apache.org/licenses/LICENSE-2.0 - -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# refs: https://github.com/NVIDIA/physicsnemo/tree/main/examples/cfd/external_aerodynamics/domino - -""" -This code runs the data processing in parallel to load OpenFoam files, process them -and save in the npy format for faster processing in the DoMINO datapipes. Several -parameters such as number of processors, input and output paths, etc. can be -configured in config.yaml in the data_processing tab. -""" - -import multiprocessing - -import hydra -import numpy as np -from omegaconf import DictConfig -from omegaconf import OmegaConf -from openfoam_datapipe import OpenFoamDataset -from physicsnemo.utils.domino.utils import * # noqa: F403 - -from ppsci.data.process.openfoam import process_files # noqa: F401 - - -@hydra.main(version_base="1.3", config_path="conf", config_name="config") -def main(cfg: DictConfig): - print(f"Config summary:\n{OmegaConf.to_yaml(cfg, sort_keys=True)}") - volume_variable_names = list(cfg.variables.volume.solution.keys()) - num_vol_vars = 0 - for j in volume_variable_names: - if cfg.variables.volume.solution[j] == "vector": - num_vol_vars += 3 - else: - num_vol_vars += 1 - - surface_variable_names = list(cfg.variables.surface.solution.keys()) - num_surf_vars = 0 - for j in surface_variable_names: - if cfg.variables.surface.solution[j] == "vector": - num_surf_vars += 3 - else: - num_surf_vars += 1 - - fm_data = OpenFoamDataset( - cfg.data_processor.input_dir, - kind=cfg.data_processor.kind, - volume_variables=volume_variable_names, - surface_variables=surface_variable_names, - model_type=cfg.model.model_type, - ) - output_dir = cfg.data_processor.output_dir - create_directory(output_dir) # noqa: F405 - n_processors = cfg.data_processor.num_processors - - num_files = len(fm_data) - ids = np.arange(num_files) - num_elements = int(num_files / n_processors) + 1 - process_list = [] - ctx = multiprocessing.get_context("spawn") - for i in range(n_processors): - if i != n_processors - 1: - sf = ids[i * num_elements : i * num_elements + num_elements] - else: - sf = ids[i * num_elements :] - # print(sf) - process = ctx.Process(target=process_files, args=(sf, i, fm_data, output_dir)) - - process.start() - process_list.append(process) - - for process in process_list: - process.join() - - -if __name__ == "__main__": - main() diff --git a/examples/domino/requirements.txt b/examples/domino/requirements.txt deleted file mode 100644 index 57cc3b0363..0000000000 --- a/examples/domino/requirements.txt +++ /dev/null @@ -1,6 +0,0 @@ -hydra-core -importlib_metadata -pyvista==0.34.2 -termcolor -treelib -warp-lang diff --git a/examples/domino/test.py b/examples/domino/test.py deleted file mode 100644 index 9b48e0de28..0000000000 --- a/examples/domino/test.py +++ /dev/null @@ -1,747 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. -# SPDX-FileCopyrightText: All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -This code defines a distributed pipeline for testing the DoMINO model on -CFD datasets. It includes the instantiating the DoMINO model and datapipe, -automatically loading the most recent checkpoint, reading the VTP/VTU/STL -testing files, calculation of parameters required for DoMINO model and -evaluating the model in parallel using DataParallel across multiple -GPUs. This is a common recipe that enables training of combined models for surface -and volume as well either of them separately. The model predictions are loaded in -the the VTP/VTU files and saved in the specified directory. The eval tab in -config.yaml can be used to specify the input and output directories. -""" - -import os -import re - -import hydra -import numpy as np -import paddle -import paddle.distributed as dist -import pyvista as pv -import vtk -from hydra.utils import to_absolute_path -from omegaconf import DictConfig -from omegaconf import OmegaConf -from paddle import DataParallel -from scipy.spatial import KDTree -from vtk.util import numpy_support - -from ppsci.arch.physicsnemo import DoMINO -from ppsci.arch.physicsnemo import create_directory -from ppsci.arch.physicsnemo import get_fields -from ppsci.arch.physicsnemo import get_node_to_elem -from ppsci.arch.physicsnemo import get_volume_data -from ppsci.arch.physicsnemo import write_to_vtp -from ppsci.arch.physicsnemo import write_to_vtu -from ppsci.data.dataset.domino_datapipe import cal_normal_positional_encoding -from ppsci.data.dataset.domino_datapipe import calculate_center_of_mass -from ppsci.data.dataset.domino_datapipe import create_grid -from ppsci.data.dataset.domino_datapipe import get_filenames -from ppsci.data.dataset.domino_datapipe import normalize -from ppsci.data.dataset.domino_datapipe import unnormalize -from ppsci.utils.sdf import signed_distance_field - -AIR_DENSITY = 1.205 -STREAM_VELOCITY = 30.00 - -paddle.set_device("gpu") - - -def loss_fn(output, target): - masked_loss = paddle.mean(((output - target) ** 2.0), (0, 1, 2)) - loss = paddle.mean(masked_loss) - return loss - - -def test_step(data_dict, model, device, cfg, vol_factors, surf_factors): - running_tloss_vol = 0.0 - running_tloss_surf = 0.0 - - if cfg.model.model_type == "volume" or cfg.model.model_type == "combined": - output_features_vol = True - else: - output_features_vol = None - - if cfg.model.model_type == "surface" or cfg.model.model_type == "combined": - output_features_surf = True - else: - output_features_surf = None - - with paddle.no_grad(): - point_batch_size = 256000 - - # Non-dimensionalization factors - air_density = data_dict["air_density"] - stream_velocity = data_dict["stream_velocity"] - length_scale = data_dict["length_scale"] - - # STL nodes - geo_centers = data_dict["geometry_coordinates"] - - # Bounding box grid - s_grid = data_dict["surf_grid"] - sdf_surf_grid = data_dict["sdf_surf_grid"] - # Scaling factors - surf_max = data_dict["surface_min_max"][:, 1] - surf_min = data_dict["surface_min_max"][:, 0] - - if output_features_vol is not None: - # Represent geometry on computational grid - # Computational domain grid - p_grid = data_dict["grid"] - sdf_grid = data_dict["sdf_grid"] - # Scaling factors - vol_max = data_dict["volume_min_max"][:, 1] - vol_min = data_dict["volume_min_max"][:, 0] - - # Normalize based on computational domain - geo_centers_vol = 2.0 * (geo_centers - vol_min) / (vol_max - vol_min) - 1 - encoding_g_vol = model.module.geo_rep(geo_centers_vol, p_grid, sdf_grid) - - # Normalize based on BBox around surface (car) - geo_centers_surf = ( - 2.0 * (geo_centers - surf_min) / (surf_max - surf_min) - 1 - ) - encoding_g_surf = model.module.geo_rep( - geo_centers_surf, s_grid, sdf_surf_grid - ) - - if output_features_surf is not None: - # Represent geometry on bounding box - geo_centers_surf = ( - 2.0 * (geo_centers - surf_min) / (surf_max - surf_min) - 1 - ) - encoding_g_surf = model.module.geo_rep( - geo_centers_surf, s_grid, sdf_surf_grid - ) - - geo_encoding = 0.5 * encoding_g_surf - # Average the encodings - if output_features_vol is not None: - geo_encoding += 0.5 * encoding_g_vol - - if output_features_vol is not None: - # First calculate volume predictions if required - volume_mesh_centers = data_dict["volume_mesh_centers"] - target_vol = data_dict["volume_fields"] - # SDF on volume mesh nodes - sdf_nodes = data_dict["sdf_nodes"] - # Positional encoding based on closest point on surface to a volume node - pos_volume_closest = data_dict["pos_volume_closest"] - # Positional encoding based on center of mass of geometry to volume node - pos_volume_center_of_mass = data_dict["pos_volume_center_of_mass"] - p_grid = data_dict["grid"] - - prediction_vol = np.zeros_like(target_vol.cpu().numpy()) - num_points = volume_mesh_centers.shape[1] - subdomain_points = int(np.floor(num_points / point_batch_size)) - - for p in range(subdomain_points + 1): - start_idx = p * point_batch_size - end_idx = (p + 1) * point_batch_size - with paddle.no_grad(): - target_batch = target_vol[:, start_idx:end_idx] - volume_mesh_centers_batch = volume_mesh_centers[ - :, start_idx:end_idx - ] - sdf_nodes_batch = sdf_nodes[:, start_idx:end_idx] - pos_volume_closest_batch = pos_volume_closest[:, start_idx:end_idx] - pos_normals_com_batch = pos_volume_center_of_mass[ - :, start_idx:end_idx - ] - geo_encoding_local = model.module.geo_encoding_local( - geo_encoding, volume_mesh_centers_batch, p_grid - ) - if cfg.model.use_sdf_in_basis_func: - pos_encoding = paddle.concat( - ( - sdf_nodes_batch, - pos_volume_closest_batch, - pos_normals_com_batch, - ), - axis=-1, - ) - else: - pos_encoding = pos_normals_com_batch - pos_encoding = model.module.position_encoder( - pos_encoding, eval_mode="volume" - ) - tpredictions_batch = model.module.calculate_solution( - volume_mesh_centers_batch, - geo_encoding_local, - pos_encoding, - stream_velocity, - air_density, - num_sample_points=20, - eval_mode="volume", - ) - running_tloss_vol += loss_fn(tpredictions_batch, target_batch) - prediction_vol[ - :, start_idx:end_idx - ] = tpredictions_batch.cpu().numpy() - - prediction_vol = unnormalize(prediction_vol, vol_factors[0], vol_factors[1]) - - prediction_vol[:, :, :3] = ( - prediction_vol[:, :, :3] * stream_velocity[0, 0].cpu().numpy() - ) - prediction_vol[:, :, 3] = ( - prediction_vol[:, :, 3] - * stream_velocity[0, 0].cpu().numpy() ** 2.0 - * air_density[0, 0].cpu().numpy() - ) - prediction_vol[:, :, 4] = ( - prediction_vol[:, :, 4] - * stream_velocity[0, 0].cpu().numpy() - * length_scale[0].cpu().numpy() - ) - else: - prediction_vol = None - - if output_features_surf is not None: - # Next calculate surface predictions - # Sampled points on surface - surface_mesh_centers = data_dict["surface_mesh_centers"] - surface_normals = data_dict["surface_normals"] - surface_areas = data_dict["surface_areas"] - - # Neighbors of sampled points on surface - surface_mesh_neighbors = data_dict["surface_mesh_neighbors"] - surface_neighbors_normals = data_dict["surface_neighbors_normals"] - surface_neighbors_areas = data_dict["surface_neighbors_areas"] - surface_areas = paddle.unsqueeze(surface_areas, -1) - surface_neighbors_areas = paddle.unsqueeze(surface_neighbors_areas, -1) - pos_surface_center_of_mass = data_dict["pos_surface_center_of_mass"] - num_points = surface_mesh_centers.shape[1] - subdomain_points = int(np.floor(num_points / point_batch_size)) - - target_surf = data_dict["surface_fields"] - prediction_surf = np.zeros_like(target_surf.cpu().numpy()) - - surface_areas = paddle.unsqueeze(surface_areas, -1) - surface_neighbors_areas = paddle.unsqueeze(surface_neighbors_areas, -1) - - for p in range(subdomain_points + 1): - start_idx = p * point_batch_size - end_idx = (p + 1) * point_batch_size - with paddle.no_grad(): - target_batch = target_surf[:, start_idx:end_idx] - surface_mesh_centers_batch = surface_mesh_centers[ - :, start_idx:end_idx - ] - surface_mesh_neighbors_batch = surface_mesh_neighbors[ - :, start_idx:end_idx - ] - surface_normals_batch = surface_normals[:, start_idx:end_idx] - surface_neighbors_normals_batch = surface_neighbors_normals[ - :, start_idx:end_idx - ] - surface_areas_batch = surface_areas[:, start_idx:end_idx] - surface_neighbors_areas_batch = surface_neighbors_areas[ - :, start_idx:end_idx - ] - pos_surface_center_of_mass_batch = pos_surface_center_of_mass[ - :, start_idx:end_idx - ] - geo_encoding_local = model.module.geo_encoding_local_surface( - 0.5 * encoding_g_surf, surface_mesh_centers_batch, s_grid - ) - pos_encoding = pos_surface_center_of_mass_batch - pos_encoding = model.module.position_encoder( - pos_encoding, eval_mode="surface" - ) - - if cfg.model.surface_neighbors: - tpredictions_batch = ( - model.module.calculate_solution_with_neighbors( - surface_mesh_centers_batch, - geo_encoding_local, - pos_encoding, - surface_mesh_neighbors_batch, - surface_normals_batch, - surface_neighbors_normals_batch, - surface_areas_batch, - surface_neighbors_areas_batch, - stream_velocity, - air_density, - ) - ) - else: - tpredictions_batch = model.module.calculate_solution( - surface_mesh_centers_batch, - geo_encoding_local, - pos_encoding, - stream_velocity, - air_density, - num_sample_points=1, - eval_mode="surface", - ) - running_tloss_surf += loss_fn(tpredictions_batch, target_batch) - prediction_surf[ - :, start_idx:end_idx - ] = tpredictions_batch.cpu().numpy() - - prediction_surf = ( - unnormalize(prediction_surf, surf_factors[0], surf_factors[1]) - * stream_velocity[0, 0].cpu().numpy() ** 2.0 - * air_density[0, 0].cpu().numpy() - ) - - else: - prediction_surf = None - - return prediction_vol, prediction_surf - - -@hydra.main(version_base="1.3", config_path="conf", config_name="config") -def main(cfg: DictConfig): - print(f"Config summary:\n{OmegaConf.to_yaml(cfg, sort_keys=True)}") - - input_path = cfg.eval.test_path - - model_type = cfg.model.model_type - - dist.init_parallel_env() - - if model_type == "volume" or model_type == "combined": - volume_variable_names = list(cfg.variables.volume.solution.keys()) - num_vol_vars = 0 - for j in volume_variable_names: - if cfg.variables.volume.solution[j] == "vector": - num_vol_vars += 3 - else: - num_vol_vars += 1 - else: - num_vol_vars = None - - if model_type == "surface" or model_type == "combined": - surface_variable_names = list(cfg.variables.surface.solution.keys()) - num_surf_vars = 0 - for j in surface_variable_names: - if cfg.variables.surface.solution[j] == "vector": - num_surf_vars += 3 - else: - num_surf_vars += 1 - else: - num_surf_vars = None - - vol_save_path = os.path.join( - "outputs", cfg.project.name, "volume_scaling_factors.npy" - ) - surf_save_path = os.path.join( - "outputs", cfg.project.name, "surface_scaling_factors.npy" - ) - if os.path.exists(vol_save_path) and os.path.exists(surf_save_path): - vol_factors = np.load(vol_save_path) - surf_factors = np.load(surf_save_path) - else: - vol_factors = None - surf_factors = None - - model = DoMINO( - input_features=3, - output_features_vol=num_vol_vars, - output_features_surf=num_surf_vars, - model_parameters=cfg.model, - ) - - checkpoint = paddle.load( - to_absolute_path(os.path.join(cfg.resume_dir, cfg.eval.checkpoint_name)), - ) - - model.set_state_dict(checkpoint) - - print("Model loaded") - - if paddle.distributed.get_world_size() > 1: - model = DataParallel( - model, - ) - - dirnames_per_gpu = get_filenames(input_path) - - pred_save_path = cfg.eval.save_path - create_directory(pred_save_path) - - for count, dirname in enumerate(dirnames_per_gpu): - # print(f"Processing file {dirname}") - filepath = os.path.join(input_path, dirname) - tag = int(re.findall(r"(\w+?)(\d+)", dirname)[0][1]) - stl_path = os.path.join(filepath, f"drivaer_{tag}.stl") - vtp_path = os.path.join(filepath, f"boundary_{tag}.vtp") - vtu_path = os.path.join(filepath, f"volume_{tag}.vtu") - - vtp_pred_save_path = os.path.join( - pred_save_path, f"boundary_{tag}_predicted.vtp" - ) - vtu_pred_save_path = os.path.join(pred_save_path, f"volume_{tag}_predicted.vtu") - - # Read STL - reader = pv.get_reader(stl_path) - mesh_stl = reader.read() - stl_vertices = mesh_stl.points - stl_faces = np.array(mesh_stl.faces).reshape((-1, 4))[ - :, 1: - ] # Assuming triangular elements - mesh_indices_flattened = stl_faces.flatten() - length_scale = np.amax(np.amax(stl_vertices, 0) - np.amin(stl_vertices, 0)) - stl_sizes = mesh_stl.compute_cell_sizes(length=False, area=True, volume=False) - stl_sizes = np.array(stl_sizes.cell_data["Area"], dtype=np.float32) - stl_centers = np.array(mesh_stl.cell_centers().points, dtype=np.float32) - - # Center of mass calculation - center_of_mass = calculate_center_of_mass(stl_centers, stl_sizes) - - if cfg.data.bounding_box_surface is None: - s_max = np.amax(stl_vertices, 0) - s_min = np.amin(stl_vertices, 0) - else: - bounding_box_dims_surf = [] - bounding_box_dims_surf.append(np.asarray(cfg.data.bounding_box_surface.max)) - bounding_box_dims_surf.append(np.asarray(cfg.data.bounding_box_surface.min)) - s_max = np.float32(bounding_box_dims_surf[0]) - s_min = np.float32(bounding_box_dims_surf[1]) - - nx, ny, nz = cfg.model.interp_res - - surf_grid = create_grid(s_max, s_min, [nx, ny, nz]) - surf_grid_reshaped = surf_grid.reshape(nx * ny * nz, 3) - - # SDF calculation on the grid using WARP - sdf_surf_grid = ( - signed_distance_field( - stl_vertices, - mesh_indices_flattened, - surf_grid_reshaped, - use_sign_winding_number=True, - ) - .numpy() - .reshape(nx, ny, nz) - ) - surf_grid = np.float32(surf_grid) - sdf_surf_grid = np.float32(sdf_surf_grid) - surf_grid_max_min = np.float32(np.asarray([s_min, s_max])) - - # Read VTP - if model_type == "surface" or model_type == "combined": - reader = vtk.vtkXMLPolyDataReader() - reader.SetFileName(vtp_path) - reader.Update() - polydata_surf = reader.GetOutput() - - celldata_all = get_node_to_elem(polydata_surf) - - celldata = celldata_all.GetCellData() - surface_fields = get_fields(celldata, surface_variable_names) - surface_fields = np.concatenate(surface_fields, axis=-1) - - mesh = pv.PolyData(polydata_surf) - surface_coordinates = np.array(mesh.cell_centers().points, dtype=np.float32) - - interp_func = KDTree(surface_coordinates) - dd, ii = interp_func.query( - surface_coordinates, k=cfg.model.num_surface_neighbors - ) - - surface_neighbors = surface_coordinates[ii] - surface_neighbors = surface_neighbors[:, 1:] - - surface_normals = np.array(mesh.cell_normals, dtype=np.float32) - surface_sizes = mesh.compute_cell_sizes( - length=False, area=True, volume=False - ) - surface_sizes = np.array(surface_sizes.cell_data["Area"], dtype=np.float32) - - # Normalize cell normals - surface_normals = ( - surface_normals / np.linalg.norm(surface_normals, axis=1)[:, np.newaxis] - ) - surface_neighbors_normals = surface_normals[ii] - surface_neighbors_normals = surface_neighbors_normals[:, 1:] - surface_neighbors_sizes = surface_sizes[ii] - surface_neighbors_sizes = surface_neighbors_sizes[:, 1:] - - dx, dy, dz = ( - (s_max[0] - s_min[0]) / nx, - (s_max[1] - s_min[1]) / ny, - (s_max[2] - s_min[2]) / nz, - ) - - if cfg.model.positional_encoding: - pos_surface_center_of_mass = cal_normal_positional_encoding( - surface_coordinates, center_of_mass, cell_length=[dx, dy, dz] - ) - else: - pos_surface_center_of_mass = surface_coordinates - center_of_mass - - surface_coordinates = normalize(surface_coordinates, s_max, s_min) - surface_neighbors = normalize(surface_neighbors, s_max, s_min) - surf_grid = normalize(surf_grid, s_max, s_min) - - else: - surface_coordinates = None - surface_fields = None - surface_sizes = None - surface_normals = None - surface_neighbors = None - surface_neighbors_normals = None - surface_neighbors_sizes = None - pos_surface_center_of_mass = None - - # Read VTU - if model_type == "volume" or model_type == "combined": - reader = vtk.vtkXMLUnstructuredGridReader() - reader.SetFileName(vtu_path) - reader.Update() - polydata_vol = reader.GetOutput() - volume_coordinates, volume_fields = get_volume_data( - polydata_vol, volume_variable_names - ) - volume_fields = np.concatenate(volume_fields, axis=-1) - # print(f"Processed vtu {vtu_path}") - - bounding_box_dims = [] - bounding_box_dims.append(np.asarray(cfg.data.bounding_box.max)) - bounding_box_dims.append(np.asarray(cfg.data.bounding_box.min)) - - if bounding_box_dims is None: - c_max = s_max + (s_max - s_min) / 2 - c_min = s_min - (s_max - s_min) / 2 - c_min[2] = s_min[2] - else: - c_max = np.float32(bounding_box_dims[0]) - c_min = np.float32(bounding_box_dims[1]) - - dx, dy, dz = ( - (c_max[0] - c_min[0]) / nx, - (c_max[1] - c_min[1]) / ny, - (c_max[2] - c_min[2]) / nz, - ) - # Generate a grid of specified resolution to map the bounding box - # The grid is used for capturing structured geometry features and SDF representation of geometry - grid = create_grid(c_max, c_min, [nx, ny, nz]) - grid_reshaped = grid.reshape(nx * ny * nz, 3) - - # SDF calculation on the grid using WARP - sdf_grid = ( - signed_distance_field( - stl_vertices, - mesh_indices_flattened, - grid_reshaped, - use_sign_winding_number=True, - ) - .numpy() - .reshape(nx, ny, nz) - ) - - # SDF calculation - sdf_nodes, sdf_node_closest_point = signed_distance_field( - stl_vertices, - mesh_indices_flattened, - volume_coordinates, - include_hit_points=True, - use_sign_winding_number=True, - ) - sdf_nodes = sdf_nodes.numpy().reshape(-1, 1) - sdf_node_closest_point = sdf_node_closest_point.numpy() - - if cfg.model.positional_encoding: - pos_volume_closest = cal_normal_positional_encoding( - volume_coordinates, sdf_node_closest_point, cell_length=[dx, dy, dz] - ) - pos_volume_center_of_mass = cal_normal_positional_encoding( - volume_coordinates, center_of_mass, cell_length=[dx, dy, dz] - ) - else: - pos_volume_closest = volume_coordinates - sdf_node_closest_point - pos_volume_center_of_mass = volume_coordinates - center_of_mass - - volume_coordinates = normalize(volume_coordinates, c_max, c_min) - grid = normalize(grid, c_max, c_min) - vol_grid_max_min = np.asarray([c_min, c_max]) - - else: - volume_coordinates = None - volume_fields = None - pos_volume_closest = None - pos_volume_center_of_mass = None - - # print(f"Processed sdf and normalized") - - geom_centers = np.float32(stl_vertices) - - if model_type == "combined": - # Add the parameters to the dictionary - data_dict = { - "pos_volume_closest": pos_volume_closest, - "pos_volume_center_of_mass": pos_volume_center_of_mass, - "pos_surface_center_of_mass": pos_surface_center_of_mass, - "geometry_coordinates": geom_centers, - "grid": grid, - "surf_grid": surf_grid, - "sdf_grid": sdf_grid, - "sdf_surf_grid": sdf_surf_grid, - "sdf_nodes": sdf_nodes, - "surface_mesh_centers": surface_coordinates, - "surface_mesh_neighbors": surface_neighbors, - "surface_normals": surface_normals, - "surface_neighbors_normals": surface_neighbors_normals, - "surface_areas": surface_sizes, - "surface_neighbors_areas": surface_neighbors_sizes, - "volume_fields": volume_fields, - "volume_mesh_centers": volume_coordinates, - "surface_fields": surface_fields, - "volume_min_max": vol_grid_max_min, - "surface_min_max": surf_grid_max_min, - "length_scale": np.array(length_scale, dtype=np.float32), - "stream_velocity": np.expand_dims( - np.array(STREAM_VELOCITY, dtype=np.float32), axis=-1 - ), - "air_density": np.expand_dims( - np.array(AIR_DENSITY, dtype=np.float32), axis=-1 - ), - } - elif model_type == "surface": - data_dict = { - "pos_surface_center_of_mass": np.float32(pos_surface_center_of_mass), - "geometry_coordinates": np.float32(geom_centers), - "surf_grid": np.float32(surf_grid), - "sdf_surf_grid": np.float32(sdf_surf_grid), - "surface_mesh_centers": np.float32(surface_coordinates), - "surface_mesh_neighbors": np.float32(surface_neighbors), - "surface_normals": np.float32(surface_normals), - "surface_neighbors_normals": np.float32(surface_neighbors_normals), - "surface_areas": np.float32(surface_sizes), - "surface_neighbors_areas": np.float32(surface_neighbors_sizes), - "surface_fields": np.float32(surface_fields), - "surface_min_max": np.float32(surf_grid_max_min), - "length_scale": np.array(length_scale, dtype=np.float32), - "stream_velocity": np.expand_dims( - np.array(STREAM_VELOCITY, dtype=np.float32), axis=-1 - ), - "air_density": np.expand_dims( - np.array(AIR_DENSITY, dtype=np.float32), axis=-1 - ), - } - elif model_type == "volume": - data_dict = { - "pos_volume_closest": pos_volume_closest, - "pos_volume_center_of_mass": pos_volume_center_of_mass, - "geometry_coordinates": geom_centers, - "grid": grid, - "surf_grid": surf_grid, - "sdf_grid": sdf_grid, - "sdf_surf_grid": sdf_surf_grid, - "sdf_nodes": sdf_nodes, - "volume_fields": volume_fields, - "volume_mesh_centers": volume_coordinates, - "volume_min_max": vol_grid_max_min, - "surface_min_max": surf_grid_max_min, - "length_scale": np.array(length_scale, dtype=np.float32), - "stream_velocity": np.expand_dims( - np.array(STREAM_VELOCITY, dtype=np.float32), axis=-1 - ), - "air_density": np.expand_dims( - np.array(AIR_DENSITY, dtype=np.float32), axis=-1 - ), - } - - data_dict = { - key: paddle.to_tensor(np.expand_dims(np.float32(value), 0)) - for key, value in data_dict.items() - } - - prediction_vol, prediction_surf = test_step( - data_dict, - model, - paddle.distributed.get_rank(), - cfg, - vol_factors, - surf_factors, - ) - - if prediction_surf is not None: - surface_sizes = np.expand_dims(surface_sizes, -1) - - force_x_pred = np.sum( - prediction_surf[0, :, 0] * surface_normals[:, 0] * surface_sizes[:, 0] - - prediction_surf[0, :, 1] * surface_sizes[:, 0] - ) - force_x_true = np.sum( - surface_fields[:, 0] * surface_normals[:, 0] * surface_sizes[:, 0] - - surface_fields[:, 1] * surface_sizes[:, 0] - ) - print(dirname, force_x_pred, force_x_true) - - if prediction_vol is not None: - target_vol = volume_fields - prediction_vol = prediction_vol[0] - c_min = vol_grid_max_min[0] - c_max = vol_grid_max_min[1] - volume_coordinates = unnormalize(volume_coordinates, c_max, c_min) - ids_in_bbox = np.where( - (volume_coordinates[:, 0] < c_min[0]) - | (volume_coordinates[:, 0] > c_max[0]) - | (volume_coordinates[:, 1] < c_min[1]) - | (volume_coordinates[:, 1] > c_max[1]) - | (volume_coordinates[:, 2] < c_min[2]) - | (volume_coordinates[:, 2] > c_max[2]) - ) - target_vol[ids_in_bbox] = 0.0 - prediction_vol[ids_in_bbox] = 0.0 - l2_gt = np.sum(np.square(target_vol), (0)) - l2_error = np.sum(np.square(prediction_vol - target_vol), (0)) - print( - "L-2 norm:", - dirname, - np.sqrt(l2_error), - np.sqrt(l2_gt), - np.sqrt(l2_error) / np.sqrt(l2_gt), - ) - - if prediction_surf is not None: - surfParam_vtk = numpy_support.numpy_to_vtk(prediction_surf[0, :, 0:1]) - surfParam_vtk.SetName(f"{surface_variable_names[0]}Pred") - celldata_all.GetCellData().AddArray(surfParam_vtk) - - surfParam_vtk = numpy_support.numpy_to_vtk(prediction_surf[0, :, 1:]) - surfParam_vtk.SetName(f"{surface_variable_names[1]}Pred") - celldata_all.GetCellData().AddArray(surfParam_vtk) - - write_to_vtp(celldata_all, vtp_pred_save_path) - - if prediction_vol is not None: - - volParam_vtk = numpy_support.numpy_to_vtk(prediction_vol[:, 0:3]) - volParam_vtk.SetName(f"{volume_variable_names[0]}Pred") - polydata_vol.GetPointData().AddArray(volParam_vtk) - - volParam_vtk = numpy_support.numpy_to_vtk(prediction_vol[:, 3:4]) - volParam_vtk.SetName(f"{volume_variable_names[1]}Pred") - polydata_vol.GetPointData().AddArray(volParam_vtk) - - volParam_vtk = numpy_support.numpy_to_vtk(prediction_vol[:, 4:5]) - volParam_vtk.SetName(f"{volume_variable_names[2]}Pred") - polydata_vol.GetPointData().AddArray(volParam_vtk) - - write_to_vtu(polydata_vol, vtu_pred_save_path) - - -if __name__ == "__main__": - main() diff --git a/ppsci/arch/physicsnemo.py b/ppsci/arch/physicsnemo.py index 1516694bfd..ffa401b866 100644 --- a/ppsci/arch/physicsnemo.py +++ b/ppsci/arch/physicsnemo.py @@ -43,188 +43,6 @@ scheduler = NewType("scheduler", LRScheduler) scaler = NewType("scaler", GradScaler) -try: - import pyvista as pv - - PV_AVAILABLE = True -except ImportError: - PV_AVAILABLE = False -try: - import vtk - from vtk import vtkDataSetTriangleFilter - from vtk.util import numpy_support - - VTK_AVAILABLE = True -except ImportError: - VTK_AVAILABLE = False - - -def unstandardize(field, mean, std): - """Function to unstandardize fields""" - return field * std + mean - - -def write_to_vtp(polydata, filename): - """Function to write polydata to vtp""" - if not VTK_AVAILABLE: - raise ImportError("VTK or is not installed. This function cannot be used.") - writer = vtk.vtkXMLPolyDataWriter() - writer.SetFileName(filename) - writer.SetInputData(polydata) - writer.Write() - - -def write_to_vtu(polydata, filename): - """Function to write polydata to vtu""" - if not VTK_AVAILABLE: - raise ImportError("VTK or is not installed. This function cannot be used.") - writer = vtk.vtkXMLUnstructuredGridWriter() - writer.SetFileName(filename) - writer.SetInputData(polydata) - writer.Write() - - -def extract_surface_triangles(tet_mesh): - """Extracts the surface triangles from a triangular mesh.""" - if not VTK_AVAILABLE: - raise ImportError("VTK or is not installed. This function cannot be used.") - if not PV_AVAILABLE: - raise ImportError("PyVista is not installed. This function cannot be used.") - surface_filter = vtk.vtkDataSetSurfaceFilter() - surface_filter.SetInputData(tet_mesh) - surface_filter.Update() - - surface_mesh = pv.wrap(surface_filter.GetOutput()) - triangle_indices = [] - faces = surface_mesh.faces.reshape((-1, 4)) - for face in faces: - if face[0] == 3: - triangle_indices.extend([face[1], face[2], face[3]]) - else: - raise ValueError("Face is not a triangle") - - return triangle_indices - - -def convert_to_tet_mesh(polydata): - """Function to convert tet to stl""" - if not VTK_AVAILABLE: - raise ImportError("VTK or is not installed. This function cannot be used.") - # Create a VTK DataSetTriangleFilter object - tet_filter = vtkDataSetTriangleFilter() - tet_filter.SetInputData(polydata) - tet_filter.Update() # Update to apply the filter - - # Get the output as an UnstructuredGrid - # tet_mesh = pv.wrap(tet_filter.GetOutput()) - tet_mesh = tet_filter.GetOutput() - return tet_mesh - - -def get_node_to_elem(polydata): - """Function to convert node to elem""" - if not VTK_AVAILABLE: - raise ImportError("VTK or is not installed. This function cannot be used.") - c2p = vtk.vtkPointDataToCellData() - c2p.SetInputData(polydata) - c2p.Update() - cell_data = c2p.GetOutput() - return cell_data - - -def get_fields_from_cell(ptdata, var_list): - """Function to get fields from elem""" - fields = [] - for var in var_list: - variable = ptdata.GetArray(var) - num_tuples = variable.GetNumberOfTuples() - cell_fields = [] - for j in range(num_tuples): - variable_value = np.array(variable.GetTuple(j)) - cell_fields.append(variable_value) - cell_fields = np.asarray(cell_fields) - fields.append(cell_fields) - fields = np.transpose(np.asarray(fields), (1, 0)) - - return fields - - -def get_fields(data, variables): - """Function to get fields from VTP/VTU""" - if not VTK_AVAILABLE: - raise ImportError("VTK or is not installed. This function cannot be used.") - fields = [] - for array_name in variables: - try: - array = data.GetArray(array_name) - except ValueError: - raise ValueError( - f"Failed to get array {array_name} from the unstructured grid." - ) - array_data = numpy_support.vtk_to_numpy(array).reshape( - array.GetNumberOfTuples(), array.GetNumberOfComponents() - ) - fields.append(array_data) - return fields - - -def get_vertices(polydata): - """Function to get vertices""" - if not VTK_AVAILABLE: - raise ImportError("VTK or is not installed. This function cannot be used.") - points = polydata.GetPoints() - vertices = numpy_support.vtk_to_numpy(points.GetData()) - return vertices - - -def get_volume_data(polydata, variables): - """Function to get volume data""" - vertices = get_vertices(polydata) - point_data = polydata.GetPointData() - - fields = get_fields(point_data, variables) - - return vertices, fields - - -def get_surface_data(polydata, variables): - """Function to get surface data""" - if not VTK_AVAILABLE: - raise ImportError("VTK or is not installed. This function cannot be used.") - points = polydata.GetPoints() - vertices = np.array([points.GetPoint(i) for i in range(points.GetNumberOfPoints())]) - - point_data = polydata.GetPointData() - fields = [] - for array_name in variables: - try: - array = point_data.GetArray(array_name) - except ValueError: - raise ValueError( - f"Failed to get array {array_name} from the unstructured grid." - ) - array_data = np.zeros( - (points.GetNumberOfPoints(), array.GetNumberOfComponents()) - ) - for j in range(points.GetNumberOfPoints()): - array.GetTuple(j, array_data[j]) - fields.append(array_data) - - polys = polydata.GetPolys() - if polys is None: - raise ValueError("Failed to get polygons from the polydata.") - polys.InitTraversal() - edges = [] - id_list = vtk.vtkIdList() - for _ in range(polys.GetNumberOfCells()): - polys.GetNextCell(id_list) - num_ids = id_list.GetNumberOfIds() - edges = [ - (id_list.GetId(j), id_list.GetId((j + 1) % num_ids)) for j in range(num_ids) - ] - - return vertices, fields, edges - def nd_interpolator(coodinates, field, grid): """Function to for nd interpolation""" diff --git a/ppsci/data/dataset/domino_datapipe.py b/ppsci/data/dataset/domino_datapipe.py index b349633214..7dc7fbf3a4 100644 --- a/ppsci/data/dataset/domino_datapipe.py +++ b/ppsci/data/dataset/domino_datapipe.py @@ -27,6 +27,7 @@ """ import os +import random import time from pathlib import Path from typing import Literal @@ -40,6 +41,24 @@ from ppsci.utils.sdf import signed_distance_field +try: + import pyvista as pv + + PV_AVAILABLE = True +except ImportError: + PV_AVAILABLE = False +try: + import vtk + from vtk import vtkDataSetTriangleFilter + from vtk.util import numpy_support + + VTK_AVAILABLE = True +except ImportError: + VTK_AVAILABLE = False + +AIR_DENSITY = 1.205 +STREAM_VELOCITY = 30.00 + def calculate_center_of_mass(stl_centers, stl_sizes): """Function to calculate center of mass""" @@ -63,6 +82,173 @@ def standardize(field, mean, std): return (field - mean) / std +def unstandardize(field, mean, std): + """Function to unstandardize fields""" + return field * std + mean + + +def write_to_vtp(polydata, filename): + """Function to write polydata to vtp""" + if not VTK_AVAILABLE: + raise ImportError("VTK or is not installed. This function cannot be used.") + writer = vtk.vtkXMLPolyDataWriter() + writer.SetFileName(filename) + writer.SetInputData(polydata) + writer.Write() + + +def write_to_vtu(polydata, filename): + """Function to write polydata to vtu""" + if not VTK_AVAILABLE: + raise ImportError("VTK or is not installed. This function cannot be used.") + writer = vtk.vtkXMLUnstructuredGridWriter() + writer.SetFileName(filename) + writer.SetInputData(polydata) + writer.Write() + + +def extract_surface_triangles(tet_mesh): + """Extracts the surface triangles from a triangular mesh.""" + if not VTK_AVAILABLE: + raise ImportError("VTK or is not installed. This function cannot be used.") + if not PV_AVAILABLE: + raise ImportError("PyVista is not installed. This function cannot be used.") + surface_filter = vtk.vtkDataSetSurfaceFilter() + surface_filter.SetInputData(tet_mesh) + surface_filter.Update() + + surface_mesh = pv.wrap(surface_filter.GetOutput()) + triangle_indices = [] + faces = surface_mesh.faces.reshape((-1, 4)) + for face in faces: + if face[0] == 3: + triangle_indices.extend([face[1], face[2], face[3]]) + else: + raise ValueError("Face is not a triangle") + + return triangle_indices + + +def convert_to_tet_mesh(polydata): + """Function to convert tet to stl""" + if not VTK_AVAILABLE: + raise ImportError("VTK or is not installed. This function cannot be used.") + # Create a VTK DataSetTriangleFilter object + tet_filter = vtkDataSetTriangleFilter() + tet_filter.SetInputData(polydata) + tet_filter.Update() # Update to apply the filter + + # Get the output as an UnstructuredGrid + # tet_mesh = pv.wrap(tet_filter.GetOutput()) + tet_mesh = tet_filter.GetOutput() + return tet_mesh + + +def get_node_to_elem(polydata): + """Function to convert node to elem""" + if not VTK_AVAILABLE: + raise ImportError("VTK or is not installed. This function cannot be used.") + c2p = vtk.vtkPointDataToCellData() + c2p.SetInputData(polydata) + c2p.Update() + cell_data = c2p.GetOutput() + return cell_data + + +def get_fields_from_cell(ptdata, var_list): + """Function to get fields from elem""" + fields = [] + for var in var_list: + variable = ptdata.GetArray(var) + num_tuples = variable.GetNumberOfTuples() + cell_fields = [] + for j in range(num_tuples): + variable_value = np.array(variable.GetTuple(j)) + cell_fields.append(variable_value) + cell_fields = np.asarray(cell_fields) + fields.append(cell_fields) + fields = np.transpose(np.asarray(fields), (1, 0)) + + return fields + + +def get_fields(data, variables): + """Function to get fields from VTP/VTU""" + if not VTK_AVAILABLE: + raise ImportError("VTK or is not installed. This function cannot be used.") + fields = [] + for array_name in variables: + try: + array = data.GetArray(array_name) + except ValueError: + raise ValueError( + f"Failed to get array {array_name} from the unstructured grid." + ) + array_data = numpy_support.vtk_to_numpy(array).reshape( + array.GetNumberOfTuples(), array.GetNumberOfComponents() + ) + fields.append(array_data) + return fields + + +def get_vertices(polydata): + """Function to get vertices""" + if not VTK_AVAILABLE: + raise ImportError("VTK or is not installed. This function cannot be used.") + points = polydata.GetPoints() + vertices = numpy_support.vtk_to_numpy(points.GetData()) + return vertices + + +def get_volume_data(polydata, variables): + """Function to get volume data""" + vertices = get_vertices(polydata) + point_data = polydata.GetPointData() + + fields = get_fields(point_data, variables) + + return vertices, fields + + +def get_surface_data(polydata, variables): + """Function to get surface data""" + if not VTK_AVAILABLE: + raise ImportError("VTK or is not installed. This function cannot be used.") + points = polydata.GetPoints() + vertices = np.array([points.GetPoint(i) for i in range(points.GetNumberOfPoints())]) + + point_data = polydata.GetPointData() + fields = [] + for array_name in variables: + try: + array = point_data.GetArray(array_name) + except ValueError: + raise ValueError( + f"Failed to get array {array_name} from the unstructured grid." + ) + array_data = np.zeros( + (points.GetNumberOfPoints(), array.GetNumberOfComponents()) + ) + for j in range(points.GetNumberOfPoints()): + array.GetTuple(j, array_data[j]) + fields.append(array_data) + + polys = polydata.GetPolys() + if polys is None: + raise ValueError("Failed to get polygons from the polydata.") + polys.InitTraversal() + edges = [] + id_list = vtk.vtkIdList() + for _ in range(polys.GetNumberOfCells()): + polys.GetNextCell(id_list) + num_ids = id_list.GetNumberOfIds() + edges = [ + (id_list.GetId(j), id_list.GetId((j + 1) % num_ids)) for j in range(num_ids) + ] + + return vertices, fields, edges + + def cal_normal_positional_encoding(coordinates_a, coordinates_b=None, cell_length=[]): """Function to get normal positional encoding""" dx = cell_length[0] @@ -643,10 +829,175 @@ def __getitem__(self, idx): } -if __name__ == "__main__": - fm_data = DoMINODataPipe( - data_path="/code/processed_data/new_models_1/", - phase="train", - sampling=False, - sample_in_bbox=False, - ) +class DriveSimPaths: + @staticmethod + def geometry_path(car_dir: Path) -> Path: + return car_dir / "body.stl" + + @staticmethod + def volume_path(car_dir: Path) -> Path: + return car_dir / "VTK/simpleFoam_steady_3000/internal.vtu" + + @staticmethod + def surface_path(car_dir: Path) -> Path: + return car_dir / "VTK/simpleFoam_steady_3000/boundary/aero_suv.vtp" + + +class DrivAerAwsPaths: + @staticmethod + def _get_index(car_dir: Path) -> str: + return car_dir.name.removeprefix("run_") + + @staticmethod + def geometry_path(car_dir: Path) -> Path: + return car_dir / f"drivaer_{DrivAerAwsPaths._get_index(car_dir)}.stl" + + @staticmethod + def volume_path(car_dir: Path) -> Path: + return car_dir / f"volume_{DrivAerAwsPaths._get_index(car_dir)}.vtu" + + @staticmethod + def surface_path(car_dir: Path) -> Path: + return car_dir / f"boundary_{DrivAerAwsPaths._get_index(car_dir)}.vtp" + + +class OpenFoamDataset(Dataset): + """ + Datapipe for converting openfoam dataset to npy + + """ + + def __init__( + self, + data_path: Union[str, Path], + kind: Literal["drivesim", "drivaer_aws"] = "drivesim", + surface_variables: Optional[list] = [ + "pMean", + "wallShearStress", + ], + volume_variables: Optional[list] = ["UMean", "pMean"], + device: int = 0, + model_type=None, + ): + if isinstance(data_path, str): + data_path = Path(data_path) + data_path = data_path.expanduser() + + self.data_path = data_path + + supported_kinds = ["drivesim", "drivaer_aws"] + assert ( + kind in supported_kinds + ), f"kind should be one of {supported_kinds}, got {kind}" + self.path_getter = DriveSimPaths if kind == "drivesim" else DrivAerAwsPaths + + assert self.data_path.exists(), f"Path {self.data_path} does not exist" + + assert self.data_path.is_dir(), f"Path {self.data_path} is not a directory" + + self.filenames = get_filenames(self.data_path) + random.shuffle(self.filenames) + self.indices = np.array(len(self.filenames)) + + self.surface_variables = surface_variables + self.volume_variables = volume_variables + self.device = device + self.model_type = model_type + + def __len__(self): + return len(self.filenames) + + def __getitem__(self, idx): + cfd_filename = self.filenames[idx] + car_dir = self.data_path / cfd_filename + + stl_path = self.path_getter.geometry_path(car_dir) + reader = pv.get_reader(stl_path) + mesh_stl = reader.read() + stl_vertices = mesh_stl.points + stl_faces = np.array(mesh_stl.faces).reshape((-1, 4))[ + :, 1: + ] # Assuming triangular elements + mesh_indices_flattened = stl_faces.flatten() + stl_sizes = mesh_stl.compute_cell_sizes(length=False, area=True, volume=False) + stl_sizes = np.array(stl_sizes.cell_data["Area"]) + stl_centers = np.array(mesh_stl.cell_centers().points) + + length_scale = np.amax(np.amax(stl_vertices, 0) - np.amin(stl_vertices, 0)) + + if self.model_type == "volume" or self.model_type == "combined": + filepath = self.path_getter.volume_path(car_dir) + reader = vtk.vtkXMLUnstructuredGridReader() + reader.SetFileName(filepath) + reader.Update() + + # Get the unstructured grid data + polydata = reader.GetOutput() + volume_coordinates, volume_fields = get_volume_data( + polydata, self.volume_variables + ) + volume_fields = np.concatenate(volume_fields, axis=-1) + + # Non-dimensionalize volume fields + volume_fields[:, :3] = volume_fields[:, :3] / STREAM_VELOCITY + volume_fields[:, 3:4] = volume_fields[:, 3:4] / ( + AIR_DENSITY * STREAM_VELOCITY**2.0 + ) + + volume_fields[:, 4:] = volume_fields[:, 4:] / ( + STREAM_VELOCITY * length_scale + ) + else: + volume_fields = None + volume_coordinates = None + + if self.model_type == "surface" or self.model_type == "combined": + surface_filepath = self.path_getter.surface_path(car_dir) + reader = vtk.vtkXMLPolyDataReader() + reader.SetFileName(surface_filepath) + reader.Update() + polydata = reader.GetOutput() + + celldata_all = get_node_to_elem(polydata) + celldata = celldata_all.GetCellData() + surface_fields = get_fields(celldata, self.surface_variables) + surface_fields = np.concatenate(surface_fields, axis=-1) + + mesh = pv.PolyData(polydata) + surface_coordinates = np.array(mesh.cell_centers().points) + + surface_normals = np.array(mesh.cell_normals) + surface_sizes = mesh.compute_cell_sizes( + length=False, area=True, volume=False + ) + surface_sizes = np.array(surface_sizes.cell_data["Area"]) + + # Normalize cell normals + surface_normals = ( + surface_normals / np.linalg.norm(surface_normals, axis=1)[:, np.newaxis] + ) + + # Non-dimensionalize surface fields + surface_fields = surface_fields / (AIR_DENSITY * STREAM_VELOCITY**2.0) + else: + surface_fields = None + surface_coordinates = None + surface_normals = None + surface_sizes = None + + # Add the parameters to the dictionary + return { + "stl_coordinates": np.float32(stl_vertices), + "stl_centers": np.float32(stl_centers), + "stl_faces": np.float32(mesh_indices_flattened), + "stl_areas": np.float32(stl_sizes), + "surface_mesh_centers": np.float32(surface_coordinates), + "surface_normals": np.float32(surface_normals), + "surface_areas": np.float32(surface_sizes), + "volume_fields": np.float32(volume_fields), + "volume_mesh_centers": np.float32(volume_coordinates), + "surface_fields": np.float32(surface_fields), + "filename": cfd_filename, + "stream_velocity": STREAM_VELOCITY, + "air_density": AIR_DENSITY, + } diff --git a/ppsci/data/process/__init__.py b/ppsci/data/process/__init__.py index cf217ef7f9..d3272a4730 100644 --- a/ppsci/data/process/__init__.py +++ b/ppsci/data/process/__init__.py @@ -13,10 +13,11 @@ # limitations under the License. from ppsci.data.process import batch_transform +from ppsci.data.process import openfoam from ppsci.data.process import transform __all__ = [ "batch_transform", - "openfoam", "transform", + "openfoam", ] diff --git a/ppsci/data/process/openfoam/__init__.py b/ppsci/data/process/openfoam/__init__.py index 92dc8f2974..66c1919070 100644 --- a/ppsci/data/process/openfoam/__init__.py +++ b/ppsci/data/process/openfoam/__init__.py @@ -14,7 +14,7 @@ # # refs: https://github.com/NVIDIA/physicsnemo/tree/main/examples/cfd/external_aerodynamics/domino -from ppsci.data.process.openfoam import process_files +from ppsci.data.process.openfoam.preprocess import process_files __all__ = [ "process_files", From 385a128689f1876416e539f2db3b9de23c72c313 Mon Sep 17 00:00:00 2001 From: xiaoyewww <641311428@qq.com> Date: Fri, 30 May 2025 00:04:49 +0800 Subject: [PATCH 10/11] feat(domino): refactor physicsnemo --- examples/domino/domino.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/domino/domino.py b/examples/domino/domino.py index 073d22723c..8616b3e633 100644 --- a/examples/domino/domino.py +++ b/examples/domino/domino.py @@ -60,6 +60,8 @@ AIR_DENSITY = 1.205 STREAM_VELOCITY = 30.00 +paddle.set_device("gpu") + def process(cfg: DictConfig): print(f"Config summary:\n{OmegaConf.to_yaml(cfg, sort_keys=True)}") From 2b14a8a31463b74148de1936b6261f940ee60c7b Mon Sep 17 00:00:00 2001 From: xiaoyewww <641311428@qq.com> Date: Thu, 5 Jun 2025 00:30:11 +0800 Subject: [PATCH 11/11] feat(domino): complement reade --- docs/zh/examples/domino.md | 115 +++++++++++++++++++++++++++---------- requirements.txt | 1 - 2 files changed, 86 insertions(+), 30 deletions(-) diff --git a/docs/zh/examples/domino.md b/docs/zh/examples/domino.md index e7b3eeae11..79225ac4a4 100644 --- a/docs/zh/examples/domino.md +++ b/docs/zh/examples/domino.md @@ -36,41 +36,80 @@ ## 1. 背景简介 -外部空气动力学涉及高雷诺数Navier-Stokes方程求解,传统CFD方法计算成本高昂。神经算子通过端到端映射提升了效率,但面临多尺度耦合建模与长期预测稳定性不足的挑战。Decomposable Multi-scale Iterative Neural Operator(Domino)提出可分解多尺度架构,通过分层特征解耦、迭代残差校正及参数独立编码,显著提升跨尺度流动建模精度与泛化能力。实验显示,其计算速度较CFD快2-3个量级,分离流预测精度较FNO等模型提升约40%,为飞行器设计等工程问题提供高效解决方案。 +在现代工程产品的设计与开发过程中,数值模拟(如计算流体动力学,CFD)扮演着至关重要的角色。它们能够提供对复杂物理现象的精确预测,从而指导产品性能优化和设计迭代。然而,传统的高保真数值模拟方法,特别是针对具有复杂几何形状(例如汽车、飞机等)和大规模计算域的场景,往往需要耗费巨大的计算资源和时间。动辄数小时甚至数天的模拟周期,严重制约了设计迭代的效率和并行探索多个设计方案的可能性。 -## 2. 模型原理 +为了突破这一瓶颈,近年来研究人员积极探索将机器学习(ML)模型作为传统数值模拟的快速替代(即代理模型)。这些ML模型通过从大量的模拟数据中学习物理系统的输入-输出映射关系,从而在显著减少计算时间的同时,仍能保持可接受的精度。早期的ML代理模型在处理较小规模或简化问题时表现出一定的潜力。然而,当面对大型工程模拟时,这些模型常常暴露出局限性,例如在准确性和可扩展性方面存在瓶颈。许多现有方法依赖于对模拟网格进行大幅度的降采样,这不仅可能导致预测精度的下降,还会损害模型在未见数据上的泛化能力,限制了它们在实际复杂工程问题中的应用。因此,开发一种既能处理大规模数据、又能保持高精度和良好泛化能力的机器学习代理模型,成为了当前研究的迫切需求。 -DOMINO (Decomposable Multi-scale Iterative Neural Operator)是一种新颖的机器学习模型架构,旨在解决大规模工程仿真代理建模中的挑战。它是一个基于点云的机器学习模型,利用局部几何信息来预测离散点上的流场 。 +## 2. 问题定义 -以下是DOMINO模型的主要原理: +### 2.1 数据集 -- 全局几何表示学习(Global Geometry Representation): - - 模型首先以几何体的三维表面网格作为输入。 - - 在几何体周围构建一个紧密贴合的表面包围盒和一个表示计算域的包围盒。 - - 几何点云的特征(如空间坐标)通过可学习的点卷积核投影到表面包围盒上的N维结构化网格上(分辨率为$m×m×m×f$)。 - - 点卷积核的实现使用了NVIDIA Warp加速的自定义球查询层 。 - - 通过两种方法将几何特征传播到计算域包围盒中:1)学习一组单独的多尺度点卷积核,将几何信息投影到计算域网格上;2)使用包含卷积、池化和反池化层的CNN块,将表面包围盒网格上的特征$G_s$​传播到计算域包围盒网格$G_c$。CNN块会迭代评估。 - - 计算域网格上计算出的$m×m×m×f$特征代表了几何点云的全局编码。此外,还会计算符号距离场(SDF)及其梯度分量,并附加到学习到的特征中,以提供关于几何拓扑的额外信息。 +DrivAerML数据集是一个专门为汽车空气动力学机器学习研究而设计的大规模、高保真计算流体动力学(CFD)数据集。它包含了数百个经过几何变形的DrivAer溜背式汽车变体的高保真CFD模拟结果。该数据集以其庞大的数据量、高分辨率的网格以及显著的几何变化而著称,这为训练和测试能够处理复杂几何和流场的机器学习模型提供了理想的基准。与早期的DrivAerNet相比,DrivAerML及其后续版本如DrivAerNet++包含了更多样的几何设计(例如,DrivAerNet++包含8000个汽车设计,涵盖了传统的内燃机汽车和电动汽车的各种底盘和车轮设计),以及更为丰富的CFD模拟数据,包括STL格式的参数化汽车几何体、表面压力场数据,以及全三维的压力、速度和湍流场以及壁面剪切应力等。这些数据集的发布旨在为数据驱动的空气动力学设计提供大规模、高保真的数据,支持机器学习模型在空气动力学评估、生成设计等方面的训练。 -- 局部几何表示(Local Geometry Representation): - - 局部几何表示取决于计算域中评估解场的物理位置。 - - 在计算局部几何表示之前,会在计算域中采样一批离散点。 - - 对于批次中每个采样点,在其周围定义一个大小为$l×l×l$的子区域,并计算局部几何编码。 - - 局部编码本质上是全局编码的一个子集,取决于其在计算域中的位置,并通过点卷积计算。 - - 提取的局部特征通过全连接神经网络进一步转换。 - - 这种局部几何表示用于使用聚合网络评估采样点上的解场。 +### 2.2 主要问题 -- 聚合网络(Aggregation Network): - - 局部几何表示代表了采样点及其邻居的计算模板附近几何和解的学习特征。 - - 计算模板中的每个点都由其在计算域中的物理坐标、这些坐标处的SDF、来自域质心的法向量以及表面法向量(如果点在表面上)表示。 - - 这些输入特征通过一个全连接神经网络(称为基函数神经网络),计算出一个潜在向量,代表计算模板中每个点的这些特征。 - - 每个潜在向量与局部几何编码连接,并通过另一组全连接层,以预测计算模板中每个点上的解向量。 - - 解向量通过逆距离加权方案进行平均,以预测采样点处的最终解向量。 - - 对于每个解变量,都使用聚合网络的一个独立实例,但全局几何编码网络在它们之间是共享的。 +尽管机器学习代理模型在工程模拟领域展现出广阔前景,但其在大规模工程模拟中的应用仍面临诸多挑战。具体而言,本文所要解决的问题主要集中在以下几个方面: -DOMINO模型通过这种分解式、多尺度和迭代的方法,能够有效地处理大规模仿真数据,捕捉长距离和短距离的相互作用,并在不牺牲准确性的情况下提供可扩展、准确和可推广的代理模型 。 +- *可扩展性与计算效率*: 传统的ML模型在处理大规模计算域和高分辨率网格(通常包含数亿甚至数十亿个网格元素)时,面临巨大的内存和计算资源需求。它们难以有效地扩展到如此庞大的数据量,导致训练和推理时间过长,无法满足实时或准实时的工程应用需求。 -## 3. 完整代码 +- *几何表示与泛化能力*: 许多现有ML方法难以有效地表示复杂的三维几何形状。它们通常试图学习一个全局的几何表示来预测整个计算域的解场。然而,这种全局表示往往是高维且密集的,对于具有精细特征的复杂几何体,很难准确捕捉到解场与局部几何细节之间的复杂关系。此外,许多模型对输入数据的空间结构敏感,导致在不同网格类型或点云分布之间泛化能力较差。例如,在结构化网格上训练的模型可能无法很好地应用于非结构化网格或任意分布的点云数据。 + +- *精度与长程相互作用*: 尽管一些方法能够扩展到大型网格,但由于无法有效捕捉长程物理相互作用(例如,上游几何形状对下游流场的影响),它们的预测精度往往受到限制。这意味着模型可能无法准确预测远离输入几何体或在复杂流态(如尾流)中的物理量。 + +- *迭代特性与物理一致性*: 传统的模拟过程通常是迭代的,逐步收敛到稳态解。ML代理模型往往直接预测最终解,缺乏迭代特性,这可能导致其无法充分利用物理约束或通过迭代修正来提高解的物理一致性。 + +针对这些问题,本文的目标是开发一种新的ML模型架构,能够有效地处理大规模几何数据,准确捕捉局部和长程依赖关系,同时具有良好的泛化能力,并在计算效率上取得显著提升,从而成为高保真工程模拟的实用替代方案。 + +## 3. 模型原理 + +为了克服上述挑战,本文提出了DOMINO(Decomposable Multi-scale Iterative Neural Operator)模型。DOMINO的核心思想是结合了点云处理、多尺度学习和迭代优化,以高效地建模大规模工程模拟。 + +### 3.1 模型架构概述 + +DOMINO模型以三维几何体的表面点云(通常从表面网格转换而来)作为输入。它首先在几何体周围定义一个三维包围盒作为整个计算域。关键创新在于,DOMINO不试图一次性学习整个计算域的全局解,而是将问题分解为局部的、可并行处理的子问题,并通过多尺度和迭代的方式逐步细化解。 + +### 3.2 几何表示与特征提取 + +- *输入表示*: 原始的三维几何体通过其表面网格(例如,三角形网格)或点云进行表示。这些几何信息被转换为一个标准化的N维结构化表示,用于定义计算域内的分辨率。 + +- *全局几何编码网络*: 这是一个多尺度的点卷积网络,旨在从输入的表面几何体中提取丰富的几何特征。 + + - *多尺度点卷积*: 针对表面几何体,采用一系列具有不同半径参数(核大小)的点卷积核。这使得网络能够同时捕捉几何体的精细局部特征和长程的几何相互作用(例如,车身不同部位之间的相对位置关系)。 + - *特征传播到计算域*: 提取到的表面几何特征需要传播到整个计算域。DOMINO提供了两种方法: + 1. *独立的点卷积*: 学习一组单独的多尺度点卷积核,将表面几何信息直接投影到计算域的规则网格或点云上。 + 2. *CNN块传播*: 使用包含卷积、池化和反池化层的U-Net风格的卷积神经网络(CNN)块,将表面包围盒网格上提取的特征有效、分层地传播到计算域包围盒网格上。这种方法能够更好地捕捉不同尺度下的空间上下文信息。 + +- *局部几何编码*: 尽管全局几何编码提供了丰富的上下文信息,但计算域中任何一点的解场主要受其局部物理环境的影响。因此,DOMINO设计了一种机制来从全局几何编码中提取特定点的局部几何编码。这通过在每个采样点周围定义一个子区域并对该子区域内的特征进行聚合来实现,确保模型能够关注与当前预测点最相关的几何特征。 + +### 3.3 迭代与多分辨率预测框架 + +DOMINO采用了一种多分辨率迭代方法来逐步精化预测结果,这模拟了传统数值模拟的迭代求解过程: + +- *计算模板*: 在计算域中随机或均匀采样一批离散点。对于每个采样点,在其周围采样一定数量(p个)的邻近点,形成一个“计算模板”(类似于有限差分或有限体积方法中的计算单元)。这些邻近点及其特征构成了当前点预测的局部上下文。 + +- *聚合网络*: 一个专门的聚合神经网络被设计用于处理每个采样点及其计算模板的输入特征。这个网络结合了局部几何编码、模板中点的坐标信息以及迭代过程中的当前预测值。它通过对这些局部信息的聚合和非线性变换,预测出当前采样点的新解值。 + +- *迭代细化*: DOMINO是一个迭代模型。在每次迭代中,模型会基于当前的解场和几何信息,利用聚合网络更新所有采样点的解。这种迭代过程允许模型逐步收敛到更准确的解,并可以模拟物理系统中的信息传播。 + +- *多分辨率*: DOMINO支持多分辨率处理。模型可以在不同分辨率的网格或点云上进行预测,通过迭代过程在粗粒度上捕捉大尺度特征,然后在细粒度上捕捉局部细节。 + +### 3.4 表面与体积变量预测 + +DOMINO能够同时预测表面变量(如压力系数$C_p$、壁面剪切应力$\tau_w$)和体积变量(如速度场$u$、压力$p$、湍流参数等)。由于表面变量和体积变量的物理特性和分布模式不同,DOMINO为表面预测和体积预测设计了独立的聚合神经网络。然而,共享的全局几何编码网络可以为两者提供统一的几何上下文信息,提高了模型的效率。 + +$$y_i=\Sigma_{j=0}^{j=n_y}f(\overrightarrow{x_i},\overrightarrow{x_j},d_{ij})$$ + +其中,$\overrightarrow{x_i},\overrightarrow{x_j}$分别表示不同的点云数据集,$d_{ij}$表示两者之间的距离,$n_y$表示独立的聚合神经网络。 + +### 3.5 损失函数与训练 + +模型通常通过最小化预测值与真实CFD模拟数据之间的L2范数误差进行训练。为了提高模型的泛化能力和鲁棒性,可能会结合其他损失项,例如物理约束损失或正则化项。 + +$$\epsilon=\frac{\sqrt(\Sigma(y_T^2))-y_2^P}{\sqrt(\Sigma(y_T^2))}$$ + +其中,$y_T^2$和$y_P^2$分别表示真实值与预测CFD模拟数据。 + +## 4. 完整代码 ``` py linenums="1" title="examples/domino/domino.py" --8<-- @@ -78,8 +117,26 @@ examples/domino/domino.py --8<-- ``` -## 4. 结果展示 +## 5. 结果展示 + +### 5.1 实验结果与性能评估 + +在DrivAerML数据集上的实验结果充分证明了DOMINO模型的有效性和优越性: + +- *准确捕获表面和体积流场*: DOMINO模型能够准确地预测汽车表面的压力分布和壁面剪切应力,这些是评估车辆空气动力学性能的关键指标。例如,在挡风玻璃、侧后视镜、车身底部等关键区域,模型的预测值与高保真CFD模拟结果高度吻合。对于体积流场,如速度、压力和湍流粘度,DOMINO也能在整个计算域内提供高精度的预测。 + +- *准确捕捉设计趋势与工程指标*: 除了流场的可视化对比,DOMINO还能准确预测关键工程指标,例如汽车的空气阻力(Drag Force)。模型不仅能给出准确的阻力值,还能捕获不同几何变体下的阻力设计趋势,这对于辅助工程师进行快速设计迭代至关重要。研究中提供了模拟值与DOMINO预测值之间的回归图,进一步证实了其预测精度。 + +- *出色的泛化能力*: 模型在训练集中未见过的几何变体上表现出强大的泛化能力,这表明DOMINO能够学习到通用的物理规律和几何-流场映射关系,而不仅仅是记忆训练数据。 + +- *网格独立性*: 这是DOMINO的一个显著优势。模型在均匀采样的点云上而非原始模拟网格上进行验证,证明了其预测能力不依赖于特定的网格结构。这意味着DOMINO可以应用于不同离散化方式或分辨率的数据,极大地增强了其在实际应用中的灵活性。 + +- *计算效率*: 尽管论文中没有直接给出具体的加速倍数,但作为基于深度学习的代理模型,DOMINO的目标就是在保持高精度的前提下,显著降低推理时间,从而实现近乎实时的空气动力学评估,这对于迭代设计和优化过程至关重要。通过利用局部信息和迭代细化,DOMINO避免了处理整个大规模网格的内存和计算瓶颈。 + +### 5.2 局限性与未来工作 + +尽管DOMINO取得了显著进展,但论文可能也指出了未来改进的方向。例如,对于一些稀疏且复杂的物理量(如湍流粘度),特别是在远离几何体的区域,预测精度仍有提升空间。未来工作可能包括探索更先进的神经网络架构、更有效的损失函数、或者结合更多物理约束来进一步提高模型在各种复杂流态下的预测精度和鲁棒性。 -## 5. 参考资料 +## 6. 参考资料 - [DoMINO: A Decomposable Multi-scale Iterative Neural Operator for Modeling Large Scale Engineering Simulations](https://arxiv.org/abs/2501.13350) diff --git a/requirements.txt b/requirements.txt index 58c8be4a3d..0858a3520e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,7 +2,6 @@ colorlog einops h5py hydra-core -hydra-core imageio importlib_metadata matplotlib