
Commit da4f020

add `create_graph` option to jacobian and hessian to avoid unnecessary GPU memory usage when derivatives are involved in the prediction stage (#600)
Parent: f475b12 · Commit: da4f020
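For context, a minimal sketch (not part of this commit) of what `create_graph` controls in `paddle.grad`, the call that `_Jacobian.__call__` wraps: with `create_graph=True` the derivative itself carries a gradient graph so higher-order derivatives can follow; with `create_graph=False` that graph is dropped, which is enough for prediction and frees GPU memory.

# Sketch only: illustrates the memory trade-off behind this commit.
import paddle

def forward(x):
    return paddle.sin(x) * paddle.exp(x)

x = paddle.randn([1024, 1])
x.stop_gradient = False

# Training / higher-order case: keep the gradient graph so the second
# derivative can still be taken from the first one.
y = forward(x)
du = paddle.grad(y, x, create_graph=True)[0]
d2u = paddle.grad(du, x, create_graph=False)[0]

# Prediction case: the first derivative is the final result, so the
# gradient graph is discarded right away and its GPU memory is freed.
y = forward(x)
du_pred = paddle.grad(y, x, create_graph=False)[0]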

2 files changed (+48, -9 lines)

examples/bubble/bubble.py (+2, -2)
@@ -298,8 +298,8 @@ def transform_out(in_, out):
     psi_y = out["psi"]
     y = in_["y"]
     x = in_["x"]
-    u = jacobian(psi_y, y)
-    v = -jacobian(psi_y, x)
+    u = jacobian(psi_y, y, create_graph=False)
+    v = -jacobian(psi_y, x, create_graph=False)
     return {"u": u, "v": v}

 # register transform
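Here `u` and `v` are the end products of differentiation in this output transform, so the gradient graph is presumably not needed downstream and `create_graph=False` lets it be freed during evaluation and prediction; an equation that differentiated `u` or `v` again would still need the default `create_graph=True`.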

ppsci/autodiff/ad.py (+46, -7)
@@ -45,7 +45,13 @@ def __init__(self, ys: "paddle.Tensor", xs: "paddle.Tensor"):

         self.J: Dict[str, paddle.Tensor] = {}

-    def __call__(self, i: int = 0, j: Optional[int] = None) -> "paddle.Tensor":
+    def __call__(
+        self,
+        i: int = 0,
+        j: Optional[int] = None,
+        retain_graph: Optional[bool] = None,
+        create_graph: bool = True,
+    ) -> "paddle.Tensor":
         """Returns J[`i`][`j`]. If `j` is ``None``, returns the gradient of y_i, i.e.,
         J[i].
         """
@@ -56,7 +62,9 @@ def __call__(self, i: int = 0, j: Optional[int] = None) -> "paddle.Tensor":
         # Compute J[i]
         if i not in self.J:
             y = self.ys[:, i : i + 1] if self.dim_y > 1 else self.ys
-            self.J[i] = paddle.grad(y, self.xs, create_graph=True)[0]
+            self.J[i] = paddle.grad(
+                y, self.xs, retain_graph=retain_graph, create_graph=create_graph
+            )[0]

         return self.J[i] if (j is None or self.dim_x == 1) else self.J[i][:, j : j + 1]

@@ -82,6 +90,8 @@ def __call__(
         xs: "paddle.Tensor",
         i: int = 0,
         j: Optional[int] = None,
+        retain_graph: Optional[bool] = None,
+        create_graph: bool = True,
     ) -> "paddle.Tensor":
         """Compute jacobians for given ys and xs.

@@ -90,6 +100,15 @@ def __call__(
             xs (paddle.Tensor): Input tensor.
             i (int, optional): i-th output variable. Defaults to 0.
             j (Optional[int]): j-th input variable. Defaults to None.
+            retain_graph (Optional[bool]): Whether to retain the forward graph
+                that is used to calculate the gradient. When it is True, the
+                graph is retained, so backward can be performed twice on the
+                same graph. When it is False, the graph is freed. Defaults to
+                None, which means it follows `create_graph`.
+            create_graph (bool, optional): Whether to create the gradient graph
+                of the computing process. When it is True, higher-order
+                derivatives can be computed; when it is False, the gradient
+                graph of the computing process is discarded. Defaults to True.

         Returns:
             paddle.Tensor: Jacobian matrix of ys[i] to xs[j].
@@ -105,7 +124,7 @@ def __call__(
         key = (ys, xs)
         if key not in self.Js:
             self.Js[key] = _Jacobian(ys, xs)
-        return self.Js[key](i, j)
+        return self.Js[key](i, j, retain_graph, create_graph)

     def _clear(self):
         """Clear cached Jacobians."""
@@ -157,12 +176,21 @@ def __init__(
             component = 0

         if grad_y is None:
-            grad_y = jacobian(ys, xs, i=component, j=None)
+            # `create_graph` for the first-order jacobian should be `True` in _Hessian.
+            grad_y = jacobian(
+                ys, xs, i=component, j=None, retain_graph=None, create_graph=True
+            )
         self.H = _Jacobian(grad_y, xs)

-    def __call__(self, i: int = 0, j: int = 0):
+    def __call__(
+        self,
+        i: int = 0,
+        j: int = 0,
+        retain_graph: Optional[bool] = None,
+        create_graph: bool = True,
+    ):
         """Returns H[`i`][`j`]."""
-        return self.H(i, j)
+        return self.H(i, j, retain_graph, create_graph)


 class Hessians:
@@ -188,6 +216,8 @@ def __call__(
         i: int = 0,
         j: int = 0,
         grad_y: Optional["paddle.Tensor"] = None,
+        retain_graph: Optional[bool] = None,
+        create_graph: bool = True,
     ) -> "paddle.Tensor":
         """Compute hessian matrix for given ys and xs.

@@ -201,6 +231,15 @@ def __call__(
             j (int, optional): j-th input variable. Defaults to 0.
             grad_y (Optional[paddle.Tensor]): The gradient of `y` w.r.t. `xs`. Provide `grad_y` if known to avoid
                 duplicate computation. Defaults to None.
+            retain_graph (Optional[bool]): Whether to retain the forward graph
+                that is used to calculate the gradient. When it is True, the
+                graph is retained, so backward can be performed twice on the
+                same graph. When it is False, the graph is freed. Defaults to
+                None, which means it follows `create_graph`.
+            create_graph (bool, optional): Whether to create the gradient graph
+                of the computing process. When it is True, higher-order
+                derivatives can be computed; when it is False, the gradient
+                graph of the computing process is discarded. Defaults to True.

         Returns:
             paddle.Tensor: Hessian matrix.
@@ -216,7 +255,7 @@ def __call__(
         key = (ys, xs, component)
         if key not in self.Hs:
             self.Hs[key] = _Hessian(ys, xs, component=component, grad_y=grad_y)
-        return self.Hs[key](i, j)
+        return self.Hs[key](i, j, retain_graph, create_graph)

     def _clear(self):
         """Clear cached Hessians."""
