From 01bead2303dc2796ec05eaf896b427c4c2f4b675 Mon Sep 17 00:00:00 2001 From: co63oc Date: Tue, 22 Oct 2024 19:53:24 +0800 Subject: [PATCH] Fix --- paddle/fluid/framework/infershape_utils.cc | 9 + .../new_executor/interpreter/static_build.cc | 10 +- paddle/fluid/framework/operator.h | 8 + paddle/fluid/framework/type_info.cc | 1 - paddle/fluid/framework/variable.h | 6 + paddle/fluid/operators/load_combine_op.cc | 11 - paddle/fluid/operators/load_combine_op.cu | 26 -- paddle/fluid/operators/load_combine_op_xpu.cc | 26 -- .../ops_signature/load_combine_sig.cc | 41 +++ .../fluid/pir/dialect/operator/utils/utils.cc | 1 - paddle/phi/core/compat/arg_map_context.h | 1 + paddle/phi/core/kernel_registry.cc | 7 +- paddle/phi/core/kernel_utils.h | 2 + paddle/phi/core/utils/type_info.cc | 2 + paddle/phi/kernels/cpu/load_combine_kernel.cc | 48 +++ paddle/phi/kernels/gpu/load_combine_kernel.cu | 45 +++ .../kernels/impl/load_combine_kernel_impl.h | 322 ++++++++++++++++++ paddle/phi/kernels/xpu/load_combine_kernel.cc | 45 +++ 18 files changed, 540 insertions(+), 71 deletions(-) delete mode 100644 paddle/fluid/operators/load_combine_op.cu delete mode 100644 paddle/fluid/operators/load_combine_op_xpu.cc create mode 100644 paddle/fluid/operators/ops_signature/load_combine_sig.cc create mode 100644 paddle/phi/kernels/cpu/load_combine_kernel.cc create mode 100644 paddle/phi/kernels/gpu/load_combine_kernel.cu create mode 100644 paddle/phi/kernels/impl/load_combine_kernel_impl.h create mode 100644 paddle/phi/kernels/xpu/load_combine_kernel.cc diff --git a/paddle/fluid/framework/infershape_utils.cc b/paddle/fluid/framework/infershape_utils.cc index 0490435d955b73..3c0b14c23b628b 100644 --- a/paddle/fluid/framework/infershape_utils.cc +++ b/paddle/fluid/framework/infershape_utils.cc @@ -140,6 +140,15 @@ class InferShapeArgumentMappingContext : public phi::ArgumentMappingContext { }); } + bool IsVocabOutput(const std::string& name) const override { + auto var_types = ctx_.GetOutputsVarType(name); + return std::all_of(var_types.begin(), + var_types.end(), + [](const proto::VarType::Type& type) { + return type == proto::VarType::VOCAB; + }); + } + bool IsSelectedRowsOutput(const std::string& name) const override { auto var_types = ctx_.GetOutputsVarType(name); return std::all_of(var_types.begin(), diff --git a/paddle/fluid/framework/new_executor/interpreter/static_build.cc b/paddle/fluid/framework/new_executor/interpreter/static_build.cc index c1e990ff06cd3e..deb76360762192 100644 --- a/paddle/fluid/framework/new_executor/interpreter/static_build.cc +++ b/paddle/fluid/framework/new_executor/interpreter/static_build.cc @@ -186,8 +186,8 @@ bool BlockCanBeStaticBuilt(const framework::BlockDesc& block) { } inline bool IsExtendedTensor(const phi::TensorBase& tensor) { - return framework::RawTensor::classof(&tensor) || - phi::Strings::classof(&tensor) || phi::Vocab::classof(&tensor); + return phi::RawTensor::classof(&tensor) || phi::Strings::classof(&tensor) || + phi::Vocab::classof(&tensor); } bool TensorShouldBeFakeInitialized(const OperatorBase& op, @@ -281,9 +281,11 @@ phi::TensorBase* GetTensorFormVar(framework::Variable* var) { return var->template GetMutable(); } else if (var->template IsType()) { return var->template GetMutable(); - } else if (var->template IsType() || + } else if (var->template IsType()) { + return var->template GetMutable(); + } else if (var->template IsType() || !var->IsInitialized()) { - return var->template GetMutable(); + return var->template GetMutable(); } else { PADDLE_THROW(common::errors::Unimplemented( "Unsupported `%s` type when get tensor.", diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h index 1ce1f4004965d9..9cfa3d51d6b106 100644 --- a/paddle/fluid/framework/operator.h +++ b/paddle/fluid/framework/operator.h @@ -39,6 +39,7 @@ limitations under the License. */ #include "paddle/fluid/framework/unused_var_check.h" #include "paddle/phi/core/memory/malloc.h" #include "paddle/phi/core/platform/device_context.h" +#include "paddle/phi/core/vocab/string_array.h" #include "paddle/common/flags.h" #include "paddle/common/macros.h" @@ -697,6 +698,13 @@ class ExecutionArgumentMappingContext : public phi::ArgumentMappingContext { }); } + bool IsVocabOutput(const std::string& name) const override { + auto vars = ctx_.MultiOutputVar(name); + return std::all_of(vars.begin(), vars.end(), [](const Variable* var) { + return var->IsType(); + }); + } + bool IsSelectedRowsOutput(const std::string& name) const override { auto vars = ctx_.MultiOutputVar(name); return std::all_of(vars.begin(), vars.end(), [](const Variable* var) { diff --git a/paddle/fluid/framework/type_info.cc b/paddle/fluid/framework/type_info.cc index 232922a1569c4d..862cea21f6cbe2 100644 --- a/paddle/fluid/framework/type_info.cc +++ b/paddle/fluid/framework/type_info.cc @@ -38,7 +38,6 @@ bool TypeInfoTraits::classof(const BaseT* obj) { return obj->type_info() == kType; } -template class TypeInfoTraits; template class TypeInfoTraits; template class TypeInfoTraits; template class TypeInfoTraits; diff --git a/paddle/fluid/framework/variable.h b/paddle/fluid/framework/variable.h index 9f58091fc3bbab..db23515ad8e91c 100644 --- a/paddle/fluid/framework/variable.h +++ b/paddle/fluid/framework/variable.h @@ -49,6 +49,12 @@ class Variable { if (!holder_) { holder_.reset(new PlaceholderImpl()); } else { + // If holder_ is RawTensor, call holder_->Ptr() GetMutable again. Used for + // load_combine. + if (holder_->Type() == VarTypeTrait::kId && + holder_->Type() != VarTypeTrait::kId) { + return static_cast(holder_->Ptr())->GetMutable(); + } PADDLE_ENFORCE_EQ( holder_->Type(), VarTypeTrait::kId, diff --git a/paddle/fluid/operators/load_combine_op.cc b/paddle/fluid/operators/load_combine_op.cc index 40680dbf00829a..75721ce8b6161c 100644 --- a/paddle/fluid/operators/load_combine_op.cc +++ b/paddle/fluid/operators/load_combine_op.cc @@ -81,14 +81,3 @@ namespace ops = paddle::operators; // NOLINT REGISTER_OPERATOR(load_combine, ops::LoadCombineOp, ops::LoadCombineOpProtoMaker); - -PD_REGISTER_STRUCT_KERNEL(load_combine, - CPU, - ALL_LAYOUT, - ops::LoadCombineOpKernel, - float, - double, - phi::dtype::bfloat16, - int, - int8_t, - int64_t) {} diff --git a/paddle/fluid/operators/load_combine_op.cu b/paddle/fluid/operators/load_combine_op.cu deleted file mode 100644 index 379834daec22e8..00000000000000 --- a/paddle/fluid/operators/load_combine_op.cu +++ /dev/null @@ -1,26 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/operators/load_combine_op.h" - -namespace ops = paddle::operators; -PD_REGISTER_STRUCT_KERNEL(load_combine, - GPU, - ALL_LAYOUT, - ops::LoadCombineOpKernel, - float, - double, - int, - int8_t, - int64_t) {} diff --git a/paddle/fluid/operators/load_combine_op_xpu.cc b/paddle/fluid/operators/load_combine_op_xpu.cc deleted file mode 100644 index d285af37cda98f..00000000000000 --- a/paddle/fluid/operators/load_combine_op_xpu.cc +++ /dev/null @@ -1,26 +0,0 @@ -/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/operators/load_combine_op.h" - -namespace ops = paddle::operators; -PD_REGISTER_STRUCT_KERNEL(load_combine, - XPU, - ALL_LAYOUT, - ops::LoadCombineOpKernel, - float, - double, - int, - int8_t, - int64_t) {} diff --git a/paddle/fluid/operators/ops_signature/load_combine_sig.cc b/paddle/fluid/operators/ops_signature/load_combine_sig.cc new file mode 100644 index 00000000000000..de10113e6eb0ab --- /dev/null +++ b/paddle/fluid/operators/ops_signature/load_combine_sig.cc @@ -0,0 +1,41 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/compat/op_utils.h" + +namespace phi { + +KernelSignature LoadCombineOpArgumentMapping( + const ArgumentMappingContext& ctx) { + if (ctx.IsDenseTensorOutput("Out")) { + return KernelSignature("load_combine", + {}, + {"file_path", "load_as_fp16", "model_from_memory"}, + {"Out"}); + } else if (ctx.IsVocabOutput("Out")) { + return KernelSignature("load_combine_vocab", + {}, + {"file_path", "load_as_fp16", "model_from_memory"}, + {"Out"}); + } else { + return KernelSignature("load_combine_extended", + {}, + {"file_path", "load_as_fp16", "model_from_memory"}, + {"Out"}); + } +} + +} // namespace phi + +PD_REGISTER_ARG_MAPPING_FN(load_combine, phi::LoadCombineOpArgumentMapping); diff --git a/paddle/fluid/pir/dialect/operator/utils/utils.cc b/paddle/fluid/pir/dialect/operator/utils/utils.cc index a39642bdf6edad..1d8b8affe037dd 100644 --- a/paddle/fluid/pir/dialect/operator/utils/utils.cc +++ b/paddle/fluid/pir/dialect/operator/utils/utils.cc @@ -36,7 +36,6 @@ namespace paddle { namespace dialect { const std::unordered_set LegacyOpList = { - LoadCombineOp::name(), CConcatOp::name(), CBroadcast_Op::name(), CBroadcastOp::name(), diff --git a/paddle/phi/core/compat/arg_map_context.h b/paddle/phi/core/compat/arg_map_context.h index 5ae8acad67a295..5ee383996131e9 100644 --- a/paddle/phi/core/compat/arg_map_context.h +++ b/paddle/phi/core/compat/arg_map_context.h @@ -119,6 +119,7 @@ class ArgumentMappingContext { virtual bool IsDenseTensorOutput(const std::string& name) const = 0; virtual bool IsSelectedRowsOutput(const std::string& name) const = 0; + virtual bool IsVocabOutput(const std::string& name) const { return false; } // use this function to mark it comes from InferShapeArgumentMappingContext // and will be used in infershape diff --git a/paddle/phi/core/kernel_registry.cc b/paddle/phi/core/kernel_registry.cc index 4c5c5785d6301d..66f405bedaf1fd 100644 --- a/paddle/phi/core/kernel_registry.cc +++ b/paddle/phi/core/kernel_registry.cc @@ -218,8 +218,11 @@ void SetKernelArgsDef(const std::vector& args_type, default_tensor_layout, default_key.dtype(), arg_type); - } else if (arg_type == - std::type_index(typeid(ExtendedTensor*))) { // NOLINT + } else if (arg_type == std::type_index(typeid(ExtendedTensor*)) || + arg_type == + std::type_index(typeid(std::vector)) || + arg_type == + std::type_index(typeid(std::vector))) { // NOLINT args_def->AppendOutput(default_key.backend(), default_tensor_layout, default_key.dtype(), diff --git a/paddle/phi/core/kernel_utils.h b/paddle/phi/core/kernel_utils.h index 225326e26cfeee..6d0ac20e5c83e3 100644 --- a/paddle/phi/core/kernel_utils.h +++ b/paddle/phi/core/kernel_utils.h @@ -387,6 +387,8 @@ struct KernelImpl { PD_SPECIALIZE_KernelCallHelper_FOR_OUTPUT(TensorArray); PD_SPECIALIZE_KernelCallHelper_FOR_OUTPUT(ExtendedTensor); + PD_SPECIALIZE_KernelCallHelper_FOR_MULTI_OUTPUT(ExtendedTensor); + PD_SPECIALIZE_KernelCallHelper_FOR_MULTI_OUTPUT(Vocab); /* End case */ template diff --git a/paddle/phi/core/utils/type_info.cc b/paddle/phi/core/utils/type_info.cc index b12ba9168ed879..f6ffec27fa8604 100644 --- a/paddle/phi/core/utils/type_info.cc +++ b/paddle/phi/core/utils/type_info.cc @@ -20,6 +20,7 @@ limitations under the License. */ #include "paddle/phi/backends/xpu/xpu_context.h" #include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h" #include "paddle/phi/core/framework/feed_fetch_type.h" +#include "paddle/phi/core/raw_tensor.h" #include "paddle/phi/core/selected_rows.h" #include "paddle/phi/core/sparse_coo_tensor.h" #include "paddle/phi/core/sparse_csr_tensor.h" @@ -54,6 +55,7 @@ template class TypeInfoTraits; template class TypeInfoTraits; template class TypeInfoTraits; template class TypeInfoTraits; +template class TypeInfoTraits; template class TypeInfoTraits; template class TypeInfoTraits; diff --git a/paddle/phi/kernels/cpu/load_combine_kernel.cc b/paddle/phi/kernels/cpu/load_combine_kernel.cc new file mode 100644 index 00000000000000..e1bf4ec0a03430 --- /dev/null +++ b/paddle/phi/kernels/cpu/load_combine_kernel.cc @@ -0,0 +1,48 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/impl/load_combine_kernel_impl.h" + +PD_REGISTER_KERNEL(load_combine, + CPU, + ALL_LAYOUT, + phi::LoadCombineKernel, + float, + double, + phi::dtype::bfloat16, + int, + int8_t, + int64_t) {} + +PD_REGISTER_KERNEL(load_combine_vocab, + CPU, + ALL_LAYOUT, + phi::LoadCombineVocabKernel, + float, + double, + phi::dtype::bfloat16, + int, + int8_t, + int64_t) {} + +PD_REGISTER_KERNEL(load_combine_extended, + CPU, + ALL_LAYOUT, + phi::LoadCombineExtendedKernel, + float, + double, + phi::dtype::bfloat16, + int, + int8_t, + int64_t) {} diff --git a/paddle/phi/kernels/gpu/load_combine_kernel.cu b/paddle/phi/kernels/gpu/load_combine_kernel.cu new file mode 100644 index 00000000000000..8d0f01c4ac22b5 --- /dev/null +++ b/paddle/phi/kernels/gpu/load_combine_kernel.cu @@ -0,0 +1,45 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/impl/load_combine_kernel_impl.h" + +PD_REGISTER_KERNEL(load_combine, + GPU, + ALL_LAYOUT, + phi::LoadCombineKernel, + float, + double, + int, + int8_t, + int64_t) {} + +PD_REGISTER_KERNEL(load_combine_vocab, + GPU, + ALL_LAYOUT, + phi::LoadCombineVocabKernel, + float, + double, + int, + int8_t, + int64_t) {} + +PD_REGISTER_KERNEL(load_combine_extended, + GPU, + ALL_LAYOUT, + phi::LoadCombineExtendedKernel, + float, + double, + int, + int8_t, + int64_t) {} diff --git a/paddle/phi/kernels/impl/load_combine_kernel_impl.h b/paddle/phi/kernels/impl/load_combine_kernel_impl.h new file mode 100644 index 00000000000000..c197b7ab5c525f --- /dev/null +++ b/paddle/phi/kernels/impl/load_combine_kernel_impl.h @@ -0,0 +1,322 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include + +#include "paddle/phi/core/extended_tensor.h" +#include "paddle/phi/core/framework/convert_utils.h" +#include "paddle/phi/core/framework/data_type_transform.h" +#include "paddle/phi/core/framework/lod_tensor_serialize.h" +#include "paddle/phi/core/framework/var_type_helper.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/core/platform/device_context.h" +#include "paddle/phi/core/raw_tensor.h" +#include "paddle/phi/core/tensor_utils.h" +#include "paddle/phi/core/vocab/string_array.h" + +namespace phi { + +template +void LoadParamsFromBuffer(const Context& dev_ctx, + const phi::Place& place, + std::istream* buffer, + bool load_as_fp16, + const std::vector& out) { + auto out_vars = out; + for (size_t i = 0; i < out_vars.size(); i++) { + PADDLE_ENFORCE_NOT_NULL( + out_vars[i], + common::errors::InvalidArgument( + "The variable index %d to be loaded cannot be found.", i)); + // Error checking + PADDLE_ENFORCE_EQ( + static_cast(*buffer), + true, + common::errors::Unavailable( + "An error occurred while loading model parameters. " + "Please check whether the model file is complete or damaged.")); + + dev_ctx.template Alloc(out_vars[i]); + phi::DenseTensor* tensor = out_vars[i]; + // Get data from fin to tensor + phi::DeserializeFromStream(*buffer, tensor, dev_ctx); + auto in_dtype = tensor->dtype(); + auto out_dtype = load_as_fp16 ? phi::DataType::FLOAT16 : in_dtype; + if (in_dtype != out_dtype) { + // convert to float16 tensor + auto in_kernel_type = + phi::KernelKey(place, phi::DataLayout::ALL_LAYOUT, in_dtype); + auto out_kernel_type = + phi::KernelKey(place, phi::DataLayout::ALL_LAYOUT, out_dtype); + phi::DenseTensor fp16_tensor; + // copy LoD info to the new tensor + fp16_tensor.set_lod(tensor->lod()); + TransDataType(in_kernel_type, out_kernel_type, *tensor, &fp16_tensor); + + // reset output tensor + tensor->set_lod(fp16_tensor.lod()); + tensor->ShareDataWith(fp16_tensor); + } + } + buffer->peek(); + PADDLE_ENFORCE_EQ(buffer->eof(), + true, + common::errors::Unavailable( + "Not allowed to load partial data via " + "load_combine_op, please use load_op instead.")); +} + +template +void LoadParamsFromBuffer(const Context& dev_ctx, + const phi::Place& place, + std::istream* buffer, + bool load_as_fp16, + const std::vector& out) { + auto out_vars = out; + for (size_t i = 0; i < out_vars.size(); i++) { + PADDLE_ENFORCE_NOT_NULL( + out_vars[i], + common::errors::InvalidArgument( + "The variable index %d to be loaded cannot be found.", i)); + // Error checking + PADDLE_ENFORCE_EQ( + static_cast(*buffer), + true, + common::errors::Unavailable( + "An error occurred while loading model parameters. " + "Please check whether the model file is complete or damaged.")); + + auto* tensor = out_vars[i]; + tensor->clear(); + std::unordered_map data; + StringMapFromStream(*buffer, &data); + for (auto it = data.begin(); it != data.end(); ++it) { + std::string tmp; + NFD(it->first, &tmp); + if (tmp.empty()) { + // VLOG(0) << "The string " << it->first + // << " was converted to unicode unsuccessfully! " + // << "Then dropped to load it."; + continue; + } + std::wstring token; + bool status = ConvertStrToWstr(tmp, &token); + if (!status) continue; + tensor->emplace(token, it->second); + } + } + buffer->peek(); + PADDLE_ENFORCE_EQ(buffer->eof(), + true, + common::errors::Unavailable( + "Not allowed to load partial data via " + "load_combine_op, please use load_op instead.")); +} + +template +void LoadParamsFromBuffer(const Context& dev_ctx, + const phi::Place& place, + std::istream* buffer, + bool load_as_fp16, + const std::vector& out) { + auto out_vars = out; + for (size_t i = 0; i < out_vars.size(); i++) { + PADDLE_ENFORCE_NOT_NULL( + out_vars[i], + common::errors::InvalidArgument( + "The variable index %d to be loaded cannot be found.", i)); + // Error checking + PADDLE_ENFORCE_EQ( + static_cast(*buffer), + true, + common::errors::Unavailable( + "An error occurred while loading model parameters. " + "Please check whether the model file is complete or damaged.")); + auto* raw_tensor = static_cast(out_vars[i]); + if (raw_tensor->IsType()) { + auto* tensor = raw_tensor->GetMutable(); + tensor->clear(); + std::unordered_map data; + StringMapFromStream(*buffer, &data); + for (auto it = data.begin(); it != data.end(); ++it) { + std::string tmp; + NFD(it->first, &tmp); + if (tmp.empty()) { + // VLOG(0) << "The string " << it->first + // << " was converted to unicode unsuccessfully! " + // << "Then dropped to load it."; + continue; + } + std::wstring token; + bool status = ConvertStrToWstr(tmp, &token); + if (!status) continue; + tensor->emplace(token, it->second); + } + } else { + auto* tensor = raw_tensor->GetMutable(); + + // Get data from fin to tensor + DeserializeFromStream(*buffer, tensor, dev_ctx); + + auto in_dtype = tensor->dtype(); + auto out_dtype = load_as_fp16 ? phi::DataType::FLOAT16 : in_dtype; + + if (in_dtype != out_dtype) { + // convert to float16 tensor + auto in_kernel_type = + phi::KernelKey(place, phi::DataLayout::ALL_LAYOUT, in_dtype); + auto out_kernel_type = + phi::KernelKey(place, phi::DataLayout::ALL_LAYOUT, out_dtype); + phi::DenseTensor fp16_tensor; + // copy LoD info to the new tensor + fp16_tensor.set_lod(tensor->lod()); + TransDataType(in_kernel_type, out_kernel_type, *tensor, &fp16_tensor); + + // reset output tensor + // raw_tensor->Clear(); + tensor = raw_tensor->GetMutable(); + tensor->set_lod(fp16_tensor.lod()); + tensor->ShareDataWith(fp16_tensor); + } + } + } + buffer->peek(); + PADDLE_ENFORCE_EQ(buffer->eof(), + true, + common::errors::Unavailable( + "Not allowed to load partial data via " + "load_combine_op, please use load_op instead.")); +} + +template +void LoadCombineKernel(const Context& dev_ctx, + const std::string& file_path, + bool load_as_fp16, + bool model_from_memory, + std::vector out) { + auto place = dev_ctx.GetPlace(); + auto filename = file_path; + auto out_var_names = out; + + PADDLE_ENFORCE_GT(out_var_names.size(), + 0UL, + common::errors::InvalidArgument( + "The number of variables to be loaded is %d, expect " + "it to be greater than 0.", + out_var_names.size())); + if (!model_from_memory) { + std::ifstream fin(filename, std::ios::binary); + PADDLE_ENFORCE_EQ( + static_cast(fin), + true, + common::errors::Unavailable( + "LoadCombine operator fails to open file %s, please check " + "whether the model file is complete or damaged.", + filename)); + LoadParamsFromBuffer(dev_ctx, place, &fin, load_as_fp16, out); + } else { + PADDLE_ENFORCE_NE( + filename.empty(), + true, + common::errors::Unavailable( + "LoadCombine operator fails to open file %s, please check " + "whether the model file is complete or damaged.", + filename)); + std::stringstream fin(filename, std::ios::in | std::ios::binary); + LoadParamsFromBuffer(dev_ctx, place, &fin, load_as_fp16, out); + } +} + +template +void LoadCombineVocabKernel(const Context& dev_ctx, + const std::string& file_path, + bool load_as_fp16, + bool model_from_memory, + std::vector out) { + auto place = dev_ctx.GetPlace(); + auto filename = file_path; + auto out_var_names = out; + + PADDLE_ENFORCE_GT(out_var_names.size(), + 0UL, + common::errors::InvalidArgument( + "The number of variables to be loaded is %d, expect " + "it to be greater than 0.", + out_var_names.size())); + if (!model_from_memory) { + std::ifstream fin(filename, std::ios::binary); + PADDLE_ENFORCE_EQ( + static_cast(fin), + true, + common::errors::Unavailable( + "LoadCombine operator fails to open file %s, please check " + "whether the model file is complete or damaged.", + filename)); + LoadParamsFromBuffer(dev_ctx, place, &fin, load_as_fp16, out); + } else { + PADDLE_ENFORCE_NE( + filename.empty(), + true, + common::errors::Unavailable( + "LoadCombine operator fails to open file %s, please check " + "whether the model file is complete or damaged.", + filename)); + std::stringstream fin(filename, std::ios::in | std::ios::binary); + LoadParamsFromBuffer(dev_ctx, place, &fin, load_as_fp16, out); + } +} + +template +void LoadCombineExtendedKernel(const Context& dev_ctx, + const std::string& file_path, + bool load_as_fp16, + bool model_from_memory, + std::vector out) { + auto place = dev_ctx.GetPlace(); + auto filename = file_path; + auto out_var_names = out; + + PADDLE_ENFORCE_GT(out_var_names.size(), + 0UL, + common::errors::InvalidArgument( + "The number of variables to be loaded is %d, expect " + "it to be greater than 0.", + out_var_names.size())); + if (!model_from_memory) { + std::ifstream fin(filename, std::ios::binary); + PADDLE_ENFORCE_EQ( + static_cast(fin), + true, + common::errors::Unavailable( + "LoadCombine operator fails to open file %s, please check " + "whether the model file is complete or damaged.", + filename)); + LoadParamsFromBuffer(dev_ctx, place, &fin, load_as_fp16, out); + } else { + PADDLE_ENFORCE_NE( + filename.empty(), + true, + common::errors::Unavailable( + "LoadCombine operator fails to open file %s, please check " + "whether the model file is complete or damaged.", + filename)); + std::stringstream fin(filename, std::ios::in | std::ios::binary); + LoadParamsFromBuffer(dev_ctx, place, &fin, load_as_fp16, out); + } +} +} // namespace phi diff --git a/paddle/phi/kernels/xpu/load_combine_kernel.cc b/paddle/phi/kernels/xpu/load_combine_kernel.cc new file mode 100644 index 00000000000000..2e673c86447aec --- /dev/null +++ b/paddle/phi/kernels/xpu/load_combine_kernel.cc @@ -0,0 +1,45 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/impl/load_combine_kernel_impl.h" + +PD_REGISTER_KERNEL(load_combine, + XPU, + ALL_LAYOUT, + phi::LoadCombineKernel, + float, + double, + int, + int8_t, + int64_t) {} + +PD_REGISTER_KERNEL(load_combine_vocab, + XPU, + ALL_LAYOUT, + phi::LoadCombineVocabKernel, + float, + double, + int, + int8_t, + int64_t) {} + +PD_REGISTER_KERNEL(load_combine_extended, + XPU, + ALL_LAYOUT, + phi::LoadCombineExtendedKernel, + float, + double, + int, + int8_t, + int64_t) {}