❓ Questions and Help
Hi, I have noticed that when `world_size == 1`, `all_reduce` is a no-op and does not apply `scale`. In `torch_xla.core.xla_model`, inside `def all_reduce`:
```python
  # No-op if there is only one device
  if runtime.world_size() == 1 and not xu.getenv_as('XLA_ALWAYS_ALLREDUCE',
                                                    bool, False):
    if isinstance(inputs, torch.Tensor):
      return inputs.clone()
    else:
      return inputs
```
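On a single device, this fast path means the `scale` argument is silustrated below in a minimal sketch of the discrepancy; the tensor value and single-device setup are illustrative and not from the original report:

```python
import torch
import torch_xla.core.xla_model as xm

# Illustrative repro, assuming exactly one XLA device (world_size == 1).
device = xm.xla_device()
loss_accum = torch.tensor(4.0, device=device)

# Because of the single-device fast path quoted above, the call returns a
# clone of the input and the scale factor is never applied.
reduced = xm.all_reduce(xm.REDUCE_SUM, loss_accum, scale=0.5)
print(reduced.item())  # prints 4.0 rather than the scaled 2.0
```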
Is this intended behavior? If so, it makes `all_reduce` behave inconsistently between `world_size == 1` and `world_size > 1`. The issue manifests, for example, when logging a running-average loss value:
```python
epoch_loss = xm.all_reduce(xm.REDUCE_SUM, loss_accum, scale=1.0 / ((idx + 1) * world_size))
```
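A possible workaround is to either set `XLA_ALWAYS_ALLREDUCE=1` (per the check quoted above) or apply the scale manually on the single-device path. The helper below is only a sketch; `all_reduce_mean` is a hypothetical name, not part of the torch_xla API:

```python
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr

def all_reduce_mean(value: torch.Tensor, denom: int) -> torch.Tensor:
    # Hypothetical helper: sum-reduce across devices and scale by
    # 1 / (denom * world_size), compensating for the single-device
    # fast path that drops the `scale` argument.
    world_size = xr.world_size()
    scale = 1.0 / (denom * world_size)
    reduced = xm.all_reduce(xm.REDUCE_SUM, value, scale=scale)
    if world_size == 1:
        # all_reduce returned an unscaled clone; apply the scale here.
        reduced = reduced * scale
    return reduced

# Usage with the running-average example above (loss_accum and idx as defined there):
# epoch_loss = all_reduce_mean(loss_accum, idx + 1)
```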