
Commit d7a32fa
[Prim][PIR] Add composite rule of tile for both static and dynamic shape (PaddlePaddle#61571)
* support dynamic shape of tile and full_like
* fix conflict bug
* add squeeze and unsqueeze as primitive ops
* fix dynamic shape
* fix code
* add debug info
* fix bug
* remove unused code
* skip rank-4 case
* fix check op
1 parent 5b8f331 commit d7a32fa

5 files changed: +178 additions, -15 deletions


paddle/fluid/pir/dialect/op_generator/decomp_interface_gen_op_list.py

Lines changed: 2 additions & 0 deletions
@@ -47,6 +47,7 @@
     "squeeze",
     "stack",
     "unsqueeze",
+    "tile",
 ]
 
 # come into effect in generated file op_decomp.cc
@@ -77,6 +78,7 @@
     "squeeze",
     "stack",
     "unsqueeze",
+    "tile",
 ]

paddle/fluid/primitive/base/decomp_trans.cc

Lines changed: 12 additions & 1 deletion
@@ -38,6 +38,9 @@ std::unordered_set<std::string> decomp_op_contain_none = {"pd_op.squeeze",
                                                           "pd_op.flatten",
                                                           "pd_op.batch_norm",
                                                           "pd_op.batch_norm_"};
+//
+std::unordered_set<std::string> dynamic_shape_blacklist = {"pd_op.squeeze",
+                                                           "pd_op.unsqueeze"};
 
 static bool find_value(const std::vector<int64_t>& vec, int64_t value) {
   if (std::find(vec.begin(), vec.end(), value) != vec.end()) {
@@ -48,6 +51,9 @@ static bool find_value(const std::vector<int64_t>& vec, int64_t value) {
 }
 
 static const phi::DDim& GetValueDims(pir::Value value) {
+  if (!value.type()) {
+    PADDLE_THROW(phi::errors::InvalidArgument("The type of value is nullptr."));
+  }
   if (value.type().isa<DenseTensorType>()) {
     return value.type().dyn_cast<DenseTensorType>().dims();
   } else if (value.type().isa<SelectedRowsType>()) {
@@ -101,7 +107,7 @@ bool DecompProgram::check_decomp_dynamic_shape(pir::Operation* op) {
     // check if initialized in case of optional input.
     if (!paddle::dialect::IsEmptyValue(value)) {
       pir::Operation* prev_op = value.defining_op();
-      if (prev_op->name() == "builtin.combine") {
+      if (prev_op && prev_op->name() == "builtin.combine") {
        for (pir::OpOperand& sub_item : prev_op->operands()) {
          if (check_dynamic_shape(sub_item, *op)) {
            return true;
@@ -336,6 +342,11 @@ void DecompProgram::decomp_block(
         check_decomp_dynamic_shape(op)) {
       enable_prim = false;
     }
+    if (enable_prim && check_decomp_dynamic_shape(op) &&
+        dynamic_shape_blacklist.find(op->name()) !=
+            dynamic_shape_blacklist.end()) {
+      enable_prim = false;
+    }
     if (enable_prim) {
       VLOG(4) << "[Prim] decomp op name " << op->name();
       check_decomp_dynamic_shape(op);
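
The net effect of the guards above: an op is decomposed only when prim is enabled, and an op with dynamic-shape inputs is additionally skipped when it sits on the blacklist. A minimal Python paraphrase of the added guard (the helper name should_decompose is ours, not Paddle's):

    # Hypothetical paraphrase of the guard added in DecompProgram::decomp_block.
    DYNAMIC_SHAPE_BLACKLIST = {"pd_op.squeeze", "pd_op.unsqueeze"}

    def should_decompose(op_name: str, enable_prim: bool, has_dynamic_shape: bool) -> bool:
        """Skip decomposition for blacklisted ops when shapes are dynamic."""
        if enable_prim and has_dynamic_shape and op_name in DYNAMIC_SHAPE_BLACKLIST:
            return False
        return enable_prim

    assert should_decompose("pd_op.tile", True, True)         # tile is decomposed
    assert not should_decompose("pd_op.squeeze", True, True)  # squeeze falls back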

paddle/fluid/primitive/composite/composite.h

Lines changed: 93 additions & 0 deletions
@@ -14,6 +14,7 @@
 
 #pragma once
 
+#include <numeric>
 #include "paddle/fluid/primitive/primitive/primitive.h"
 #include "paddle/fluid/primitive/type/lazy_tensor.h"
 #include "paddle/fluid/primitive/utils/utils.h"
@@ -25,6 +26,11 @@ namespace details {
 // empty_shape means x.shape=[]
 static std::vector<int64_t> empty_shape;
 
+template <typename T>
+static Tensor get_slice(const Tensor& x, int64_t idx) {
+  return slice<T>(x, {0}, {idx}, {idx + 1}, {1}, {});
+}
+
 template <typename T>
 Tensor any_decomp(const Tensor& x, const IntArray& axis, bool keepdim) {
   auto org_dtype = x.dtype();
@@ -825,6 +831,93 @@ std::tuple<Tensor, Tensor, Tensor> group_norm_decomp(
   return std::make_tuple(out, mean_out, var_out);
 }
 
+template <typename T>
+Tensor tile_decomp(const Tensor& x, const IntArray& repeat_times) {
+  // Example: x.shape = [3, 4], repeat_times = (a, b, c)
+  // shape1 = [1, 3, 4]
+  // shape2 = [1, 1, 1, 3, 1, 4]
+  // shape3 = [a, 1, b, 3, c, 4]
+  // shape4 = [a, b*3, c*4]
+  // t1 = x.reshape(shape1)
+  // t2 = t1.reshape(shape2)
+  // t3 = t2.expand(shape3)
+  // res = t3.reshape(shape4)
+  std::vector<int64_t> repeat_times_ = repeat_times.GetData();
+  std::vector<int64_t> shape1 = common::vectorize<int64_t>(x.dims());
+  auto diff = int64_t(repeat_times_.size()) - int64_t(shape1.size());
+  Tensor t1;
+  if (find_value(shape1, -1)) {
+    size_t repeat_time_length = repeat_times_.size();
+    std::vector<int64_t> unsqueeze_idx2;
+    if (diff > 0) {
+      std::vector<int64_t> unsqueeze_idx1(diff);
+      std::iota(unsqueeze_idx1.begin(), unsqueeze_idx1.end(), 0);
+      t1 = unsqueeze<T>(x, unsqueeze_idx1);
+    } else {
+      t1 = x;
+    }
+    auto length2 = t1.dims().size();
+    for (size_t i = 0; i < repeat_times_.size(); i++) {
+      unsqueeze_idx2.push_back(length2 - repeat_times_.size() + i * 2);
+    }
+
+    Tensor t2 = unsqueeze<T>(t1, unsqueeze_idx2);
+    std::vector<int64_t> ref_shape(t2.dims().size(), 1);
+    for (size_t i = 0; i < unsqueeze_idx2.size(); i++) {
+      ref_shape[unsqueeze_idx2[i]] = repeat_times_[i];
+    }
+    Tensor ref_t = full<T>(ref_shape, 1.0, t2.dtype());
+    Tensor t3 = t2 * ref_t;
+    Tensor origin_shape_t = shape<T>(t1);
+    std::vector<int64_t> t1_shape = common::vectorize<int64_t>(t1.dims());
+    std::vector<Tensor> res_s;
+    for (int64_t i = int64_t(length2) - 1; i >= 0; i--) {
+      auto relative_idx =
+          int64_t(repeat_time_length) - 1 - int64_t(length2 - i - 1);
+
+      if (relative_idx >= 0) {
+        res_s.insert(
+            res_s.begin(),
+            get_slice<T>(origin_shape_t, i) * repeat_times_[relative_idx]);
+      } else {
+        res_s.insert(res_s.begin(), get_slice<T>(origin_shape_t, i));
+      }
+    }
+    Tensor s4 = concat<T>(res_s, 0);
+    return backend::reshape_with_tensor<T>(t3, s4);
+
+  } else {
+    if (diff > 0) {
+      for (int64_t i = 0; i < diff; i++) {
+        shape1.insert(shape1.begin(), 1);
+      }
+    }
+
+    auto length = int64_t(shape1.size());
+    std::vector<int64_t> shape2 = shape1;
+    std::vector<int64_t> shape3 = shape1;
+    std::vector<int64_t> final_shape = shape1;
+    auto r_length = repeat_times_.size();
+    for (size_t j = 0; j < repeat_times_.size(); j++) {
+      int64_t i = int64_t(j);
+
+      shape2.insert(shape2.begin() + (length - 1 - i), 1);
+      shape3.insert(shape3.begin() + (length - 1 - i),
+                    repeat_times_[r_length - i - 1]);
+
+      final_shape[length - i - 1] =
+          final_shape[length - i - 1] * repeat_times_[r_length - i - 1];
+    }
+
+    t1 = reshape<T>(x, shape1);
+
+    auto t2 = reshape<T>(t1, shape2);
+    auto t3 = t2.expand(shape3);
+    auto res = reshape<T>(t3, final_shape);
+    return res;
+  }
+}
+
 template <typename T>
 Tensor square_decomp(const Tensor& x) {
   auto org_dtype = x.dtype();
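
The static-shape branch above is straightforward to sanity-check outside Paddle. Here is a NumPy sketch of the same shape bookkeeping (our own paraphrase; np.broadcast_to stands in for expand), which agrees with np.tile:

    import numpy as np

    def tile_decomp_static(x, repeat_times):
        # Mirror the else-branch of tile_decomp: left-pad the rank, insert a
        # size-1 axis before each repeated dim, broadcast, then merge pairs.
        shape1 = list(x.shape)
        diff = len(repeat_times) - len(shape1)
        shape1 = [1] * max(diff, 0) + shape1          # e.g. [3, 4] -> [1, 3, 4]
        n, r = len(shape1), len(repeat_times)

        shape2, shape3, final_shape = list(shape1), list(shape1), list(shape1)
        for i in range(r):                            # walk dims back to front
            shape2.insert(n - 1 - i, 1)               # -> [1, 1, 1, 3, 1, 4]
            shape3.insert(n - 1 - i, repeat_times[r - i - 1])  # -> [a, 1, b, 3, c, 4]
            final_shape[n - i - 1] *= repeat_times[r - i - 1]  # -> [a, 3*b, 4*c]

        t2 = x.reshape(shape1).reshape(shape2)
        t3 = np.broadcast_to(t2, shape3)              # stands in for expand
        return t3.reshape(final_shape)

    x = np.arange(12).reshape(3, 4)
    assert np.array_equal(tile_decomp_static(x, [2, 5]), np.tile(x, [2, 5]))
    assert np.array_equal(tile_decomp_static(x, [3, 2, 5]), np.tile(x, [3, 2, 5]))
    assert np.array_equal(tile_decomp_static(x, [2]), np.tile(x, [2]))

This reshape-broadcast-reshape chain is what lets tile lower to primitive reshape and expand ops only.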

test/legacy_test/test_tile_op.py

Lines changed: 22 additions & 3 deletions
@@ -48,7 +48,9 @@ def init_data(self):
         self.repeat_times = [2]
 
     def test_check_output(self):
-        self.check_output(check_cinn=self.check_cinn, check_pir=True)
+        self.check_output(
+            check_cinn=self.check_cinn, check_pir=True, check_prim_pir=True
+        )
 
     def test_check_grad(self):
         self.check_grad(
@@ -144,6 +146,18 @@ def init_data(self):
     def if_enable_cinn(self):
         self.check_cinn = True
 
+    def test_check_output(self):
+        # todo: enable check_prim_pir
+        self.check_output(check_cinn=self.check_cinn, check_pir=True)
+
+    def test_check_grad(self):
+        self.check_grad(
+            ['X'],
+            'Out',
+            check_prim=True,
+            check_pir=True,
+        )
+
 
 # Situation 2: repeat_times is a list (with tensor)
 # CINN not support repeat_times is a tensor now
@@ -269,7 +283,9 @@ def init_data(self):
         self.repeat_times = [2, 1, 4]
 
     def test_check_output(self):
-        self.check_output(check_cinn=self.check_cinn, check_pir=True)
+        self.check_output(
+            check_cinn=self.check_cinn, check_pir=True, check_prim_pir=True
+        )
 
     def test_check_grad(self):
         self.check_grad(
@@ -307,7 +323,10 @@ def if_enable_cinn(self):
     def test_check_output(self):
         place = core.CUDAPlace(0)
         self.check_output_with_place(
-            place, check_cinn=self.check_cinn, check_pir=True
+            place,
+            check_cinn=self.check_cinn,
+            check_pir=True,
+            check_prim_pir=True,
         )
 
     def init_data(self):

test/prim/pir_prim/test_prim_sub_graph_dynamic_shape.py

Lines changed: 49 additions & 11 deletions
@@ -70,12 +70,23 @@ def stack_net(x):
     return paddle.stack([x, y], axis=0)
 
 
+def tile_net1(x):
+    y = paddle.tile(x, repeat_times=[2, 5])
+    return y
+
+
+def tile_net2(x):
+    y = paddle.tile(x, repeat_times=[3, 2, 5])
+    return y
+
+
 class TestPrimOne(unittest.TestCase):
     def setUp(self):
         np.random.seed(2023)
         self.dtype = "float32"
-        self.shape_x = [1, 300, 4096]
-        self.x = np.random.random(self.shape_x).astype(self.dtype)
+        self.x_shape = [1, 300, 4096]
+        self.init_x_shape = [None, None, 4096]
+        self.x = np.random.random(self.x_shape).astype(self.dtype)
         self.net = log_softmax_net
         self.necessary_ops = "pd_op.log_softmax"
         self.enable_cinn = False
@@ -89,7 +100,7 @@ def base_net(self, flag=None):
             self.net,
             use_cinn=self.enable_cinn,
             input_spec=[
-                InputSpec(shape=[None, None, 4096], dtype='float32'),
+                InputSpec(shape=self.init_x_shape, dtype='float32'),
             ],
         )
         fn.eval()
@@ -119,8 +130,9 @@ class TestPrimOne2(TestPrimOne):
     def setUp(self):
         np.random.seed(2023)
         self.dtype = "bool"
-        self.shape_x = [1, 300, 4096]
-        self.x = np.random.random(self.shape_x).astype(self.dtype)
+        self.x_shape = [1, 300, 4096]
+        self.init_x_shape = [None, None, 4096]
+        self.x = np.random.random(self.x_shape).astype(self.dtype)
         self.net = any_net
         self.necessary_ops = "pd_op.any"
         self.enable_cinn = False
@@ -131,8 +143,8 @@ def setUp(self):
 #     def setUp(self):
 #         np.random.seed(2023)
 #         self.dtype = "int"
-#         self.shape_x = [1, 300, 4096]
-#         self.x = np.random.randint(0, 10, size=self.shape_x)
+#         self.x_shape = [1, 300, 4096]
+#         self.x = np.random.randint(0, 10, size=self.x_shape)
 #         self.net = embedding_net
 #         self.necessary_ops = "pd_op.embedding"
 #         self.enable_cinn = False
@@ -142,8 +154,9 @@ class TestPrimOne3(TestPrimOne):
     def setUp(self):
         np.random.seed(2023)
         self.dtype = "float32"
-        self.shape_x = [1, 300, 4096]
-        self.x = np.random.random(self.shape_x).astype(self.dtype)
+        self.x_shape = [1, 300, 4096]
+        self.init_x_shape = [None, None, 4096]
+        self.x = np.random.random(self.x_shape).astype(self.dtype)
         self.net = full_like_net
         self.necessary_ops = "pd_op.full_like"
         self.enable_cinn = False
@@ -153,12 +166,37 @@ class TestPrimOne4(TestPrimOne):
     def setUp(self):
         np.random.seed(2023)
         self.dtype = "float32"
-        self.shape_x = [1, 300, 4096]
-        self.x = np.random.random(self.shape_x).astype(self.dtype)
+        self.x_shape = [1, 300, 4096]
+        self.init_x_shape = [None, None, 4096]
+        self.x = np.random.random(self.x_shape).astype(self.dtype)
         self.net = stack_net
         self.necessary_ops = "pd_op.stack"
         self.enable_cinn = False
 
 
+class TestPrimOne5(TestPrimOne):
+    def setUp(self):
+        np.random.seed(2023)
+        self.dtype = "float32"
+        self.x_shape = [1, 300, 4096]
+        self.init_x_shape = [None, None, 4096]
+        self.x = np.random.random(self.x_shape).astype(self.dtype)
+        self.net = tile_net1
+        self.necessary_ops = "pd_op.tile"
+        self.enable_cinn = False
+
+
+class TestPrimOne6(TestPrimOne):
+    def setUp(self):
+        np.random.seed(2023)
+        self.dtype = "float32"
+        self.x_shape = [300, 4096]
+        self.init_x_shape = [None, 4096]
+        self.x = np.random.random(self.x_shape).astype(self.dtype)
+        self.net = tile_net2
+        self.necessary_ops = "pd_op.tile"
+        self.enable_cinn = False
+
+
 if __name__ == "__main__":
     unittest.main()
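
tile_net1 and tile_net2 exercise both alignments of the new composite rule under dynamic shapes: repeat_times shorter than the input rank and longer than it. For reference, a NumPy paraphrase (all names ours) of the dynamic-shape branch of tile_decomp, where expand becomes a broadcasted multiply with ones and the output shape is assembled from the runtime shape of the input:

    import numpy as np

    def tile_decomp_dynamic(x, repeat_times):
        # Paraphrase of tile_decomp's dynamic-shape branch: expand via a
        # broadcasted multiply with ones, then reshape using a shape derived
        # from x at run time (shape<T> + get_slice + concat in the C++).
        r = len(repeat_times)
        diff = r - x.ndim
        t1 = x.reshape((1,) * max(diff, 0) + x.shape)  # leading unsqueeze
        n = t1.ndim

        # Interleave one new axis before each of the last r dims of t1.
        t2 = t1
        for i in range(r):
            t2 = np.expand_dims(t2, n - r + 2 * i)

        # ref_t carries the repeat counts on the interleaved axes.
        ref_shape = [1] * t2.ndim
        for i in range(r):
            ref_shape[n - r + 2 * i] = repeat_times[i]
        t3 = t2 * np.ones(ref_shape, dtype=t2.dtype)

        # Final shape: dim i of t1 times its repeat, aligned at the back.
        out_shape = []
        for i in range(n):
            rel = r - n + i
            out_shape.append(t1.shape[i] * (repeat_times[rel] if rel >= 0 else 1))
        return t3.reshape(out_shape)

    x = np.random.rand(300, 4096).astype("float32")
    assert np.array_equal(tile_decomp_dynamic(x, [2, 5]), np.tile(x, [2, 5]))        # tile_net1's repeats
    assert np.array_equal(tile_decomp_dynamic(x, [3, 2, 5]), np.tile(x, [3, 2, 5]))  # tile_net2's repeats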
