Skip to content

Commit bbf1423

Browse files
authored
fix gridmask (#2695)
1 parent 9eb9cd1 commit bbf1423

File tree

2 files changed

+4
-5
lines changed

2 files changed

+4
-5
lines changed

ppdet/data/transform/gridmask_utils.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,7 @@ def __call__(self, x, curr_iter):
4545
self.prob = self.st_prob * min(1, 1.0 * curr_iter / self.upper_iter)
4646
if np.random.rand() > self.prob:
4747
return x
48-
# image should be C, H, W format
49-
_, h, w = x.shape
48+
h, w, _ = x.shape
5049
hh = int(1.5 * h)
5150
ww = int(1.5 * w)
5251
d = np.random.randint(2, h)
@@ -74,7 +73,7 @@ def __call__(self, x, curr_iter):
7473

7574
if self.mode == 1:
7675
mask = 1 - mask
77-
mask = np.expand_dims(mask, axis=0)
76+
mask = np.expand_dims(mask, axis=-1)
7877
if self.offset:
7978
offset = (2 * (np.random.rand(h, w) - 0.5)).astype(np.float32)
8079
x = (x * mask + offset * (1 - mask)).astype(x.dtype)

static/ppdet/data/transform/gridmask_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def __call__(self, x, curr_iter):
4646
if np.random.rand() > self.prob:
4747
return x
4848
# image should be C, H, W format
49-
_, h, w = x.shape
49+
h, w, _ = x.shape
5050
hh = int(1.5 * h)
5151
ww = int(1.5 * w)
5252
d = np.random.randint(2, h)
@@ -74,7 +74,7 @@ def __call__(self, x, curr_iter):
7474

7575
if self.mode == 1:
7676
mask = 1 - mask
77-
mask = np.expand_dims(mask, axis=0)
77+
mask = np.expand_dims(mask, axis=-1)
7878
if self.offset:
7979
offset = (2 * (np.random.rand(h, w) - 0.5)).astype(np.float32)
8080
x = (x * mask + offset * (1 - mask)).astype(x.dtype)

0 commit comments

Comments
 (0)