@@ -45,7 +45,13 @@ def __init__(self, ys: "paddle.Tensor", xs: "paddle.Tensor"):
 
         self.J: Dict[str, paddle.Tensor] = {}
 
-    def __call__(self, i: int = 0, j: Optional[int] = None) -> "paddle.Tensor":
+    def __call__(
+        self,
+        i: int = 0,
+        j: Optional[int] = None,
+        retain_graph: Optional[bool] = None,
+        create_graph: bool = True,
+    ) -> "paddle.Tensor":
         """Returns J[`i`][`j`]. If `j` is ``None``, returns the gradient of y_i, i.e.,
         J[i].
         """
@@ -56,7 +62,9 @@ def __call__(self, i: int = 0, j: Optional[int] = None) -> "paddle.Tensor":
         # Compute J[i]
         if i not in self.J:
             y = self.ys[:, i : i + 1] if self.dim_y > 1 else self.ys
-            self.J[i] = paddle.grad(y, self.xs, create_graph=True)[0]
+            self.J[i] = paddle.grad(
+                y, self.xs, retain_graph=retain_graph, create_graph=create_graph
+            )[0]
 
         return self.J[i] if (j is None or self.dim_x == 1) else self.J[i][:, j : j + 1]
 
@@ -82,6 +90,8 @@ def __call__(
         xs: "paddle.Tensor",
         i: int = 0,
         j: Optional[int] = None,
+        retain_graph: Optional[bool] = None,
+        create_graph: bool = True,
     ) -> "paddle.Tensor":
         """Compute jacobians for given ys and xs.
 
@@ -90,6 +100,15 @@ def __call__(
             xs (paddle.Tensor): Input tensor.
             i (int, optional): i-th output variable. Defaults to 0.
             j (Optional[int]): j-th input variable. Defaults to None.
+            retain_graph (Optional[bool]): Whether to retain the forward graph
+                used to compute the gradient. If True, the graph is kept and
+                backward can run through it more than once; if False, it is
+                freed after the gradient is computed. Defaults to None, which
+                means it takes the same value as `create_graph`.
+            create_graph (bool, optional): Whether to build the graph of the
+                gradient computation itself. If True, higher-order derivatives
+                can be computed from the result; if False, the gradient graph
+                is discarded. Defaults to True.
 
         Returns:
             paddle.Tensor: Jacobian matrix of ys[i] to xs[j].
@@ -105,7 +124,7 @@ def __call__(
         key = (ys, xs)
         if key not in self.Js:
             self.Js[key] = _Jacobian(ys, xs)
-        return self.Js[key](i, j)
+        return self.Js[key](i, j, retain_graph, create_graph)
 
     def _clear(self):
         """Clear cached Jacobians."""
@@ -157,12 +176,21 @@ def __init__(
            component = 0
 
        if grad_y is None:
-            grad_y = jacobian(ys, xs, i=component, j=None)
+            # `create_graph` of the first-order jacobian must be True in _Hessian,
+            # so the second-order gradient can backpropagate through it.
+            grad_y = jacobian(
+                ys, xs, i=component, j=None, retain_graph=None, create_graph=True
+            )
        self.H = _Jacobian(grad_y, xs)
 
-    def __call__(self, i: int = 0, j: int = 0):
+    def __call__(
+        self,
+        i: int = 0,
+        j: int = 0,
+        retain_graph: Optional[bool] = None,
+        create_graph: bool = True,
+    ):
        """Returns H[`i`][`j`]."""
-        return self.H(i, j)
+        return self.H(i, j, retain_graph, create_graph)
 
 
 class Hessians:
@@ -188,6 +216,8 @@ def __call__(
         i: int = 0,
         j: int = 0,
         grad_y: Optional["paddle.Tensor"] = None,
+        retain_graph: Optional[bool] = None,
+        create_graph: bool = True,
     ) -> "paddle.Tensor":
         """Compute hessian matrix for given ys and xs.
 
@@ -201,6 +231,15 @@ def __call__(
             j (int, optional): j-th input variable. Defaults to 0.
             grad_y (Optional[paddle.Tensor]): The gradient of `y` w.r.t. `xs`. Provide `grad_y` if known to avoid
                 duplicate computation. Defaults to None.
+            retain_graph (Optional[bool]): Whether to retain the forward graph
+                used to compute the gradient. If True, the graph is kept and
+                backward can run through it more than once; if False, it is
+                freed after the gradient is computed. Defaults to None, which
+                means it takes the same value as `create_graph`.
+            create_graph (bool, optional): Whether to build the graph of the
+                gradient computation itself. If True, higher-order derivatives
+                can be computed from the result; if False, the gradient graph
+                is discarded. Defaults to True.
 
         Returns:
             paddle.Tensor: Hessian matrix.
@@ -216,7 +255,7 @@ def __call__(
         key = (ys, xs, component)
         if key not in self.Hs:
             self.Hs[key] = _Hessian(ys, xs, component=component, grad_y=grad_y)
-        return self.Hs[key](i, j)
+        return self.Hs[key](i, j, retain_graph, create_graph)
 
     def _clear(self):
         """Clear cached Hessians."""
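And a sketch of how the new keywords would flow through the cached jacobian/hessian callables this diff modifies; the import path and module-level instances are assumptions based on the file being edited:

import paddle
from ppsci.autodiff import hessian, jacobian  # assumed import path

x = paddle.randn([16, 2])
x.stop_gradient = False
u = paddle.sin(x).sum(axis=1, keepdim=True)

# First-order derivative; create_graph defaults to True, so du_dx stays
# differentiable and can feed a higher-order derivative later.
du_dx = jacobian(u, x, i=0, j=0)

# Second-order derivative as the final backward pass: create_graph=False
# drops the gradient graph. Internally, _Hessian still builds its first-order
# jacobian with create_graph=True, as the comment in the diff notes.
d2u_dx2 = hessian(u, x, i=0, j=0, create_graph=False)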