|
21 | 21 | import weakref
|
22 | 22 | from collections import OrderedDict
|
23 | 23 | from typing import TYPE_CHECKING, Any, Callable, Dict, Union
|
| 24 | +import functools |
24 | 25 |
|
25 | 26 | import numpy as np
|
26 | 27 | from typing_extensions import Self
|
@@ -462,6 +463,9 @@ def __init__(
|
462 | 463 | # Records original functions after @to_static to support to rollback
|
463 | 464 | self._original_funcs = OrderedDict()
|
464 | 465 |
|
| 466 | + # Records parameters whether they are initialized |
| 467 | + self._is_parameters_initialized = False |
| 468 | + |
465 | 469 | def train(self) -> None:
|
466 | 470 | """
|
467 | 471 |
|
@@ -1555,6 +1559,33 @@ def _dygraph_call_func(self, *inputs: Any, **kwargs: Any) -> Any:
|
1555 | 1559 |
|
1556 | 1560 | return outputs
|
1557 | 1561 |
|
| 1562 | + |
| 1563 | + def _init_params_decorator(func): |
| 1564 | + """ |
| 1565 | + Decorator function that initializes parameters before calling the decorated method. |
| 1566 | +
|
| 1567 | + This decorator checks whether each parameter has been initialized using the '_is_initialized' property. |
| 1568 | + If any parameter is uninitialized, it calls the 'initialize' method on that parameter. |
| 1569 | +
|
| 1570 | + Args: |
| 1571 | + func (function): The function being decorated. |
| 1572 | +
|
| 1573 | + Returns: |
| 1574 | + function: A wrapped version of the input function that performs parameter initialization before calling the original function. |
| 1575 | + """ |
| 1576 | + @functools.wraps(func) |
| 1577 | + def wrapper(self, *args, **kwargs): |
| 1578 | + if not self._is_parameters_initialized: |
| 1579 | + for _, param in self.named_parameters(): |
| 1580 | + if not param._is_initialized(): |
| 1581 | + param.initialize() |
| 1582 | + self._is_parameters_initialized = True |
| 1583 | + |
| 1584 | + return func(self, *args, **kwargs) |
| 1585 | + |
| 1586 | + return wrapper |
| 1587 | + |
| 1588 | + @_init_params_decorator |
1558 | 1589 | def __call__(self, *inputs: Any, **kwargs: Any) -> Any:
|
1559 | 1590 | if (
|
1560 | 1591 | (not in_to_static_mode())
|
|
0 commit comments