Commit 25bb1bd

[Enh] Refactor sum aggregator (#834)

* add Sum loss aggregator
* simplify loss aggregation code in train.py and add checks for the AGDA and PCGrad aggregators when used with AMP
* add a check for using L-BFGS with use_amp=True
* refine Relobralo
* fix docstrings in timedomain.py
* remove unnecessary code in train.py
* automatically download the *.pdeqn file, if available, when downloading a pretrained model
* wrap functions generated by the symbolic module with DDP
* fix Relobralo
* initialize the loss with 0.0 instead of the first loss
1 parent fcb277c commit 25bb1bd

File tree: 11 files changed, +117 −45 lines

docs/zh/api/loss/mtl.md (+1)

```diff
@@ -8,5 +8,6 @@
         - LossAggregator
         - PCGrad
         - Relobralo
+        - Sum
       show_root_heading: true
       heading_level: 3
```

docs/zh/examples/viv.md (+1 −3)

````diff
@@ -11,9 +11,7 @@
 === "模型评估命令"

     ``` sh
-    wget -nc https://paddle-org.bj.bcebos.com/paddlescience/models/viv/viv_pretrained.pdeqn
-    wget -nc https://paddle-org.bj.bcebos.com/paddlescience/models/viv/viv_pretrained.pdparams
-    python viv.py mode=eval EVAL.pretrained_model_path=./viv_pretrained
+    python viv.py mode=eval EVAL.pretrained_model_path=https://paddle-org.bj.bcebos.com/paddlescience/models/viv/viv_pretrained.pdparams
     ```

 | 预训练模型 | 指标 |
````

ppsci/geometry/timedomain.py (+4 −4)

```diff
@@ -207,7 +207,7 @@ def random_points(

         Args:
             n (int): The total number of random points to generate.
-            random (string): Specifies the way to generate random points, default is "pseudo", which means that a pseudo-random number generator is used.
+            random (str): Specifies the way to generate random points, default is "pseudo", which means that a pseudo-random number generator is used.
             criteria (Optional[Callable]): A method that filters on the generated random points, default is None.

         Returns:
@@ -432,7 +432,7 @@ def random_boundary_points(

         Args:
             n (int): The total number of spatial-temporal points generated on a given geometry boundary.
-            random (string): Controls the way to generate random points. Default is "pseudo".
+            random (str): Controls the way to generate random points. Default is "pseudo".
             criteria (Optional[Callable]): Used to filter the generated boundary points, only points that meet certain conditions are retained. Default is None.

         Returns:
@@ -650,7 +650,7 @@ def random_initial_points(self, n: int, random: str = "pseudo"):

         Args:
             n (int): The total number of generated points.
-            random (string): Controls the way to generate random points. Default is "pseudo".
+            random (str): Controls the way to generate random points. Default is "pseudo".

         Returns:
             np.ndarray: A set of point coordinates randomly distributed on the spatial-temporal domain at the initial moment.
@@ -709,7 +709,7 @@ def sample_initial_interior(

         Args:
             n (int): The total number of interior points generated.
-            random (string): The method used to specify the initial point of generation. Default is "pseudo".
+            random (str): The method used to specify the initial point of generation. Default is "pseudo".
             criteria (Optional[Callable]): Used to filter the generated interior points, only points that meet certain conditions are retained. Default is None.
             evenly (bool): Indicates whether the initial points are generated evenly. Default is False.
             compute_sdf_derivatives (bool): Indicates whether to calculate the derivative of the signed distance function or not. Default is False.
```

ppsci/loss/mtl/__init__.py (+2)

```diff
@@ -18,12 +18,14 @@
 from ppsci.loss.mtl.base import LossAggregator
 from ppsci.loss.mtl.pcgrad import PCGrad
 from ppsci.loss.mtl.relobralo import Relobralo
+from ppsci.loss.mtl.sum import Sum

 __all__ = [
     "AGDA",
     "LossAggregator",
     "PCGrad",
     "Relobralo",
+    "Sum",
 ]
```
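With the re-export in place, `Sum` is importable from the `mtl` subpackage like the other aggregators. A minimal sketch:

```python
from ppsci.loss import mtl

# The new default aggregator can now be constructed directly.
aggregator = mtl.Sum()
```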
ppsci/loss/mtl/agda.py (+1 −1)

```diff
@@ -57,7 +57,7 @@ def __init__(self, model: nn.Layer, M: int = 100, gamma: float = 0.999) -> None:
         self.Lf_tilde_acc = 0.0
         self.Lu_tilde_acc = 0.0

-    def __call__(self, losses, step: int = 0):
+    def __call__(self, losses, step: int = 0) -> "AGDA":
         if len(losses) != 2:
             raise ValueError(
                 f"Number of losses(tasks) for AGDA should be 2, but got {len(losses)}"
```

ppsci/loss/mtl/base.py (+1 −1)

```diff
@@ -32,7 +32,7 @@ def __init__(self, model: nn.Layer) -> None:
             if not param.stop_gradient:
                 self.param_num += 1

-    def __call__(self, losses, step: int = 0):
+    def __call__(self, losses, step: int = 0) -> "LossAggregator":
         self.losses = losses
         self.loss_num = len(losses)
         self.step = step
```

ppsci/loss/mtl/relobralo.py (+8 −10)

```diff
@@ -69,28 +69,28 @@ def __init__(
         self.register_buffer("losses_prev", paddle.zeros([self.num_losses]))
         self.register_buffer("lmbda", paddle.ones([self.num_losses]))

-    def _softmax(self, vec: paddle.Tensor) -> paddle.Tensor:
+    def _softmax(self, vec: "paddle.Tensor") -> "paddle.Tensor":
         max_item = vec.max()
         result = paddle.exp(vec - max_item) / paddle.exp(vec - max_item).sum()
         return result

     def _compute_bal(
-        self, losses_vec1: paddle.Tensor, losses_vec2: paddle.Tensor
-    ) -> paddle.Tensor:
+        self, losses_vec1: "paddle.Tensor", losses_vec2: "paddle.Tensor"
+    ) -> "paddle.Tensor":
         return self.num_losses * (
             self._softmax(losses_vec1 / (self.tau * losses_vec2 + self.eps))
         )

-    def __call__(self, losses: List[paddle.Tensor], step: int = 0) -> "Relobralo":
-        self.step = step
+    def __call__(self, losses: List["paddle.Tensor"], step: int = 0) -> "paddle.Tensor":
         assert len(losses) == self.num_losses, (
             f"Length of given losses({len(losses)}) should be equal to "
             f"num_losses({self.num_losses})."
         )
+        self.step = step
         losses_stacked = paddle.stack(losses)  # [num_losses, ]

         if self.step == 0:
-            self.loss = losses_stacked.sum()
+            loss = losses_stacked.sum()
             with paddle.no_grad():
                 paddle.assign(losses_stacked.detach(), self.losses_init)
         else:
@@ -110,12 +110,10 @@ def __call__(self, losses: List[paddle.Tensor], step: int = 0) -> "Relobralo":
             )

         # 3. compute reweighted total loss with lambda
-        self.loss = (losses_stacked * self.lmbda).sum()
+        loss = (losses_stacked * self.lmbda).sum()

         # update losses_prev at the end of each step
         with paddle.no_grad():
             paddle.assign(losses_stacked.detach(), self.losses_prev)
-        return self

-    def backward(self) -> None:
-        self.loss.backward()
+        return loss
```
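The behavioral change: `__call__` now returns the reweighted loss tensor instead of `self`, so the separate `backward()` method is gone and callers backpropagate on the returned value. A minimal sketch of the new calling convention (the `num_losses` constructor argument and the dummy loss values are illustrative):

```python
import paddle

from ppsci.loss import mtl

aggregator = mtl.Relobralo(num_losses=2)
losses = [
    paddle.to_tensor(0.5, stop_gradient=False),
    paddle.to_tensor(1.5, stop_gradient=False),
]

loss = aggregator(losses, step=0)  # now a paddle.Tensor, not the aggregator
loss.backward()  # replaces the removed aggregator.backward()
```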

ppsci/loss/mtl/sum.py (+50, new file)

```diff
@@ -0,0 +1,50 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+from typing import Sequence
+
+if TYPE_CHECKING:
+    import paddle
+
+from ppsci.loss.mtl.base import LossAggregator
+
+
+class Sum(LossAggregator):
+    r"""
+    **Default loss aggregator** which does simple summation of the given losses as below.
+
+    $$
+    loss = \sum_i^N losses_i
+    $$
+    """
+
+    def __init__(self) -> None:
+        self.step = 0
+
+    def __call__(
+        self, losses: Sequence["paddle.Tensor"], step: int = 0
+    ) -> paddle.Tensor:
+        assert (
+            len(losses) > 0
+        ), f"Number of given losses({len(losses)}) should be greater than 0."
+        self.step = step
+
+        loss = 0.0
+        for i in range(len(losses)):
+            loss += losses[i]
+
+        return loss
```
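A minimal usage sketch (dummy tensors; during training the losses come from the solver's constraints):

```python
import paddle

from ppsci.loss import mtl

aggregator = mtl.Sum()
losses = [paddle.to_tensor(1.0), paddle.to_tensor(2.0), paddle.to_tensor(3.0)]

total = aggregator(losses, step=0)
print(float(total))  # 6.0
```

Starting the accumulator at `0.0` rather than at the first loss (per the commit message) keeps the loop uniform and, presumably, avoids aliasing the first loss tensor during in-place accumulation.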

ppsci/solver/solver.py (+15 −4)

```diff
@@ -300,6 +300,10 @@ def __init__(

         # choosing an appropriate training function for different optimizers
         if misc.typename(self.optimizer) == "LBFGS":
+            if self.use_amp:
+                raise ValueError(
+                    "Auto Mix Precision is not supported for L-BFGS optimizer."
+                )
             self.train_epoch_func = ppsci.solver.train.train_LBFGS_epoch_func
             if self.update_freq != 1:
                 self.update_freq = 1
@@ -398,8 +402,13 @@ def dist_wrapper(model: nn.Layer) -> paddle.DataParallel:
         jit.enable_to_static(to_static)
         logger.info(f"Set to_static={to_static} for computational optimization.")

-        # use loss aggregator, use summation if None
-        self.loss_aggregator = loss_aggregator
+        # use loss aggregator, use Sum if None
+        if isinstance(loss_aggregator, (mtl.AGDA, mtl.PCGrad)) and self.use_amp:
+            raise ValueError(
+                "Auto Mix Precision does not support the AGDA and PCGrad loss "
+                "aggregators yet, please set use_amp=False."
+            )
+        self.loss_aggregator = loss_aggregator or mtl.Sum()

         # convert sympy to callable object if any exist
         extra_parameters = []
@@ -432,6 +441,10 @@ def convert_expr(
             for name in container.output_expr:
                 if isinstance(container.output_expr[name], sp.Basic):
                     container.output_expr[name] = funcs[ind]
+                    if self.world_size > 1:
+                        container.output_expr[name] = dist_wrapper(
+                            container.output_expr[name]
+                        )
                     ind += 1

         if self.constraint:
@@ -775,7 +788,6 @@ def export(
         )
         logger.message(f"ONNX model has been exported to: {export_path}.onnx")

-    @functools.lru_cache()
     def autocast_context_manager(
         self, enable: bool, level: Literal["O0", "O1", "O2", "OD"] = "O1"
     ) -> contextlib.AbstractContextManager:
@@ -820,7 +832,6 @@ def no_grad_context_manager(
         )
         return ctx_manager

-    @functools.lru_cache()
     def no_sync_context_manager(
         self,
         enable: bool,
```
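Two effects worth noting: unsupported AMP combinations now fail fast at construction time, and a missing aggregator falls back to `mtl.Sum()`, so downstream code can assume an aggregator always exists. A hedged sketch of the fallback idiom (a standalone rewrite, not the Solver source):

```python
from typing import Optional

from ppsci.loss import mtl


def resolve_aggregator(
    loss_aggregator: Optional[mtl.LossAggregator], use_amp: bool
) -> mtl.LossAggregator:
    # Fail fast: gradient-surgery aggregators rewrite gradients themselves
    # and are not compatible with AMP loss scaling yet.
    if isinstance(loss_aggregator, (mtl.AGDA, mtl.PCGrad)) and use_amp:
        raise ValueError(
            "Auto Mix Precision does not support the AGDA and PCGrad loss "
            "aggregators yet, please set use_amp=False."
        )
    # Default to plain summation when no aggregator is given.
    return loss_aggregator or mtl.Sum()


aggregator = resolve_aggregator(None, use_amp=False)
assert isinstance(aggregator, mtl.Sum)
```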

ppsci/solver/train.py (+17 −22)

```diff
@@ -45,7 +45,6 @@ def train_epoch_func(solver: "solver.Solver", epoch_id: int, log_freq: int):
                 f"Training iteration {solver.global_step + 1}"
             )  # Training iteration

-            total_loss = 0.0
             total_batch_size = 0
             reader_cost = 0.0
             batch_cost = 0.0
@@ -106,31 +105,30 @@ def train_epoch_func(solver: "solver.Solver", epoch_id: int, log_freq: int):
             if solver.nvtx_flag:  # only for nsight analysis
                 core.nvprof_nvtx_push("Loss aggregator")

+            total_loss = solver.loss_aggregator(
+                constraint_losses, solver.global_step
+            )
+            if solver.update_freq > 1:
+                total_loss = total_loss / solver.update_freq
+
             for i, _constraint in enumerate(solver.constraint.values()):
-                total_loss += constraint_losses[i]
-                loss_dict[_constraint.name] += (
+                loss_dict[_constraint.name] = (
                     float(constraint_losses[i]) / solver.update_freq
                 )
-            if solver.update_freq > 1:
-                total_loss = total_loss / solver.update_freq
+            loss_dict["loss"] = float(total_loss)

             if solver.nvtx_flag:  # only for nsight analysis
                 core.nvprof_nvtx_pop()  # Loss aggregator

-            loss_dict["loss"] = float(total_loss)
-
             # backward
             if solver.nvtx_flag:  # only for nsight analysis
                 core.nvprof_nvtx_push("Loss backward")

-            if solver.loss_aggregator is None:
-                if solver.use_amp:
-                    total_loss_scaled = solver.scaler.scale(total_loss)
-                    total_loss_scaled.backward()
-                else:
-                    total_loss.backward()
+            if solver.use_amp:
+                total_loss_scaled = solver.scaler.scale(total_loss)
+                total_loss_scaled.backward()
             else:
-                solver.loss_aggregator(constraint_losses, solver.global_step).backward()
+                total_loss.backward()

             if solver.nvtx_flag:  # only for nsight analysis
                 core.nvprof_nvtx_pop()  # Loss backward
@@ -233,7 +231,6 @@ def closure() -> paddle.Tensor:
             Returns:
                 paddle.Tensor: Computed loss scalar.
             """
-            total_loss = 0
             with solver.no_sync_context_manager(solver.world_size > 1, solver.model):
                 with solver.autocast_context_manager(solver.use_amp, solver.amp_level):
                     # forward for every constraint, including model and equation expression
@@ -248,20 +245,18 @@
                         label_dicts,
                         weight_dicts,
                     )
+
+                    total_loss = solver.loss_aggregator(
+                        constraint_losses, solver.global_step
+                    )
                     # accumulate all losses
                     for i, _constraint in enumerate(solver.constraint.values()):
-                        total_loss += constraint_losses[i]
                         loss_dict[_constraint.name] = float(constraint_losses[i])
                     loss_dict["loss"] = float(total_loss)

             # backward
             solver.optimizer.clear_grad()
-            if solver.loss_aggregator is None:
-                total_loss.backward()
-            else:
-                solver.loss_aggregator(
-                    constraint_losses, solver.global_step
-                ).backward()
+            total_loss.backward()

             if solver.world_size > 1:
                 # fuse + allreduce manually before optimization if use DDP model
```
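The net effect is a single aggregation path: the aggregator (never `None` anymore, thanks to the `Sum` default) produces `total_loss`, and the backward branch depends only on AMP. A condensed sketch as a hypothetical helper (the names mirror the training loop; this is not a ppsci API):

```python
def aggregate_and_backward(
    loss_aggregator, constraint_losses, global_step, update_freq=1, scaler=None
):
    # One code path for every aggregator: Sum, Relobralo, AGDA, PCGrad, ...
    total_loss = loss_aggregator(constraint_losses, global_step)
    if update_freq > 1:
        # average over gradient-accumulation steps
        total_loss = total_loss / update_freq
    if scaler is not None:
        # AMP: scale before backward so small gradients survive fp16
        scaler.scale(total_loss).backward()
    else:
        total_loss.backward()
    return total_loss
```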

ppsci/utils/save_load.py (+17)

```diff
@@ -98,8 +98,25 @@ def load_pretrain(
         ...     path="path/to/pretrain_model")  # doctest: +SKIP
     """
     if path.startswith("http"):
+        # download from path(url) and get its physical path
+        eqn_path = path.replace(".pdparams", ".pdeqn", 1)
         path = download.get_weights_path_from_url(path)

+        # automatically download additional equation weights if available
+        def is_url_accessible(url: str):
+            try:
+                import requests
+
+                response = requests.head(url, timeout=5)
+                return response.status_code == requests.codes.ok
+            except requests.RequestException:
+                return False
+            except Exception:
+                return False
+
+        if is_url_accessible(eqn_path):
+            download.get_weights_path_from_url(eqn_path)
+
     # remove ".pdparams" suffix of path for convenience
     if path.endswith(".pdparams"):
         path = path[:-9]
```
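End to end, this is what enables the simplified eval command in the viv docs above. The probe is an ordinary HTTP HEAD request; a runnable sketch of the same check (URL taken from the viv example, result depends on network access):

```python
import requests

# Decide whether a .pdeqn equation file accompanies the .pdparams weights.
url = "https://paddle-org.bj.bcebos.com/paddlescience/models/viv/viv_pretrained.pdeqn"
response = requests.head(url, timeout=5)
print(response.status_code == requests.codes.ok)  # True if the file exists
```

When the probe succeeds, `load_pretrain` downloads the equation file into the same weights cache via `download.get_weights_path_from_url`.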
