Skip to content

Commit 7bf128c

Browse files
[PIR] Add state_dict to program_patch to enable default paramater (#70377)
* add state_dict to program_patch * fix
1 parent 92eca04 commit 7bf128c

File tree

2 files changed

+15
-1
lines changed

2 files changed

+15
-1
lines changed

paddle/fluid/pybind/pir.cc

+1-1
Original file line numberDiff line numberDiff line change
@@ -594,7 +594,7 @@ void BindProgram(py::module *m) {
594594
})
595595
.def("num_ops", [](Program &self) { return self.num_ops(); })
596596
.def(
597-
"state_dict",
597+
"_state_dict",
598598
[](std::shared_ptr<Program> self,
599599
const std::string &mode = "all",
600600
const framework::Scope &scope = framework::Scope()) {

python/paddle/pir/program_patch.py

+14
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,22 @@ def _lr_schedule_guard(self, is_with_opt=False):
2727
# be fixed in the future.
2828
yield
2929

30+
def state_dict(self, mode="all", scope=None):
31+
from paddle.base import core
32+
from paddle.base.executor import global_scope
33+
34+
if scope is not None and not isinstance(scope, core._Scope):
35+
raise TypeError(
36+
f"`scope` should be None or `paddle.static.Scope'` type, but received {type(scope)}."
37+
)
38+
39+
if scope is None:
40+
scope = global_scope()
41+
return self._state_dict(mode, scope)
42+
3043
program_attrs = {
3144
"_lr_schedule_guard": _lr_schedule_guard,
45+
"state_dict": state_dict,
3246
}
3347

3448
global _already_patch_program

0 commit comments

Comments
 (0)