Skip to content

Commit 54fd11b

Browse files
authored
[SOT] Copy cached meta info to avoid got wrong symbol (#71971)
1 parent efb2886 commit 54fd11b

File tree

2 files changed

+33
-2
lines changed

2 files changed

+33
-2
lines changed

python/paddle/jit/sot/infer_meta.py

+26
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414
from __future__ import annotations
1515

16+
import copy
1617
from functools import cached_property
1718
from typing import TypeVar
1819

@@ -102,6 +103,13 @@ def from_value(value: paddle.pir.Value) -> DistInfo:
102103
value._local_shape,
103104
)
104105

106+
def __deepcopy__(self, memo):
107+
return DistInfo(
108+
mesh=copy.deepcopy(self.mesh),
109+
dims_mapping=copy.deepcopy(self.dims_mapping),
110+
local_shape=copy.deepcopy(self.local_shape),
111+
)
112+
105113
def __repr__(self) -> str:
106114
return f"DistInfo(mesh={self.mesh}, dims_mapping={self.dims_mapping}, local_shape={self.local_shape})"
107115

@@ -270,6 +278,18 @@ def guard_str(self):
270278
shape = self.shape_with_special_symbol(SymbolicInt())
271279
return f"({shape}, {self.dtype}, {self.stop_gradient})"
272280

281+
def __deepcopy__(self, memo):
282+
return MetaInfo(
283+
list(self.shape),
284+
self.dtype,
285+
self.stop_gradient,
286+
self.name,
287+
self.persistable,
288+
self.type,
289+
self.place,
290+
dist_info=copy.deepcopy(self.dist_info),
291+
)
292+
273293
def __repr__(self):
274294
return meta_str(self.shape, self.dtype, self.stop_gradient)
275295

@@ -541,6 +561,9 @@ def infermeta_grad(
541561

542562

543563
class InferMetaCache(Cache, metaclass=Singleton):
564+
def __init__(self):
565+
super().__init__(copy=True)
566+
544567
def key_fn(
545568
self, func, *args, **kwargs
546569
): # args & kwargs have transformed to MetaInfo
@@ -556,6 +579,9 @@ def value_fn(self, func, *args, **kwargs):
556579

557580

558581
class LayerInferMetaCache(Cache, metaclass=Singleton):
582+
def __init__(self):
583+
super().__init__(copy=True)
584+
559585
def key_fn(self, layer, *args, **kwargs):
560586
params = [
561587
MetaInfo.from_value(x)

python/paddle/jit/sot/utils/utils.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from __future__ import annotations
1616

1717
import builtins
18+
import copy
1819
import inspect
1920
import sys
2021
import time
@@ -267,12 +268,13 @@ def is_true(*args):
267268

268269

269270
class Cache:
270-
def __init__(self, weak=False):
271+
def __init__(self, weak=False, copy=False):
271272
if not weak:
272273
self.cache = {}
273274
else:
274275
self.cache = WeakValueDictionary()
275276
self.hit_num = 0
277+
self.copy = copy
276278

277279
def __call__(self, *args, **kwargs):
278280
cache_key = self.key_fn(*args, **kwargs)
@@ -281,7 +283,10 @@ def __call__(self, *args, **kwargs):
281283
if cache_key in self.cache:
282284
log(5, "cache hit: ", cache_key, "\n")
283285
self.hit_num += 1
284-
return self.cache[cache_key]
286+
cache_item = self.cache[cache_key]
287+
if self.copy:
288+
cache_item = copy.deepcopy(cache_item)
289+
return cache_item
285290
value = self.value_fn(*args, **kwargs)
286291
self.cache[cache_key] = value
287292
return value

0 commit comments

Comments
 (0)