Skip to content

Commit d629745

Browse files
[Fix] Fix weight dtype and bug in symbolic (#709)
* fix dtype of weight to avoid unnecessary type promotion * fix bug in symbolic
1 parent 0d3db44 commit d629745

File tree

4 files changed

+14
-9
lines changed

4 files changed

+14
-9
lines changed

examples/hpinns/functions.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -100,8 +100,8 @@ def init_lambda(output_dict: Dict[str, paddle.Tensor], bound: int):
100100
"""
101101
global lambda_re, lambda_im, loss_weight
102102
x, y = output_dict["x"], output_dict["y"]
103-
lambda_re = np.zeros((len(x[bound:]), 1))
104-
lambda_im = np.zeros((len(y[bound:]), 1))
103+
lambda_re = np.zeros((len(x[bound:]), 1), paddle.get_default_dtype())
104+
lambda_im = np.zeros((len(y[bound:]), 1), paddle.get_default_dtype())
105105
# loss_weight: [PDE loss 1, PDE loss 2, Lagrangian loss 1, Lagrangian loss 2, objective loss]
106106
if train_mode == "aug_lag":
107107
loss_weight = [0.5 * mu] * 2 + [1.0, 1.0] + [1.0]

ppsci/data/dataset/array_dataset.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,10 @@ def __init__(
108108
self.input_keys = tuple(input.keys())
109109
self.label_keys = tuple(self.label.keys())
110110
self.weight = (
111-
{key: paddle.to_tensor(value) for key, value in weight.items()}
111+
{
112+
key: paddle.to_tensor(value, paddle.get_default_dtype())
113+
for key, value in weight.items()
114+
}
112115
if weight is not None
113116
else None
114117
)

ppsci/data/dataset/trphysx_dataset.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,8 @@ def __getitem__(self, i):
130130

131131
weight_shape = [1] * len(data_item.shape)
132132
weight_item = {
133-
key: np.full(weight_shape, value) for key, value in self.weight_dict.items()
133+
key: np.full(weight_shape, value, paddle.get_default_dtype())
134+
for key, value in self.weight_dict.items()
134135
}
135136
return (input_item, label_item, weight_item)
136137

@@ -307,6 +308,7 @@ def __getitem__(self, i):
307308
label_item[self.label_keys[1]] = data_item[1:, :]
308309
weight_shape = [1] * len(data_item.shape)
309310
weight_item = {
310-
key: np.full(weight_shape, value) for key, value in self.weight_dict.items()
311+
key: np.full(weight_shape, value, paddle.get_default_dtype())
312+
for key, value in self.weight_dict.items()
311313
}
312314
return (input_item, label_item, weight_item)

ppsci/utils/symbolic.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -256,8 +256,8 @@ def _minimum_operator_func(self, data_dict: DATA_DICT) -> DATA_DICT:
256256
)
257257
for i in range(2, len(self.childs)):
258258
data_dict[self.key] = paddle.minimum(
259-
data_dict[data_dict[self.key]],
260-
data_dict[data_dict[self.childs[i]]],
259+
data_dict[self.key],
260+
data_dict[self.childs[i]],
261261
)
262262
return data_dict
263263

@@ -267,8 +267,8 @@ def _maximum_operator_func(self, data_dict: DATA_DICT) -> DATA_DICT:
267267
)
268268
for i in range(2, len(self.childs)):
269269
data_dict[self.key] = paddle.maximum(
270-
data_dict[data_dict[self.key]],
271-
data_dict[data_dict[self.childs[i]]],
270+
data_dict[self.key],
271+
data_dict[self.childs[i]],
272272
)
273273
return data_dict
274274

0 commit comments

Comments
 (0)