Skip to content

Commit 91a3d15

Browse files
Support 'complex promote' in yaml (#50611)
* support 'complex promote' in yaml * change the complex_promote * change 'kron' in math.py * change 'kron' comment in python * change kron comment in python * change kron comment in python
1 parent ff4ec23 commit 91a3d15

File tree

10 files changed

+92
-231
lines changed

10 files changed

+92
-231
lines changed

paddle/fluid/operators/generator/generate_op.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,11 @@ def add_grad_op_compat_name(grad_op_item, args_name_map):
293293
if new_op_name != op_name:
294294
forward_op_item['op_name'] = op_name
295295

296+
# add complex promote information
297+
if "complex_promote" in op_args:
298+
forward_op_item["complex_promote"] = op_args["complex_promote"]
299+
if has_backward:
300+
backward_op_item["complex_promote"] = op_args["complex_promote"]
296301
scalar_configs = None
297302
int_array_configs = None
298303
if 'scalar' in op_args:

paddle/fluid/operators/generator/templates/operator_utils.c.j2

Lines changed: 40 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -279,30 +279,52 @@ phi::KernelKey GetExpectedKernelType(
279279
data_type = framework::OperatorWithKernel::IndicateVarDataType(ctx, {{data_type_args[1] | to_opmaker_name}});
280280
}
281281
{% endif %}
282+
{% elif "complex_promote" in op and "forward" not in op%}
283+
{% set inputs = op["complex_promote"]%}
284+
auto data_type =
285+
OperatorWithKernel::IndicateOrPromoteVarDataTypes(ctx, "{{inputs[0]}}", "{{inputs[1]}}");
282286
{% endif %}
283287
return phi::KernelKey(data_type, ctx.GetPlace());
284288
}
285-
{% endmacro %}
289+
{% endmacro -%}
290+
291+
{% macro get_kernel_for_var(op) %}
292+
{% set skip_args = none %}
293+
{% if op["data_transform"] is not none%}
294+
{% if "skip_transform" in op["data_transform"] %}
295+
{% set skip_args = op["data_transform"]["skip_transform"] %}
296+
{% elif "support_trans_dtype" in op["data_transform"] %}
297+
{% set skip_args = op["data_transform"]["support_trans_dtype"] %}
298+
{% endif %}
299+
{% endif %}
300+
{% set var_name = "var_name" -%}
286301

287-
{% macro get_kernel_for_var(op) %} {# only for data_transform #}
288-
{% set skip_args = op["data_transform"]["skip_transform"] %}
289-
{% set var_name = "var_name" %}
290-
{% set skip_args_len = skip_args | length %}
291302
phi::KernelKey GetKernelTypeForVar(
292303
const std::string& {{var_name}},
293304
const phi::DenseTensor& tensor,
294305
const phi::KernelKey& expected_kernel_type) const override {
295-
306+
{%if skip_args is not none%}{# deal data_transform #}
307+
{% set skip_args_len = skip_args | length %}
296308
if (
297309
{%- for skip_arg in skip_args -%}
298310
var_name == "{{ skip_arg }}"
299311
{%- if skip_args_len != 1 and loop.index != skip_args_len %} || {% endif -%}
300312
{%- endfor -%}
301313
){
314+
{% if "skip_transform" in op["data_transform"] %}
302315
return phi::KernelKey(phi::Backend::ALL_BACKEND,
303316
expected_kernel_type.layout(),
304317
expected_kernel_type.dtype());
318+
{% elif "support_trans_dtype" in op["data_transform"] %}
319+
return phi::KernelKey(tensor.place(), tensor.layout(), tensor.dtype());
320+
{% endif %}
321+
}
322+
{% else %}{# deal complex_promote #}
323+
if (framework::IsComplexType(expected_kernel_type.dtype())) {
324+
// only promote inputs' types when an input contains complex type
325+
return phi::KernelKey(tensor.place(), tensor.layout(), tensor.dtype());
305326
}
327+
{% endif %}
306328
else{
307329
return phi::KernelKey(
308330
tensor.place(), tensor.layout(), expected_kernel_type.dtype());
@@ -317,20 +339,23 @@ class {{op["op_name"] | to_pascal_case}}Op : public framework::OperatorWithKerne
317339
using framework::OperatorWithKernel::OperatorWithKernel;
318340
{# ----------- get expected kernel type function -------------------------- #}
319341
{% set kernel = op["kernel"] %}
320-
{% if kernel["data_type"] is not none %}
342+
{% if kernel["data_type"] is not none or "complex_promote" in op or "data_transform" in op%}
321343
protected:
322-
{% filter indent(2, True)%}
344+
{% if kernel["data_type"] is not none or "complex_promote" in op %}
345+
{% filter indent(2, True)%}
323346
{{get_expected_kernel(op)}}
324-
{% endfilter %}
325-
{%- if "data_transform" in op and op["data_transform"] is not none -%}
326-
{%- if "skip_transform" in op["data_transform"] -%}
327-
{% filter indent(2, True) %}
347+
{% endfilter %}
348+
{% endif %}
349+
{% endif %}
350+
{%- if "data_transform" in op and op["data_transform"] is not none -%}
351+
{% filter indent(2, True) %}
352+
{{get_kernel_for_var(op)}}
353+
{% endfilter %}
354+
{%- elif "complex_promote" in op and op["complex_promote"] is not none -%}
355+
{% filter indent(2, True) %}
328356
{{get_kernel_for_var(op)}}
329357
{% endfilter %}
330358
{%- endif %}
331-
{%- endif -%}
332-
{# TODO(lizhiyu): add the 'support_trans_dtype' #}
333-
{% endif %}
334359
};
335360

336361
DECLARE_INFER_SHAPE_FUNCTOR({{op["op_name"]}}, {{op["op_name"] | to_pascal_case}}InferShapeFunctor,

paddle/fluid/operators/kron_op.cc

Lines changed: 0 additions & 166 deletions
This file was deleted.

paddle/phi/api/yaml/backward.yaml

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -663,6 +663,17 @@
663663
kernel :
664664
func : inverse_grad
665665

666+
- backward_op : kron_grad
667+
forward : kron (Tensor x, Tensor y) -> Tensor(out)
668+
args : (Tensor x, Tensor y, Tensor out_grad)
669+
output : Tensor(x_grad), Tensor(y_grad)
670+
infer_meta :
671+
func : GeneralBinaryGradInferMeta
672+
param : [x, y]
673+
kernel :
674+
func : kron_grad
675+
data_type : out_grad
676+
666677
- backward_op : kthvalue_grad
667678
forward : kthvalue(Tensor x, int k, int axis, bool keepdim) -> Tensor(out), Tensor(indices)
668679
args : (Tensor x, Tensor indices, Tensor out_grad, int k, int axis, bool keepdim)

paddle/phi/api/yaml/legacy_backward.yaml

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -636,17 +636,6 @@
636636
func : kldiv_loss_grad
637637
no_need_buffer : x
638638

639-
- backward_op : kron_grad
640-
forward : kron (Tensor x, Tensor y) -> Tensor(out)
641-
args : (Tensor x, Tensor y, Tensor out_grad)
642-
output : Tensor(x_grad), Tensor(y_grad)
643-
infer_meta :
644-
func : GeneralBinaryGradInferMeta
645-
param : [x, y]
646-
kernel :
647-
func : kron_grad
648-
data_type : out_grad
649-
650639
- backward_op : layer_norm_grad
651640
forward : layer_norm (Tensor x, Tensor scale, Tensor bias, float epsilon, int begin_norm_axis) -> Tensor(out), Tensor(mean), Tensor(variance)
652641
args : (Tensor x, Tensor scale, Tensor bias, Tensor mean, Tensor variance, Tensor out_grad, float epsilon, int begin_norm_axis)

paddle/phi/api/yaml/legacy_ops.yaml

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -921,15 +921,6 @@
921921
data_type : x
922922
backward : kldiv_loss_grad
923923

924-
- op : kron
925-
args : (Tensor x, Tensor y)
926-
output : Tensor
927-
infer_meta :
928-
func : KronInferMeta
929-
kernel :
930-
func : kron
931-
backward : kron_grad
932-
933924
- op : lamb_
934925
args : (Tensor param, Tensor grad, Tensor learning_rate, Tensor moment1, Tensor moment2, Tensor beta1_pow, Tensor beta2_pow, Tensor master_param, Tensor skip_update, float weight_decay, float beta1, float beta2, float epsilon, bool multi_precision)
935926
output : Tensor(param_out), Tensor(moment1_out), Tensor(moment2_out), Tensor(beta1_pow_out), Tensor(beta2_pow_out), Tensor(master_param_outs)

paddle/phi/api/yaml/op_compat.yaml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -828,6 +828,14 @@
828828
outputs :
829829
out : Out
830830

831+
- op : kron
832+
backward : kron_grad
833+
inputs :
834+
{x : X, y : Y}
835+
outputs :
836+
{out : Out}
837+
complex_promote : [X, Y]
838+
831839
- op : kthvalue
832840
inputs :
833841
x : X

paddle/phi/api/yaml/ops.yaml

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -653,6 +653,15 @@
653653
func : isnan {dense -> dense},
654654
isnan_sr {selected_rows -> selected_rows}
655655

656+
- op : kron
657+
args : (Tensor x, Tensor y)
658+
output : Tensor
659+
infer_meta :
660+
func : KronInferMeta
661+
kernel :
662+
func : kron
663+
backward : kron_grad
664+
656665
- op : kthvalue
657666
args : (Tensor x, int k = 1, int axis = -1, bool keepdim = false)
658667
output : Tensor(out), Tensor(indices)

paddle/phi/ops/compat/kron_sig.cc

Lines changed: 0 additions & 26 deletions
This file was deleted.

0 commit comments

Comments
 (0)