Commit 5b3b3f4

Add support for stateful layers
1 parent 055a953 commit 5b3b3f4

File tree: 2 files changed (+114, −17 lines)

src/kernels/layer.py

Lines changed: 72 additions & 17 deletions
@@ -7,6 +7,7 @@
 import sys
 import warnings
 from abc import ABC, abstractmethod
+from collections import OrderedDict
 from contextvars import ContextVar
 from copy import deepcopy
 from dataclasses import dataclass
@@ -17,8 +18,10 @@
 from typing import (
     TYPE_CHECKING,
     Dict,
+    Mapping,
     Optional,
     Protocol,
+    Set,
     Tuple,
     Type,
     Union,
@@ -868,10 +871,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         raise ValueError("kernelize mode must contain Mode.INFERENCE or Mode.TRAINING.")
 
     if device is None:
-        device_type = _find_device(model)
+        device = _find_device(model)
+        device_type = _find_device_type(model)
     elif isinstance(device, str):
         _validate_device_type(device)
+        import torch
+
         device_type = Device(type=device)
+        device = torch.device(device)
     else:
         device_type = Device(device.type)
 
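Note: `kernelize` now tracks two values: the concrete `torch.device` (later handed to `create_state`) and the coarser `Device` type used as the mapping key. A minimal sketch of the distinction, not part of the diff and assuming a plain linear model:

import torch
import torch.nn as nn

model = nn.Linear(16, 16)

# Inferred path (device=None): both values come from the first parameter.
inferred = next(model.parameters()).device  # e.g. torch.device("cpu")

# Explicit path (device="cuda"): the string is validated and kept as the
# mapping key, and also converted to a concrete torch.device so stateful
# layers can allocate their state on the right device.
explicit = torch.device("cuda")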
@@ -884,7 +891,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         layer_name = module_class.kernel_layer_name
 
         if _DISABLE_KERNEL_MAPPING:
-            _replace_forward(module, module_class)
+            _replace_forward(device, module, module_class)
             continue
 
         kernel = _KERNEL_MAPPING.get().get(str(layer_name))
@@ -898,7 +905,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             )
             if not use_fallback:
                 raise ValueError(f"No layer mapping for `{layer_name}`")
-            _replace_forward(module, module_class)
+            _replace_forward(device, module, module_class)
             continue
 
         # Get kernel options for the device
@@ -909,7 +916,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
                 raise ValueError(
                     f"No layer mapping for `{layer_name}` with device type `{device_type}`"
                 )
-            _replace_forward(module, module_class)
+            _replace_forward(device, module, module_class)
             continue
 
         repos = property_repos.repos
@@ -919,7 +926,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
                 raise ValueError(
                     f"No layer mapping for `{layer_name}` device `{device_type}` with the right properties"
                 )
-            _replace_forward(module, module_class)
+            _replace_forward(device, module, module_class)
             continue
 
         repo_with_mode = _select_repository(
@@ -932,7 +939,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
                 raise ValueError(
                     f"No repository for `{layer_name}` for configuration mode={mode}"
                 )
-            _replace_forward(module, module_class)
+            _replace_forward(device, module, module_class)
             continue
 
         repo, repo_mode = repo_with_mode
@@ -951,6 +958,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         )
 
         _conditionally_replace_forward(
+            device=device,
             module=module,
             layer=layer,
             mode=mode,
@@ -1037,19 +1045,31 @@ def _validate_layer(*, check_cls, cls, repo: LayerRepositoryProtocol):
         raise TypeError(f"{repo} must not override nn.Module constructor.")
 
     # ... or predefined member variables.
-    torch_module_members = {name for name, _ in inspect.getmembers(nn.Module)}
-    cls_members = {name for name, _ in inspect.getmembers(cls)}
-    difference = cls_members - torch_module_members
+    unique_members = _unique_layer_members(cls)
     # verify if : difference ⊄ {"can_torch_compile", "has_backward"}
-    if not difference <= {"can_torch_compile", "has_backward"}:
+    if not unique_members <= {
+        "can_torch_compile",
+        "create_state",
+        "has_backward",
+        "forward_with_state",
+    }:
         raise TypeError(
             f"{repo} must not contain additional members compared to `{check_cls.__name__}`."
         )
 
     # Check whether the forward signatures are similar.
-    params = inspect.signature(cls.forward).parameters
     ref_params = inspect.signature(check_cls.forward).parameters
 
+    params: Mapping[str, inspect.Parameter]
+    if _is_stateful_layer(cls):
+        params = inspect.signature(cls.forward_with_state).parameters
+        # Get rid of the mappingproxy.
+        params = params.copy()
+        # Remove the state to be able to compare with forward.
+        del params["state"]
+    else:
+        params = inspect.signature(cls.forward).parameters
+
     if len(params) != len(ref_params):
         raise TypeError(
             f"Forward signature of {repo} does not match `{check_cls.__name__}`: different number of arguments."
@@ -1074,15 +1094,21 @@ def _is_rocm_platform():
     return torch.version.hip is not None
 
 
-def _find_device(model: "nn.Module") -> Device:
+def _find_device(model: "nn.Module") -> torch.device:
     try:
         param = next(model.parameters())
     except StopIteration:
         raise ValueError(
             "Cannot determine model device, provide as `device` argument to `kernelize`."
         )
 
-    dev_type = param.device.type
+    return param.device
+
+
+def _find_device_type(model: "nn.Module") -> Device:
+    device = _find_device(model)
+
+    dev_type = device.type
     if dev_type == "cuda":
         # Refine based on actual platform
         if _is_rocm_platform():
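Note: `_find_device` now returns the raw `torch.device`, and `_find_device_type` reuses it while keeping the existing platform refinement. An illustrative sketch of the split (assumption: simplified labels, the real code wraps the result in `Device`):

import torch
import torch.nn as nn

model = nn.Linear(8, 8)
device = next(model.parameters()).device  # what _find_device returns

dev_type = device.type                    # what _find_device_type starts from
if dev_type == "cuda" and torch.version.hip is not None:
    dev_type = "rocm"                     # refine CUDA into ROCm on HIP builds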
@@ -1103,6 +1129,7 @@ def _find_capability() -> int:
 
 def _conditionally_replace_forward(
     *,
+    device: "torch.device",
     module: "nn.Module",
     layer: Type["nn.Module"],
     mode: Mode,
@@ -1128,15 +1155,25 @@ def _conditionally_replace_forward(
                 logging.info("Layer does not support torch.compile, using fallback")
             if needs_fallback_for_backward:
                 logging.info("Layer does not support backward, using fallback")
-            _replace_forward(module, module_class)
+            _replace_forward(device, module, module_class)
         else:
             raise ValueError(f"Available kernel does not support mode: {mode}")
     else:
-        _replace_forward(module, layer)
+        _replace_forward(device, module, layer)
 
 
-def _replace_forward(module: "nn.Module", layer: Type["nn.Module"]):
-    module.forward = MethodType(layer.forward, module)  # type: ignore[method-assign]
+def _replace_forward(
+    device: "torch.device", module: "nn.Module", layer: Type["nn.Module"]
+):
+    if _is_stateful_layer(layer):
+        state = layer.create_state(device, module)  # type: ignore[attr-defined]
+
+        def forward(self, *args, **kwargs):
+            return layer.forward_with_state(self, state, *args, **kwargs)
+
+        module.forward = MethodType(forward, module)
+    else:
+        module.forward = MethodType(layer.forward, module)  # type: ignore[method-assign]
 
 
 def _validate_layer_has_mode(
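Note: the stateful path builds the state once at kernelize time and closes over it, so callers still invoke `module(x)` unchanged. A hypothetical hub-side layer with the assumed shape of the API (the actual `StatefulReLU` in `kernels-test/state-test` is not shown in this diff):

import torch
import torch.nn as nn
from torch.nn import functional as F

class StatefulReLU(nn.Module):
    # Invoked on the class as layer.create_state(device, module), so no `self`.
    def create_state(device: torch.device, module: nn.Module) -> torch.Tensor:
        # The "state" here is just a buffer sized from the wrapped module.
        return torch.zeros(module.hidden_size, device=device)

    # `self` is the wrapped module; `state` is injected by the closure
    # that _replace_forward binds as the module's forward.
    def forward_with_state(self, state: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        return F.relu(x) + state  # state is all zeros, so this matches plain ReLU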
@@ -1179,3 +1216,21 @@ def _get_layer_memoize(
         _CACHED_LAYER[repo] = layer
 
     return layer
+
+
+def _unique_layer_members(layer: Type["nn.Module"]) -> Set[str]:
+    import torch.nn as nn
+
+    torch_module_members = {name for name, _ in inspect.getmembers(nn.Module)}
+    cls_members = {name for name, _ in inspect.getmembers(layer)}
+    return cls_members - torch_module_members
+
+
+def _is_stateful_layer(layer: Type[nn.Module]) -> bool:
+    unique = _unique_layer_members(layer)
+    is_stateful = "forward_with_state" in unique
+    if is_stateful and len(unique & {"create_state", "forward_with_state"}) != 2:
+        raise TypeError(
+            f"Stateful layer `{layer.__name__}` must implement both `create_state` and `forward_with_state` or neither."
+        )
+    return is_stateful
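Note: `forward_with_state` acts as the marker for statefulness, and `_is_stateful_layer` rejects half-implemented layers. A sketch with hypothetical classes:

import torch.nn as nn

class Complete(nn.Module):
    def create_state(device, module): ...
    def forward_with_state(self, state, x): ...

class Partial(nn.Module):
    def forward_with_state(self, state, x): ...  # create_state is missing

# _is_stateful_layer(Complete) -> True
# _is_stateful_layer(Partial)  -> raises TypeError (must implement both)
# A layer defining only forward() is simply reported as not stateful.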

tests/test_layer.py

Lines changed: 42 additions & 0 deletions
@@ -5,6 +5,7 @@
 import torch
 import torch.nn as nn
 from torch.nn import functional as F
+from torch.testing import assert_close
 
 from kernels import (
     CUDAProperties,
@@ -321,6 +322,47 @@ def test_local_layer_repo(device):
     assert linear.n_calls == 0
 
 
+def test_stateful_layer(device):
+    @use_kernel_forward_from_hub("ReluWithHiddenSize")
+    class ReluWithHiddenSize(nn.Module):
+        hidden_size: int
+
+        def __init__(self, hidden_size: int):
+            super().__init__()
+            self.hidden_size = hidden_size
+
+        def forward(self, x: torch.Tensor) -> torch.Tensor:
+            return F.relu(x)
+
+    model = ReluWithHiddenSize(hidden_size=64).to(device)
+    x = torch.randn((32, 64), device=device)
+    y_ref = model(x)
+
+    with use_kernel_mapping(
+        {
+            "ReluWithHiddenSize": {
+                "cuda": LayerRepository(
+                    repo_id="kernels-test/state-test",
+                    layer_name="StatefulReLU",
+                ),
+                "xpu": LayerRepository(
+                    repo_id="kernels-test/state-test",
+                    layer_name="StatefulReLU",
+                ),
+            }
+        },
+        inherit_mapping=False,
+    ):
+        model = kernelize(model, mode=Mode.TRAINING | Mode.TORCH_COMPILE, device=device)
+
+    y = model(x)
+    assert_close(y, y_ref)
+
+    model = torch.compile(model, fullgraph=True)
+    y = model(x)
+    assert_close(y, y_ref)
+
+
 @pytest.mark.cuda_only
 @pytest.mark.parametrize("cls", [SiluAndMulWithKernel, SiluAndMulNoCompileKernel])
 @pytest.mark.parametrize("device", ["cuda"])
