@@ -28,7 +28,7 @@
 import warnings
 from collections.abc import Iterable
 from types import FunctionType, MethodType
-from typing import TYPE_CHECKING, Callable, TypeVar
+from typing import TYPE_CHECKING, Callable, TypeVar, overload

 import numpy as np
 from typing_extensions import ParamSpec
@@ -49,6 +49,8 @@
 _RetT = TypeVar("_RetT")

 if TYPE_CHECKING:
+    from typing import Generator, Sequence
+
     from paddle.static.amp.fp16_utils import AmpOptions

 __all__ = []
@@ -106,7 +108,7 @@ def _global_flags():
     return _global_flags_


-def set_flags(flags):
+def set_flags(flags: dict[str, bool | str | float]) -> None:
     """
     This function sets the GFlags value in Paddle.
     For FLAGS please refer to :ref:`en_guides_flags_flags`
@@ -131,7 +133,7 @@ def set_flags(flags):
             )


-def get_flags(flags):
+def get_flags(flags: str | Sequence[str]) -> dict[str, bool | str | float]:
     """
     This function gets the GFlags value in Paddle.
     For FLAGS please refer to :ref:`en_guides_flags_flags`
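
The two signatures above encode the documented contract: `set_flags` takes a mapping from flag name to a bool, string, or float value, and `get_flags` accepts either one flag name or a sequence of names and returns such a mapping. A minimal usage sketch, assuming the commonly documented `FLAGS_eager_delete_tensor_gb` and `FLAGS_check_nan_inf` flags are available in the build:

    import paddle

    # set_flags(flags: dict[str, bool | str | float]) -> None
    paddle.set_flags({'FLAGS_eager_delete_tensor_gb': 1.0})

    # get_flags(flags: str | Sequence[str]) -> dict[str, bool | str | float]
    res = paddle.get_flags(['FLAGS_eager_delete_tensor_gb', 'FLAGS_check_nan_inf'])
    print(res)  # e.g. {'FLAGS_eager_delete_tensor_gb': 1.0, 'FLAGS_check_nan_inf': False}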
@@ -404,7 +406,9 @@ def in_cinn_mode() -> bool:


 @signature_safe_contextmanager
-def ipu_shard_guard(index=-1, stage=-1):
+def ipu_shard_guard(
+    index: int = -1, stage: int = -1
+) -> Generator[None, None, None]:
     """
     Used to shard the graph on IPUs. Set each Op run on which IPU in the sharding and which stage in the pipelining.

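
Functions wrapped by `signature_safe_contextmanager` are generator functions that yield exactly once and yield no value, which is why they are annotated as returning `Generator[None, None, None]` (yield, send, and return types are all `None`). A minimal sketch of the same pattern outside Paddle, with hypothetical names mirroring the save/restore of `global_ipu_index`:

    from contextlib import contextmanager
    from typing import Generator

    _current_index = -1  # hypothetical module-level state

    @contextmanager
    def demo_guard(index: int = -1) -> Generator[None, None, None]:
        global _current_index
        prev, _current_index = _current_index, index
        try:
            yield  # nothing is yielded, so the yield/send/return types are all None
        finally:
            _current_index = prev  # restore the previous value on exit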
@@ -456,6 +460,20 @@ def ipu_shard_guard(index=-1, stage=-1):
         global_ipu_stage = prev_ipu_stage


+@overload
+def set_ipu_shard(
+    call_func: Callable[_InputT, _RetT], index: int = ..., stage: int = ...
+) -> Callable[_InputT, _RetT]:
+    ...
+
+
+@overload
+def set_ipu_shard(
+    call_func: paddle.nn.Layer, index: int = ..., stage: int = ...
+) -> paddle.nn.Layer:
+    ...
+
+
 def set_ipu_shard(call_func, index=-1, stage=-1):
     """
     Shard the ipu with the given call function. Set every ops in call function to the given ipu sharding.
@@ -467,9 +485,9 @@ def set_ipu_shard(call_func, index=-1, stage=-1):

     Args:
         call_func(Layer|function): Specify the call function to be wrapped.
-        index(int, optional): Specify which ipu the Tensor is computed on, (such as ‘0, 1, 2, 3’).
+        index(int, optional): Specify which ipu the Tensor is computed on, (such as '0, 1, 2, 3').
             The default value is -1, which means the Op only run on IPU 0.
-        stage(int, optional): Specify the computation order of the sharded model(such as ‘0, 1, 2, 3’).
+        stage(int, optional): Specify the computation order of the sharded model(such as '0, 1, 2, 3').
             The sharded model will be computed from small to large. The default value is -1,
             which means no pipelining computation order and run Ops in terms of graph.

@@ -489,8 +507,8 @@ def set_ipu_shard(call_func, index=-1, stage=-1):
             >>> relu(a)
     """

-    def decorate(func):
-        def wrapper(*args, **kwargs):
+    def decorate(func: Callable[_InputT, _RetT]) -> Callable[_InputT, _RetT]:
+        def wrapper(*args: _InputT.args, **kwargs: _InputT.kwargs) -> _RetT:
             with ipu_shard_guard(index=index, stage=stage):
                 return func(*args, **kwargs)

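
The combination of `@overload` and `ParamSpec` above gives `set_ipu_shard` two static views: wrapping a plain callable preserves its exact parameter and return types, while wrapping a `paddle.nn.Layer` yields a `Layer` again. A self-contained sketch of the same pattern with hypothetical names (`shard` and `Layer` standing in for the real API):

    from typing import Callable, TypeVar, overload
    from typing_extensions import ParamSpec

    _InputT = ParamSpec("_InputT")
    _RetT = TypeVar("_RetT")

    class Layer:  # stand-in for paddle.nn.Layer
        def forward(self, x: int) -> int:
            return x

    @overload
    def shard(call_func: Callable[_InputT, _RetT]) -> Callable[_InputT, _RetT]: ...
    @overload
    def shard(call_func: Layer) -> Layer: ...

    def shard(call_func):
        def decorate(func: Callable[_InputT, _RetT]) -> Callable[_InputT, _RetT]:
            # ParamSpec keeps *args/**kwargs typed exactly as the wrapped callable
            def wrapper(*args: _InputT.args, **kwargs: _InputT.kwargs) -> _RetT:
                return func(*args, **kwargs)
            return wrapper

        if isinstance(call_func, Layer):
            call_func.forward = decorate(call_func.forward)
            return call_func
        return decorate(call_func)

    @shard
    def add(a: int, b: int) -> int:
        return a + b

    result = add(1, 2)  # a type checker infers `int` here, not `Any`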
@@ -843,7 +861,7 @@ def _custom_device_ids(device_type):
     return device_ids


-def is_compiled_with_xpu():
+def is_compiled_with_xpu() -> bool:
     """
     Whether this whl package can be used to run the model on XPU.

@@ -858,7 +876,7 @@ def is_compiled_with_xpu():
     return core.is_compiled_with_xpu()


-def disable_signal_handler():
+def disable_signal_handler() -> None:
     """
     Reset signal handler registered by Paddle.

@@ -884,7 +902,7 @@ def disable_signal_handler():
     core.disable_signal_handler()


-def is_compiled_with_cinn():
+def is_compiled_with_cinn() -> bool:
     """
     Whether this whl package can be used to run the model on CINN.

@@ -900,7 +918,7 @@ def is_compiled_with_cinn():
     return core.is_compiled_with_cinn()


-def is_compiled_with_cuda():
+def is_compiled_with_cuda() -> bool:
     """
     Whether this whl package can be used to run the model on GPU.

@@ -916,7 +934,7 @@ def is_compiled_with_cuda():
     return core.is_compiled_with_cuda()


-def is_compiled_with_distribute():
+def is_compiled_with_distribute() -> bool:
     """
     Whether this whl package can be used to run the model with distribute.

@@ -932,7 +950,7 @@ def is_compiled_with_distribute():
     return core.is_compiled_with_distribute()


-def is_compiled_with_rocm():
+def is_compiled_with_rocm() -> bool:
     """
     Whether this whl package can be used to run the model on AMD or Hygon GPU(ROCm).

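
These build-capability predicates now advertise a `bool` return, so they read naturally in conditionals. An illustrative device-selection snippet (not part of the diff) using the public `paddle` namespace:

    import paddle

    if paddle.is_compiled_with_cuda():
        place = paddle.CUDAPlace(0)
    elif paddle.is_compiled_with_xpu():
        place = paddle.XPUPlace(0)
    else:
        place = paddle.CPUPlace()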
@@ -948,7 +966,9 @@ def is_compiled_with_rocm():
     return core.is_compiled_with_rocm()


-def cuda_places(device_ids=None):
+def cuda_places(
+    device_ids: Sequence[int] | None = None,
+) -> list[core.CUDAPlace]:
     """
     Note:
         For multi-card tasks, please use `FLAGS_selected_gpus` environment variable to set the visible GPU device.
@@ -996,7 +1016,7 @@ def cuda_places(device_ids=None):
     return [core.CUDAPlace(dev_id) for dev_id in device_ids]


-def xpu_places(device_ids=None):
+def xpu_places(device_ids: Sequence[int] | None = None) -> list[core.XPUPlace]:
     """
     **Note**:
         For multi-card tasks, please use `FLAGS_selected_xpus` environment variable to set the visible XPU device.
@@ -1035,7 +1055,7 @@ def xpu_places(device_ids=None):
     return [core.XPUPlace(dev_id) for dev_id in device_ids]


-def cpu_places(device_count=None):
+def cpu_places(device_count: int | None = None) -> list[core.CPUPlace]:
     """
     This function creates a list of :code:`paddle.CPUPlace` objects, and returns the created list.

@@ -1069,7 +1089,9 @@ def cpu_places(device_count=None):
     return [core.CPUPlace()] * device_count


-def cuda_pinned_places(device_count=None):
+def cuda_pinned_places(
+    device_count: int | None = None,
+) -> list[core.CUDAPinnedPlace]:
     """
     This function creates a list of :code:`base.CUDAPinnedPlace` objects.

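
With these annotations the place helpers declare both the accepted id lists (`Sequence[int] | None`, `int | None`) and the concrete place types they return. An illustrative call pattern, assuming a CUDA build with at least two visible GPUs and the public `paddle.static` re-exports:

    import paddle

    gpu_places = paddle.static.cuda_places([0, 1])  # list of CUDAPlace for GPUs 0 and 1
    cpu_places = paddle.static.cpu_places(4)        # four CPUPlace objects
    all_gpus = paddle.static.cuda_places()          # None: fall back to FLAGS_selected_gpus / all GPUs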
@@ -1130,7 +1152,7 @@ def name(self):


 @signature_safe_contextmanager
-def name_scope(prefix=None):
+def name_scope(prefix: str | None = None) -> Generator[None, None, None]:
     """

     Generate hierarchical name prefix for the operators in Static Graph.
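
Like `ipu_shard_guard`, `name_scope` is a contextmanager-style generator, hence the `Generator[None, None, None]` return annotation. Illustrative usage in a static-graph program (following the documented public API):

    import paddle

    paddle.enable_static()
    with paddle.static.name_scope('block1'):
        a = paddle.static.data(name='data', shape=[None, 32], dtype='float32')
        b = a + 1  # ops created here carry the 'block1' name prefix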
@@ -7784,7 +7806,7 @@ def _copy_to(self, device, blocking):
 _startup_program_._is_start_up_program_ = True


-def default_startup_program():
+def default_startup_program() -> Program:
     """
     Get default/global startup program.

@@ -7813,7 +7835,7 @@ def default_startup_program():
     return _startup_program_


-def default_main_program():
+def default_main_program() -> Program:
     """
     This API can be used to get ``default main program`` which store the
     descriptions of Ops and tensors.
@@ -7850,7 +7872,7 @@ def default_main_program():
     return _main_program_


-def switch_main_program(program):
+def switch_main_program(program: Program) -> Program:
     """
     Switch the main program to a new program.

@@ -7866,7 +7888,7 @@ def switch_main_program(program):
     return prev_program


-def switch_startup_program(program):
+def switch_startup_program(program: Program) -> Program:
     """
     Switch the startup program to a new program
     Args:
@@ -7882,7 +7904,9 @@ def switch_startup_program(program):


 @signature_safe_contextmanager
-def program_guard(main_program, startup_program=None):
+def program_guard(
+    main_program: Program, startup_program: Program | None = None
+) -> Generator[None, None, None]:
     """
     :api_attr: Static Graph

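
`program_guard` combines both patterns above: it accepts the freshly typed `Program` objects and is itself a contextmanager generator. Illustrative usage following the documented static-graph workflow:

    import paddle

    paddle.enable_static()
    main_program = paddle.static.Program()
    startup_program = paddle.static.Program()
    with paddle.static.program_guard(main_program, startup_program):
        x = paddle.static.data(name='x', shape=[-1, 784], dtype='float32')
        out = paddle.static.nn.fc(x=x, size=10)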
@@ -8016,7 +8040,7 @@ def switch_device(device):


 @signature_safe_contextmanager
-def device_guard(device=None):
+def device_guard(device: str | None = None) -> Generator[None, None, None]:
     """

     Note:
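
`device_guard` takes the device as a plain string (for example 'cpu' or 'gpu'), now made explicit by `str | None`. An illustrative static-graph snippet:

    import paddle

    paddle.enable_static()
    with paddle.static.device_guard('cpu'):
        # ops created inside this block are placed on the CPU
        shape = paddle.ones(shape=[2], dtype='int32')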
@@ -8233,7 +8257,7 @@ def dtype_to_str(in_dtype):
     elif in_dtype == core.VarDesc.VarType.COMPLEX128:
         return "complex128"
     else:
-        raise TypeError(f"got unspport data type for promotion: {in_dtype}.")
+        raise TypeError(f"got unsupported data type for promotion: {in_dtype}.")


 def add_cast_for_type_promotion(op, block, idx, var_name, out_dtype):