Commit b733fa6

add dynamic support for max_grad op

Parent: 1bde725

File tree

4 files changed (+225 / -34 lines):
  paddle/fluid/primitive/rule/vjp/details.h
  paddle/fluid/primitive/utils/utils.h
  python/paddle/autograd/backward_utils.py
  test/prim/pir_prim/test_prim_sub_graph_backward_dynamic_shape.py

paddle/fluid/primitive/rule/vjp/details.h

Lines changed: 80 additions & 32 deletions

@@ -1560,43 +1560,91 @@ void max_grad(const Tensor& x,
   if (!x_grad) {
     return;
   }
-  auto zero_tensor = full<T>(common::vectorize(x.dims()), 0.0, x.dtype());
-  std::vector<int64_t> x_dim = common::vectorize<int64_t>(x.dims());
-  int64_t axis_size = axis.size();
-  int64_t x_dim_size = x_dim.size();
-  reduce_all = false;
-  if (reduce_all || axis_size == 0 || axis_size == x_dim_size) {
-    reduce_all = true;
-  } else {
+
+  Tensor x_grad_tmp;
+  if (has_dynamic_shape(x.shape())) {
+    const Tensor x_shape = shape<T>(x);
+    const Tensor zero_tensor =
+        backend::full_with_tensor<T>(x_shape, 0.0, x.dtype());
+    const size_t axis_size = axis.size();
+    const size_t x_dim_size = x.dims().size();
+
     reduce_all = false;
-  }
-  auto x_grad_tmp = Tensor();
-  if (x_dim_size == 0 || x_dim_size == 1 || keepdim) {
-    auto out_grad_tmp = out_grad.expand(IntArray(x_dim));
-    auto out_tmp = out.expand(IntArray(x_dim));
-    auto mask = equal<T>(x, out_tmp);
-    x_grad_tmp = where<T>(mask, out_grad_tmp, zero_tensor);
-  } else {
-    auto axis_ = std::vector<int64_t>();
-    if (reduce_all) {
-      for (int64_t i = 0; i < x_dim_size; i++) {
-        axis_.push_back(i);
+    if (reduce_all || axis_size == 0 || axis_size == x_dim_size) {
+      reduce_all = true;
+    } else {
+      reduce_all = false;
+    }
+
+    if (x_dim_size == 0 || x_dim_size == 1 || keepdim) {
+      auto out_grad_tmp = backend::expand<T>(out_grad, x_shape);
+      auto out_tmp = backend::expand<T>(out, x_shape);
+      auto mask = equal<T>(x, out_tmp);
+      x_grad_tmp = where<T>(mask, out_grad_tmp, zero_tensor);
+    } else {
+      const Tensor out_grad_shape = shape<T>(out_grad);
+      auto axis_ = std::vector<int64_t>();
+
+      if (reduce_all) {
+        for (int64_t i = 0; i < x_dim_size; i++) {
+          axis_.push_back(i);
+        }
+      } else {
+        axis_ = axis.GetData();
+        for (int64_t i = 0; i < axis_size; i++) {
+          if (axis[i] < 0) {
+            axis_[i] = axis[i] + x_dim_size;
+          }
+        }
       }
+      const Tensor out_grad_shape_extend =
+          get_unsqueeze_dims<T>(out_grad_shape, axis_);
+      auto out_grad_ = backend::reshape<T>(out_grad, out_grad_shape_extend);
+      auto out_ = backend::reshape<T>(out, out_grad_shape_extend);
+      auto out_grad_tmp = backend::expand<T>(out_grad_, x_shape);
+      auto out_tmp = backend::expand<T>(out_, x_shape);
+      auto mask = equal<T>(x, out_tmp);
+      x_grad_tmp = where<T>(mask, out_grad_tmp, zero_tensor);
+    }
+  } else {
+    auto zero_tensor = full<T>(common::vectorize(x.dims()), 0.0, x.dtype());
+    std::vector<int64_t> x_dim = common::vectorize<int64_t>(x.dims());
+    int64_t axis_size = axis.size();
+    int64_t x_dim_size = x_dim.size();
+    reduce_all = false;
+    if (reduce_all || axis_size == 0 || axis_size == x_dim_size) {
+      reduce_all = true;
     } else {
-      axis_ = axis.GetData();
-      for (int64_t i = 0; i < axis_size; i++) {
-        if (axis[i] < 0) {
-          axis_[i] = axis[i] + x_dim_size;
+      reduce_all = false;
+    }
+
+    if (x_dim_size == 0 || x_dim_size == 1 || keepdim) {
+      auto out_grad_tmp = out_grad.expand(IntArray(x_dim));
+      auto out_tmp = out.expand(IntArray(x_dim));
+      auto mask = equal<T>(x, out_tmp);
+      x_grad_tmp = where<T>(mask, out_grad_tmp, zero_tensor);
+    } else {
+      auto axis_ = std::vector<int64_t>();
+      if (reduce_all) {
+        for (int64_t i = 0; i < x_dim_size; i++) {
+          axis_.push_back(i);
+        }
+      } else {
+        axis_ = axis.GetData();
+        for (int64_t i = 0; i < axis_size; i++) {
+          if (axis[i] < 0) {
+            axis_[i] = axis[i] + x_dim_size;
+          }
         }
       }
+      auto out_grad_shape = get_unsqueeze_dims(out_grad, axis_);
+      auto out_grad_ = reshape<T>(out_grad, out_grad_shape);
+      auto out_ = reshape<T>(out, out_grad_shape);
+      auto out_grad_tmp = out_grad_.expand(IntArray(x_dim));
+      auto out_tmp = out_.expand(IntArray(x_dim));
+      auto mask = equal<T>(x, out_tmp);
+      x_grad_tmp = where<T>(mask, out_grad_tmp, zero_tensor);
     }
-    auto out_grad_shape = get_unsqueeze_dims(out_grad, axis_);
-    auto out_grad_ = reshape<T>(out_grad, out_grad_shape);
-    auto out_ = reshape<T>(out, out_grad_shape);
-    auto out_grad_tmp = out_grad_.expand(IntArray(x_dim));
-    auto out_tmp = out_.expand(IntArray(x_dim));
-    auto mask = equal<T>(x, out_tmp);
-    x_grad_tmp = where<T>(mask, out_grad_tmp, zero_tensor);
   }
   set_output<T>(x_grad_tmp, x_grad);
 }
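In both the dynamic and the static branch of the hunk above, the gradient rule is the same: broadcast the reduced `out` and `out_grad` back to the shape of `x`, mark the positions where `x` equals the maximum, and route the upstream gradient only to those positions. A minimal dygraph sketch of the same computation with Paddle's public Python API (the tensor values are illustrative, not part of the commit):

import paddle

x = paddle.to_tensor([[1.0, 3.0, 3.0],
                      [2.0, 0.0, 5.0]])
out = paddle.max(x, axis=1, keepdim=True)   # shape [2, 1]
out_grad = paddle.ones_like(out)            # upstream gradient

# Broadcast the reduced result back to x's shape and build the mask.
mask = paddle.equal(x, out)                 # True wherever x attains the row max
zero = paddle.zeros_like(x)
x_grad = paddle.where(mask, paddle.expand(out_grad, x.shape), zero)
print(x_grad)  # gradient lands on every element equal to the row maximum

Using keepdim=True sidesteps the reshape/unsqueeze step; the non-keepdim branch in the C++ rule first re-inserts the reduced axes (via get_unsqueeze_dims) before expanding.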
@@ -2292,7 +2340,7 @@ void swiglu_grad(const Tensor& x,
                  Tensor* dx,
                  Tensor* dy) {
   const auto& x_shape = x.shape();
-  auto one_tensor = full<T>(x_shape, 1.0, x.dtype());
+  auto one_tensor = full_scalar<T>(1.0, x.dtype());
   Tensor x_grad;
   if (y) {
     const auto& y_tensor = y.get();
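The swiglu_grad hunk above replaces a ones tensor materialized at x's full shape with a scalar constant; presumably full_scalar produces a 0-D tensor that simply broadcasts in the elementwise arithmetic that follows, so the rule no longer depends on a concrete x_shape. A small dygraph sketch of the broadcasting this relies on (values and shapes are illustrative assumptions):

import paddle

x = paddle.rand([4, 8])
one_full = paddle.full(x.shape, 1.0, dtype=x.dtype)  # old style: needs a static shape
one_scalar = paddle.full([], 1.0, dtype=x.dtype)     # 0-D constant, broadcasts

sig = paddle.nn.functional.sigmoid(x)
# Both forms give the same elementwise result.
print(paddle.allclose(one_full - sig, one_scalar - sig).item())  # True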

paddle/fluid/primitive/utils/utils.h

Lines changed: 36 additions & 2 deletions

@@ -77,7 +77,7 @@ static std::vector<int64_t> get_expand_dims(const Tensor& origin,
   return result;
 }

-// This fucction compute unsqueeze dims for reshape to replace unsqueeze.
+// This function compute unsqueeze dims for reshape to replace unsqueeze.
 static std::vector<int64_t> get_unsqueeze_dims(
     const Tensor& origin, const std::vector<int64_t>& axis) {
   auto origin_dims = origin.shape();

@@ -103,7 +103,41 @@ static std::vector<int64_t> get_unsqueeze_dims(
   return result;
 }

-// This fucction compute unsqueeze dims for reshape to replace unsqueeze.
+// This function compute `dynamic` unsqueeze dims for reshape to replace
+// unsqueeze. And should used only on `dynamic`.
+template <typename T>
+Tensor get_unsqueeze_dims(const Tensor& origin_shape,
+                          const std::vector<int64_t>& axis) {
+  auto total_shape_size = origin_shape.numel() + axis.size();
+  const Tensor one = full<T>({1}, 1, origin_shape.dtype());
+
+  std::vector<Tensor> result(total_shape_size, one);
+  // to support axis not in increasing order.
+  std::vector<bool> is_set(total_shape_size, false);
+
+  for (size_t i = 0; i < axis.size(); ++i) {
+    PADDLE_ENFORCE_LT(
+        axis[i],
+        total_shape_size,
+        common::errors::OutOfRange("Your index [%lu] exceeds the number of "
+                                   "elements in origin_dims[%lu].",
+                                   axis[i],
+                                   total_shape_size));
+    is_set[axis[i]] = true;
+  }
+
+  size_t j = 0;
+  for (size_t i = 0; i < total_shape_size; ++i) {
+    if (is_set[i]) {
+      continue;
+    }
+    result[i] = get_slice<T>(origin_shape, int64_t(j));
+    is_set[i] = true;
+    ++j;
+  }
+  return concat<T>(result);
+}
+
 static std::vector<int64_t> get_squeeze_dims(const Tensor& origin,
                                              const std::vector<int64_t>& axis) {
   auto origin_dims = origin.shape();
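The new templated overload builds the unsqueezed shape as a 1-D shape tensor rather than a std::vector, so it stays usable when concrete dimensions are only known at runtime: slots listed in axis are filled with a constant 1, the remaining slots take the entries of origin_shape in order, and the pieces are concatenated into one shape tensor. A rough Python equivalent of the index bookkeeping (plain lists stand in for the shape tensor; names are mine, not Paddle's):

def unsqueeze_dims(origin_shape, axis):
    """Insert 1s into origin_shape at the positions listed in axis."""
    total = len(origin_shape) + len(axis)
    result = [1] * total
    is_set = [False] * total
    for a in axis:                  # axis may be unsorted, mirroring the C++ code
        assert a < total, "axis index out of range"
        is_set[a] = True
    j = 0
    for i in range(total):
        if is_set[i]:
            continue                # this slot keeps the inserted 1
        result[i] = origin_shape[j]
        j += 1
    return result

print(unsqueeze_dims([30, 40], [0, 2]))  # [1, 30, 1, 40]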

python/paddle/autograd/backward_utils.py

Lines changed: 1 addition & 0 deletions

@@ -65,6 +65,7 @@
     "pd_op.gelu",
     "pd_op.hardswish",
     "pd_op.reduce_as",
+    "pd_op.max",
 ]
test/prim/pir_prim/test_prim_sub_graph_backward_dynamic_shape.py

Lines changed: 108 additions & 0 deletions

@@ -175,6 +175,30 @@ def reduce_as_net(x, y):
     return paddle.reduce_as(x, y)


+def max_net1(x):
+    return paddle.max(x, keepdim=True)
+
+
+def max_net2(x):
+    return paddle.max(x, keepdim=False)
+
+
+def max_net3(x):
+    return paddle.max(x, axis=[0, 1], keepdim=False)
+
+
+def max_net4(x):
+    return paddle.max(x, axis=[-1, -2], keepdim=False)
+
+
+def max_net5(x):
+    return paddle.max(x, axis=[-1, 0], keepdim=False)
+
+
+def max_net6(x):
+    return paddle.max(x)
+
+
 def apply_to_static(net, use_cinn, input_spec=None):
     build_strategy = paddle.static.BuildStrategy()
     build_strategy.build_cinn_pass = use_cinn
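Each max_netK covers one axis/keepdim combination; the test harness compiles them with input specs whose dimensions are None so the decomposed backward runs under symbolic shapes. A standalone sketch of that wiring (illustrative only, the real harness is apply_to_static and TestPrimBaseWithGrad in this file):

import numpy as np
import paddle
from paddle.static import InputSpec

def max_net3(x):
    return paddle.max(x, axis=[0, 1], keepdim=False)

x = paddle.to_tensor(np.random.random([30, 200, 40]).astype("float32"))
x.stop_gradient = False

# None dims mark every axis as dynamic when the net is compiled to a static graph.
spec = InputSpec(shape=[None, None, None], dtype="float32")
static_net = paddle.jit.to_static(max_net3, input_spec=[spec])

out = static_net(x)
grad = paddle.grad(out.sum(), x)[0]
print(grad.shape)  # [30, 200, 40]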
@@ -1966,5 +1990,89 @@ def setUp(self):
         self.y_without_grad = True


+class TestPrimMaxWithGrad1(TestPrimBaseWithGrad):
+    def setUp(self):
+        np.random.seed(2024)
+        self.dtype = "float32"
+        self.x_shape = [30, 200, 40]
+        self.init_x_shape = [None, None, None]
+        self.x = np.random.random(self.x_shape).astype(self.dtype)
+        self.net = max_net1
+        self.enable_cinn = False
+        self.tol = 1e-6
+
+
+class TestPrimMaxWithGrad2(TestPrimBaseWithGrad):
+    def setUp(self):
+        np.random.seed(2024)
+        self.dtype = "float32"
+        self.x_shape = [30]
+        self.init_x_shape = [None]
+        self.x = np.random.random(self.x_shape).astype(self.dtype)
+        self.net = max_net1
+        self.enable_cinn = False
+        self.tol = 1e-6
+
+
+class TestPrimMaxWithGrad3(TestPrimBaseWithGrad):
+    def setUp(self):
+        np.random.seed(2024)
+        self.dtype = "float32"
+        self.x_shape = [30, 200, 40]
+        self.init_x_shape = [None, None, None]
+        self.x = np.random.random(self.x_shape).astype(self.dtype)
+        self.net = max_net2
+        self.enable_cinn = False
+        self.tol = 1e-6
+
+
+class TestPrimMaxWithGrad4(TestPrimBaseWithGrad):
+    def setUp(self):
+        np.random.seed(2024)
+        self.dtype = "float32"
+        self.x_shape = [30, 200, 40]
+        self.init_x_shape = [None, None, None]
+        self.x = np.random.random(self.x_shape).astype(self.dtype)
+        self.net = max_net3
+        self.enable_cinn = False
+        self.tol = 1e-6
+
+
+class TestPrimMaxWithGrad5(TestPrimBaseWithGrad):
+    def setUp(self):
+        np.random.seed(2024)
+        self.dtype = "float32"
+        self.x_shape = [30, 200, 40]
+        self.init_x_shape = [None, None, None]
+        self.x = np.random.random(self.x_shape).astype(self.dtype)
+        self.net = max_net4
+        self.enable_cinn = False
+        self.tol = 1e-6
+
+
+class TestPrimMaxWithGrad6(TestPrimBaseWithGrad):
+    def setUp(self):
+        np.random.seed(2024)
+        self.dtype = "float32"
+        self.x_shape = [30, 200, 40]
+        self.init_x_shape = [None, None, None]
+        self.x = np.random.random(self.x_shape).astype(self.dtype)
+        self.net = max_net5
+        self.enable_cinn = False
+        self.tol = 1e-6
+
+
+class TestPrimMaxWithGrad7(TestPrimBaseWithGrad):
+    def setUp(self):
+        np.random.seed(2024)
+        self.dtype = "float32"
+        self.x_shape = [30, 200, 40]
+        self.init_x_shape = [None, None, None]
+        self.x = np.random.random(self.x_shape).astype(self.dtype)
+        self.net = max_net6
+        self.enable_cinn = False
+        self.tol = 1e-6
+
+
 if __name__ == "__main__":
     unittest.main()
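The TestPrimMaxWithGrad cases compare the decomposed backward against the composite op; the behaviour they exercise is simply that upstream gradient flows to every element equal to the reduced maximum. A small NumPy cross-check of that expectation, assuming a ones upstream gradient (illustrative, not part of the test file):

import numpy as np
import paddle

np.random.seed(2024)
x_np = np.random.random([30, 200, 40]).astype("float32")

# Reference: mask of positions that attain the max along the reduced axes.
ref = (x_np == x_np.max(axis=(0, 1), keepdims=True)).astype("float32")

x = paddle.to_tensor(x_np, stop_gradient=False)
out = paddle.max(x, axis=[0, 1])
(grad,) = paddle.grad(out.sum(), x)

np.testing.assert_allclose(grad.numpy(), ref, rtol=1e-6)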
