
Commit b20b074

[Typing][B-96] Add type annotations for python/paddle/base/framework.py (#66301)
1 parent c1b8092 commit b20b074

1 file changed: python/paddle/base/framework.py (+50, -26 lines)
@@ -28,7 +28,7 @@
 import warnings
 from collections.abc import Iterable
 from types import FunctionType, MethodType
-from typing import TYPE_CHECKING, Callable, TypeVar
+from typing import TYPE_CHECKING, Callable, TypeVar, overload

 import numpy as np
 from typing_extensions import ParamSpec
@@ -49,6 +49,8 @@
 _RetT = TypeVar("_RetT")

 if TYPE_CHECKING:
+    from typing import Generator, Sequence
+
     from paddle.static.amp.fp16_utils import AmpOptions

 __all__ = []
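
The `Generator` and `Sequence` imports are added under `TYPE_CHECKING`, so they are only evaluated by static type checkers and cost nothing at runtime. This works because the annotations that use them are never evaluated at runtime (the module presumably relies on `from __future__ import annotations`, as the other typing PRs in this series do). A minimal standalone sketch of the pattern, with made-up names:

```python
from __future__ import annotations  # annotations are not evaluated at runtime

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen only by static type checkers (mypy, pyright); never imported at
    # runtime, so there is no startup cost and no risk of import cycles.
    from typing import Sequence


def first_or_none(items: Sequence[int]) -> int | None:
    # Sequence appears only inside an annotation, which stays a plain string
    # at runtime thanks to the __future__ import above.
    return items[0] if items else None
```
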
@@ -106,7 +108,7 @@ def _global_flags():
     return _global_flags_


-def set_flags(flags):
+def set_flags(flags: dict[str, bool | str | float]) -> None:
     """
     This function sets the GFlags value in Paddle.
     For FLAGS please refer to :ref:`en_guides_flags_flags`
@@ -131,7 +133,7 @@ def set_flags(flags):
     )


-def get_flags(flags):
+def get_flags(flags: str | Sequence[str]) -> dict[str, bool | str | float]:
     """
     This function gets the GFlags value in Paddle.
     For FLAGS please refer to :ref:`en_guides_flags_flags`
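
The new signatures spell out the accepted shapes: `set_flags` takes a mapping from flag name to value, while `get_flags` accepts a single flag name or a sequence of names and returns a dict of current values. A short usage sketch consistent with these annotations (the flag names are illustrative; available flags depend on the build):

```python
import paddle

# set_flags expects a dict[str, bool | str | float]
paddle.set_flags({"FLAGS_eager_delete_tensor_gb": 1.0})

# get_flags accepts a single flag name or a sequence of names and
# returns a dict mapping each name to its current value.
print(paddle.get_flags("FLAGS_eager_delete_tensor_gb"))
print(paddle.get_flags(["FLAGS_eager_delete_tensor_gb", "FLAGS_check_nan_inf"]))
```
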
@@ -404,7 +406,9 @@ def in_cinn_mode() -> bool:


 @signature_safe_contextmanager
-def ipu_shard_guard(index=-1, stage=-1):
+def ipu_shard_guard(
+    index: int = -1, stage: int = -1
+) -> Generator[None, None, None]:
     """
     Used to shard the graph on IPUs. Set each Op run on which IPU in the sharding and which stage in the pipelining.

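
`ipu_shard_guard` is a generator-based context manager (wrapped by `signature_safe_contextmanager`), so its return type is written as `Generator[None, None, None]`: it yields `None`, accepts nothing through `send()`, and returns `None`. A self-contained sketch of the same annotation idiom using the standard-library `contextmanager` (not Paddle code; `temporary_flag` is a made-up example):

```python
from __future__ import annotations

from contextlib import contextmanager
from typing import Generator


@contextmanager
def temporary_flag(flags: dict[str, bool], name: str) -> Generator[None, None, None]:
    # Generator[yield type, send type, return type]: the guard yields nothing
    # useful, so all three parameters are None.
    old = flags.get(name, False)
    flags[name] = True
    try:
        yield
    finally:
        flags[name] = old


flags: dict[str, bool] = {}
with temporary_flag(flags, "verbose"):
    assert flags["verbose"] is True
assert flags["verbose"] is False
```
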
@@ -456,6 +460,20 @@ def ipu_shard_guard(index=-1, stage=-1):
         global_ipu_stage = prev_ipu_stage


+@overload
+def set_ipu_shard(
+    call_func: Callable[_InputT, _RetT], index: int = ..., stage: int = ...
+) -> Callable[_InputT, _RetT]:
+    ...
+
+
+@overload
+def set_ipu_shard(
+    call_func: paddle.nn.Layer, index: int = ..., stage: int = ...
+) -> paddle.nn.Layer:
+    ...
+
+
 def set_ipu_shard(call_func, index=-1, stage=-1):
     """
     Shard the ipu with the given call function. Set every ops in call function to the given ipu sharding.
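
The two `@overload` declarations describe `set_ipu_shard` for the type checker: a plain callable comes back as a callable with the same parameters, and a `paddle.nn.Layer` comes back as a `Layer`. Only the undecorated definition below them exists at runtime. A self-contained sketch of the pattern with toy types (not the Paddle implementation):

```python
from __future__ import annotations

from typing import Callable, TypeVar, overload

from typing_extensions import ParamSpec  # ParamSpec lives in typing on 3.10+

_InputT = ParamSpec("_InputT")
_RetT = TypeVar("_RetT")


class Layer:  # stand-in for paddle.nn.Layer
    pass


@overload
def shard(obj: Callable[_InputT, _RetT]) -> Callable[_InputT, _RetT]: ...
@overload
def shard(obj: Layer) -> Layer: ...


def shard(obj):
    # The only definition that runs; @overload bodies are never executed.
    # The checker matches each call site against the overload signatures.
    return obj
```
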
@@ -467,9 +485,9 @@ def set_ipu_shard(call_func, index=-1, stage=-1):

     Args:
         call_func(Layer|function): Specify the call function to be wrapped.
-        index(int, optional): Specify which ipu the Tensor is computed on, (such as 0, 1, 2, 3).
+        index(int, optional): Specify which ipu the Tensor is computed on, (such as '0, 1, 2, 3').
             The default value is -1, which means the Op only run on IPU 0.
-        stage(int, optional): Specify the computation order of the sharded model(such as 0, 1, 2, 3).
+        stage(int, optional): Specify the computation order of the sharded model(such as '0, 1, 2, 3').
             The sharded model will be computed from small to large. The default value is -1,
             which means no pipelining computation order and run Ops in terms of graph.
@@ -489,8 +507,8 @@ def set_ipu_shard(call_func, index=-1, stage=-1):
            >>> relu(a)
     """

-    def decorate(func):
-        def wrapper(*args, **kwargs):
+    def decorate(func: Callable[_InputT, _RetT]) -> Callable[_InputT, _RetT]:
+        def wrapper(*args: _InputT.args, **kwargs: _InputT.kwargs) -> _RetT:
             with ipu_shard_guard(index=index, stage=stage):
                 return func(*args, **kwargs)

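
Typing the inner `decorate`/`wrapper` pair with `ParamSpec` (`_InputT.args` / `_InputT.kwargs`) keeps the wrapped function's exact parameter list visible to callers instead of collapsing it to `(*args: Any, **kwargs: Any)`. A minimal standalone sketch of the same idiom (the `logged` decorator is hypothetical):

```python
from __future__ import annotations

import functools
from typing import Callable, TypeVar

from typing_extensions import ParamSpec

_InputT = ParamSpec("_InputT")
_RetT = TypeVar("_RetT")


def logged(func: Callable[_InputT, _RetT]) -> Callable[_InputT, _RetT]:
    @functools.wraps(func)
    def wrapper(*args: _InputT.args, **kwargs: _InputT.kwargs) -> _RetT:
        # ParamSpec ties *args/**kwargs here to func's real signature, so a
        # type checker still flags a bad call such as add("1", 2) after
        # decoration, and the return type stays _RetT rather than Any.
        print(f"calling {func.__name__}")
        return func(*args, **kwargs)

    return wrapper


@logged
def add(x: int, y: int) -> int:
    return x + y


print(add(1, 2))  # the checker knows this returns int
```
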

@@ -843,7 +861,7 @@ def _custom_device_ids(device_type):
     return device_ids


-def is_compiled_with_xpu():
+def is_compiled_with_xpu() -> bool:
     """
     Whether this whl package can be used to run the model on XPU.

@@ -858,7 +876,7 @@ def is_compiled_with_xpu():
     return core.is_compiled_with_xpu()


-def disable_signal_handler():
+def disable_signal_handler() -> None:
     """
     Reset signal handler registered by Paddle.

@@ -884,7 +902,7 @@ def disable_signal_handler():
     core.disable_signal_handler()


-def is_compiled_with_cinn():
+def is_compiled_with_cinn() -> bool:
     """
     Whether this whl package can be used to run the model on CINN.

@@ -900,7 +918,7 @@ def is_compiled_with_cinn():
     return core.is_compiled_with_cinn()


-def is_compiled_with_cuda():
+def is_compiled_with_cuda() -> bool:
     """
     Whether this whl package can be used to run the model on GPU.

@@ -916,7 +934,7 @@ def is_compiled_with_cuda():
     return core.is_compiled_with_cuda()


-def is_compiled_with_distribute():
+def is_compiled_with_distribute() -> bool:
     """
     Whether this whl package can be used to run the model with distribute.

@@ -932,7 +950,7 @@ def is_compiled_with_distribute():
     return core.is_compiled_with_distribute()


-def is_compiled_with_rocm():
+def is_compiled_with_rocm() -> bool:
     """
     Whether this whl package can be used to run the model on AMD or Hygon GPU(ROCm).

@@ -948,7 +966,9 @@ def is_compiled_with_rocm():
     return core.is_compiled_with_rocm()


-def cuda_places(device_ids=None):
+def cuda_places(
+    device_ids: Sequence[int] | None = None,
+) -> list[core.CUDAPlace]:
     """
     Note:
         For multi-card tasks, please use `FLAGS_selected_gpus` environment variable to set the visible GPU device.
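
`device_ids` is annotated as `Sequence[int] | None` rather than `list[int]`, so any read-only sequence of device ids type-checks, and the return type `list[core.CUDAPlace]` is now explicit. A usage sketch against the public alias (assumed here to be `paddle.static.cuda_places`; requires a CUDA build of Paddle):

```python
import paddle.static as static

# Any sequence of device ids satisfies Sequence[int]: list, tuple, range, ...
places = static.cuda_places([0, 1])
places = static.cuda_places((0, 1))
places = static.cuda_places(range(2))

# With no argument, the devices selected by FLAGS_selected_gpus (or all
# visible GPUs) are used; the result is a list of CUDAPlace objects.
places = static.cuda_places()
```
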
@@ -996,7 +1016,7 @@ def cuda_places(device_ids=None):
     return [core.CUDAPlace(dev_id) for dev_id in device_ids]


-def xpu_places(device_ids=None):
+def xpu_places(device_ids: Sequence[int] | None = None) -> list[core.XPUPlace]:
     """
     **Note**:
         For multi-card tasks, please use `FLAGS_selected_xpus` environment variable to set the visible XPU device.
@@ -1035,7 +1055,7 @@ def xpu_places(device_ids=None):
     return [core.XPUPlace(dev_id) for dev_id in device_ids]


-def cpu_places(device_count=None):
+def cpu_places(device_count: int | None = None) -> list[core.CPUPlace]:
     """
     This function creates a list of :code:`paddle.CPUPlace` objects, and returns the created list.

@@ -1069,7 +1089,9 @@ def cpu_places(device_count=None):
     return [core.CPUPlace()] * device_count


-def cuda_pinned_places(device_count=None):
+def cuda_pinned_places(
+    device_count: int | None = None,
+) -> list[core.CUDAPinnedPlace]:
     """
     This function creates a list of :code:`base.CUDAPinnedPlace` objects.

@@ -1130,7 +1152,7 @@ def name(self):


 @signature_safe_contextmanager
-def name_scope(prefix=None):
+def name_scope(prefix: str | None = None) -> Generator[None, None, None]:
     """

     Generate hierarchical name prefix for the operators in Static Graph.
@@ -7784,7 +7806,7 @@ def _copy_to(self, device, blocking):
 _startup_program_._is_start_up_program_ = True


-def default_startup_program():
+def default_startup_program() -> Program:
     """
     Get default/global startup program.

@@ -7813,7 +7835,7 @@ def default_startup_program():
     return _startup_program_


-def default_main_program():
+def default_main_program() -> Program:
     """
     This API can be used to get ``default main program`` which store the
     descriptions of Ops and tensors.
@@ -7850,7 +7872,7 @@ def default_main_program():
     return _main_program_


-def switch_main_program(program):
+def switch_main_program(program: Program) -> Program:
     """
     Switch the main program to a new program.

@@ -7866,7 +7888,7 @@ def switch_main_program(program):
     return prev_program


-def switch_startup_program(program):
+def switch_startup_program(program: Program) -> Program:
     """
     Switch the startup program to a new program
     Args:
@@ -7882,7 +7904,9 @@ def switch_startup_program(program):


 @signature_safe_contextmanager
-def program_guard(main_program, startup_program=None):
+def program_guard(
+    main_program: Program, startup_program: Program | None = None
+) -> Generator[None, None, None]:
     """
     :api_attr: Static Graph

@@ -8016,7 +8040,7 @@ def switch_device(device):


 @signature_safe_contextmanager
-def device_guard(device=None):
+def device_guard(device: str | None = None) -> Generator[None, None, None]:
     """

     Note:
@@ -8233,7 +8257,7 @@ def dtype_to_str(in_dtype):
     elif in_dtype == core.VarDesc.VarType.COMPLEX128:
         return "complex128"
     else:
-        raise TypeError(f"got unspport data type for promotion: {in_dtype}.")
+        raise TypeError(f"got unsupport data type for promotion: {in_dtype}.")


 def add_cast_for_type_promotion(op, block, idx, var_name, out_dtype):
