130
130
"remainder_" : ["x" , "y" ],
131
131
}
132
132
133
# Ops that support casting an integer input tensor to float32 so the forward
# calculation can run in floating point.
# Maps the forward op name to the names of the tensor inputs that are replaced
# by their casted `new_<name>` counterpart when the op is re-dispatched.
# NOTE: the generated dtype check only looks at the first listed input.
type_autocast_op_list = {
    "acos": ["x"],
    "acosh": ["x"],
    "asin": ["x"],
    "asinh": ["x"],
    "atan": ["x"],
    "atanh": ["x"],
    "ceil": ["x"],
    "cos": ["x"],
    "cosh": ["x"],
    "digamma": ["x"],
    "erf": ["x"],
    "erfinv": ["x"],
    "floor": ["x"],
    "i0": ["x"],
    "i0e": ["x"],
    "i1": ["x"],
    "i1e": ["x"],
    "lgamma": ["x"],
    "logcumsumexp": ["x"],
    "logit": ["x"],
    "logsumexp": ["x"],
    "polygamma": ["x"],
    "reciprocal": ["x"],
    "rsqrt": ["x"],
    "sigmoid": ["x"],
    "sin": ["x"],
    "sinh": ["x"],
    "sqrt": ["x"],
    "stanh": ["x"],
    "tan": ["x"],
    "tanh": ["x"],
}
167
+
168
# Ops that support casting an integer input tensor to float32 for the forward
# calculation AND for which it is valid to cast the float32 gradient back to
# the integer tensor. Membership here (together with the op not being
# forward-only) turns on backward tracing of the inserted cast.
type_autocast_valid_grad_op_list = {
    "ceil",
    "floor",
}
174
+
133
175
# dict of special api that forward api's output will affect backward api's output
134
176
# backward api's output usually affected by backward api's input
135
177
@@ -326,6 +368,8 @@ class {} : public egr::GradNodeBase {{
326
368
// AMP Logic
327
369
{}
328
370
// Type promotion Logic
371
+ {}
372
+ // Type autocast Logic
329
373
{}
330
374
// Layout autotune
331
375
{}
@@ -403,6 +447,8 @@ class {} : public egr::GradNodeBase {{
403
447
// AMP Logic
404
448
{}
405
449
// Type promotion Logic
450
+ {}
451
+ // Type autocast Logic
406
452
{}
407
453
// Layout autotune
408
454
{}
@@ -617,6 +663,15 @@ class {} : public egr::GradNodeBase {{
617
663
}}
618
664
"""
619
665
666
# C++ snippet placed in the generated forward function under the
# "Type autocast Logic" marker. Filled in with str.format; doubled braces are
# literal C++ braces. Placeholders:
#   op_func_name   - the forward op name, already wrapped in double quotes
#   x              - name of the input tensor whose dtype is checked and cast
#   op_name        - statement(s) supplied by the caller, emitted before the cast
#   trace_backward - "true" / "false", forwarded to egr::PromoteCast
#   return_value   - statement that re-invokes the forward function with new_{x}
TYPE_AUTOCAST_LOGIC_TEMPLATE = """
  if (phi::NeedTypeAutoCast({op_func_name}, {x}.dtype())) {{
    VLOG(5) << "math operation got integer input data type, run type autocast.";
    LOG_FIRST_N(WARNING, 1) << "math operation got integer input data type, run type autocast, this may cause data type been changed.";
    {op_name}
    auto new_{x} = egr::PromoteCast("{x}", {x}, phi::DataType::FLOAT32, {trace_backward});
    {return_value}
  }}
"""
620
675
621
676
LAYOUT_LOGIC_TEMPLATE = """
622
677
if (egr::Controller::Instance().UseLayoutAutoTune()) {{
@@ -1562,6 +1617,7 @@ def GenerateForwardDefinitionAndDeclaration(self, is_inplaced):
1562
1617
1563
1618
amp_inputs_call_list = ["" for i in range (num_inputs )]
1564
1619
type_promote_inputs_call_list = ["" for i in range (num_inputs )]
1620
+ type_autocast_inputs_call_list = ["" for i in range (num_inputs )]
1565
1621
amp_tensors_vector_list = []
1566
1622
amp_tensors_vector_optional_list = []
1567
1623
amp_autocast_list = []
@@ -1590,6 +1646,11 @@ def GenerateForwardDefinitionAndDeclaration(self, is_inplaced):
1590
1646
type_promote_inputs_call_list [pos ] = f"new_{ name } "
1591
1647
else :
1592
1648
type_promote_inputs_call_list [pos ] = f"{ name } "
1649
+ if forward_api_name in type_autocast_op_list :
1650
+ if name in type_autocast_op_list [forward_api_name ]:
1651
+ type_autocast_inputs_call_list [pos ] = f"new_{ name } "
1652
+ else :
1653
+ type_autocast_inputs_call_list [pos ] = f"{ name } "
1593
1654
if IsPlainTensorType (ttype ):
1594
1655
if is_optional :
1595
1656
if (
@@ -1681,6 +1742,7 @@ def GenerateForwardDefinitionAndDeclaration(self, is_inplaced):
1681
1742
inputs_call_list [pos ] = name
1682
1743
amp_inputs_call_list [pos ] = name
1683
1744
type_promote_inputs_call_list [pos ] = name
1745
+ type_autocast_inputs_call_list [pos ] = name
1684
1746
if default_val is not None :
1685
1747
inputs_args_declaration_list [pos ] = (
1686
1748
f"{ atype } { name } = { default_val } "
@@ -1970,6 +2032,31 @@ def GenerateForwardDefinitionAndDeclaration(self, is_inplaced):
1970
2032
)
1971
2033
else :
1972
2034
type_promotion_logic_str = f'\n VLOG(5) << " No Type Promotion for { forward_ad_function_name } api. "; '
2035
+
2036
+ # Forward type autocast logic
2037
+ if forward_api_name in type_autocast_op_list :
2038
+ # only a single input is supported
2039
+ op_func_name = f'"{ forward_api_name } "'
2040
+ x = type_autocast_op_list [forward_api_name ][0 ]
2041
+ type_autocast_inputs_call_args_str = ", " .join (
2042
+ type_autocast_inputs_call_list
2043
+ )
2044
+ trace_backward = (
2045
+ forward_api_name in type_autocast_valid_grad_op_list
2046
+ ) and (not self .is_forward_only )
2047
+ trace_backward = str (trace_backward ).lower ()
2048
+ type_autocast_call_list = f"return { forward_ad_function_name } ({ type_autocast_inputs_call_args_str } );"
2049
+
2050
+ type_autocast_logic_str = TYPE_AUTOCAST_LOGIC_TEMPLATE .format (
2051
+ op_func_name = op_func_name ,
2052
+ x = x ,
2053
+ op_name = kernel_trans2_op_name_str ,
2054
+ trace_backward = trace_backward ,
2055
+ return_value = type_autocast_call_list ,
2056
+ )
2057
+ else :
2058
+ type_autocast_logic_str = f'\n VLOG(5) << " No Type Autocast for { forward_ad_function_name } api. "; '
2059
+
1973
2060
# Forward layout autotune
1974
2061
layout_autotune_list_str = " " .join (
1975
2062
layout_autotune_list
@@ -2019,6 +2106,7 @@ def GenerateForwardDefinitionAndDeclaration(self, is_inplaced):
2019
2106
dygraph_event_str ,
2020
2107
amp_logic_str ,
2021
2108
type_promotion_logic_str ,
2109
+ type_autocast_logic_str ,
2022
2110
layout_logic_str ,
2023
2111
forward_api_name ,
2024
2112
before_log_str ,
@@ -2043,6 +2131,7 @@ def GenerateForwardDefinitionAndDeclaration(self, is_inplaced):
2043
2131
dygraph_event_str ,
2044
2132
amp_logic_str ,
2045
2133
type_promotion_logic_str ,
2134
+ type_autocast_logic_str ,
2046
2135
layout_logic_str ,
2047
2136
inputs_autograd_meta_str ,
2048
2137
forward_api_name ,
0 commit comments