Description
Bug description
The `_AcceleratorRegistry.register()` method fails when used as a decorator because the inner `do_register` function has an incorrect signature.
The decorator form `@registry.register("name")` raises a `TypeError`: the nested function expects two arguments (`name` and `accelerator`), but Python calls it with only the decorated class.

The non-decorator form `registry.register("name", MyClass)` works correctly, which shows the bug is specific to decorator usage and breaks the documented API.
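For context, decorator syntax with arguments involves two calls: `register("name")` runs first and must return a callable, and Python then calls that callable with the decorated class as its single positional argument. The toy decorator factory below (`record_calls` is a hypothetical name, not part of Lightning) is a minimal sketch that makes this visible.

```python
from typing import Any, Callable

calls: list[tuple[Any, ...]] = []


def record_calls(name: str) -> Callable[..., Any]:
    """Hypothetical decorator factory that logs how its inner function is called."""

    def inner(*args: Any) -> Any:
        # Record exactly what Python passes when it applies the decorator.
        calls.append(args)
        return args[0]

    return inner


@record_calls("test")
class MyAcc:
    pass


# The decorator form above is equivalent to:
#     MyAcc = record_calls("test")(MyAcc)
# so `inner` receives exactly one positional argument: the class itself.
print(len(calls[0]), calls[0][0] is MyAcc)  # -> 1 True
```

Any inner function returned by a decorator factory therefore has to accept the decorated object as its only required argument.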
Steps to Reproduce
Minimal test case:
```python
from lightning.fabric.accelerators.registry import _AcceleratorRegistry
from lightning.fabric.accelerators import Accelerator

registry = _AcceleratorRegistry()
# Decorator-style registration raises TypeError
@registry.register("test")
class MyAcc(Accelerator):
    pass
```
Error output:
```
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[4], line 6
      2 from lightning.fabric.accelerators import Accelerator
      4 registry = _AcceleratorRegistry()
----> 6 @registry.register("test")
      7 class MyAcc(Accelerator):
      8     pass

TypeError: _AcceleratorRegistry.register.<locals>.do_register() missing 1 required positional argument: 'accelerator'
```
Root cause
The issue is in `src/lightning/fabric/accelerators/registry.py`, lines 68-75, where the inner `do_register` function has an incorrect signature:
```python
def do_register(name: str, accelerator: Callable) -> Callable:  # Wrong: extra 'name' parameter
    data["accelerator"] = accelerator
    data["accelerator_name"] = name
    self[name] = data
    return accelerator
```
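When `register()` is used as a decorator, Python calls the returned `do_register` with the decorated class as its only argument. That class binds to `name`, so `accelerator` is never supplied and the call raises the `TypeError` shown in the error output. The inner `name` parameter also shadows the `name` argument of the enclosing `register()` call.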
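The failure can be reproduced without Lightning. The sketch below uses a hypothetical `BrokenRegistry` that copies only the faulty inner-function pattern from `_AcceleratorRegistry.register()`; it omits the real method's `description`/`override` handling and name checks for brevity.

```python
from typing import Any, Callable, Optional


class BrokenRegistry(dict):
    """Hypothetical stand-in that copies the faulty pattern from the report."""

    def register(self, name: str, accelerator: Optional[Callable] = None, **init_params: Any) -> Callable:
        data: dict[str, Any] = {"init_params": init_params}

        if accelerator is not None:
            # Direct form: register("name", cls) never reaches do_register.
            data["accelerator"] = accelerator
            data["accelerator_name"] = name
            self[name] = data
            return accelerator

        # Bug: decorator application calls this with ONE argument (the class),
        # but the signature requires two.
        def do_register(name: str, accelerator: Callable) -> Callable:
            data["accelerator"] = accelerator
            data["accelerator_name"] = name
            self[name] = data
            return accelerator

        return do_register


registry = BrokenRegistry()


class MyAcc:
    pass


registry.register("direct", MyAcc)  # works: the class is passed explicitly

try:
    registry.register("test")(MyAcc)  # equivalent to @registry.register("test")
except TypeError as err:
    error_message = str(err)
else:
    error_message = ""

# The message ends with: missing 1 required positional argument: 'accelerator'
print(error_message)
```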
Proposed Fix
Remove the `name` parameter from `do_register` so that it takes only `accelerator`; `name` is already available from the enclosing `register()` call through the closure.
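A minimal sketch of the proposed fix, again on a hypothetical `FixedRegistry` stand-in rather than the real `_AcceleratorRegistry` (description, override, and name validation are omitted). The only change from the broken pattern is the single-parameter `do_register`, which reads `name` from the enclosing scope.

```python
from typing import Any, Callable, Optional


class FixedRegistry(dict):
    """Hypothetical stand-in showing the proposed single-argument do_register."""

    def register(self, name: str, accelerator: Optional[Callable] = None, **init_params: Any) -> Callable:
        data: dict[str, Any] = {"init_params": init_params}

        if accelerator is not None:
            data["accelerator"] = accelerator
            data["accelerator_name"] = name
            self[name] = data
            return accelerator

        # Fix: accept only the decorated class; `name` is captured from the
        # enclosing register() call via the closure.
        def do_register(accelerator: Callable) -> Callable:
            data["accelerator"] = accelerator
            data["accelerator_name"] = name
            self[name] = data
            return accelerator

        return do_register


registry = FixedRegistry()


@registry.register("test", a=1)  # decorator form now works
class MyAcc:
    pass


class OtherAcc:
    pass


registry.register("direct", OtherAcc)  # direct form still works

print(registry["test"]["accelerator_name"], registry["test"]["accelerator"] is MyAcc)  # -> test True
```

Because `do_register` returns the class unchanged, the decorated name `MyAcc` still refers to the original class after registration.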
Additional context
`_StrategyRegistry` uses the same pattern, and its inner `do_register` is implemented correctly.
What version are you seeing the problem on?
master
Environment
Current environment
- CUDA:
  - GPU:
    - NVIDIA GeForce GTX 1650
  - available: True
  - version: 12.6
- Lightning:
  - lightning: 2.6.0.dev0
  - lightning-utilities: 0.14.3
  - pytorch-lightning: 2.5.2
  - torch: 2.7.1
  - torchmetrics: 1.7.4
- System:
  - OS: Linux
  - architecture: 64bit
  - processor: x86_64
  - python: 3.12.3