Commit fa0f838
add autocast logic
1 parent b2a43a7 commit fa0f838
File tree
6 files changed: +809 −81 lines changed

paddle/fluid/eager/auto_code_generator/generator/eager_gen.py  (+89)
@@ -130,6 +130,48 @@
     "remainder_": ["x", "y"],
 }
 
+# ops support casting int tensor into float32 to do forward calculation
+type_autocast_op_list = {
+    "acos": ["x"],
+    "acosh": ["x"],
+    "asin": ["x"],
+    "asinh": ["x"],
+    "atan": ["x"],
+    "atanh": ["x"],
+    "ceil": ["x"],
+    "cos": ["x"],
+    "cosh": ["x"],
+    "digamma": ["x"],
+    "erf": ["x"],
+    "erfinv": ["x"],
+    "floor": ["x"],
+    "i0": ["x"],
+    "i0e": ["x"],
+    "i1": ["x"],
+    "i1e": ["x"],
+    "lgamma": ["x"],
+    "logcumsumexp": ["x"],
+    "logit": ["x"],
+    "logsumexp": ["x"],
+    "polygamma": ["x"],
+    "reciprocal": ["x"],
+    "rsqrt": ["x"],
+    "sigmoid": ["x"],
+    "sin": ["x"],
+    "sinh": ["x"],
+    "sqrt": ["x"],
+    "stanh": ["x"],
+    "tan": ["x"],
+    "tanh": ["x"],
+}
+
+# ops support casting int tensor into float32 to do forward calculation,
+# and it is valid to cast float32 gradient back to int tensor.
+type_autocast_valid_grad_op_list = {
+    "ceil",
+    "floor",
+}
+
 # dict of special api that forward api's output will affect backward api's output
 # backward api's output usually affected by backward api's input
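The user-visible effect of type_autocast_op_list is that the unary math ops listed here accept integer tensors by casting them to float32 before the kernel runs. A minimal sketch of the expected behavior, assuming a Paddle build that includes this commit:

    import paddle

    # Sketch only: with this commit, ops in type_autocast_op_list accept integer
    # inputs by autocasting them to float32 in the generated forward function.
    x = paddle.to_tensor([1, 2, 3], dtype='int64')
    y = paddle.sin(x)   # logs the "run type autocast" warning once, then computes in float32
    print(y.dtype)      # expected: paddle.float32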

@@ -326,6 +368,8 @@ class {} : public egr::GradNodeBase {{
 // AMP Logic
 {}
 // Type promotion Logic
+{}
+// Type autocast Logic
 {}
 // Layout autotune
 {}
@@ -403,6 +447,8 @@ class {} : public egr::GradNodeBase {{
 // AMP Logic
 {}
 // Type promotion Logic
+{}
+// Type autocast Logic
 {}
 // Layout autotune
 {}
@@ -617,6 +663,15 @@ class {} : public egr::GradNodeBase {{
 }}
 """
 
+TYPE_AUTOCAST_LOGIC_TEMPLATE = """
+  if (phi::NeedTypeAutoCast({op_func_name}, {x}.dtype())) {{
+    VLOG(5) << "math operation got integer input data type, run type autocast.";
+    LOG_FIRST_N(WARNING, 1) << "math operation got integer input data type, run type autocast, this may cause data type been changed.";
+    {op_name}
+    auto new_{x} = egr::PromoteCast("{x}", {x}, phi::DataType::FLOAT32, {trace_backward});
+    {return_value}
+  }}
+"""
 
 LAYOUT_LOGIC_TEMPLATE = """
 if (egr::Controller::Instance().UseLayoutAutoTune()) {{
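To see roughly what the generated C++ looks like, the template can be instantiated by hand. The sketch below fills TYPE_AUTOCAST_LOGIC_TEMPLATE for sin; the op_name and return_value strings are illustrative stand-ins for what eager_gen.py actually passes (kernel_trans2_op_name_str and the rebuilt forward call), not the exact generated text:

    # Illustrative only: fill the template the way the generator does,
    # with assumed argument strings.
    TYPE_AUTOCAST_LOGIC_TEMPLATE = """
      if (phi::NeedTypeAutoCast({op_func_name}, {x}.dtype())) {{
        VLOG(5) << "math operation got integer input data type, run type autocast.";
        LOG_FIRST_N(WARNING, 1) << "math operation got integer input data type, run type autocast, this may cause data type been changed.";
        {op_name}
        auto new_{x} = egr::PromoteCast("{x}", {x}, phi::DataType::FLOAT32, {trace_backward});
        {return_value}
      }}
    """

    print(
        TYPE_AUTOCAST_LOGIC_TEMPLATE.format(
            op_func_name='"sin"',
            x="x",
            op_name='auto op_name = "sin";',          # stand-in for kernel_trans2_op_name_str
            trace_backward="false",                    # sin is not in type_autocast_valid_grad_op_list
            return_value="return sin_ad_func(new_x);",
        )
    )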
@@ -1562,6 +1617,7 @@ def GenerateForwardDefinitionAndDeclaration(self, is_inplaced):
 
         amp_inputs_call_list = ["" for i in range(num_inputs)]
         type_promote_inputs_call_list = ["" for i in range(num_inputs)]
+        type_autocast_inputs_call_list = ["" for i in range(num_inputs)]
         amp_tensors_vector_list = []
         amp_tensors_vector_optional_list = []
         amp_autocast_list = []
@@ -1590,6 +1646,11 @@ def GenerateForwardDefinitionAndDeclaration(self, is_inplaced):
                     type_promote_inputs_call_list[pos] = f"new_{name}"
                 else:
                     type_promote_inputs_call_list[pos] = f"{name}"
+            if forward_api_name in type_autocast_op_list:
+                if name in type_autocast_op_list[forward_api_name]:
+                    type_autocast_inputs_call_list[pos] = f"new_{name}"
+                else:
+                    type_autocast_inputs_call_list[pos] = f"{name}"
             if IsPlainTensorType(ttype):
                 if is_optional:
                     if (
@@ -1681,6 +1742,7 @@ def GenerateForwardDefinitionAndDeclaration(self, is_inplaced):
             inputs_call_list[pos] = name
             amp_inputs_call_list[pos] = name
             type_promote_inputs_call_list[pos] = name
+            type_autocast_inputs_call_list[pos] = name
             if default_val is not None:
                 inputs_args_declaration_list[pos] = (
                     f"{atype} {name} = {default_val}"
@@ -1970,6 +2032,31 @@ def GenerateForwardDefinitionAndDeclaration(self, is_inplaced):
            )
        else:
            type_promotion_logic_str = f'\n VLOG(5) << " No Type Promotion for {forward_ad_function_name} api. "; '
+
+        # Forward type autocast logic
+        if forward_api_name in type_autocast_op_list:
+            # only support one inputs
+            op_func_name = f'"{forward_api_name}"'
+            x = type_autocast_op_list[forward_api_name][0]
+            type_autocast_inputs_call_args_str = ", ".join(
+                type_autocast_inputs_call_list
+            )
+            trace_backward = (
+                forward_api_name in type_autocast_valid_grad_op_list
+            ) and (not self.is_forward_only)
+            trace_backward = str(trace_backward).lower()
+            type_autocast_call_list = f"return {forward_ad_function_name}({type_autocast_inputs_call_args_str});"
+
+            type_autocast_logic_str = TYPE_AUTOCAST_LOGIC_TEMPLATE.format(
+                op_func_name=op_func_name,
+                x=x,
+                op_name=kernel_trans2_op_name_str,
+                trace_backward=trace_backward,
+                return_value=type_autocast_call_list,
+            )
+        else:
+            type_autocast_logic_str = f'\n VLOG(5) << " No Type Autocast for {forward_ad_function_name} api. "; '
+
        # Forward layout autotune
        layout_autotune_list_str = " ".join(
            layout_autotune_list
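The trace_backward flag computed above is the only per-op variation: it is true only for ops in type_autocast_valid_grad_op_list, and only when the API is not forward-only, so most autocast ops do not propagate gradients back to the original integer tensor. A small sketch of that decision, with the self.is_forward_only check omitted:

    # Names copied from this commit; the is_forward_only condition is left out here.
    type_autocast_valid_grad_op_list = {"ceil", "floor"}

    for forward_api_name in ("sin", "ceil"):
        trace_backward = forward_api_name in type_autocast_valid_grad_op_list
        print(forward_api_name, str(trace_backward).lower())
    # sin  false -> the float32 gradient is not traced back to the int input
    # ceil true  -> casting the float32 gradient back to the int tensor is treated as valid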
@@ -2019,6 +2106,7 @@ def GenerateForwardDefinitionAndDeclaration(self, is_inplaced):
             dygraph_event_str,
             amp_logic_str,
             type_promotion_logic_str,
+            type_autocast_logic_str,
             layout_logic_str,
             forward_api_name,
             before_log_str,
@@ -2043,6 +2131,7 @@ def GenerateForwardDefinitionAndDeclaration(self, is_inplaced):
             dygraph_event_str,
             amp_logic_str,
             type_promotion_logic_str,
+            type_autocast_logic_str,
             layout_logic_str,
             inputs_autograd_meta_str,
             forward_api_name,

paddle/fluid/pir/dialect/op_generator/api_gen.py  (+100)
@@ -72,6 +72,48 @@
     "remainder_": ["x", "y"],
 }
 
+# ops support casting int tensor into float32 to do forward calculation
+type_autocast_op_list = {
+    "acos": ["x"],
+    "acosh": ["x"],
+    "asin": ["x"],
+    "asinh": ["x"],
+    "atan": ["x"],
+    "atanh": ["x"],
+    "ceil": ["x"],
+    "cos": ["x"],
+    "cosh": ["x"],
+    "digamma": ["x"],
+    "erf": ["x"],
+    "erfinv": ["x"],
+    "floor": ["x"],
+    "i0": ["x"],
+    "i0e": ["x"],
+    "i1": ["x"],
+    "i1e": ["x"],
+    "lgamma": ["x"],
+    "logcumsumexp": ["x"],
+    "logit": ["x"],
+    "logsumexp": ["x"],
+    "polygamma": ["x"],
+    "reciprocal": ["x"],
+    "rsqrt": ["x"],
+    "sigmoid": ["x"],
+    "sin": ["x"],
+    "sinh": ["x"],
+    "sqrt": ["x"],
+    "stanh": ["x"],
+    "tan": ["x"],
+    "tanh": ["x"],
+}
+
+# ops support casting int tensor into float32 to do forward calculation,
+# and it is valid to cast float32 gradient back to int tensor.
+type_autocast_valid_grad_op_list = {
+    "ceil",
+    "floor",
+}
+
 PD_MANUAL_API_LIST = {
     'embedding_grad',
     'assign',
@@ -140,6 +182,8 @@
 {amp_logic}
 // Type Promotion Logic
 {type_promotion_logic}
+// Type Autocast Logic
+{type_autocast_logic}
 {check_data_type}
 {handle_optional_inputs}
 {in_combine}
@@ -196,6 +240,18 @@
 }}
 """
 
+TYPE_AUTOCAST_LOGIC_TEMPLATE = """
+  auto x_dtype = paddle::imperative::GetDataType({x});
+  if (phi::NeedTypeAutoCast("{op_name}", x_dtype)) {{
+    VLOG(5) << "math operation got integer input data type, run type autocast.";
+    LOG_FIRST_N(WARNING, 1) << "math operation got integer input data type, run type autocast, this may cause data type been changed.";
+    //{op_name}
+    if (!{trace_backward}) {{ SetStopGradient({x}); }}
+    auto new_{x} = pir::PromoteCast("{x}", {x}, phi::DataType::FLOAT32);
+    return paddle::dialect::{op_name}({args});
+  }}
+"""
+
 OP_DISPATCH_TEMPLATE = """
 if ({cond}) {{
   {inner_code}
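As with the eager template, the PIR version can be filled by hand to inspect the generated code; here trace_backward controls whether SetStopGradient is called on the original input before the cast. The argument values in this sketch are illustrative, not the exact strings api_gen.py emits:

    # Sketch: fill the PIR-side template for "floor". Since floor is in
    # type_autocast_valid_grad_op_list, trace_backward is "true" and the
    # SetStopGradient branch becomes dead code (`if (!true) ...`).
    TYPE_AUTOCAST_LOGIC_TEMPLATE = """
      auto x_dtype = paddle::imperative::GetDataType({x});
      if (phi::NeedTypeAutoCast("{op_name}", x_dtype)) {{
        VLOG(5) << "math operation got integer input data type, run type autocast.";
        LOG_FIRST_N(WARNING, 1) << "math operation got integer input data type, run type autocast, this may cause data type been changed.";
        //{op_name}
        if (!{trace_backward}) {{ SetStopGradient({x}); }}
        auto new_{x} = pir::PromoteCast("{x}", {x}, phi::DataType::FLOAT32);
        return paddle::dialect::{op_name}({args});
      }}
    """

    print(
        TYPE_AUTOCAST_LOGIC_TEMPLATE.format(
            op_name="floor", x="x", trace_backward="true", args="new_x"
        )
    )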
@@ -861,6 +917,44 @@ def _gen_type_promotion_logic(self, op_info, op_name):
 
        return type_promotion_logic_str
 
+    def _gen_type_autocast_args(self, op_info, op_name):
+        type_autocast_inputs_call_list = []
+        for name in op_info.input_name_list:
+            if op_name in type_autocast_op_list:
+                if name in type_autocast_op_list[op_name]:
+                    type_autocast_inputs_call_list.append(f"new_{name}")
+                else:
+                    type_autocast_inputs_call_list.append(f"{name}")
+
+        attr_list = op_info.attribute_name_list
+        args = type_autocast_inputs_call_list + attr_list
+        return ', '.join(args)
+
+    def _gen_type_autocast_logic(self, op_info, op_name):
+        if op_name in type_autocast_op_list:
+            x = type_autocast_op_list[op_name][0]
+
+            type_autocast_inputs_call_args_str = self._gen_type_autocast_args(
+                op_info, op_name
+            )
+            trace_backward = op_name in type_autocast_valid_grad_op_list
+            trace_backward = str(trace_backward).lower()
+
+            if op_info.is_sparse_op:
+                op_name += "sp_" if op_name[-1] == "_" else "_sp"
+            type_autocast_logic_str = TYPE_AUTOCAST_LOGIC_TEMPLATE.format(
+                op_name=op_name,
+                x=x,
+                trace_backward=trace_backward,
+                args=type_autocast_inputs_call_args_str,
+            )
+        else:
+            type_autocast_logic_str = (
+                f'\n VLOG(5) << " No Type Autocast for {op_name} api. "; '
+            )
+
+        return type_autocast_logic_str
+
    def _gen_check_data_type(self, op_info, op_name):
        mapping_input_name_to_type = dict(
            zip(op_info.input_name_list, op_info.input_type_list)
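The argument-assembly helper can be exercised outside the generator with a stand-in op_info object. The field names below (input_name_list, attribute_name_list) mirror the ones used above, while the logit/eps example values are assumptions for illustration:

    from types import SimpleNamespace

    # Stand-in for the real op_info object used by api_gen.py.
    op_info = SimpleNamespace(input_name_list=["x"], attribute_name_list=["eps"])
    type_autocast_op_list = {"logit": ["x"]}

    op_name = "logit"
    call_list = [
        f"new_{name}" if name in type_autocast_op_list.get(op_name, []) else name
        for name in op_info.input_name_list
    ]
    print(", ".join(call_list + op_info.attribute_name_list))  # -> new_x, eps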
@@ -1044,6 +1138,9 @@ def _gen_one_impl(
            type_promotion_logic=self._gen_type_promotion_logic(
                op_info, op_name
            ),
+            type_autocast_logic=self._gen_type_autocast_logic(
+                op_info, op_name
+            ),
            check_data_type=self._gen_check_data_type(
                op_info, kernel_name
            ),
@@ -1109,6 +1206,9 @@ def _gen_one_impl(
            type_promotion_logic=self._gen_type_promotion_logic(
                op_info, op_name
            ),
+            type_autocast_logic=self._gen_type_autocast_logic(
+                op_info, op_name
+            ),
            check_data_type=self._gen_check_data_type(op_info, kernel_name),
            handle_optional_inputs=self._gen_handle_optional_inputs(
                op_info

paddle/phi/common/type_promotion.h  (+17)
@@ -94,6 +94,15 @@ static std::unordered_set<std::string> support_promotion_ops = {
     "less_than", "less_equal", "greater_than", "greater_equal",
 };
 
+static std::unordered_set<std::string> support_autocast_ops = {
+    "acos", "acosh", "asin", "asinh", "atan", "atanh",
+    "ceil", "cos", "cosh", "digamma", "erf", "erfinv",
+    "floor", "lgamma", "logcumsumexp", "logit", "logsumexp", "polygamma",
+    "reciprocal", "rsqrt", "sin", "sinh", "sqrt", "stanh",
+    "tan", "tanh", "i0", "i0e", "i1", "i1e",
+    "sigmoid",
+};
+
 inline bool is_support_float(DataType dtype) {
   if (dtype == DataType::FLOAT16 || dtype == DataType::FLOAT32 ||
       dtype == DataType::FLOAT64 || dtype == DataType::BFLOAT16) {
@@ -264,4 +273,12 @@ inline bool NeedTypePromotionOldIr(const std::string& op_name,
   }
 }
 
+inline bool NeedTypeAutoCast(const std::string& op_name,
+                             const DataType& x_dtype) {
+  if (support_autocast_ops.find(op_name) != support_autocast_ops.end() &&
+      (is_support_int(x_dtype))) {
+    return true;
+  }
+  return false;
+}
 }  // namespace phi
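Read together, support_autocast_ops and NeedTypeAutoCast form a simple predicate: the op must be in the set and the input dtype must be an integer type. A Python rendering of that check (the dtype strings below are assumptions standing in for is_support_int and phi::DataType):

    # Sketch of the NeedTypeAutoCast predicate; dtype names are illustrative strings
    # and the integer set approximates is_support_int.
    support_autocast_ops = {"sin", "cos", "sqrt", "ceil", "floor"}  # abbreviated
    integer_dtypes = {"int8", "int16", "int32", "int64", "uint8"}

    def need_type_autocast(op_name: str, x_dtype: str) -> bool:
        return op_name in support_autocast_ops and x_dtype in integer_dtypes

    print(need_type_autocast("sin", "int64"))    # True  -> generated code casts to float32
    print(need_type_autocast("sin", "float32"))  # False -> op runs on the original dtype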
