
math API support int tensor autocast to float32 (usability improvement) #69252


Merged: 3 commits merged on Nov 21, 2024. Showing changes from 2 commits.
89 changes: 89 additions & 0 deletions paddle/fluid/eager/auto_code_generator/generator/eager_gen.py
@@ -131,6 +131,48 @@
"remainder_": ["x", "y"],
}

# ops that support casting an int tensor to float32 for the forward calculation
type_autocast_op_list = {
"acos": ["x"],
"acosh": ["x"],
"asin": ["x"],
"asinh": ["x"],
"atan": ["x"],
"atanh": ["x"],
"ceil": ["x"],
"cos": ["x"],
"cosh": ["x"],
"digamma": ["x"],
"erf": ["x"],
"erfinv": ["x"],
"floor": ["x"],
"i0": ["x"],
"i0e": ["x"],
"i1": ["x"],
"i1e": ["x"],
"lgamma": ["x"],
"logcumsumexp": ["x"],
"logit": ["x"],
"logsumexp": ["x"],
"polygamma": ["x"],
"reciprocal": ["x"],
"rsqrt": ["x"],
"sigmoid": ["x"],
"sin": ["x"],
"sinh": ["x"],
"sqrt": ["x"],
"stanh": ["x"],
"tan": ["x"],
"tanh": ["x"],
}

# ops that support casting an int tensor to float32 for the forward calculation,
# and for which it is valid to cast the float32 gradient back to the int tensor.
type_autocast_valid_grad_op_list = {
"ceil",
"floor",
}
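
For context, here is a minimal sketch of the user-facing behavior these two tables are meant to enable (not part of the diff; it assumes the public paddle.to_tensor / paddle.sin / paddle.floor APIs):

import paddle

x = paddle.to_tensor([1, 2, 3], dtype='int32')

# Ops in type_autocast_op_list: the int input is autocast to float32 before the kernel
# runs, so the call no longer raises a dtype error and the result dtype is float32.
y = paddle.sin(x)
print(y.dtype)  # paddle.float32 (expected after this change)

# ceil and floor are also in type_autocast_valid_grad_op_list: for them it is considered
# valid to cast the float32 gradient back to the integer input, so backward can be traced.
z = paddle.floor(x)
print(z.dtype)  # paddle.float32 (expected after this change)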

# dict of special api that forward api's output will affect backward api's output
# backward api's output usually affected by backward api's input

@@ -327,6 +369,8 @@ class {} : public egr::GradNodeBase {{
// AMP Logic
{}
// Type promotion Logic
{}
// Type autocast Logic
{}
// Layout autotune
{}
@@ -404,6 +448,8 @@ class {} : public egr::GradNodeBase {{
// AMP Logic
{}
// Type promotion Logic
{}
// Type autocast Logic
{}
// Layout autotune
{}
@@ -618,6 +664,15 @@ class {} : public egr::GradNodeBase {{
}}
"""

TYPE_AUTOCAST_LOGIC_TEMPLATE = """
Contributor (review comment):
Is the cast applied only to int inputs, or in all cases? This must not introduce an incompatible performance regression; the cast should only cover cases that previously were unsupported and raised an error.

Contributor Author (@NKNaN, Nov 13, 2024):
The cast is applied only when the op is in the specified list and its input is an int tensor, so it should not break compatibility. The CE-Framework failures may require updating some unit tests in PaddleTest, because a few of them contain cases that expect an error for int inputs.

Results from a run on AI Studio: [image]

The paddle.sin unit test is commented out in PaddleTest: PaddlePaddle/PaddleTest#2992

if (phi::NeedTypeAutoCast({op_func_name}, {x}.dtype())) {{
VLOG(5) << "math operation got integer input data type, run type autocast.";
LOG_FIRST_N(WARNING, 1) << "math operation got integer input data type, run type autocast, this may cause data type been changed.";
{op_name}
auto new_{x} = egr::PromoteCast("{x}", {x}, phi::DataType::FLOAT32, {trace_backward});
Contributor (@HydrogenSulfate, Jan 9, 2025):
Casting straight to float32 here without any check may lose precision. For example, 16777217 becomes 16777216 after being cast to float32. [image]

Contributor Author:
Yes, that problem exists. I followed torch, which casts uniformly to fp32, and did not consider the precision loss. Also, floor and ceil probably cannot be changed this way, since the outputs of those two ought to stay integer types. [image]

{return_value}
}}
"""

LAYOUT_LOGIC_TEMPLATE = """
if (egr::Controller::Instance().UseLayoutAutoTune()) {{
@@ -1563,6 +1618,7 @@ def GenerateForwardDefinitionAndDeclaration(self, is_inplaced):

amp_inputs_call_list = ["" for i in range(num_inputs)]
type_promote_inputs_call_list = ["" for i in range(num_inputs)]
type_autocast_inputs_call_list = ["" for i in range(num_inputs)]
amp_tensors_vector_list = []
amp_tensors_vector_optional_list = []
amp_autocast_list = []
@@ -1591,6 +1647,11 @@ def GenerateForwardDefinitionAndDeclaration(self, is_inplaced):
type_promote_inputs_call_list[pos] = f"new_{name}"
else:
type_promote_inputs_call_list[pos] = f"{name}"
if forward_api_name in type_autocast_op_list:
if name in type_autocast_op_list[forward_api_name]:
type_autocast_inputs_call_list[pos] = f"new_{name}"
else:
type_autocast_inputs_call_list[pos] = f"{name}"
if IsPlainTensorType(ttype):
if is_optional:
if (
@@ -1682,6 +1743,7 @@ def GenerateForwardDefinitionAndDeclaration(self, is_inplaced):
inputs_call_list[pos] = name
amp_inputs_call_list[pos] = name
type_promote_inputs_call_list[pos] = name
type_autocast_inputs_call_list[pos] = name
if default_val is not None:
inputs_args_declaration_list[pos] = (
f"{atype} {name} = {default_val}"
@@ -1971,6 +2033,31 @@ def GenerateForwardDefinitionAndDeclaration(self, is_inplaced):
)
else:
type_promotion_logic_str = f'\n VLOG(5) << " No Type Promotion for {forward_ad_function_name} api. "; '

# Forward type autocast logic
if forward_api_name in type_autocast_op_list:
# only supports a single input
op_func_name = f'"{forward_api_name}"'
x = type_autocast_op_list[forward_api_name][0]
type_autocast_inputs_call_args_str = ", ".join(
type_autocast_inputs_call_list
)
trace_backward = (
forward_api_name in type_autocast_valid_grad_op_list
) and (not self.is_forward_only)
trace_backward = str(trace_backward).lower()
type_autocast_call_list = f"return {forward_ad_function_name}({type_autocast_inputs_call_args_str});"

type_autocast_logic_str = TYPE_AUTOCAST_LOGIC_TEMPLATE.format(
op_func_name=op_func_name,
x=x,
op_name=kernel_trans2_op_name_str,
trace_backward=trace_backward,
return_value=type_autocast_call_list,
)
else:
type_autocast_logic_str = f'\n VLOG(5) << " No Type Autocast for {forward_ad_function_name} api. "; '

# Forward layout autotune
layout_autotune_list_str = " ".join(
layout_autotune_list
@@ -2020,6 +2107,7 @@ def GenerateForwardDefinitionAndDeclaration(self, is_inplaced):
dygraph_event_str,
amp_logic_str,
type_promotion_logic_str,
type_autocast_logic_str,
layout_logic_str,
forward_api_name,
before_log_str,
@@ -2044,6 +2132,7 @@ def GenerateForwardDefinitionAndDeclaration(self, is_inplaced):
dygraph_event_str,
amp_logic_str,
type_promotion_logic_str,
type_autocast_logic_str,
layout_logic_str,
inputs_autograd_meta_str,
forward_api_name,
100 changes: 100 additions & 0 deletions paddle/fluid/pir/dialect/op_generator/api_gen.py
@@ -72,6 +72,48 @@
"remainder_": ["x", "y"],
}

# ops that support casting an int tensor to float32 for the forward calculation
type_autocast_op_list = {
"acos": ["x"],
"acosh": ["x"],
"asin": ["x"],
"asinh": ["x"],
"atan": ["x"],
"atanh": ["x"],
"ceil": ["x"],
"cos": ["x"],
"cosh": ["x"],
"digamma": ["x"],
"erf": ["x"],
"erfinv": ["x"],
"floor": ["x"],
"i0": ["x"],
"i0e": ["x"],
"i1": ["x"],
"i1e": ["x"],
"lgamma": ["x"],
"logcumsumexp": ["x"],
"logit": ["x"],
"logsumexp": ["x"],
"polygamma": ["x"],
"reciprocal": ["x"],
"rsqrt": ["x"],
"sigmoid": ["x"],
"sin": ["x"],
"sinh": ["x"],
"sqrt": ["x"],
"stanh": ["x"],
"tan": ["x"],
"tanh": ["x"],
}

# ops that support casting an int tensor to float32 for the forward calculation,
# and for which it is valid to cast the float32 gradient back to the int tensor.
type_autocast_valid_grad_op_list = {
"ceil",
"floor",
}

PD_MANUAL_API_LIST = {
'embedding_grad',
'assign',
@@ -140,6 +182,8 @@
{amp_logic}
// Type Promotion Logic
{type_promotion_logic}
// Type Autocast Logic
{type_autocast_logic}
{check_data_type}
{handle_optional_inputs}
{in_combine}
@@ -196,6 +240,18 @@
}}
"""

TYPE_AUTOCAST_LOGIC_TEMPLATE = """
auto x_dtype = paddle::imperative::GetDataType({x});
if (phi::NeedTypeAutoCast("{op_name}", x_dtype)) {{
VLOG(5) << "math operation got integer input data type, run type autocast.";
LOG_FIRST_N(WARNING, 1) << "math operation got integer input data type, run type autocast, this may cause data type been changed.";
//{op_name}
if (!{trace_backward}) {{ SetStopGradient({x}); }}
auto new_{x} = pir::PromoteCast("{x}", {x}, phi::DataType::FLOAT32);
return paddle::dialect::{op_name}({args});
}}
"""

OP_DISPATCH_TEMPLATE = """
if ({cond}) {{
{inner_code}
@@ -861,6 +917,44 @@ def _gen_type_promotion_logic(self, op_info, op_name):

return type_promotion_logic_str

def _gen_type_autocast_args(self, op_info, op_name):
type_autocast_inputs_call_list = []
for name in op_info.input_name_list:
if op_name in type_autocast_op_list:
if name in type_autocast_op_list[op_name]:
type_autocast_inputs_call_list.append(f"new_{name}")
else:
type_autocast_inputs_call_list.append(f"{name}")

attr_list = op_info.attribute_name_list
args = type_autocast_inputs_call_list + attr_list
return ', '.join(args)

def _gen_type_autocast_logic(self, op_info, op_name):
if op_name in type_autocast_op_list:
x = type_autocast_op_list[op_name][0]

type_autocast_inputs_call_args_str = self._gen_type_autocast_args(
op_info, op_name
)
trace_backward = op_name in type_autocast_valid_grad_op_list
trace_backward = str(trace_backward).lower()

if op_info.is_sparse_op:
op_name += "sp_" if op_name[-1] == "_" else "_sp"
type_autocast_logic_str = TYPE_AUTOCAST_LOGIC_TEMPLATE.format(
op_name=op_name,
x=x,
trace_backward=trace_backward,
args=type_autocast_inputs_call_args_str,
)
else:
type_autocast_logic_str = (
f'\n VLOG(5) << " No Type Autocast for {op_name} api. "; '
)

return type_autocast_logic_str
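
A standalone sketch (not part of the diff) of what the argument rewriting above produces for a listed unary op; autocast_args_sketch is a hypothetical helper that mirrors _gen_type_autocast_args on top of the module-level type_autocast_op_list:

def autocast_args_sketch(op_name, input_names, attr_names):
    # Inputs named in type_autocast_op_list[op_name] are replaced by their casted
    # counterparts (new_<name>); all other inputs and attributes pass through unchanged.
    calls = [
        f"new_{name}" if name in type_autocast_op_list.get(op_name, []) else name
        for name in input_names
    ]
    return ", ".join(calls + attr_names)

print(autocast_args_sketch("sin", ["x"], []))         # new_x
print(autocast_args_sketch("logit", ["x"], ["eps"]))  # new_x, eps (eps attribute assumed)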

def _gen_check_data_type(self, op_info, op_name):
mapping_input_name_to_type = dict(
zip(op_info.input_name_list, op_info.input_type_list)
@@ -1044,6 +1138,9 @@ def _gen_one_impl(
type_promotion_logic=self._gen_type_promotion_logic(
op_info, op_name
),
type_autocast_logic=self._gen_type_autocast_logic(
op_info, op_name
),
check_data_type=self._gen_check_data_type(
op_info, kernel_name
),
@@ -1109,6 +1206,9 @@ def _gen_one_impl(
type_promotion_logic=self._gen_type_promotion_logic(
op_info, op_name
),
type_autocast_logic=self._gen_type_autocast_logic(
op_info, op_name
),
check_data_type=self._gen_check_data_type(op_info, kernel_name),
handle_optional_inputs=self._gen_handle_optional_inputs(
op_info
17 changes: 17 additions & 0 deletions paddle/phi/common/type_promotion.h
@@ -94,6 +94,15 @@ static std::unordered_set<std::string> support_promotion_ops = {
"less_than", "less_equal", "greater_than", "greater_equal",
};

static std::unordered_set<std::string> support_autocast_ops = {
"acos", "acosh", "asin", "asinh", "atan", "atanh",
"ceil", "cos", "cosh", "digamma", "erf", "erfinv",
"floor", "lgamma", "logcumsumexp", "logit", "logsumexp", "polygamma",
"reciprocal", "rsqrt", "sin", "sinh", "sqrt", "stanh",
"tan", "tanh", "i0", "i0e", "i1", "i1e",
"sigmoid",
};

inline bool is_support_float(DataType dtype) {
if (dtype == DataType::FLOAT16 || dtype == DataType::FLOAT32 ||
dtype == DataType::FLOAT64 || dtype == DataType::BFLOAT16) {
@@ -264,4 +273,12 @@ inline bool NeedTypePromotionOldIr(const std::string& op_name,
}
}

inline bool NeedTypeAutoCast(const std::string& op_name,
const DataType& x_dtype) {
if (support_autocast_ops.find(op_name) != support_autocast_ops.end() &&
(is_support_int(x_dtype))) {
return true;
}
return false;
}
} // namespace phi
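
Finally, a Python-level sketch of the NeedTypeAutoCast condition above, spelling out that only ops in the support list whose input has an integer dtype are affected; the exact dtype coverage of is_support_int is an assumption here, and none of this snippet is part of the diff:

# Abbreviated stand-ins for the C++ tables and dtype check; names and coverage are assumptions.
SUPPORT_AUTOCAST_OPS = {"sin", "cos", "sqrt", "rsqrt", "floor", "ceil", "tanh"}
INT_DTYPES = {"int32", "int64"}

def need_type_autocast(op_name: str, x_dtype: str) -> bool:
    return op_name in SUPPORT_AUTOCAST_OPS and x_dtype in INT_DTYPES

print(need_type_autocast("sin", "int32"))     # True  -> cast input to float32
print(need_type_autocast("sin", "float32"))   # False -> float path untouched, no performance impact
print(need_type_autocast("matmul", "int32"))  # False -> op not in the list, behavior unchanged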