Skip to content

Commit d78c059

Browse files
authored
[Typing] Fix margin_cross_entropy argument reduction missing None type (#66007)
* [Typing] Fix `margin_cross_entropy` arg reduction missing `None` type * fix some typing * ignore more
1 parent f9212ea commit d78c059

File tree

5 files changed

+22
-16
lines changed

5 files changed

+22
-16
lines changed

python/paddle/base/data_feeder.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,14 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from __future__ import annotations
16+
1517
import struct
18+
from typing import TYPE_CHECKING
1619

1720
import numpy as np
1821

1922
from paddle import pir
20-
from paddle._typing.dtype_like import DTypeLike
2123

2224
from ..pir import Value
2325
from ..pir.core import _PADDLE_PIR_DTYPE_2_NUMPY_DTYPE, ParameterMeta
@@ -31,6 +33,10 @@
3133
in_pir_mode,
3234
)
3335

36+
if TYPE_CHECKING:
37+
from paddle._typing import DTypeLike
38+
from paddle._typing.dtype_like import _DTypeLiteral
39+
3440
__all__ = []
3541

3642
_PADDLE_DTYPE_2_NUMPY_DTYPE = {
@@ -92,7 +98,7 @@ def convert_uint16_to_float(data):
9298
return np.reshape(new_data, data.shape)
9399

94100

95-
def convert_dtype(dtype: DTypeLike) -> str:
101+
def convert_dtype(dtype: DTypeLike) -> _DTypeLiteral:
96102
if isinstance(dtype, core.VarDesc.VarType):
97103
if dtype in _PADDLE_DTYPE_2_NUMPY_DTYPE:
98104
return _PADDLE_DTYPE_2_NUMPY_DTYPE[dtype]

python/paddle/distribution/lkj_cholesky.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828

2929
if TYPE_CHECKING:
3030
from paddle import Tensor
31-
from paddle._typing import DTypeLike
31+
from paddle._typing.dtype_like import _DTypeLiteral
3232

3333

3434
__all__ = ["LKJCholesky"]
@@ -147,7 +147,7 @@ class LKJCholesky(distribution.Distribution):
147147
"""
148148

149149
concentration: Tensor
150-
dtype: DTypeLike
150+
dtype: _DTypeLiteral
151151
dim: int
152152
sample_method: Literal["onion", "cvine"]
153153

python/paddle/io/dataloader/batch_sampler.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ class BatchSampler(Sampler[Sequence[int]]):
7070
7171
>>> np.random.seed(2023)
7272
>>> # init with dataset
73-
>>> class RandomDataset(Dataset): # type: ignore[type-arg]
73+
>>> class RandomDataset(Dataset): # type: ignore[type-arg]
7474
... def __init__(self, num_samples):
7575
... self.num_samples = num_samples
7676
...
@@ -226,7 +226,7 @@ class DistributedBatchSampler(BatchSampler):
226226
>>> from paddle.io import Dataset, DistributedBatchSampler
227227
228228
>>> # init with dataset
229-
>>> class RandomDataset(Dataset):
229+
>>> class RandomDataset(Dataset): # type: ignore[type-arg]
230230
... def __init__(self, num_samples):
231231
... self.num_samples = num_samples
232232
...
@@ -386,7 +386,7 @@ def set_epoch(self, epoch: int) -> None:
386386
>>> from paddle.io import Dataset, DistributedBatchSampler
387387
388388
>>> # init with dataset
389-
>>> class RandomDataset(Dataset):
389+
>>> class RandomDataset(Dataset): # type: ignore[type-arg]
390390
... def __init__(self, num_samples):
391391
... self.num_samples = num_samples
392392
...

python/paddle/io/dataloader/sampler.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ class Sampler(Generic[_T]):
6969
>>> import numpy as np
7070
>>> from paddle.io import Dataset, Sampler
7171
72-
>>> class RandomDataset(Dataset):
72+
>>> class RandomDataset(Dataset): # type: ignore[type-arg]
7373
... def __init__(self, num_samples):
7474
... self.num_samples = num_samples
7575
...
@@ -81,15 +81,15 @@ class Sampler(Generic[_T]):
8181
... def __len__(self):
8282
... return self.num_samples
8383
...
84-
>>> class MySampler(Sampler):
84+
>>> class MySampler(Sampler): # type: ignore[type-arg]
8585
... def __init__(self, data_source):
8686
... self.data_source = data_source
8787
...
8888
... def __iter__(self):
89-
... return iter(range(len(self.data_source)))
89+
... return iter(range(len(self.data_source))) # type: ignore[arg-type]
9090
...
9191
... def __len__(self):
92-
... return len(self.data_source)
92+
... return len(self.data_source) # type: ignore[arg-type]
9393
...
9494
>>> sampler = MySampler(data_source=RandomDataset(100))
9595
@@ -142,7 +142,7 @@ class SequenceSampler(Sampler[int]):
142142
>>> import numpy as np
143143
>>> from paddle.io import Dataset, SequenceSampler
144144
145-
>>> class RandomDataset(Dataset): # type: ignore[type-arg]
145+
>>> class RandomDataset(Dataset): # type: ignore[type-arg]
146146
... def __init__(self, num_samples):
147147
... self.num_samples = num_samples
148148
...
@@ -205,7 +205,7 @@ class RandomSampler(Sampler[int]):
205205
>>> from paddle.io import Dataset, RandomSampler
206206
207207
>>> np.random.seed(2023)
208-
>>> class RandomDataset(Dataset):
208+
>>> class RandomDataset(Dataset): # type: ignore[type-arg]
209209
... def __init__(self, num_samples):
210210
... self.num_samples = num_samples
211211
...

python/paddle/nn/functional/loss.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2175,7 +2175,7 @@ def margin_cross_entropy(
21752175
scale: float = ...,
21762176
group=...,
21772177
return_softmax: Literal[True] = ...,
2178-
reduction: _ReduceMode = ...,
2178+
reduction: _ReduceMode | None = ...,
21792179
) -> tuple[Tensor, Tensor]:
21802180
...
21812181

@@ -2190,7 +2190,7 @@ def margin_cross_entropy(
21902190
scale: float = ...,
21912191
group=...,
21922192
return_softmax: Literal[False] = ...,
2193-
reduction: _ReduceMode = ...,
2193+
reduction: _ReduceMode | None = ...,
21942194
) -> Tensor:
21952195
...
21962196

@@ -2205,7 +2205,7 @@ def margin_cross_entropy(
22052205
scale: float = ...,
22062206
group=...,
22072207
return_softmax: bool = ...,
2208-
reduction: _ReduceMode = ...,
2208+
reduction: _ReduceMode | None = ...,
22092209
) -> Tensor | tuple[Tensor, Tensor]:
22102210
...
22112211

0 commit comments

Comments
 (0)