Commit 407b1b8

[CINN] support forOp with vectorize (#68918)

1 parent dcbea30 commit 407b1b8

File tree

9 files changed: +374 -2 lines changed

paddle/cinn/backends/codegen_cuda_dev.cc (+7 -1)

@@ -25,9 +25,15 @@ const std::string CodeGenCudaDev::source_header_ =  // NOLINT
 #include "float16.h"
 using cinn::common::bfloat16;
 using cinn::common::float16;
+using cinn::common::float8;
 using cinn::common::half4;
 using cinn::common::half8;
-using cinn::common::float8;
+using cinn::common::float168;
+using cinn::common::float164;
+using cinn::common::float162;
+using cinn::common::bfloat168;
+using cinn::common::bfloat164;
+using cinn::common::bfloat162;
 
 #include "cinn_cuda_runtime_source.cuh"
 )";

paddle/cinn/common/bfloat16.h (+12)

@@ -219,6 +219,18 @@ struct CINN_ALIGN(2) bfloat16 {
 #endif  // __cplusplus
 };
 
+struct CINN_ALIGN(16) bfloat168 {
+  bfloat16 x, y, z, w, v, u, t, s;
+};
+
+struct CINN_ALIGN(8) bfloat164 {
+  bfloat16 x, y, z, w;
+};
+
+struct CINN_ALIGN(4) bfloat162 {
+  bfloat16 x, y;
+};
+
 __host__ __device__ inline bfloat16 operator+(const bfloat16& a,
                                               const bfloat16& b) {
 #if defined(CINN_CUDA_BF16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800

paddle/cinn/common/float16.h (+12)

@@ -324,6 +324,18 @@ struct CINN_ALIGN(8) half4 {
   float16 x, y, z, w;
 };
 
+struct CINN_ALIGN(16) float168 {
+  float16 x, y, z, w, v, u, t, s;
+};
+
+struct CINN_ALIGN(8) float164 {
+  float16 x, y, z, w;
+};
+
+struct CINN_ALIGN(4) float162 {
+  float16 x, y;
+};
+
 #ifdef __cplusplus
 // Arithmetic operators on GPU
 // CUDA 9.0 provides built-in arithmetic operators for half while
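These wrappers mirror CUDA's built-in vector types: CINN_ALIGN pins each aggregate to its full byte size (16, 8, or 4 bytes) so that a load or store of 8, 4, or 2 half-precision elements can be issued as one aligned wide memory transaction. A minimal, self-contained layout sketch, assuming CINN_ALIGN expands to a standard alignment attribute (the float16 below is a 2-byte stand-in, not CINN's real type):

// Layout sketch only, not CINN code: assumes CINN_ALIGN(n) expands to a
// standard alignment attribute such as alignas(n).
#include <cstdint>
#include <cstdio>

#define CINN_ALIGN(n) alignas(n)

struct float16 { uint16_t bits; };  // 2-byte stand-in for cinn::common::float16

// Mirrors the structs added to paddle/cinn/common/float16.h.
struct CINN_ALIGN(16) float168 { float16 x, y, z, w, v, u, t, s; };
struct CINN_ALIGN(8) float164 { float16 x, y, z, w; };
struct CINN_ALIGN(4) float162 { float16 x, y; };

int main() {
  // 8 halves = 16 bytes, 16-byte aligned: eligible for one 128-bit transaction.
  std::printf("float168: size=%zu align=%zu\n", sizeof(float168), alignof(float168));
  std::printf("float164: size=%zu align=%zu\n", sizeof(float164), alignof(float164));
  std::printf("float162: size=%zu align=%zu\n", sizeof(float162), alignof(float162));
}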

paddle/cinn/optim/CMakeLists.txt (+2 -1)

@@ -38,7 +38,8 @@ gather_srcs(
   eliminate_common_global_memory_read.cc
   rearrange_load_instruction.cc
   check_tensor_buffer_map.cc
-  longlong2int.cc)
+  longlong2int.cc
+  vectorize_for_trans.cc)
 
 if(WITH_CUDA OR WITH_ROCM)
   gather_srcs(cinnapi_src SRCS transform_gpu_forloop.cc)

paddle/cinn/optim/optimize.cc (+7)

@@ -38,6 +38,7 @@
 #include "paddle/cinn/optim/transform_gpu_forloop.h"
 #include "paddle/cinn/optim/transform_polyfor_to_for.h"
 #include "paddle/cinn/optim/unroll_loops.h"
+#include "paddle/cinn/optim/vectorize_for_trans.h"
 #include "paddle/cinn/optim/vectorize_loops.h"
 
 namespace cinn {
@@ -104,6 +105,12 @@ ir::LoweredFunc Optimize(ir::LoweredFunc fn,
   IfFusion(&copied->body);
   VLOG(10) << "After Optimize IfFusion" << copied;
 
+  VectorizeForTrans(&copied->body);
+  VLOG(10) << "After Optimize vectorize" << copied;
+
+  Simplify(&copied->body);
+  VLOG(10) << "After Optimize Simplify" << copied;
+
   return copied;
 }

paddle/cinn/optim/vectorize_for_trans.cc (new file, +263 lines)

// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/cinn/optim/vectorize_for_trans.h"

#include <unordered_set>
#include "paddle/cinn/common/ir_util.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir_base.h"
#include "paddle/cinn/ir/ir_mutator.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/utils/ir_copy.h"
#include "paddle/cinn/ir/utils/ir_replace.h"
#include "paddle/cinn/optim/unroll_loops.h"

namespace cinn {
namespace optim {

namespace {

std::unordered_set<std::string> CollectIndexSymbols(Expr *x) {
  struct Mutator : public ir::IRMutator<Expr *> {
    void operator()(ir::Expr *expr) { ir::IRMutator<>::Visit(expr, expr); }
    void Visit(const ir::_Var_ *op, Expr *expr) override {
      auto *node = expr->As<ir::_Var_>();
      PADDLE_ENFORCE_NOT_NULL(node,
                              ::common::errors::InvalidArgument(
                                  "Expected a _Var_ node, but received nullptr."));
      symbols_.insert(op->name);
    }

    std::unordered_set<std::string> GetSymbols() { return symbols_; }

   private:
    std::unordered_set<std::string> symbols_;
  };

  Mutator mutator;
  mutator(x);
  return std::move(mutator.GetSymbols());
}

class VectorizeForTransMutator : public ir::IRMutator<ir::Expr *> {
 public:
  void operator()(ir::Expr *expr) { ir::IRMutator<>::Visit(expr, expr); }

  void Visit(const ir::Load *op, ir::Expr *expr) override {
    if (in_vectorize_) {
      auto *node = expr->As<ir::Load>();
      auto *tensor = node->tensor.As<ir::_Tensor_>();
      if (node->is_addr_tensor() && !IsScalarTensor(node->indices)) {
        TensorVectorized(node, &node->indices, false);
      }
    }
  }

  void Visit(const ir::Store *op, ir::Expr *expr) override {
    auto *node = expr->As<ir::Store>();
    auto *tensor = node->tensor.As<ir::_Tensor_>();
    PADDLE_ENFORCE_NOT_NULL(
        tensor,
        ::common::errors::InvalidArgument(
            "Expected _Tensor_ node in Store, but received nullptr."));
    if (in_vectorize_ && !IsScalarTensor(node->indices)) {
      TensorVectorized(node, &node->indices, true);
    }
    IRMutator::Visit(&node->value, &node->value);
  }

  // forOp doesn't support vectorization with an adjacent if-block.
  void Visit(const ir::IfThenElse *op, Expr *expr) override {
    in_vectorize_ = false;
    ir::IRMutator<>::Visit(op, expr);
  }

  void Visit(const ir::For *op, ir::Expr *expr) override {
    auto *forloop = expr->As<ir::For>();
    if (op->is_vectorized()) {
      vectorize_size_ = forloop->vectorize_info().factor;
      loop_var_ = op->loop_var;
      in_vectorize_ = true;
    }

    // handle vectorized tensor loads and stores
    IRMutator::Visit(forloop, expr);

    if (in_vectorize_) {
      const int factor = forloop->vectorize_info().factor;
      PADDLE_ENFORCE_GT(factor,
                        1,
                        ::common::errors::InvalidArgument(
                            "The value of factor in SplitForLoop is incorrect. "
                            "Expected value is larger than 1, but received %d.",
                            factor));

      auto copied_loop =
          ir::ir_utils::IRCopy(forloop, /* copy_buffer_node = */ false);
      copied_loop.As<ir::For>()->set_unrolled();
      optim::UnrollLoop(&copied_loop);
      auto unroll_body = copied_loop.As<ir::Block>()->stmts;
      auto &body_stmts = forloop->body.As<ir::Block>()->stmts;
      if (!update_cast_stmts_.empty()) {
        body_stmts.assign(update_cast_stmts_.begin(), update_cast_stmts_.end());
        update_cast_stmts_.clear();
      }
      body_stmts.insert(
          body_stmts.end(), unroll_body.begin(), unroll_body.end());

      if (!update_store_stmts_.empty()) {
        body_stmts.insert(body_stmts.end(),
                          update_store_stmts_.begin(),
                          update_store_stmts_.end());
        update_store_stmts_.clear();
      }
      *expr = forloop->body;
    }

    tensor2vectorized_vars_.clear();
    in_vectorize_ = false;
  }

 private:
  std::string GetVectorTypeName(ir::Type type) {
    std::string name_prefix =
        cinn::common::customized_type::kcuda_builtin_vector_t;

#define GET_CUDA_VECTOR_TYPE_NAME(pred_expr, scalar_name)                \
  if (pred_expr) {                                                       \
    return name_prefix + scalar_name + std::to_string(vectorize_size_); \
  }
    GET_CUDA_VECTOR_TYPE_NAME(type.is_int(8), "char");
    GET_CUDA_VECTOR_TYPE_NAME(type.is_int(16), "short");
    GET_CUDA_VECTOR_TYPE_NAME(type.is_int(32), "int");
    GET_CUDA_VECTOR_TYPE_NAME(type.is_uint(32), "uint");
    GET_CUDA_VECTOR_TYPE_NAME(type.is_float(32), "float");
    GET_CUDA_VECTOR_TYPE_NAME(type.is_float16(), "float16");
    GET_CUDA_VECTOR_TYPE_NAME(type.is_bfloat16(), "bfloat16");
#undef GET_CUDA_VECTOR_TYPE_NAME

    // others are not implemented yet
    CINN_NOT_IMPLEMENTED
    return "";
  }

  bool IsScalarTensor(const std::vector<ir::Expr> &indices) {
    for (auto var : indices) {
      std::unordered_set<std::string> index_symbols = CollectIndexSymbols(&var);
      if (index_symbols.count(loop_var_->name)) return false;
    }
    return true;
  }

  void TensorVectorized(ir::LoadStoreAddrMnger *node,
                        std::vector<ir::Expr> *indices,
                        bool is_store) {
    auto *tensor = node->tensor.As<ir::_Tensor_>();

    if (!tensor2vectorized_vars_.count(tensor->name)) {
      AppendCast(node->tensor, *indices, is_store);
    }

    auto vectorized_var = tensor2vectorized_vars_.at(tensor->name);
    // substitute a new tensor with the vector name and dtype
    auto t = vectorized_var->type().is_cpp_handle()
                 ? node->tensor->type().PointerOf()
                 : node->tensor->type();
    node->tensor = ir::Tensor(vectorized_var->name,
                              t,
                              {ir::Expr(vectorize_size_)},
                              {ir::Expr(vectorize_size_)},
                              tensor->operation);
    // retain only the loop variable as the last iteration index
    indices->assign({loop_var_});
  }

  void AppendCast(ir::Expr tensor,
                  const std::vector<ir::Expr> &indices,
                  bool is_store) {
    auto *node = tensor.As<ir::_Tensor_>();

    // generate the corresponding vector type
    Type scalar_type = tensor->type().ElementOf();
    Type vector_type_ptr(
        ir::Type::type_t::Customized, scalar_type.bits(), vectorize_size_);
    Type vector_type(
        ir::Type::type_t::Customized, scalar_type.bits(), vectorize_size_);
    vector_type_ptr.set_customized_type(GetVectorTypeName(scalar_type));
    vector_type_ptr.set_cpp_handle();
    vector_type_ptr.set_cpp_const(false);

    vector_type.set_customized_type(GetVectorTypeName(scalar_type));
    vector_type.set_cpp_const(false);

    // generate a local vector variable to be used in subsequent statements
    std::string vectorized_name =
        "vectorized_" + node->name + "_" + std::to_string(var_index_++);
    Var vectorized_var = ir::_Var_::Make(vectorized_name, vector_type);
    tensor2vectorized_vars_.emplace(node->name, vectorized_var);

    // generate a get_addr expr to get the address of the tensor
    Expr converted_tensor = ir::Load::Make(tensor, indices);
    cinn::ir::ir_utils::IrReplaceVarBroadcast(
        &converted_tensor, loop_var_, Expr(int32_t(0)));
    auto get_addr = ir::intrinsics::GetAddr::Make(converted_tensor);

    // generate a let expression to cast the tensor into the local vector
    auto cast = ir::Cast::Make(vector_type_ptr, get_addr);
    if (!is_store) {
      auto load = ir::Load::Make(cast, {cinn::common::make_const(0)});
      auto let = ir::Let::Make(vectorized_var, load);
      update_cast_stmts_.emplace_back(let);
    } else {
      Var vectorized_ptr =
          ir::_Var_::Make(vectorized_name + "_ptr", vector_type_ptr);
      auto let1 = ir::Let::Make(vectorized_ptr, cast);
      auto let2 = ir::Let::Make(vectorized_var, ir::Expr(0));
      update_cast_stmts_.emplace_back(let1);
      update_cast_stmts_.emplace_back(let2);

      auto t = ir::Tensor(vectorized_ptr->name,
                          node->type().PointerOf(),
                          {ir::Expr(vectorize_size_)},
                          {ir::Expr(vectorize_size_)},
                          node->operation);
      auto store =
          ir::Store::Make(t, vectorized_var, {cinn::common::make_const(0)});
      update_store_stmts_.emplace_back(store);
      VLOG(5) << "Append a vectorized expr:" << store;
    }
  }

  std::vector<ir::Expr> update_cast_stmts_;
  std::vector<ir::Expr> update_store_stmts_;
  absl::flat_hash_map<std::string, ir::Var> tensor2vectorized_vars_;

  int vectorize_size_{0};
  ir::Var loop_var_;
  bool in_vectorize_{false};
  int var_index_{0};
};

}  // namespace

void VectorizeForTrans(Expr *expr) {
  VectorizeForTransMutator collector;
  VLOG(5) << "before vectorize for trans " << *expr;
  collector(expr);
  VLOG(5) << "after vectorize for trans " << *expr;
}

}  // namespace optim
}  // namespace cinn
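For orientation: under the naming scheme in AppendCast above (a local "vectorized_<tensor>_<n>" vector, plus a "_ptr" alias for stores), a float32 store vectorized by factor 4 would be printed by the CUDA backend roughly as below. This is a hand-written illustration under those assumptions, with a hypothetical buffer name "out", not captured compiler output:

// Hypothetical emitted CUDA for a factor-4 float32 store (illustration only):
float4* vectorized_out_0_ptr = (float4*)(get_addr(out[base]));  // Let(ptr, Cast(GetAddr(...)))
float4 vectorized_out_0 = 0;          // Let(var, Expr(0)): local vector accumulator
// ... unrolled scalar body writes vectorized_out_0[0..3] ...
vectorized_out_0_ptr[0] = vectorized_out_0;  // single wide store, appended after the body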
paddle/cinn/optim/vectorize_for_trans.h (new file, +47 lines)

// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include "paddle/cinn/ir/ir.h"
namespace cinn {
namespace optim {
/**
 * Transform a forOp marked with vectorize info. The rewrite applies when the
 * vectorize factor matches the width of a vector instruction and the loop has
 * no adjacent if-block.
 *
 * e.g.
 *
 * serial for (i, 0, 4)
 *   serial for (j, 0, 4)
 *     vectorize[4] for (v1, 0, 4)
 *       float a[i, j, v1] = float b[i, j, v1] + float c[i, j, v1]
 *
 * to
 *
 * serial for (i, 0, 4)
 *   serial for (j, 0, 4)
 *     float4* temp_0_ptr = float4<4>*(get_addr(a[i * 4 + j]))
 *     float4 temp_1
 *     float4 temp_2 = b[i * 4 + j]
 *     float4 temp_3 = c[i * 4 + j]
 *     temp_1[0] = temp_2[0] + temp_3[0]
 *     temp_1[1] = temp_2[1] + temp_3[1]
 *     temp_1[2] = temp_2[2] + temp_3[2]
 *     temp_1[3] = temp_2[3] + temp_3[3]
 *     temp_0_ptr[0] = temp_1
 */
void VectorizeForTrans(Expr *expr);
}  // namespace optim
}  // namespace cinn
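As wired into optimize.cc above, the pass mutates the lowered function body in place, and a Simplify pass runs right after it to fold the index arithmetic left behind by unrolling:

// Call-site sketch mirroring this commit's change to Optimize():
VectorizeForTrans(&copied->body);  // rewrite vectorize[N] loops into wide loads/stores
Simplify(&copied->body);           // fold indices produced by the unrolled body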
