
Commit 75669e3

kangguangli authored and BeingGod committed
[NewIR] register set_value in new ir (PaddlePaddle#56436)
* register set_value in new ir
* fix
* register set_value_grad
* fix
* fix
* remove debug info
* add unittest
* fix
* fix
* fix
* fix
* fix
* resolve comments
1 parent 957d59d commit 75669e3

File tree

14 files changed: +431, -69 lines


paddle/fluid/ir/dialect/paddle_dialect/ir/pd_attribute.h

Lines changed: 5 additions & 0 deletions

@@ -15,6 +15,7 @@
 #pragma once
 
 #include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_attribute_storage.h"
+#include "paddle/fluid/ir/dialect/paddle_dialect/utils/utils.h"
 #include "paddle/ir/core/attribute.h"
 #include "paddle/ir/core/builtin_attribute.h"
 #include "paddle/phi/common/scalar.h"
@@ -49,6 +50,10 @@ class ScalarAttribute : public ir::Attribute {
            (val.type_id() == ir::StrAttribute::type_id());
   }
 
+  static ir::Attribute get(ir::IrContext *ctx, phi::Scalar scalar) {
+    return TransToIrAttribute(scalar, ctx);
+  }
+
   phi::Scalar data();
 };
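
The new static `get` builds a dialect attribute directly from a `phi::Scalar`, with `TransToIrAttribute` dispatching on the scalar's runtime dtype. A minimal usage sketch follows, assuming the usual `ir::IrContext::Instance()` singleton accessor; only `ScalarAttribute::get` itself comes from this diff:

#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_attribute.h"
#include "paddle/ir/core/ir_context.h"

// Sketch only: wrap a concrete value in phi::Scalar and let
// TransToIrAttribute pick the matching builtin attribute type.
ir::Attribute MakeScalarAttr() {
  ir::IrContext *ctx = ir::IrContext::Instance();
  phi::Scalar s(static_cast<int64_t>(42));  // becomes an int64 attribute
  return paddle::dialect::ScalarAttribute::get(ctx, s);
}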

paddle/fluid/ir/dialect/paddle_dialect/ir/pd_op.yaml

Lines changed: 69 additions & 1 deletion

@@ -329,7 +329,6 @@
   view: null
   backward: null
 
-
 - name: shadow_feed
   inputs:
   - typename: Tensor
@@ -355,3 +354,72 @@
   force_backend: null
   inplace: null
   backward: null
+
+- name : set_value
+  inputs:
+  - {typename: Tensor, name: x, optional: false, no_need_buffer: false, data_transform: {} }
+  attrs:
+  - {typename: 'int64_t[]', name: starts}
+  - {typename: 'int64_t[]', name: ends}
+  - {typename: 'int64_t[]', name: steps}
+  - {typename: 'int64_t[]', name: axes}
+  - {typename: 'int64_t[]', name: decrease_axes}
+  - {typename: 'int64_t[]', name: none_axes}
+  - {typename: 'int64_t[]', name: shape}
+  - {typename: 'Scalar[]', name: values}
+  outputs:
+  - {typename: Tensor, name: out, optional: false, intermediate: false}
+  infer_meta:
+    func: SetValueInferMeta
+    param: [x]
+  kernel:
+    func: [set_value]
+    param: [x, starts, ends, steps, axes, decrease_axes, none_axes, shape, values]
+  inplace: {out: x}
+  backward: set_value_grad
+
+- name : set_value_with_tensor
+  inputs:
+  - {typename: Tensor, name: x, optional: false, no_need_buffer: false, data_transform: {} }
+  - {typename: Tensor, name: values, optional: false, no_need_buffer: false, data_transform: {} }
+  attrs:
+  - {typename: 'int64_t[]', name: starts}
+  - {typename: 'int64_t[]', name: ends}
+  - {typename: 'int64_t[]', name: steps}
+  - {typename: 'int64_t[]', name: axes}
+  - {typename: 'int64_t[]', name: decrease_axes}
+  - {typename: 'int64_t[]', name: none_axes}
+  outputs:
+  - {typename: Tensor, name: out, optional: false, intermediate: false}
+  infer_meta:
+    func: SetValueInferMeta
+    param: [x]
+  kernel:
+    func: [set_value_with_tensor]
+    param: [x, values, starts, ends, steps, axes, decrease_axes, none_axes]
+  inplace: {out: x}
+  backward: set_value_grad
+
+
+- name : set_value_grad
+  inputs:
+  - {typename: Tensor, name: out_grad, optional: false, no_need_buffer: false, data_transform: {} }
+  - {typename: Tensor, name: values, optional: false, no_need_buffer: false, data_transform: {} }
+  attrs:
+  - {typename: 'int64_t[]', name: starts}
+  - {typename: 'int64_t[]', name: ends}
+  - {typename: 'int64_t[]', name: steps}
+  - {typename: 'int64_t[]', name: axes}
+  - {typename: 'int64_t[]', name: decrease_axes}
+  - {typename: 'int64_t[]', name: none_axes}
+  outputs:
+  - {typename: Tensor, name: x_grad, optional: false, intermediate: false}
+  - {typename: Tensor, name: values_grad, optional: false, intermediate: false}
+  infer_meta:
+    func: SetValueGradInferMeta
+    param: [out_grad, values]
+  kernel:
+    func: [set_value_grad]
+    param: [out_grad, starts, ends, steps, axes, decrease_axes, none_axes]
+  inplace: null
+  backward: null
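
As a reading aid: `set_value` writes `values` into the strided slice of `x` selected by `starts`/`ends`/`steps` along `axes`, in the spirit of Python slice assignment. A self-contained one-dimensional sketch of those semantics (a hypothetical helper for illustration, not Paddle's kernel):

#include <cstdint>
#include <vector>

// Hypothetical illustration: write `value` into x[start:end:step],
// mirroring how set_value treats a single axis.
void SetValue1D(std::vector<float> &x, int64_t start, int64_t end,
                int64_t step, float value) {
  for (int64_t i = start; i < end && i < static_cast<int64_t>(x.size());
       i += step) {
    x[i] = value;
  }
}

// Usage: with x = {0,0,0,0,0,0}, SetValue1D(x, 1, 6, 2, 9.f)
// yields {0,9,0,9,0,9}.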

paddle/fluid/ir/dialect/paddle_dialect/utils/utils.cc

Lines changed: 1 addition & 0 deletions

@@ -13,6 +13,7 @@
 // limitations under the License.
 
 #include "paddle/fluid/ir/dialect/paddle_dialect/utils/utils.h"
+#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_attribute.h"
 
 namespace paddle {
 namespace dialect {

paddle/fluid/ir/dialect/paddle_dialect/utils/utils.h

Lines changed: 0 additions & 1 deletion

@@ -16,7 +16,6 @@
 
 // #include "paddle/fluid/framework/convert_utils.h"
 #include "paddle/fluid/framework/data_type.h"
-#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_attribute.h"
 #include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_type_storage.h"
 #include "paddle/ir/core/builtin_attribute.h"
 #include "paddle/ir/core/builtin_type.h"

paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.h

Lines changed: 3 additions & 2 deletions

@@ -37,6 +37,7 @@
 #include "paddle/fluid/ir/dialect/paddle_dialect/utils/op_yaml_info_parser.h"
 #include "paddle/fluid/ir/dialect/paddle_kernel_dialect/ir/kernel_attribute.h"
 #include "paddle/fluid/ir/dialect/paddle_kernel_dialect/ir/kernel_type.h"
+#include "paddle/ir/core/type_name.h"
 #include "paddle/phi/core/infermeta_utils.h"
 
 #include "glog/logging.h"
@@ -81,8 +82,8 @@ void BuildPhiContext(ir::Operation* op,
                      Context* ctx) {
   paddle::framework::Scope* inner_scope =
       local_scope != nullptr ? local_scope : scope;
-  VLOG(6) << "BuildPhiContext in scope[" << scope << "] inner_scope["
-          << inner_scope << "]";
+  VLOG(6) << "Build " << get_type_name<Context>() << " in scope[" << scope
+          << "] inner_scope[" << inner_scope << "]";
 
   auto attr_map = op->attributes();
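
Here `get_type_name<Context>()` from the newly included `paddle/ir/core/type_name.h` makes the log line name the concrete context being built, since `BuildPhiContext` is instantiated for several context types. A rough standalone stand-in for such a utility, using RTTI rather than Paddle's implementation (which this sketch does not claim to reproduce):

#include <string>
#include <typeinfo>

// Returns an implementation-defined (often mangled) name for T;
// enough to tell template instantiations apart in a log line.
template <typename T>
std::string get_type_name_approx() {
  return typeid(T).name();
}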

paddle/fluid/ir_adaptor/translator/attribute_translator.cc

Lines changed: 13 additions & 3 deletions

@@ -113,7 +113,7 @@ class AttributeVisitor {
   }
 
   virtual ir::Attribute operator()(const std::vector<int64_t>& i64s) {
-    VLOG(10) << "translating vector<int64>";
+    VLOG(10) << "translating vector<int64> size: " << i64s.size();
     std::vector<ir::Attribute> attrs;
     attrs.reserve(i64s.size());
     for (const auto& v : i64s) {
@@ -135,8 +135,13 @@ class AttributeVisitor {
   virtual ir::Attribute operator()(
       const std::vector<paddle::experimental::Scalar>& ss) {
     VLOG(10) << "translating vector<scalar>";
-    IR_THROW(
-        "not support translating std::vector<paddle::experimental::Scalar>");
+    std::vector<ir::Attribute> attrs;
+    attrs.reserve(ss.size());
+    for (const auto& v : ss) {
+      attrs.push_back(dialect::ScalarAttribute::get(ctx, v));
+    }
+    VLOG(10) << "translating vector<scalar> Done";
+    return ir::ArrayAttribute::get(ctx, attrs);
   }
 
   virtual ir::Attribute operator()(const paddle::blank& blank) {
@@ -164,6 +169,11 @@ class Int64ArrayAttributeVisitor : public AttributeVisitor {
     }
     return ir::ArrayAttribute::get(ctx, attrs);
   }
+
+  ir::Attribute operator()(const paddle::blank& blank) override {
+    VLOG(10) << "translating paddle::blank to int64[]";
+    return ir::ArrayAttribute::get(ctx, {});
+  }
 };
 
 class IntArrayAttributeVisitor : public AttributeVisitor {
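
These visitor classes are overload sets applied to each alternative of the legacy attribute variant: the new `vector<Scalar>` overload folds the elements into one `ArrayAttribute`, and the `Int64ArrayAttributeVisitor` override maps an empty `paddle::blank` to an empty array. A self-contained toy sketch of the same dispatch shape, with toy types rather than the Paddle classes:

#include <cstdint>
#include <iostream>
#include <string>
#include <variant>
#include <vector>

// Toy visitor: one operator() per variant alternative, each producing
// a printable "attribute", mirroring the translator's structure.
struct ToyVisitor {
  std::string operator()(int64_t v) const { return std::to_string(v); }
  std::string operator()(const std::vector<int64_t> &v) const {
    std::string out = "[";
    for (int64_t e : v) out += std::to_string(e) + ",";
    return out + "]";
  }
};

int main() {
  std::variant<int64_t, std::vector<int64_t>> attr =
      std::vector<int64_t>{1, 2, 3};
  std::cout << std::visit(ToyVisitor{}, attr) << "\n";  // prints [1,2,3,]
}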

paddle/fluid/ir_adaptor/translator/op_compat_gen.py

Lines changed: 3 additions & 0 deletions

@@ -126,6 +126,9 @@ def insert_new_mutable_attributes(
                 backward_op, op_compat_item["scalar"]
             )
 
+    # special mapping list
+    op_arg_name_mappings["set_value_grad"]["values_grad"] = "ValueTensor@GRAD"
+
     op_name_normailzer_template = env.get_template("op_compat_info.cc.j2")
     with open(output_source_file, 'wt') as f:
         op_compat_definition = op_name_normailzer_template.render(

paddle/fluid/ir_adaptor/translator/op_compat_info.h

Lines changed: 50 additions & 25 deletions

@@ -13,6 +13,7 @@
 // limitations under the License.
 
 #include <functional>
+#include <optional>
 #include <string>
 #include <unordered_map>
 #include <unordered_set>
@@ -75,42 +76,66 @@ class OpNameNormalizer {
     return op_mutable_attribute_infos.at(op_type).at(arg_name);
   }
 
+  std::optional<std::string> GetDirectMapping(const std::string& op_type,
+                                              const std::string& arg_name) {
+    if (op_arg_name_mappings.find(op_type) == op_arg_name_mappings.end()) {
+      return {};
+    }
+    auto& arg_mappings = op_arg_name_mappings[op_type];
+    if (arg_mappings.find(arg_name) == arg_mappings.end()) {
+      return {};
+    }
+    return arg_mappings.at(arg_name);
+  }
+
+  std::optional<std::string> GetGradNameMapping(const std::string& op_type,
+                                                const std::string& arg_name) {
+    std::string target = kPhiGradSuffix;
+    std::string data = kFluidVarGradSuffix;
+
+    size_t first_grad_pos = arg_name.find(target);
+    size_t type_pos = op_type.find(target);
+    std::string legacy_name = arg_name.substr(0, first_grad_pos);
+    std::optional<std::string> ret =
+        this->GetDirectMapping(op_type.substr(0, type_pos), legacy_name);
+    if (ret) {
+      legacy_name = ret.value();
+    }
+    legacy_name = legacy_name + arg_name.substr(first_grad_pos);
+    for (size_t pos = 0;
+         legacy_name.npos != (pos = legacy_name.find(target, pos));
+         pos += data.length()) {
+      legacy_name.replace(pos, target.length(), data);
+    }
+    return legacy_name;
+  }
+
   std::string GetLegacyArgName(const std::string& op_type,
                                const std::string& arg_name) {
+    if (auto ret = GetDirectMapping(op_type, arg_name)) {
+      VLOG(10) << "[" << op_type << "] found " << ret.value();
+      return ret.value();
+    }
+
     bool is_grad_op = (op_type.find(kPhiGradSuffix) != std::string::npos);
     bool is_grad_arg = (arg_name.find(kPhiGradSuffix) != std::string::npos);
 
     if (is_grad_op && is_grad_arg) {
-      std::string target = kPhiGradSuffix;
-      std::string data = kFluidVarGradSuffix;
-
-      size_t first_grad_pos = arg_name.find(target);
-      size_t type_pos = op_type.find(target);
-      std::string legacy_name = this->GetLegacyArgName(
-          op_type.substr(0, type_pos), arg_name.substr(0, first_grad_pos));
-      legacy_name += arg_name.substr(first_grad_pos);
-      for (size_t pos = 0;
-           legacy_name.npos != (pos = legacy_name.find(target, pos));
-           pos += data.length()) {
-        legacy_name.replace(pos, target.length(), data);
+      if (auto ret = GetGradNameMapping(op_type, arg_name)) {
+        VLOG(10) << "[" << op_type << "] found " << ret.value();
+        return ret.value();
      }
-      return legacy_name;
    } else if (is_grad_op && !is_grad_arg) {
       // backwward op using forward args: like trace_grad using forward input
       size_t type_pos = op_type.find(kPhiGradSuffix);
-      std::string legacy_name =
-          this->GetLegacyArgName(op_type.substr(0, type_pos), arg_name);
-
-      return legacy_name;
-    }
-    if (op_arg_name_mappings.find(op_type) == op_arg_name_mappings.end()) {
-      return arg_name;
-    }
-    auto& arg_mappings = op_arg_name_mappings[op_type];
-    if (arg_mappings.find(arg_name) == arg_mappings.end()) {
-      return arg_name;
+      if (auto ret = GetDirectMapping(op_type.substr(0, type_pos), arg_name)) {
+        VLOG(10) << "[" << op_type << "] found " << ret.value();
+        return ret.value();
+      }
     }
-    return arg_mappings.at(arg_name);
+
+    VLOG(10) << "[" << op_type << "] not found mapping for " << arg_name;
+    return arg_name;
   }
 
   std::string GetLegacyAttrName(const std::string& op_type,
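
This refactor splits the old recursive logic in two: `GetDirectMapping` is a plain table lookup, and `GetGradNameMapping` strips the phi `_grad` suffix, maps the forward name, then rewrites each remaining phi suffix to the legacy fluid `@GRAD` form. A standalone sketch of just that suffix rewrite, with the literal values of `kPhiGradSuffix` and `kFluidVarGradSuffix` assumed:

#include <iostream>
#include <string>

// Replace each phi grad suffix with the legacy fluid one, mirroring
// the replace loop in GetGradNameMapping.
std::string RewriteGradSuffix(std::string name) {
  const std::string target = "_grad";  // assumed kPhiGradSuffix
  const std::string data = "@GRAD";    // assumed kFluidVarGradSuffix
  for (size_t pos = 0; std::string::npos != (pos = name.find(target, pos));
       pos += data.length()) {
    name.replace(pos, target.length(), data);
  }
  return name;
}

int main() {
  std::cout << RewriteGradSuffix("values_grad") << "\n";  // values@GRAD
}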
