Skip to content

Commit 57a6dbb

Browse files
authored
【Hackathon 7th Fundable Projects 2 No.73】 [fluid_ops] load_combine (#68665)
1 parent d814598 commit 57a6dbb

18 files changed

+540
-71
lines changed

paddle/fluid/framework/infershape_utils.cc

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,15 @@ class InferShapeArgumentMappingContext : public phi::ArgumentMappingContext {
140140
});
141141
}
142142

143+
bool IsVocabOutput(const std::string& name) const override {
144+
auto var_types = ctx_.GetOutputsVarType(name);
145+
return std::all_of(var_types.begin(),
146+
var_types.end(),
147+
[](const proto::VarType::Type& type) {
148+
return type == proto::VarType::VOCAB;
149+
});
150+
}
151+
143152
bool IsSelectedRowsOutput(const std::string& name) const override {
144153
auto var_types = ctx_.GetOutputsVarType(name);
145154
return std::all_of(var_types.begin(),

paddle/fluid/framework/new_executor/interpreter/static_build.cc

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -186,8 +186,8 @@ bool BlockCanBeStaticBuilt(const framework::BlockDesc& block) {
186186
}
187187

188188
inline bool IsExtendedTensor(const phi::TensorBase& tensor) {
189-
return framework::RawTensor::classof(&tensor) ||
190-
phi::Strings::classof(&tensor) || phi::Vocab::classof(&tensor);
189+
return phi::RawTensor::classof(&tensor) || phi::Strings::classof(&tensor) ||
190+
phi::Vocab::classof(&tensor);
191191
}
192192

193193
bool TensorShouldBeFakeInitialized(const OperatorBase& op,
@@ -281,9 +281,11 @@ phi::TensorBase* GetTensorFormVar(framework::Variable* var) {
281281
return var->template GetMutable<phi::TensorArray>();
282282
} else if (var->template IsType<phi::Strings>()) {
283283
return var->template GetMutable<phi::Strings>();
284-
} else if (var->template IsType<paddle::framework::RawTensor>() ||
284+
} else if (var->template IsType<phi::Vocab>()) {
285+
return var->template GetMutable<phi::Vocab>();
286+
} else if (var->template IsType<phi::RawTensor>() ||
285287
!var->IsInitialized()) {
286-
return var->template GetMutable<paddle::framework::RawTensor>();
288+
return var->template GetMutable<phi::RawTensor>();
287289
} else {
288290
PADDLE_THROW(common::errors::Unimplemented(
289291
"Unsupported `%s` type when get tensor.",

paddle/fluid/framework/operator.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ limitations under the License. */
3939
#include "paddle/fluid/framework/unused_var_check.h"
4040
#include "paddle/phi/core/memory/malloc.h"
4141
#include "paddle/phi/core/platform/device_context.h"
42+
#include "paddle/phi/core/vocab/string_array.h"
4243

4344
#include "paddle/common/flags.h"
4445
#include "paddle/common/macros.h"
@@ -697,6 +698,13 @@ class ExecutionArgumentMappingContext : public phi::ArgumentMappingContext {
697698
});
698699
}
699700

701+
bool IsVocabOutput(const std::string& name) const override {
702+
auto vars = ctx_.MultiOutputVar(name);
703+
return std::all_of(vars.begin(), vars.end(), [](const Variable* var) {
704+
return var->IsType<phi::Vocab>();
705+
});
706+
}
707+
700708
bool IsSelectedRowsOutput(const std::string& name) const override {
701709
auto vars = ctx_.MultiOutputVar(name);
702710
return std::all_of(vars.begin(), vars.end(), [](const Variable* var) {

paddle/fluid/framework/type_info.cc

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@ bool TypeInfoTraits<BaseT, DerivedT>::classof(const BaseT* obj) {
3838
return obj->type_info() == kType;
3939
}
4040

41-
template class TypeInfoTraits<phi::TensorBase, paddle::framework::RawTensor>;
4241
template class TypeInfoTraits<phi::TensorBase, egr::VariableCompatTensor>;
4342
template class TypeInfoTraits<phi::TensorBase, paddle::prim::DescTensor>;
4443
template class TypeInfoTraits<phi::TensorBase, paddle::primitive::LazyTensor>;

paddle/fluid/framework/variable.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,12 @@ class Variable {
4949
if (!holder_) {
5050
holder_.reset(new PlaceholderImpl<T>());
5151
} else {
52+
// If holder_ is RawTensor, call holder_->Ptr() GetMutable again. Used for
53+
// load_combine.
54+
if (holder_->Type() == VarTypeTrait<RawTensor>::kId &&
55+
holder_->Type() != VarTypeTrait<T>::kId) {
56+
return static_cast<RawTensor*>(holder_->Ptr())->GetMutable<T>();
57+
}
5258
PADDLE_ENFORCE_EQ(
5359
holder_->Type(),
5460
VarTypeTrait<T>::kId,

paddle/fluid/operators/load_combine_op.cc

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -81,14 +81,3 @@ namespace ops = paddle::operators; // NOLINT
8181
REGISTER_OPERATOR(load_combine,
8282
ops::LoadCombineOp,
8383
ops::LoadCombineOpProtoMaker);
84-
85-
PD_REGISTER_STRUCT_KERNEL(load_combine,
86-
CPU,
87-
ALL_LAYOUT,
88-
ops::LoadCombineOpKernel,
89-
float,
90-
double,
91-
phi::dtype::bfloat16,
92-
int,
93-
int8_t,
94-
int64_t) {}

paddle/fluid/operators/load_combine_op.cu

Lines changed: 0 additions & 26 deletions
This file was deleted.

paddle/fluid/operators/load_combine_op_xpu.cc

Lines changed: 0 additions & 26 deletions
This file was deleted.
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/phi/core/compat/op_utils.h"
16+
17+
namespace phi {
18+
19+
KernelSignature LoadCombineOpArgumentMapping(
20+
const ArgumentMappingContext& ctx) {
21+
if (ctx.IsDenseTensorOutput("Out")) {
22+
return KernelSignature("load_combine",
23+
{},
24+
{"file_path", "load_as_fp16", "model_from_memory"},
25+
{"Out"});
26+
} else if (ctx.IsVocabOutput("Out")) {
27+
return KernelSignature("load_combine_vocab",
28+
{},
29+
{"file_path", "load_as_fp16", "model_from_memory"},
30+
{"Out"});
31+
} else {
32+
return KernelSignature("load_combine_extended",
33+
{},
34+
{"file_path", "load_as_fp16", "model_from_memory"},
35+
{"Out"});
36+
}
37+
}
38+
39+
} // namespace phi
40+
41+
PD_REGISTER_ARG_MAPPING_FN(load_combine, phi::LoadCombineOpArgumentMapping);

paddle/fluid/pir/dialect/operator/utils/utils.cc

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@ namespace paddle {
3636
namespace dialect {
3737

3838
const std::unordered_set<std::string> LegacyOpList = {
39-
LoadCombineOp::name(),
4039
CConcatOp::name(),
4140
CBroadcast_Op::name(),
4241
CBroadcastOp::name(),

paddle/phi/core/compat/arg_map_context.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ class ArgumentMappingContext {
119119

120120
virtual bool IsDenseTensorOutput(const std::string& name) const = 0;
121121
virtual bool IsSelectedRowsOutput(const std::string& name) const = 0;
122+
virtual bool IsVocabOutput(const std::string& name) const { return false; }
122123

123124
// use this function to mark it comes from InferShapeArgumentMappingContext
124125
// and will be used in infershape

paddle/phi/core/kernel_registry.cc

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -218,8 +218,11 @@ void SetKernelArgsDef(const std::vector<std::type_index>& args_type,
218218
default_tensor_layout,
219219
default_key.dtype(),
220220
arg_type);
221-
} else if (arg_type ==
222-
std::type_index(typeid(ExtendedTensor*))) { // NOLINT
221+
} else if (arg_type == std::type_index(typeid(ExtendedTensor*)) ||
222+
arg_type ==
223+
std::type_index(typeid(std::vector<ExtendedTensor*>)) ||
224+
arg_type ==
225+
std::type_index(typeid(std::vector<Vocab*>))) { // NOLINT
223226
args_def->AppendOutput(default_key.backend(),
224227
default_tensor_layout,
225228
default_key.dtype(),

paddle/phi/core/kernel_utils.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -387,6 +387,8 @@ struct KernelImpl<Return (*)(DevCtx, Args...), kernel_fn> {
387387

388388
PD_SPECIALIZE_KernelCallHelper_FOR_OUTPUT(TensorArray);
389389
PD_SPECIALIZE_KernelCallHelper_FOR_OUTPUT(ExtendedTensor);
390+
PD_SPECIALIZE_KernelCallHelper_FOR_MULTI_OUTPUT(ExtendedTensor);
391+
PD_SPECIALIZE_KernelCallHelper_FOR_MULTI_OUTPUT(Vocab);
390392

391393
/* End case */
392394
template <typename T>

paddle/phi/core/utils/type_info.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ limitations under the License. */
2020
#include "paddle/phi/backends/xpu/xpu_context.h"
2121
#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h"
2222
#include "paddle/phi/core/framework/feed_fetch_type.h"
23+
#include "paddle/phi/core/raw_tensor.h"
2324
#include "paddle/phi/core/selected_rows.h"
2425
#include "paddle/phi/core/sparse_coo_tensor.h"
2526
#include "paddle/phi/core/sparse_csr_tensor.h"
@@ -54,6 +55,7 @@ template class TypeInfoTraits<phi::TensorBase, TensorArray>;
5455
template class TypeInfoTraits<phi::TensorBase, phi::distributed::DistTensor>;
5556
template class TypeInfoTraits<phi::TensorBase, Vocab>;
5657
template class TypeInfoTraits<phi::TensorBase, Strings>;
58+
template class TypeInfoTraits<phi::TensorBase, RawTensor>;
5759
template class TypeInfoTraits<phi::TensorBase, FeedList>;
5860

5961
template class TypeInfoTraits<phi::DeviceContext, CPUContext>;
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/phi/kernels/impl/load_combine_kernel_impl.h"
16+
17+
PD_REGISTER_KERNEL(load_combine,
18+
CPU,
19+
ALL_LAYOUT,
20+
phi::LoadCombineKernel,
21+
float,
22+
double,
23+
phi::dtype::bfloat16,
24+
int,
25+
int8_t,
26+
int64_t) {}
27+
28+
PD_REGISTER_KERNEL(load_combine_vocab,
29+
CPU,
30+
ALL_LAYOUT,
31+
phi::LoadCombineVocabKernel,
32+
float,
33+
double,
34+
phi::dtype::bfloat16,
35+
int,
36+
int8_t,
37+
int64_t) {}
38+
39+
PD_REGISTER_KERNEL(load_combine_extended,
40+
CPU,
41+
ALL_LAYOUT,
42+
phi::LoadCombineExtendedKernel,
43+
float,
44+
double,
45+
phi::dtype::bfloat16,
46+
int,
47+
int8_t,
48+
int64_t) {}
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/phi/kernels/impl/load_combine_kernel_impl.h"
16+
17+
PD_REGISTER_KERNEL(load_combine,
18+
GPU,
19+
ALL_LAYOUT,
20+
phi::LoadCombineKernel,
21+
float,
22+
double,
23+
int,
24+
int8_t,
25+
int64_t) {}
26+
27+
PD_REGISTER_KERNEL(load_combine_vocab,
28+
GPU,
29+
ALL_LAYOUT,
30+
phi::LoadCombineVocabKernel,
31+
float,
32+
double,
33+
int,
34+
int8_t,
35+
int64_t) {}
36+
37+
PD_REGISTER_KERNEL(load_combine_extended,
38+
GPU,
39+
ALL_LAYOUT,
40+
phi::LoadCombineExtendedKernel,
41+
float,
42+
double,
43+
int,
44+
int8_t,
45+
int64_t) {}

0 commit comments

Comments
 (0)