Skip to content

fix assign value op bug #1350

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Aug 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 7 additions & 23 deletions paddle2onnx/mapper/mapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -183,33 +183,17 @@ class Mapper {
auto &op = parser_->GetOpDesc(block_idx_, op_idx_);
return parser_->OpHasAttr(op, name);
}
void GetAttr(const std::string &name, int64_t *val) {
auto &op = parser_->GetOpDesc(block_idx_, op_idx_);
parser_->GetOpAttr(op, name, val);
}
void GetAttr(const std::string &name, float *val) {
auto &op = parser_->GetOpDesc(block_idx_, op_idx_);
parser_->GetOpAttr(op, name, val);
}
void GetAttr(const std::string &name, bool *val) {
auto &op = parser_->GetOpDesc(block_idx_, op_idx_);
parser_->GetOpAttr(op, name, val);
}
void GetAttr(const std::string &name, std::string *val) {
auto &op = parser_->GetOpDesc(block_idx_, op_idx_);
parser_->GetOpAttr(op, name, val);
}
void GetAttr(const std::string &name, std::vector<int64_t> *val) {
auto &op = parser_->GetOpDesc(block_idx_, op_idx_);
parser_->GetOpAttr(op, name, val);
}
void GetAttr(const std::string &name, std::vector<float> *val) {

template<typename T>
void GetAttr(const std::string &name, T* val) {
auto &op = parser_->GetOpDesc(block_idx_, op_idx_);
parser_->GetOpAttr(op, name, val);
}
void GetAttr(const std::string &name, std::vector<double> *val) {

template<typename T>
void GetScalars(const std::string &name, std::vector<T>* val){
auto &op = parser_->GetOpDesc(block_idx_, op_idx_);
parser_->GetOpAttr(op, name, val);
parser_->GetOpScalarsAttr(op, name, val);
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

GetAttr 和 GetScalars 感觉都可以用模板简化~

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok,我改一下


bool IsConstantInput(const std::string &input_key) const {
Expand Down
63 changes: 55 additions & 8 deletions paddle2onnx/mapper/tensor/assign_value.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@
#pragma once
#include <string>
#include <vector>

#include "paddle2onnx/mapper/mapper.h"
#include <unordered_map>
#include <functional>

namespace paddle2onnx {

Expand All @@ -27,21 +28,67 @@ class AssignValueMapper : public Mapper {
: Mapper(p, helper, block_id, op_id) {
GetAttr("dtype", &dtype_);
GetAttr("shape", &shape_);
GetAttrValues();
}
int32_t GetMinOpsetVersion(bool verbose) override;
void Opset7() override;

private:
void GetAttrValues(){
int32_t dtype = static_cast<int32_t>(dtype_);
const std::string attr_name = HasAttr("values") ? "values" : GetAttrNameByDtype(dtype);
std::unordered_map<int32_t, std::function<void()>> type_handlers = {
{P2ODataType::INT32, [&](){
if (attr_name == "values") GetScalars(attr_name, &int64_values_);
else if (attr_name == "int32_values") GetAttr(attr_name, &int64_values_);
}},
{P2ODataType::INT64, [&](){
if (attr_name == "values") GetScalars(attr_name, &int64_values_);
else if (attr_name == "int64_values") GetAttr(attr_name, &int64_values_);
}},
{P2ODataType::FP32, [&](){
if (attr_name == "values") GetScalars(attr_name, &fp32_values_);
else if (attr_name == "fp32_values") GetAttr(attr_name, &fp32_values_);
}},
{P2ODataType::FP64, [&](){
if (attr_name == "values") GetScalars(attr_name, &double_values_);
else if (attr_name == "fp32_values") GetAttr(attr_name, &double_values_);
}},
{P2ODataType::BOOL, [&](){
if (attr_name == "values") GetScalars(attr_name, &bool_values_);
else if (attr_name == "bool_values") GetAttr(attr_name, &bool_values_);
}},
};

auto handler = type_handlers.find(dtype);
if (handler != type_handlers.end()) {
handler->second();
} else {
Error() << "Unsupported dtype value" << std::endl;
}
}

std::string GetAttrNameByDtype(int32_t dtype) {
if (dtype == P2ODataType::INT32) {
GetAttr("int32_values", &int64_values_);
} else if (dtype == P2ODataType::FP32) {
GetAttr("fp32_values", &fp32_values_);
return "int32_values";
} else if (dtype == P2ODataType::INT64) {
GetAttr("int64_values", &int64_values_);
return "int64_values";
}else if (dtype == P2ODataType::FP32) {
return "fp32_values";
} else if (dtype == P2ODataType::FP64) {
return "double_values";
} else if (dtype == P2ODataType::BOOL) {
return "bool_values";
}
Error() << "Unsupported dtype value" << std::endl;

}
int32_t GetMinOpsetVersion(bool verbose) override;
void Opset7() override;

private:
std::vector<float> fp32_values_;
std::vector<int64_t> int64_values_;
std::vector<bool> bool_values_;
std::vector<double> double_values_;
std::vector<int32_t> int32_values_;
std::vector<int64_t> shape_;
int64_t dtype_;
};
Expand Down
53 changes: 52 additions & 1 deletion paddle2onnx/parser/parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -739,6 +739,7 @@ void PaddleParser::GetOpAttr(const paddle2onnx::framework::proto::OpDesc& op,
Assert(found, "Cannot found attribute " + name + " in op: " + op.type());
}


void PaddleParser::GetOpAttr(const paddle2onnx::framework::proto::OpDesc& op,
const std::string& name,
std::vector<double>* res) const {
Expand All @@ -759,7 +760,26 @@ void PaddleParser::GetOpAttr(const paddle2onnx::framework::proto::OpDesc& op,
}
Assert(found, "Cannot found attribute " + name + " in op: " + op.type());
}

void PaddleParser::GetOpAttr(const paddle2onnx::framework::proto::OpDesc& op,
const std::string& name,
std::vector<bool>* res) const {
bool found = false;
res->clear();
for (auto i = 0; i < op.attrs_size(); ++i) {
if (op.attrs(i).name() == name) {
found = true;
if (IsAttrVar(op, i)) break;
Assert(op.attrs(i).bools_size() >= 0,
"Cannot find list of double data from attr: " + name + " in op: " +
op.type());
for (auto j = 0; j < op.attrs(i).bools_size(); ++j) {
res->push_back(static_cast<double>(op.attrs(i).bools(j)));
}
break;
}
}
Assert(found, "Cannot found attribute " + name + " in op: " + op.type());
}
void PaddleParser::GetGlobalBlockInputOutputInfo() {
inputs.clear();
outputs.clear();
Expand Down Expand Up @@ -860,4 +880,35 @@ bool PaddleParser::ExistsDumplicateTensorName() const {
}
return false;
}

#define DECLARE_GET_OP_SCALARS(scalar_type, target_type) \
template <> \
void PaddleParser::GetOpScalarsAttr<target_type>(const paddle2onnx::framework::proto::OpDesc& op, \
const std::string& name, \
std::vector<target_type>* res) const { \
bool found = false; \
res->clear(); \
for (auto i = 0; i < op.attrs_size(); ++i) { \
if (op.attrs(i).name() == name) { \
found = true; \
if (IsAttrVar(op, i)) break; \
Assert(op.attrs(i).scalars_size() >= 0, \
"Cannot find list of scalars data from attr: " + name + \
" in op: " + op.type()); \
for (auto j = 0; j < op.attrs(i).scalars_size(); ++j) { \
Assert(op.attrs(i).scalars(j).has_##scalar_type(), \
"Scalar type does not match with " #scalar_type); \
res->push_back(static_cast<target_type>(op.attrs(i).scalars(j).scalar_type())); \
} \
break; \
} \
} \
Assert(found, "Cannot found attribute " + name + " in op: " + op.type()); \
}

DECLARE_GET_OP_SCALARS(i, int64_t)
DECLARE_GET_OP_SCALARS(i, int32_t)
DECLARE_GET_OP_SCALARS(r, float)
DECLARE_GET_OP_SCALARS(r, double)
DECLARE_GET_OP_SCALARS(b, bool)
} // namespace paddle2onnx
7 changes: 7 additions & 0 deletions paddle2onnx/parser/parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,8 @@ class PaddleParser {
const std::string& name, std::vector<float>* res) const;
void GetOpAttr(const paddle2onnx::framework::proto::OpDesc& op,
const std::string& name, std::vector<double>* res) const;
void GetOpAttr(const paddle2onnx::framework::proto::OpDesc& op,
const std::string& name, std::vector<bool>* res) const;

bool IsConstantTensor(const int64_t& block_idx,
const std::string& tensor_name) const;
Expand All @@ -187,6 +189,11 @@ class PaddleParser {
const std::string& tensor_name,
std::vector<T>* data) const;

template <typename T>
void GetOpScalarsAttr(const paddle2onnx::framework::proto::OpDesc& op,
const std::string& name,
std::vector<T>* res) const;

private:
// If the model has same output name in difference operators
// will fail to convert
Expand Down