Skip to content

Commit bd71872

Browse files
authored
[Update]update setting of 'auto_collation' and fix errors (#783)
* set 'auto_collation' false when using data transform * fix data transform error of topopt example
1 parent 7ebdc0b commit bd71872

File tree

3 files changed

+5
-2
lines changed

3 files changed

+5
-2
lines changed

examples/topopt/functions.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,8 @@ def augmentation(
113113
"""
114114
inputs = input_dict["input"]
115115
labels = label_dict["output"]
116+
assert len(inputs.shape) == 3
117+
assert len(labels.shape) == 3
116118

117119
# random horizontal flip
118120
if np.random.random() > 0.5:
@@ -125,7 +127,7 @@ def augmentation(
125127
# random 90* rotation
126128
if np.random.random() > 0.5:
127129
new_perm = list(range(len(inputs.shape)))
128-
new_perm[1], new_perm[2] = new_perm[2], new_perm[1]
130+
new_perm[-2], new_perm[-1] = new_perm[-1], new_perm[-2]
129131
inputs = np.transpose(inputs, new_perm)
130132
labels = np.transpose(labels, new_perm)
131133

examples/topopt/topoptmodel.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ class TopOptNN(ppsci.arch.UNetEx):
3939
4040
Examples:
4141
>>> import ppsci
42-
>>> model = ppsci.arch.ppsci.arch.UNetEx("input", "output", 2, 1, 3, (16, 32, 64), 2, lambda: 1, Flase, False)
42+
>>> model = ppsci.arch.ppsci.arch.TopOptNN("input", "output", 2, 1, 3, (16, 32, 64), 2, lambda: 1, Flase, False)
4343
"""
4444

4545
def __init__(

ppsci/data/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@ def build_dataloader(_dataset, cfg):
134134
if (
135135
cfg.get("auto_collation", not getattr(_dataset, "batch_index", False))
136136
is False
137+
and "transforms" not in cfg["dataset"]
137138
):
138139
# 1. wrap batch_sampler again into BatchSampler for disabling auto collation,
139140
# which can speed up the process of batch samples indexing from dataset. See

0 commit comments

Comments
 (0)