Skip to content

Commit 9eb9cd1

Browse files
authored
modify gridmask op (#2693)
1 parent a542e3d commit 9eb9cd1

File tree

2 files changed

+4
-2
lines changed

2 files changed

+4
-2
lines changed

ppdet/data/transform/gridmask_utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,8 @@ def __call__(self, x, curr_iter):
4545
self.prob = self.st_prob * min(1, 1.0 * curr_iter / self.upper_iter)
4646
if np.random.rand() > self.prob:
4747
return x
48-
h, w, _ = x.shape
48+
# image should be C, H, W format
49+
_, h, w = x.shape
4950
hh = int(1.5 * h)
5051
ww = int(1.5 * w)
5152
d = np.random.randint(2, h)

static/ppdet/data/transform/gridmask_utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,8 @@ def __call__(self, x, curr_iter):
4545
self.prob = self.st_prob * min(1, 1.0 * curr_iter / self.upper_iter)
4646
if np.random.rand() > self.prob:
4747
return x
48-
h, w, _ = x.shape
48+
# image should be C, H, W format
49+
_, h, w = x.shape
4950
hh = int(1.5 * h)
5051
ww = int(1.5 * w)
5152
d = np.random.randint(2, h)

0 commit comments

Comments
 (0)