 import paddle
 import paddle.fluid
 from paddle.fluid import core, unique_name
-from paddle.fluid.framework import EagerParamBase
 from paddle.fluid.framework import in_dygraph_mode

+if in_dygraph_mode():
+    from paddle.fluid.framework import EagerParamBase
+else:
+    from paddle.fluid.framework import ParamBase
+
 alignment = {"gpu": 256, }
 align = {
     paddle.bfloat16: 2,
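The same dispatch pattern recurs throughout this patch: when `in_dygraph_mode()` is true, use the eager-mode classes (`core.eager.Tensor`, `EagerParamBase`); otherwise fall back to the legacy dygraph classes (`core.VarBase`, `ParamBase`). A minimal sketch of that pattern, assuming a Paddle build where both APIs exist; the helper name `_cpu_tensor_from_numpy` is made up for illustration and is not part of the patch:

```python
import numpy as np
from paddle.fluid import core
from paddle.fluid.framework import in_dygraph_mode


def _cpu_tensor_from_numpy(value):
    # Hypothetical helper mirroring the branches added by this patch:
    # eager-mode Tensor when available, legacy VarBase otherwise.
    if in_dygraph_mode():
        return core.eager.Tensor(value=value, place=core.CPUPlace())
    return core.VarBase(value=value, place=core.CPUPlace())


buffer = _cpu_tensor_from_numpy(np.zeros([8], dtype=np.float32))
```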
@@ -78,7 +82,10 @@ def __init__(self, size, dtype, device, convert_cpu=True):
             else:
                 value = np.zeros(size, dtype=np.float32)

-            self.buffer = core.eager.Tensor(value=value, place=core.CPUPlace())
+            if in_dygraph_mode():
+                self.buffer = core.eager.Tensor(value=value, place=core.CPUPlace())
+            else:
+                self.buffer = core.VarBase(value=value, place=core.CPUPlace())
         else:
             self.buffer = paddle.zeros(size, dtype=dtype)

@@ -166,9 +173,14 @@ def _add_param_as_view(self, param, align, convert_gpu=True):
         dev_id = 0 if paddle.get_device() == "cpu" else int(paddle.get_device()
                                                             .split(":")[1])
         with device_guard(dev_id, "cpu"):
-            tmp_var = core.eager.Tensor(
-                self.buffer._slice(self._fill, var_end))
-            tmp_var.get_tensor()._set_dims(param.shape)
+            if in_dygraph_mode():
+                tmp_var = core.eager.Tensor(
+                    self.buffer._slice(self._fill, var_end))
+                tmp_var.get_tensor()._set_dims(param.shape)
+            else:
+                tmp_var = core.VarBase(
+                    tensor=self.buffer._slice(self._fill, var_end))
+                tmp_var.value().get_tensor()._set_dims(param.shape)
             if convert_gpu:
                 param_cpu = param.cpu()
                 param.value().get_tensor()._clear()
@@ -188,8 +200,9 @@ def _convert_buffer(self, param, p_shape, align):

         # Convert the param value
         tmp_tensor = self.buffer._slice(self._fill, var_end)
-        param.value().get_tensor()._share_data_with(tmp_tensor.value()
-                                                    .get_tensor())
+        if in_dygraph_mode():
+            tmp_tensor = tmp_tensor.value().get_tensor()
+        param.value().get_tensor()._share_data_with(tmp_tensor)
         param.value().get_tensor()._set_dims(p_shape)

         self._fill = offset
@@ -204,8 +217,11 @@ def _array_params(self):

         self._fill = 0
         for p in self._params:
-            self._convert_buffer(p, p.shape,
-                                 self.param2align[p.name])  # modify
+            if in_dygraph_mode():
+                self._convert_buffer(p, p.shape,
+                                     self.param2align[p.name])  # modify
+            else:
+                self._convert_buffer(p, p.shape, self.param2align[p.name])  # modify


 class GradStorage(object):
@@ -240,7 +256,10 @@ def __init__(self,
                 value = np.zeros(size, dtype=np.uint16)
             else:
                 value = np.zeros(size, dtype=np.float32)
-            self.buffer = core.eager.Tensor(value=value, place=core.CPUPlace())
+            if in_dygraph_mode():
+                self.buffer = core.eager.Tensor(value=value, place=core.CPUPlace())
+            else:
+                self.buffer = core.VarBase(value=value, place=core.CPUPlace())
         else:
             self.buffer = paddle.zeros(size, dtype=dtype)

@@ -371,15 +390,23 @@ def _add_grad_as_view(self, param, align):
                                                             .split(":")[1])
         if self._device == "cpu":
             with device_guard(dev_id, self._device):
-                tmp_var = core.eager.Tensor(
-                    self.buffer._slice(self._fill, grad_end))
-                tmp_var.get_tensor()._set_dims(param.shape)
+                if in_dygraph_mode():
+                    tmp_var = core.eager.Tensor(
+                        self.buffer._slice(self._fill, grad_end))
+                    tmp_var.get_tensor()._set_dims(param.shape)
+                else:
+                    tmp_var = core.VarBase(self.buffer._slice(self._fill, grad_end))
+                    tmp_var.value().get_tensor()._set_dims(param.shape)
                 param._copy_gradient_from(tmp_var)

         elif self._device == "gpu":
-            tmp_var = core.eager.Tensor(
-                self.buffer._slice(self._fill, grad_end))
-            tmp_var.get_tensor()._set_dims(param.shape)
+            if in_dygraph_mode():
+                tmp_var = core.eager.Tensor(
+                    self.buffer._slice(self._fill, grad_end))
+                tmp_var.get_tensor()._set_dims(param.shape)
+            else:
+                tmp_var = core.VarBase(self.buffer._slice(self._fill, grad_end))
+                tmp_var.value().get_tensor()._set_dims(param.shape)
             param._copy_gradient_from(tmp_var)

         self._fill = offset
@@ -390,8 +417,12 @@ def assign_group_by_size(parameters, group_size=256 * 1024 * 1024):
     assign group by size
     """
     is_sparse_gradient = [False] * len(parameters)
-    group_indices = core.eager_assign_group_by_size(
-        parameters, is_sparse_gradient, [group_size, group_size])
+    if in_dygraph_mode():
+        group_indices = core.eager_assign_group_by_size(
+            parameters, is_sparse_gradient, [group_size, group_size])
+    else:
+        group_indices = core.assign_group_by_size(
+            parameters, is_sparse_gradient, [group_size, group_size])
     var_groups = OrderedDict()
     for group_idx, indices in enumerate(group_indices):
         for index in indices:
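For intuition, the grouping done natively by `core.eager_assign_group_by_size` / `core.assign_group_by_size` can be pictured as bucketing parameter indices so that each bucket's accumulated byte size stays under `group_size`. A simplified pure-Python sketch of that idea; it ignores sparse gradients and the per-dtype group limits the core ops handle, and assumes float32 parameters:

```python
import numpy as np


def assign_group_by_size_py(parameters, group_size=256 * 1024 * 1024):
    # Simplified stand-in for the core ops: walk the parameters in order and
    # start a new bucket whenever adding the next one would exceed group_size.
    groups, current, current_bytes = [], [], 0
    for index, param in enumerate(parameters):
        nbytes = int(np.prod(param.shape)) * 4  # float32 assumed
        if current and current_bytes + nbytes > group_size:
            groups.append(current)
            current, current_bytes = [], 0
        current.append(index)
        current_bytes += nbytes
    if current:
        groups.append(current)
    return groups
```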
@@ -435,10 +466,16 @@ def flatten_dense_tensors(parameters):
     for param in parameters:
         grad_storage.add_grad(param, _param2align[param.name])

-    fused_param = EagerParamBase(
-        shape=param_storage.buffer.shape,
-        dtype=dtype,
-        name=unique_name.generate('fused_param'))
+    if in_dygraph_mode():
+        fused_param = EagerParamBase(
+            shape=param_storage.buffer.shape,
+            dtype=dtype,
+            name=unique_name.generate('fused_param'))
+    else:
+        fused_param = ParamBase(
+            shape=param_storage.buffer.shape,
+            dtype=dtype,
+            name=unique_name.generate('fused_param'))
     param_storage.buffer._share_buffer_to(fused_param)
     fused_param._copy_gradient_from(grad_storage.buffer)
     fused_param.__dict__.update(state)
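Conceptually, `flatten_dense_tensors` backs a list of parameters with one contiguous buffer (plus a matching flat gradient buffer) and exposes it as a single fused parameter. A rough illustration of the flatten-and-view idea using only public Paddle APIs; this is not the patched code and omits alignment padding and gradient handling:

```python
import numpy as np
import paddle

# Flatten a few tensors into one contiguous buffer, then read each original
# tensor back as a reshaped slice of that buffer.
params = [paddle.randn([4, 3]), paddle.randn([5])]
flat = paddle.concat([p.flatten() for p in params])

slices, offset = [], 0
for p in params:
    n = int(np.prod(p.shape))
    slices.append(flat[offset:offset + n].reshape(p.shape))
    offset += n
```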