13
13
# limitations under the License.
14
14
from __future__ import annotations
15
15
16
+ import copy
16
17
from functools import cached_property
17
18
from typing import TypeVar
18
19
@@ -102,6 +103,13 @@ def from_value(value: paddle.pir.Value) -> DistInfo:
102
103
value ._local_shape ,
103
104
)
104
105
106
+ def __deepcopy__ (self , memo ):
107
+ return DistInfo (
108
+ mesh = copy .deepcopy (self .mesh ),
109
+ dims_mapping = copy .deepcopy (self .dims_mapping ),
110
+ local_shape = copy .deepcopy (self .local_shape ),
111
+ )
112
+
105
113
def __repr__ (self ) -> str :
106
114
return f"DistInfo(mesh={ self .mesh } , dims_mapping={ self .dims_mapping } , local_shape={ self .local_shape } )"
107
115
@@ -270,6 +278,18 @@ def guard_str(self):
270
278
shape = self .shape_with_special_symbol (SymbolicInt ())
271
279
return f"({ shape } , { self .dtype } , { self .stop_gradient } )"
272
280
281
+ def __deepcopy__ (self , memo ):
282
+ return MetaInfo (
283
+ list (self .shape ),
284
+ self .dtype ,
285
+ self .stop_gradient ,
286
+ self .name ,
287
+ self .persistable ,
288
+ self .type ,
289
+ self .place ,
290
+ dist_info = copy .deepcopy (self .dist_info ),
291
+ )
292
+
273
293
def __repr__ (self ):
274
294
return meta_str (self .shape , self .dtype , self .stop_gradient )
275
295
@@ -541,6 +561,9 @@ def infermeta_grad(
541
561
542
562
543
563
class InferMetaCache (Cache , metaclass = Singleton ):
564
+ def __init__ (self ):
565
+ super ().__init__ (copy = True )
566
+
544
567
def key_fn (
545
568
self , func , * args , ** kwargs
546
569
): # args & kwargs have transformed to MetaInfo
@@ -556,6 +579,9 @@ def value_fn(self, func, *args, **kwargs):
556
579
557
580
558
581
class LayerInferMetaCache (Cache , metaclass = Singleton ):
582
+ def __init__ (self ):
583
+ super ().__init__ (copy = True )
584
+
559
585
def key_fn (self , layer , * args , ** kwargs ):
560
586
params = [
561
587
MetaInfo .from_value (x )
0 commit comments