Skip to content

Commit c1b8092

Browse files
【pir_save_load】Jit layer (#66185)
* add all * add notes * add jit.layer part * split pir.h * recover * add pir_utils * modify link bug * modify link bug * modify build bug
1 parent 483d9d9 commit c1b8092

13 files changed

+224
-55
lines changed

paddle/fluid/jit/function_schema.cc

+48-21
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,9 @@
1515
#include "paddle/fluid/jit/function_schema.h"
1616

1717
#include "paddle/fluid/framework/program_desc.h"
18+
#include "paddle/fluid/pybind/pir_utils.h"
1819
#include "paddle/phi/core/enforce.h"
20+
#include "paddle/pir/include/core/program.h"
1921

2022
#include "paddle/fluid/jit/function_utils.h"
2123
namespace paddle::jit {
@@ -51,10 +53,37 @@ void FunctionSchema::AddOutputArg(const std::string& name) {
5153
output_args.emplace_back(name, true);
5254
}
5355

56+
/* base function info*/
57+
BaseFunctionInfo::BaseFunctionInfo(const std::string& func_name,
58+
const std::vector<std::string>& param_names)
59+
: func_name_(func_name), param_names_(param_names) {}
60+
const std::string& BaseFunctionInfo::FunctionName() const { return func_name_; }
61+
62+
const std::vector<std::string>& BaseFunctionInfo::ParamNames() const {
63+
return param_names_;
64+
}
65+
66+
const std::vector<std::string> BaseFunctionInfo::InputArgNames() const {
67+
return schema_.InputArgNames();
68+
}
69+
70+
const std::vector<std::string> BaseFunctionInfo::OutputArgNames() const {
71+
return schema_.OutputArgNames();
72+
}
73+
74+
const std::string& BaseFunctionInfo::ProgramFilePath() const {
75+
return prog_file_path_;
76+
}
77+
78+
void BaseFunctionInfo::SetProgramFilePath(const std::string& path) {
79+
prog_file_path_ = path;
80+
}
81+
82+
/* FunctionInfo */
5483
FunctionInfo::FunctionInfo(const std::string& func_name,
5584
const std::vector<std::string>& param_names,
5685
const framework::ProgramDesc& program_desc)
57-
: func_name_(func_name), param_names_(param_names) {
86+
: BaseFunctionInfo(func_name, param_names) {
5887
program_desc_.reset(new framework::ProgramDesc(program_desc));
5988
// Parse FunctionSchema
6089
for (auto& in_name : program_desc_->GetFeedTargetNames()) {
@@ -65,34 +94,32 @@ FunctionInfo::FunctionInfo(const std::string& func_name,
6594
}
6695
}
6796

68-
const std::string& FunctionInfo::FunctionName() const { return func_name_; }
69-
7097
const framework::ProgramDesc& FunctionInfo::ProgramDesc() const {
7198
return *program_desc_.get(); // NOLINT
7299
}
73100

74-
const std::vector<std::string>& FunctionInfo::ParamNames() const {
75-
return param_names_;
76-
}
77-
78-
const std::vector<std::string> FunctionInfo::InputArgNames() const {
79-
return schema_.InputArgNames();
80-
}
81-
82-
const std::vector<std::string> FunctionInfo::OutputArgNames() const {
83-
return schema_.OutputArgNames();
84-
}
85-
86-
const std::string& FunctionInfo::ProgramFilePath() const {
87-
return prog_file_path_;
101+
void FunctionInfo::RemoveDescFeedFetch() {
102+
utils::RemoveFeedFetch(program_desc_.get());
88103
}
89104

90-
void FunctionInfo::SetProgramFilePath(const std::string& path) {
91-
prog_file_path_ = path;
105+
/* pirFunctionInfo*/
106+
PirFunctionInfo::PirFunctionInfo(const std::string& func_name,
107+
const std::vector<std::string>& param_names,
108+
pir::Program* program)
109+
: BaseFunctionInfo(func_name, param_names) {
110+
program_ = program;
111+
// Parse FunctionSchema
112+
for (auto& in_name : GetFeedTargetNames(program_)) {
113+
schema_.AddInputArg(in_name);
114+
}
115+
for (auto& out_name : GetFetchTargetNames(program_)) {
116+
schema_.AddOutputArg(out_name);
117+
}
92118
}
93119

94-
void FunctionInfo::RemoveDescFeedFetch() {
95-
utils::RemoveFeedFetch(program_desc_.get());
120+
pir::Program* PirFunctionInfo::Program() const {
121+
return program_; // NOLINT
96122
}
97123

124+
void PirFunctionInfo::RemoveFeedFetch() { utils::RemoveFeedFetch(program_); }
98125
} // namespace paddle::jit

paddle/fluid/jit/function_schema.h

+37-11
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@
1818
#include <string>
1919
#include <vector>
2020

21+
namespace pir {
22+
class Program;
23+
}
2124
namespace paddle {
2225

2326
namespace framework {
@@ -55,16 +58,14 @@ class FunctionSchema {
5558
std::vector<Argument> input_args;
5659
std::vector<Argument> output_args;
5760
};
58-
59-
class FunctionInfo {
61+
class BaseFunctionInfo {
6062
public:
61-
FunctionInfo(const std::string& func_name,
62-
const std::vector<std::string>& param_names,
63-
const framework::ProgramDesc& program_desc);
63+
BaseFunctionInfo(const std::string& func_name,
64+
const std::vector<std::string>& param_names);
6465

65-
const std::string& FunctionName() const;
66+
virtual ~BaseFunctionInfo() = default;
6667

67-
const framework::ProgramDesc& ProgramDesc() const;
68+
const std::string& FunctionName() const;
6869

6970
const std::vector<std::string>& ParamNames() const;
7071

@@ -76,15 +77,40 @@ class FunctionInfo {
7677

7778
void SetProgramFilePath(const std::string& path);
7879

79-
void RemoveDescFeedFetch();
80-
81-
private:
80+
protected:
8281
std::string func_name_;
8382
std::vector<std::string> param_names_;
84-
std::shared_ptr<framework::ProgramDesc> program_desc_;
8583
FunctionSchema schema_;
8684
std::string prog_file_path_;
8785
};
8886

87+
class FunctionInfo : public BaseFunctionInfo {
88+
public:
89+
FunctionInfo(const std::string& func_name,
90+
const std::vector<std::string>& param_names,
91+
const framework::ProgramDesc& program_desc);
92+
93+
const framework::ProgramDesc& ProgramDesc() const;
94+
95+
void RemoveDescFeedFetch();
96+
97+
private:
98+
std::shared_ptr<framework::ProgramDesc> program_desc_;
99+
};
100+
101+
class PirFunctionInfo : public BaseFunctionInfo {
102+
public:
103+
PirFunctionInfo(const std::string& func_name,
104+
const std::vector<std::string>& param_names,
105+
pir::Program* program);
106+
107+
pir::Program* Program() const;
108+
109+
void RemoveFeedFetch();
110+
111+
private:
112+
pir::Program* program_;
113+
};
114+
89115
} // namespace jit
90116
} // namespace paddle

paddle/fluid/jit/function_utils.cc

+2
Original file line numberDiff line numberDiff line change
@@ -109,4 +109,6 @@ void RemoveFeedFetch(framework::ProgramDesc *program_desc) {
109109
}
110110
}
111111

112+
void RemoveFeedFetch(pir::Program *program) {}
113+
112114
} // namespace paddle::jit::utils

paddle/fluid/jit/function_utils.h

+2
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,8 @@ void ShareParamsIntoScope(const std::vector<std::string> &param_names,
5656

5757
void RemoveFeedFetch(framework::ProgramDesc *program_desc);
5858

59+
void RemoveFeedFetch(pir::Program *program);
60+
5961
template <typename T>
6062
std::shared_ptr<T> MakeEngine(const std::shared_ptr<FunctionInfo> &info,
6163
const std::shared_ptr<VariableMap> &params_dict,

paddle/fluid/jit/serializer.cc

+1-1
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ using FunctionInfoMap =
3737

3838
Layer Deserializer::operator()(const std::string& path,
3939
const phi::Place& place) {
40-
const auto& pdmodel_paths = utils::PdmodelFilePaths(path);
40+
const auto& pdmodel_paths = utils::ModelFilePaths(path);
4141
// set is ordered
4242
std::set<std::string> param_names_set;
4343
FunctionInfoMap info_map;

paddle/fluid/jit/serializer_utils.cc

+49-18
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@
2020
#include "paddle/fluid/framework/phi_utils.h"
2121
#include "paddle/fluid/framework/var_desc.h"
2222

23+
#include "paddle/pir/include/core/attribute.h"
24+
#include "paddle/pir/include/core/builtin_attribute.h"
25+
COMMON_DECLARE_bool(enable_pir_api);
2326
namespace paddle::jit::utils {
2427

2528
bool IsPersistable(framework::VarDesc* desc_ptr) {
@@ -33,6 +36,15 @@ bool IsPersistable(framework::VarDesc* desc_ptr) {
3336
return desc_ptr->Persistable();
3437
}
3538

39+
bool IsPersistable(pir::Value* value_ptr) {
40+
auto is_persistable =
41+
value_ptr->attribute<pir::BoolAttribute>(kAttrIsPersistable);
42+
if (is_persistable && is_persistable.data()) {
43+
return true;
44+
}
45+
return false;
46+
}
47+
3648
bool StartsWith(const std::string& str, const std::string& prefix) {
3749
return str.compare(0, prefix.length(), prefix) == 0;
3850
}
@@ -62,9 +74,9 @@ bool FileExists(const std::string& file_path) {
6274
return file.good();
6375
}
6476

65-
const std::vector<std::pair<std::string, std::string>> PdmodelFilePaths(
77+
const std::vector<std::pair<std::string, std::string>> ModelFilePaths(
6678
const std::string& path) {
67-
std::vector<std::pair<std::string, std::string>> pdmodel_paths;
79+
std::vector<std::pair<std::string, std::string>> model_paths;
6880
std::string format_path = path;
6981
ReplaceAll(&format_path, R"(\\)", "/");
7082
ReplaceAll(&format_path, R"(\)", "/");
@@ -80,27 +92,46 @@ const std::vector<std::pair<std::string, std::string>> PdmodelFilePaths(
8092
struct dirent* ptr = nullptr;
8193

8294
while ((ptr = readdir(dir)) != nullptr) {
95+
std::string prefix = "";
8396
std::string file_name = ptr->d_name;
84-
85-
if (StartsWith(file_name, layer_name) &&
86-
EndsWith(file_name, PDMODEL_SUFFIX)) {
87-
std::string prefix = file_name.substr(
88-
0, file_name.length() - std::string(PDMODEL_SUFFIX).length());
89-
90-
if (prefix == layer_name) {
91-
pdmodel_paths.emplace_back(
92-
std::make_pair("forward", dir_path + file_name));
93-
} else {
94-
std::string func_name = prefix.substr(layer_name.size() + 1);
95-
pdmodel_paths.emplace_back(
96-
std::make_pair(func_name, dir_path + file_name));
97+
if (FLAGS_enable_pir_api) {
98+
if (StartsWith(file_name, layer_name) &&
99+
EndsWith(file_name, JSON_SUFFIX)) {
100+
std::string prefix = file_name.substr(
101+
0, file_name.length() - std::string(JSON_SUFFIX).length());
102+
103+
if (prefix == layer_name) {
104+
model_paths.emplace_back(
105+
std::make_pair("forward", dir_path + file_name));
106+
} else {
107+
std::string func_name = prefix.substr(layer_name.size() + 1);
108+
model_paths.emplace_back(
109+
std::make_pair(func_name, dir_path + file_name));
110+
}
111+
VLOG(3) << "func_name: " << model_paths.back().first
112+
<< ", path:" << dir_path + file_name;
113+
}
114+
} else {
115+
if (StartsWith(file_name, layer_name) &&
116+
EndsWith(file_name, PDMODEL_SUFFIX)) {
117+
std::string prefix = file_name.substr(
118+
0, file_name.length() - std::string(PDMODEL_SUFFIX).length());
119+
120+
if (prefix == layer_name) {
121+
model_paths.emplace_back(
122+
std::make_pair("forward", dir_path + file_name));
123+
} else {
124+
std::string func_name = prefix.substr(layer_name.size() + 1);
125+
model_paths.emplace_back(
126+
std::make_pair(func_name, dir_path + file_name));
127+
}
128+
VLOG(3) << "func_name: " << model_paths.back().first
129+
<< ", path:" << dir_path + file_name;
97130
}
98-
VLOG(3) << "func_name: " << pdmodel_paths.back().first
99-
<< ", path:" << dir_path + file_name;
100131
}
101132
}
102133
closedir(dir);
103-
return pdmodel_paths;
134+
return model_paths;
104135
}
105136

106137
void InitKernelSignatureMap() {

paddle/fluid/jit/serializer_utils.h

+6-2
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@
1616

1717
#include <string>
1818
#include <vector>
19-
19+
namespace pir {
20+
class Value;
21+
}
2022
namespace paddle {
2123

2224
namespace framework {
@@ -25,11 +27,13 @@ class VarDesc;
2527

2628
namespace jit {
2729
static const char PDMODEL_SUFFIX[] = ".pdmodel";
30+
static const char JSON_SUFFIX[] = ".json";
2831
static const char PDPARAMS_SUFFIX[] = ".pdiparams";
2932
static const char PROPERTY_SUFFIX[] = ".meta";
3033

3134
namespace utils {
3235
bool IsPersistable(framework::VarDesc* desc_ptr);
36+
bool IsPersistable(pir::Value* value_ptr);
3337

3438
bool StartsWith(const std::string& str, const std::string& suffix);
3539

@@ -41,7 +45,7 @@ void ReplaceAll(std::string* str,
4145

4246
bool FileExists(const std::string& file_path);
4347

44-
const std::vector<std::pair<std::string, std::string>> PdmodelFilePaths(
48+
const std::vector<std::pair<std::string, std::string>> ModelFilePaths(
4549
const std::string& path);
4650

4751
void InitKernelSignatureMap();

paddle/fluid/pybind/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@ set(PYBIND_SRCS
123123
inference_api.cc
124124
control_flow_api.cc
125125
pir.cc
126+
pir_utils.cc
126127
graph.cc
127128
bind_fleet_executor.cc
128129
reader_py.cc

paddle/fluid/pybind/pir.cc

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
1+
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
22
//
33
// Licensed under the Apache License, Version 2.0 (the "License");
44
// you may not use this file except in compliance with the License.

paddle/fluid/pybind/pir.h

+3-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
1+
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
22
//
33
// Licensed under the Apache License, Version 2.0 (the "License");
44
// you may not use this file except in compliance with the License.
@@ -20,12 +20,14 @@
2020
#include "paddle/pir/include/core/value.h"
2121

2222
namespace paddle {
23+
2324
namespace pybind {
2425
using pir::Value;
2526
void BindPir(pybind11::module *m);
2627
const phi::DDim &GetValueDims(Value value);
2728
bool GetValueBoolAttr(Value value, const std::string &attr_name);
2829
std::string GetValueName(Value value);
2930
bool HasValueName(const Value &value);
31+
3032
} // namespace pybind
3133
} // namespace paddle

0 commit comments

Comments
 (0)