File tree 1 file changed +2
-13
lines changed
1 file changed +2
-13
lines changed Original file line number Diff line number Diff line change @@ -88,13 +88,7 @@ def __call__(self, params):
88
88
if param .grad .dtype == paddle .float16 :
89
89
clip_coef = clip_coef_fp32 .astype ("float16" )
90
90
91
- # inplace calculate
92
- paddle .fluid .framework ._dygraph_tracer ().trace_op (
93
- type = "elementwise_mul" ,
94
- inputs = {'X' : param .grad ,
95
- 'Y' : clip_coef },
96
- outputs = {'Out' : param .grad },
97
- attrs = {'axis' : - 1 })
91
+ param .grad .detach ().scale_ (clip_coef )
98
92
99
93
100
94
@paddle .no_grad ()
@@ -141,10 +135,5 @@ def clip_grad_norm_(parameters,
141
135
clip_coef = max_norm / (total_norm + 1e-6 )
142
136
clip_coef_clamped = paddle .clip (clip_coef , max = 1.0 )
143
137
for p in parameters :
144
- paddle .fluid .framework ._dygraph_tracer ().trace_op (
145
- type = "elementwise_mul" ,
146
- inputs = {'X' : p .grad ,
147
- 'Y' : clip_coef_clamped },
148
- outputs = {'Out' : p .grad },
149
- attrs = {'axis' : - 1 })
138
+ p .grad .detach ().scale_ (clip_coef_clamped )
150
139
return total_norm
You can’t perform that action at this time.
0 commit comments