Fix is_dataset_splitted → is_dataset_split #71713

Merged: 1 commit, Mar 19, 2025
22 changes: 11 additions & 11 deletions python/paddle/distributed/auto_parallel/api.py
@@ -3248,7 +3248,7 @@ class ShardDataloader:
ShardDataloader converts a dataloader to a new dataloader which provides two capabilities:
1. split the dataloader by shard_dim to do data parallel.
2. reshard the output of the dataloader to distributed tensors.
if is_dataset_splitted is True, only resharding is needed.
if is_dataset_split is True, only resharding is needed.

Args:
dataloader (paddle.io.DataLoader): The dataloader to be sharded.
@@ -3262,7 +3262,7 @@ class ShardDataloader:
shard_dims (list|tuple|str|int): The mesh dimension to shard the dataloader.
Users can specify the shard_dim of each mesh or specify a single shard_dim for all meshes.
Default: None, which means the data loader will not be split, i.e. mp.
is_dataset_splitted (bool): Whether the dataset has already been split.
is_dataset_split (bool): Whether the dataset has already been split.
dense_tensor_idx (list): A paired 2D list that specifies the indices of the dense_tensors in the output of the dataloader.
It allows users to identify which elements within each output batch are dense_tensors.
first dense_tensor: the dense_tensor returned by the dataloader.
@@ -3277,13 +3277,13 @@ def __init__(
meshes: ProcessMesh | list[ProcessMesh] | tuple[ProcessMesh],
input_keys: list[str] | tuple[str] | None = None,
shard_dims: list | tuple | str | int | None = None,
is_dataset_splitted: bool = False,
is_dataset_split: bool = False,
dense_tensor_idx: list[list[int]] | None = None,
):
# do some check
if is_dataset_splitted is True and shard_dims is None:
if is_dataset_split is True and shard_dims is None:
raise ValueError(
"shard_dims must be set when is_dataset_splitted is True"
"shard_dims must be set when is_dataset_split is True"
)

self._meshes = to_list(meshes)
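The check in this hunk makes shard_dims mandatory once is_dataset_split is True. A minimal sketch of that behavior through the public paddle.distributed.shard_dataloader entry point; ToyDataset and the 1-D "dp" mesh are assumptions, not part of this PR:

import numpy as np
import paddle.distributed as dist
from paddle.io import DataLoader, Dataset

class ToyDataset(Dataset):  # hypothetical stand-in dataset
    def __len__(self):
        return 4
    def __getitem__(self, idx):
        return np.array([idx], dtype="float32")

mesh = dist.ProcessMesh([0, 1], dim_names=["dp"])
loader = DataLoader(ToyDataset(), batch_size=2)

# shard_dims is omitted here, so construction fails with the ValueError shown above
try:
    dist.shard_dataloader(loader, meshes=mesh, is_dataset_split=True)
except ValueError as err:
    print(err)  # shard_dims must be set when is_dataset_split is True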
@@ -3310,7 +3310,7 @@ def __init__(
dp_rank = mesh.get_rank_by_dim_and_process_id(shard_dim, process_id)
dp_world_size = mesh.get_dim_size(shard_dim)

if is_dataset_splitted is True or shard_dims is None:
if is_dataset_split is True or shard_dims is None:
self._dataloader = dataloader
self.batch_size = dataloader.batch_sampler.batch_size
else:
@@ -3566,15 +3566,15 @@ def shard_dataloader(
meshes: ProcessMesh | Sequence[ProcessMesh],
input_keys: Sequence[str] | None = None,
shard_dims: Sequence[str] | Sequence[int] | str | int | None = None,
is_dataset_splitted: bool = False,
is_dataset_split: bool = False,
dense_tensor_idx: list[list[int]] | None = None,
) -> ShardDataloader:
"""
Convert the dataloader to a ShardDataloader which provides two capabilities:
1. split the dataloader by shard_dim to do data parallel if it is not None.
2. reshard the output of the dataloader to distributed tensors.
if is_dataset_splitted is True, it means that the dataset has already been split by users, and only resharding is needed.
the split is performed only if is_dataset_splitted is False and shard_dims is not None.
if is_dataset_split is True, it means that the dataset has already been split by users, and only resharding is needed.
the split is performed only if is_dataset_split is False and shard_dims is not None.

Args:
dataloader (paddle.io.DataLoader): The dataloader to be sharded. The output of the dataloader
@@ -3591,7 +3591,7 @@ def shard_dataloader(
The mesh dimension to shard the dataloader.
Users can specify the shard_dim of each mesh or specify a single shard_dim for all meshes.
Default: None, which means the data loader will not be split, i.e. mp.
is_dataset_splitted (bool): Whether the dataset has already been split. Default: False.
is_dataset_split (bool): Whether the dataset has already been split. Default: False.
dense_tensor_idx (list): A paired 2D list that specifies the indices of the dense_tensors in the output of the dataloader.
It allows users to identify which elements within each output batch are dense_tensors.
first dense_tensor: the dense_tensor returned by the dataloader.
@@ -3761,7 +3761,7 @@ def shard_dataloader(
meshes,
input_keys,
shard_dims,
is_dataset_splitted,
is_dataset_split,
dense_tensor_idx,
)

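A short usage sketch of the renamed parameter on the default path: is_dataset_split is left at False, so shard_dataloader splits the dataloader along the "dp" mesh dimension and reshards its outputs, as described in the docstring above. RandomDataset and the 2x2 mesh are assumptions, and the snippet is meant to run under a four-rank paddle.distributed.launch job:

import numpy as np
import paddle.distributed as dist
from paddle.io import DataLoader, Dataset

class RandomDataset(Dataset):  # hypothetical toy dataset
    def __len__(self):
        return 8
    def __getitem__(self, idx):
        image = np.random.rand(4).astype("float32")
        label = np.array(idx % 2, dtype="int64")
        return image, label

mesh = dist.ProcessMesh([[0, 1], [2, 3]], dim_names=["dp", "mp"])
loader = DataLoader(RandomDataset(), batch_size=4)

# is_dataset_split defaults to False, so the dataloader is split along "dp"
dist_loader = dist.shard_dataloader(loader, meshes=mesh, shard_dims="dp")

for image, label in dist_loader:
    ...  # each batch is resharded to distributed tensors on `mesh`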
2 changes: 1 addition & 1 deletion test/auto_parallel/pir/while_unittest_pir.py
@@ -95,7 +95,7 @@ def create_data_loader(self):
meshes=mesh,
shard_dims="x",
input_keys=["inputs", "label"],
is_dataset_splitted=True,
is_dataset_split=True,
)
return dist_dataloader
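The test above exercises the pre-split path: every rank builds a dataloader over its own shard of the data, so the wrapper only reshards the outputs. A hedged sketch of that pattern; PerRankDataset and the 1-D "x" mesh are assumptions, and the dataset yields dicts matching input_keys:

import numpy as np
import paddle.distributed as dist
from paddle.io import DataLoader, Dataset

class PerRankDataset(Dataset):  # hypothetical: already holds only this rank's shard
    def __len__(self):
        return 4
    def __getitem__(self, idx):
        return {
            "inputs": np.random.rand(4).astype("float32"),
            "label": np.array(idx % 2, dtype="int64"),
        }

mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
local_loader = DataLoader(PerRankDataset(), batch_size=2)

dist_loader = dist.shard_dataloader(
    local_loader,
    meshes=mesh,
    shard_dims="x",                  # still required when is_dataset_split=True
    input_keys=["inputs", "label"],
    is_dataset_split=True,           # dataset is pre-split, so only reshard the outputs
)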
