
Commit eaff0fc

[Fix] restore 'by_epoch' for SchedulerList and fix EPNN (#777)
* restore 'by_epoch' for SchedulerList
* fix for epnn
1 parent 988fd33 commit eaff0fc

File tree: 3 files changed (+19 -6 lines)

examples/epnn/functions.py (+14 -5)
@@ -226,6 +226,11 @@ def __init__(self, data_state, data_stress, itrain):
         self.data_stress = data_stress
         self.itrain = itrain

+    def _cvt_to_ndarray(self, list_dict):
+        for key in list_dict:
+            list_dict[key] = np.asarray(list_dict[key])
+        return list_dict
+
     def get(self, epochs=1):
         # Slow if using BatchSampler to obtain data
         input_dict_train = {
@@ -243,7 +248,7 @@ def get(self, epochs=1):
         label_dict_train = {"dummy_loss": []}
         label_dict_val = {"dummy_loss": []}
         for i in range(epochs):
-            shuffled_indices = paddle.randperm(n=self.data_state.x_train.shape[0])
+            shuffled_indices = np.random.permutation(self.data_state.x_train.shape[0])
             input_dict_train["state_x"].append(
                 self.data_state.x_train[shuffled_indices[0 : self.itrain]]
             )
@@ -256,9 +261,9 @@ def get(self, epochs=1):
             input_dict_train["stress_y"].append(
                 self.data_stress.y_train[shuffled_indices[0 : self.itrain]]
             )
-            label_dict_train["dummy_loss"].append(paddle.to_tensor(0.0))
+            label_dict_train["dummy_loss"].append(0.0)

-            shuffled_indices = paddle.randperm(n=self.data_state.x_valid.shape[0])
+            shuffled_indices = np.random.permutation(self.data_state.x_valid.shape[0])
             input_dict_val["state_x"].append(
                 self.data_state.x_valid[shuffled_indices[0 : self.itrain]]
             )
@@ -271,7 +276,11 @@ def get(self, epochs=1):
             input_dict_val["stress_y"].append(
                 self.data_stress.y_valid[shuffled_indices[0 : self.itrain]]
             )
-            label_dict_val["dummy_loss"].append(paddle.to_tensor(0.0))
+            label_dict_val["dummy_loss"].append(0.0)
+        input_dict_train = self._cvt_to_ndarray(input_dict_train)
+        label_dict_train = self._cvt_to_ndarray(label_dict_train)
+        input_dict_val = self._cvt_to_ndarray(input_dict_val)
+        label_dict_val = self._cvt_to_ndarray(label_dict_val)
         return input_dict_train, label_dict_train, input_dict_val, label_dict_val


@@ -287,7 +296,7 @@ def __init__(self, dataset_path, train_p=0.6, cross_valid_p=0.2, test_p=0.2):
     def get_shuffled_data(self):
         # Need to set the seed, otherwise the loss will not match the precision
         ppsci.utils.misc.set_random_seed(seed=10)
-        shuffled_indices = paddle.randperm(n=self.x.shape[0])
+        shuffled_indices = np.random.permutation(self.x.shape[0])
         n_train = math.floor(self.train_p * self.x.shape[0])
         n_cross_valid = math.floor(self.cross_valid_p * self.x.shape[0])
         n_test = math.floor(self.test_p * self.x.shape[0])
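
The pattern above is easier to follow outside the diff. Below is a minimal sketch, with hypothetical toy data standing in for the repo's data_state/data_stress objects, of what the changed get logic now does: shuffle indices with np.random.permutation instead of paddle.randperm, accumulate per-epoch slices and plain-float dummy losses in Python lists, then convert every list to an ndarray via _cvt_to_ndarray before returning.

import numpy as np

def _cvt_to_ndarray(list_dict):
    # Convert every accumulated Python list in the dict to a numpy array.
    for key in list_dict:
        list_dict[key] = np.asarray(list_dict[key])
    return list_dict

x_train = np.random.rand(8, 3)  # hypothetical toy training inputs
itrain = 4                      # hypothetical number of samples drawn per epoch
input_dict_train = {"state_x": []}
label_dict_train = {"dummy_loss": []}

for _ in range(2):  # two "epochs" of resampling
    # numpy permutation replaces paddle.randperm for index shuffling
    shuffled_indices = np.random.permutation(x_train.shape[0])
    input_dict_train["state_x"].append(x_train[shuffled_indices[0:itrain]])
    # plain float instead of paddle.to_tensor(0.0)
    label_dict_train["dummy_loss"].append(0.0)

input_dict_train = _cvt_to_ndarray(input_dict_train)  # "state_x" -> shape (2, 4, 3)
label_dict_train = _cvt_to_ndarray(label_dict_train)  # "dummy_loss" -> shape (2,)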

ppsci/data/__init__.py (+4 -1)
@@ -89,7 +89,10 @@ def build_dataloader(_dataset, cfg):
         logger.warning(
             "`batch_size` is set to 1 as neither sampler config nor batch_size is set."
         )
-        batch_sampler = None
+        batch_sampler = io.BatchSampler(
+            _dataset,
+            batch_size=cfg["batch_size"],
+        )

     # build collate_fn if specified
     batch_transforms_cfg = cfg.pop("batch_transforms", None)
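
Instead of leaving batch_sampler as None when no sampler config is given, the fallback now constructs a default paddle.io.BatchSampler with the configured batch size. A minimal standalone sketch of that fallback (the toy dataset and cfg dict are assumptions for illustration, not the repo's code):

from paddle import io

class ToyDataset(io.Dataset):
    # Hypothetical in-memory dataset standing in for `_dataset`.
    def __init__(self, n=10):
        super().__init__()
        self.data = list(range(n))

    def __getitem__(self, idx):
        return self.data[idx]

    def __len__(self):
        return len(self.data)

cfg = {"batch_size": 4}  # assumed minimal dataloader config
_dataset = ToyDataset()

# Fallback path from the patch: build a default BatchSampler instead of None.
batch_sampler = io.BatchSampler(_dataset, batch_size=cfg["batch_size"])
for batch_indices in batch_sampler:
    print(batch_indices)  # e.g. [0, 1, 2, 3]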

ppsci/optimizer/lr_scheduler.py (+1)
@@ -750,6 +750,7 @@ class SchedulerList:
     def __init__(self, scheduler_list: Tuple[lr.LRScheduler, ...]):
         super().__init__()
         self._sch_list = scheduler_list
+        self.by_epoch = False

     def step(self):
         for sch in self._sch_list:
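
For context, SchedulerList steps several paddle.optimizer.lr.LRScheduler instances together, and the restored by_epoch attribute lets the training loop query the stepping granularity the same way it does for a single scheduler. A simplified, self-contained sketch (not the full ppsci class; the example schedulers are assumptions):

from typing import Tuple

from paddle.optimizer import lr

class SchedulerListSketch:
    # Simplified stand-in for ppsci's SchedulerList: steps every wrapped
    # scheduler together and exposes `by_epoch` like a single scheduler.
    def __init__(self, scheduler_list: Tuple[lr.LRScheduler, ...]):
        self._sch_list = scheduler_list
        self.by_epoch = False  # restored attribute: step per iteration, not per epoch

    def step(self):
        for sch in self._sch_list:
            sch.step()

schedulers = SchedulerListSketch(
    (
        lr.StepDecay(learning_rate=1e-3, step_size=10),
        lr.ExponentialDecay(learning_rate=1e-3, gamma=0.9),
    )
)
schedulers.step()  # advances both underlying schedulers by one step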
