@@ -7,6 +7,7 @@
 import sys
 import warnings
 from abc import ABC, abstractmethod
+from collections import OrderedDict
 from contextvars import ContextVar
 from copy import deepcopy
 from dataclasses import dataclass
@@ -17,8 +18,10 @@
 from typing import (
     TYPE_CHECKING,
     Dict,
+    Mapping,
     Optional,
     Protocol,
+    Set,
     Tuple,
     Type,
     Union,
@@ -868,10 +871,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         raise ValueError("kernelize mode must contain Mode.INFERENCE or Mode.TRAINING.")

     if device is None:
-        device_type = _find_device(model)
+        device = _find_device(model)
+        device_type = _find_device_type(model)
     elif isinstance(device, str):
         _validate_device_type(device)
+        import torch
+
         device_type = Device(type=device)
+        device = torch.device(device)
     else:
         device_type = Device(device.type)

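With this hunk, `kernelize` tracks two values: the concrete `torch.device` (threaded down so stateful layers can allocate their state on the right device) and the `Device` wrapper used for kernel-mapping lookups. A minimal sketch of the calling side, assuming the public `kernelize`/`Mode` API of the kernels package; the toy model is a placeholder:

import torch.nn as nn

from kernels import Mode, kernelize

model = nn.Sequential(nn.Linear(32, 32))

# Explicit string device: validated, wrapped as Device(type="cuda") for
# repository lookups, and converted to torch.device("cuda") for state creation.
model = kernelize(model, mode=Mode.INFERENCE, device="cuda")

# No device argument: both values are inferred from the model's parameters
# via _find_device / _find_device_type.
model = kernelize(model, mode=Mode.INFERENCE)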
@@ -884,7 +891,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         layer_name = module_class.kernel_layer_name

         if _DISABLE_KERNEL_MAPPING:
-            _replace_forward(module, module_class)
+            _replace_forward(device, module, module_class)
             continue

         kernel = _KERNEL_MAPPING.get().get(str(layer_name))
@@ -898,7 +905,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             )
             if not use_fallback:
                 raise ValueError(f"No layer mapping for `{layer_name}`")
-            _replace_forward(module, module_class)
+            _replace_forward(device, module, module_class)
             continue

         # Get kernel options for the device
@@ -909,7 +916,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
                 raise ValueError(
                     f"No layer mapping for `{layer_name}` with device type `{device_type}`"
                 )
-            _replace_forward(module, module_class)
+            _replace_forward(device, module, module_class)
             continue

         repos = property_repos.repos
@@ -919,7 +926,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
                 raise ValueError(
                     f"No layer mapping for `{layer_name}` device `{device_type}` with the right properties"
                 )
-            _replace_forward(module, module_class)
+            _replace_forward(device, module, module_class)
             continue

         repo_with_mode = _select_repository(
@@ -932,7 +939,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
                 raise ValueError(
                     f"No repository for `{layer_name}` for configuration mode={mode}"
                 )
-            _replace_forward(module, module_class)
+            _replace_forward(device, module, module_class)
             continue

         repo, repo_mode = repo_with_mode
@@ -951,6 +958,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         )

         _conditionally_replace_forward(
+            device=device,
             module=module,
             layer=layer,
             mode=mode,
@@ -1037,19 +1045,31 @@ def _validate_layer(*, check_cls, cls, repo: LayerRepositoryProtocol):
         raise TypeError(f"{repo} must not override nn.Module constructor.")

     # ... or predefined member variables.
-    torch_module_members = {name for name, _ in inspect.getmembers(nn.Module)}
-    cls_members = {name for name, _ in inspect.getmembers(cls)}
-    difference = cls_members - torch_module_members
+    unique_members = _unique_layer_members(cls)
     # verify if: difference ⊄ {"can_torch_compile", "has_backward"}
-    if not difference <= {"can_torch_compile", "has_backward"}:
+    if not unique_members <= {
+        "can_torch_compile",
+        "create_state",
+        "has_backward",
+        "forward_with_state",
+    }:
         raise TypeError(
             f"{repo} must not contain additional members compared to `{check_cls.__name__}`."
         )

     # Check whether the forward signatures are similar.
-    params = inspect.signature(cls.forward).parameters
     ref_params = inspect.signature(check_cls.forward).parameters

+    params: Mapping[str, inspect.Parameter]
+    if _is_stateful_layer(cls):
+        params = inspect.signature(cls.forward_with_state).parameters
+        # Get rid of the mappingproxy.
+        params = params.copy()
+        # Remove the state to be able to compare with forward.
+        del params["state"]
+    else:
+        params = inspect.signature(cls.forward).parameters
+
     if len(params) != len(ref_params):
         raise TypeError(
             f"Forward signature of {repo} does not match `{check_cls.__name__}`: different number of arguments."
@@ -1074,15 +1094,21 @@ def _is_rocm_platform():
     return torch.version.hip is not None


-def _find_device(model: "nn.Module") -> Device:
+def _find_device(model: "nn.Module") -> torch.device:
     try:
         param = next(model.parameters())
     except StopIteration:
         raise ValueError(
             "Cannot determine model device, provide as `device` argument to `kernelize`."
         )

-    dev_type = param.device.type
+    return param.device
+
+
+def _find_device_type(model: "nn.Module") -> Device:
+    device = _find_device(model)
+
+    dev_type = device.type
     if dev_type == "cuda":
         # Refine based on actual platform
         if _is_rocm_platform():
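`_find_device` previously returned the library's `Device` type; it now returns the first parameter's raw `torch.device`, and the type refinement (including the ROCm special case here) moves into the new `_find_device_type`. A quick sketch of the distinction on a CPU model:

import torch.nn as nn

model = nn.Linear(4, 4)
param = next(model.parameters())

param.device       # torch.device('cpu'): what _find_device now returns
param.device.type  # 'cpu': the string _find_device_type refines into a Device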
@@ -1103,6 +1129,7 @@ def _find_capability() -> int:

 def _conditionally_replace_forward(
     *,
+    device: "torch.device",
     module: "nn.Module",
     layer: Type["nn.Module"],
     mode: Mode,
@@ -1128,15 +1155,25 @@ def _conditionally_replace_forward(
                 logging.info("Layer does not support torch.compile, using fallback")
             if needs_fallback_for_backward:
                 logging.info("Layer does not support backward, using fallback")
-            _replace_forward(module, module_class)
+            _replace_forward(device, module, module_class)
         else:
             raise ValueError(f"Available kernel does not support mode: {mode}")
     else:
-        _replace_forward(module, layer)
+        _replace_forward(device, module, layer)


-def _replace_forward(module: "nn.Module", layer: Type["nn.Module"]):
-    module.forward = MethodType(layer.forward, module)  # type: ignore[method-assign]
+def _replace_forward(
+    device: "torch.device", module: "nn.Module", layer: Type["nn.Module"]
+):
+    if _is_stateful_layer(layer):
+        state = layer.create_state(device, module)  # type: ignore[attr-defined]
+
+        def forward(self, *args, **kwargs):
+            return layer.forward_with_state(self, state, *args, **kwargs)
+
+        module.forward = MethodType(forward, module)
+    else:
+        module.forward = MethodType(layer.forward, module)  # type: ignore[method-assign]


 def _validate_layer_has_mode(
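For a stateful layer, the replacement `forward` is a closure: `create_state` runs once per module at kernelize time, and every subsequent call threads the captured state into `forward_with_state`. A standalone sketch of the same pattern with a stand-in class (hypothetical, not part of the library):

import torch
import torch.nn as nn
from types import MethodType


class _StubLayer(nn.Module):
    @staticmethod
    def create_state(device: torch.device, module: nn.Module) -> torch.Tensor:
        # Created once, on the device resolved by kernelize.
        return torch.zeros(1, device=device)

    def forward_with_state(self, state: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        return x + state


module = nn.Identity()
state = _StubLayer.create_state(torch.device("cpu"), module)


def forward(self, *args, **kwargs):
    # `state` is captured by the closure; `self` stays the original module.
    return _StubLayer.forward_with_state(self, state, *args, **kwargs)


module.forward = MethodType(forward, module)
module(torch.ones(2))  # tensor([1., 1.])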
@@ -1179,3 +1216,21 @@ def _get_layer_memoize(
         _CACHED_LAYER[repo] = layer

     return layer
+
+
+def _unique_layer_members(layer: Type["nn.Module"]) -> Set[str]:
+    import torch.nn as nn
+
+    torch_module_members = {name for name, _ in inspect.getmembers(nn.Module)}
+    cls_members = {name for name, _ in inspect.getmembers(layer)}
+    return cls_members - torch_module_members
+
+
+def _is_stateful_layer(layer: Type["nn.Module"]) -> bool:
+    unique = _unique_layer_members(layer)
+    is_stateful = "forward_with_state" in unique
+    if is_stateful and len(unique & {"create_state", "forward_with_state"}) != 2:
+        raise TypeError(
+            f"Stateful layer `{layer.__name__}` must implement both `create_state` and `forward_with_state` or neither."
+        )
+    return is_stateful
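The contract these helpers enforce, illustrated with throwaway classes (assuming `_is_stateful_layer` from above is in scope):

import torch.nn as nn


class Plain(nn.Module):
    def forward(self, x):
        return x


class Stateful(nn.Module):
    @staticmethod
    def create_state(device, module):
        return {}

    def forward_with_state(self, state, x):
        return x


class Broken(nn.Module):
    def forward_with_state(self, state, x):  # create_state is missing
        return x


_is_stateful_layer(Plain)     # False: no forward_with_state member
_is_stateful_layer(Stateful)  # True: both members present
_is_stateful_layer(Broken)    # raises TypeError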