Skip to content

Commit a2929ec

Browse files
committed
lazy_init
1 parent 66c4ecb commit a2929ec

File tree

1 file changed

+31
-0
lines changed

1 file changed

+31
-0
lines changed

python/paddle/nn/layer/layers.py

+31
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import weakref
2222
from collections import OrderedDict
2323
from typing import TYPE_CHECKING, Any, Callable, Dict, Union
24+
import functools
2425

2526
import numpy as np
2627
from typing_extensions import Self
@@ -462,6 +463,9 @@ def __init__(
462463
# Records original functions after @to_static to support to rollback
463464
self._original_funcs = OrderedDict()
464465

466+
# Records parameters whether they are initialized
467+
self._is_parameters_initialized = False
468+
465469
def train(self) -> None:
466470
"""
467471
@@ -1555,6 +1559,33 @@ def _dygraph_call_func(self, *inputs: Any, **kwargs: Any) -> Any:
15551559

15561560
return outputs
15571561

1562+
1563+
def _init_params_decorator(func):
1564+
"""
1565+
Decorator function that initializes parameters before calling the decorated method.
1566+
1567+
This decorator checks whether each parameter has been initialized using the '_is_initialized' property.
1568+
If any parameter is uninitialized, it calls the 'initialize' method on that parameter.
1569+
1570+
Args:
1571+
func (function): The function being decorated.
1572+
1573+
Returns:
1574+
function: A wrapped version of the input function that performs parameter initialization before calling the original function.
1575+
"""
1576+
@functools.wraps(func)
1577+
def wrapper(self, *args, **kwargs):
1578+
if not self._is_parameters_initialized:
1579+
for _, param in self.named_parameters():
1580+
if not param._is_initialized():
1581+
param.initialize()
1582+
self._is_parameters_initialized = True
1583+
1584+
return func(self, *args, **kwargs)
1585+
1586+
return wrapper
1587+
1588+
@_init_params_decorator
15581589
def __call__(self, *inputs: Any, **kwargs: Any) -> Any:
15591590
if (
15601591
(not in_to_static_mode())

0 commit comments

Comments
 (0)