Skip to content

Commit 4c1dfa8

Browse files
committed
Fix
1 parent 2d6a9d7 commit 4c1dfa8

File tree

2 files changed

+11
-60
lines changed

2 files changed

+11
-60
lines changed

paddle/fluid/framework/feed_fetch_type.h

+1-29
Original file line numberDiff line numberDiff line change
@@ -13,37 +13,9 @@ See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

1515
#pragma once
16-
17-
#include <vector>
18-
19-
#include "paddle/fluid/framework/lod_tensor.h"
20-
#include "paddle/fluid/framework/lod_tensor_array.h"
21-
#include "paddle/fluid/framework/string_array.h"
22-
#include "paddle/phi/core/extended_tensor.h"
23-
16+
#include "paddle/phi/core/framework/feed_fetch_type.h "
2417
namespace paddle {
2518
namespace framework {
26-
using FeedType =
27-
paddle::variant<phi::DenseTensor, Strings, phi::SparseCooTensor>;
28-
using FetchType = paddle::variant<phi::DenseTensor,
29-
phi::TensorArray,
30-
framework::Vocab,
31-
phi::SparseCooTensor>;
32-
33-
template <>
34-
struct PhiVectorType<FeedType> {
35-
const char *type_name = "PhiVectorFeedType";
36-
};
37-
38-
template <>
39-
struct PhiVectorType<FetchType> {
40-
const char *type_name = "PhiVectorFetchType";
41-
};
42-
43-
using FeedList = paddle::framework::PhiVector<FeedType>;
44-
using FetchList = paddle::framework::PhiVector<FetchType>;
45-
46-
using FetchUnmergedList = std::vector<std::vector<FetchType>>;
4719

4820
inline bool data_is_lod_tensor(const FetchType &data) {
4921
if (data.type() == typeid(phi::DenseTensor)) {

paddle/phi/core/framework/feed_fetch_type.h

+10-31
Original file line numberDiff line numberDiff line change
@@ -42,35 +42,14 @@ using FeedList = PhiVector<FeedType>;
4242
using FetchList = PhiVector<FetchType>;
4343
using FetchUnmergedList = std::vector<std::vector<FetchType>>;
4444

45-
inline bool data_is_lod_tensor(const FetchType &data) {
46-
if (data.type() == typeid(phi::DenseTensor)) {
47-
return true;
48-
}
49-
return false;
50-
}
51-
52-
inline bool data_is_lod_tensor_array(const FetchType &data) {
53-
if (data.type() == typeid(phi::TensorArray)) {
54-
return true;
55-
}
56-
return false;
57-
}
58-
59-
inline bool data_is_string_tensor(const FeedType &data) {
60-
if (data.type() == typeid(Strings)) {
61-
return true;
62-
}
63-
return false;
64-
}
65-
66-
inline bool data_is_sparse_coo_tensor(const FetchType &data) {
67-
if (data.type() == typeid(phi::SparseCooTensor)) {
68-
return true;
69-
}
70-
return false;
71-
}
72-
73-
static const char kFeedOpType[] = "feed";
74-
static const char kFetchOpType[] = "fetch";
75-
7645
} // namespace phi
46+
47+
namespace paddle {
48+
namespace framework {
49+
using FeedType = phi::FeedType;
50+
using FetchType = phi::FetchType;
51+
using FeedList = phi::FeedList;
52+
using FetchList = phi::FetchList;
53+
using FetchUnmergedList = phi::FetchUnmergedList;
54+
} // namespace framework
55+
} // namespace paddle

0 commit comments

Comments
 (0)