Skip to content

Commit b39b2b7

Browse files
committed
modify base functions
1 parent fa3275b commit b39b2b7

File tree

5 files changed

+81
-2
lines changed

5 files changed

+81
-2
lines changed

paddle/phi/core/attribute.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,14 +30,17 @@ namespace phi {
3030
class Place;
3131

3232
// NOTE: Add needed type in the future
33+
// Move vector<int> before vector<bool>, because when
34+
// vector<bool> is before vector<int>, a python integer
35+
// list will be converted to vector<bool> in error.
3336
using Attribute = paddle::variant<bool,
3437
int,
3538
int64_t,
3639
float,
3740
double,
3841
std::string,
39-
std::vector<bool>,
4042
std::vector<int>,
43+
std::vector<bool>,
4144
std::vector<int64_t>,
4245
std::vector<float>,
4346
std::vector<double>,

paddle/phi/core/distributed/auto_parallel/inferspmd_utils.cc

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ AttrType InferSpmdContext::AttrAt(size_t idx) const {
5454
}
5555

5656
template <>
57-
bool InferSpmdContext::AttrAt<bool>(size_t idx) const {
57+
bool InferSpmdContext::AttrAt(size_t idx) const {
5858
try {
5959
auto attr = attrs_.at(idx);
6060
if (attr.type() == typeid(int)) {
@@ -70,6 +70,28 @@ bool InferSpmdContext::AttrAt<bool>(size_t idx) const {
7070
}
7171
}
7272

73+
template <>
74+
std::vector<int> InferSpmdContext::AttrAt(size_t idx) const {
75+
try {
76+
auto attr = attrs_.at(idx);
77+
if (attr.type() == typeid(std::vector<bool>)) {
78+
std::vector<bool> val = PADDLE_GET_CONST(std::vector<bool>, attr);
79+
return std::vector<int>(val.begin(), val.end());
80+
} else {
81+
return paddle::get<std::vector<int>>(attr);
82+
}
83+
} catch (paddle::bad_variant_access const& e) {
84+
PADDLE_THROW(phi::errors::InvalidArgument(
85+
"Attribute cast error in InferSpmd Context, the input attr type is "
86+
"`%s`, but the expected attribute type is `bool`.",
87+
attrs_.at(idx).type().name()));
88+
}
89+
}
90+
91+
// template const std::vector<int64_t>& InferSpmdContext::AttrAt(size_t idx)
92+
// const; template const std::vector<int>& InferSpmdContext::AttrAt(size_t idx)
93+
// const;
94+
7395
const Attribute& InferSpmdContext::AttrAt(size_t idx) const {
7496
return attrs_.at(idx);
7597
}

paddle/phi/core/distributed/auto_parallel/inferspmd_utils.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,8 +138,24 @@ struct InferSpmdFnImpl<Return (*)(Args...), infer_spmd_fn> {
138138
} \
139139
}
140140

141+
#define PD_SPECIALIZE_InferSpmdFnCallHelper_FOR_CONST_ATTRIBUTE_REF(attr_type) \
142+
template <typename... Tail> \
143+
struct InferSpmdFnCallHelper<const attr_type&, Tail...> { \
144+
template <int in_idx, int attr_idx, typename... PreviousArgs> \
145+
static SpmdInfo Call(const InferSpmdContext& ctx, \
146+
PreviousArgs&... pargs) { \
147+
attr_type arg = ctx.AttrAt<attr_type>(attr_idx); \
148+
return InferSpmdFnCallHelper<Tail...>::template Call<in_idx, \
149+
attr_idx + 1>( \
150+
ctx, pargs..., arg); \
151+
} \
152+
}
153+
141154
// TODO(chenweihang): support other attr type later as needed
142155
PD_SPECIALIZE_InferSpmdFnCallHelper_FOR_ATTRIBUTE(bool);
156+
PD_SPECIALIZE_InferSpmdFnCallHelper_FOR_CONST_ATTRIBUTE_REF(std::vector<int>);
157+
PD_SPECIALIZE_InferSpmdFnCallHelper_FOR_CONST_ATTRIBUTE_REF(
158+
std::vector<int64_t>);
143159

144160
/* End case */
145161
template <typename T>

paddle/phi/infermeta/spmd_rules/utils.cc

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,5 +157,31 @@ std::vector<int64_t> ResoluteOutputPartialDimension(
157157
return partial_on_dims;
158158
}
159159

160+
std::vector<int64_t> GetDimsMappingForAxes(
161+
const std::string& axes,
162+
const std::unordered_map<std::string, int64_t>& axis_to_dim_map,
163+
const bool unsharded_miss_axis) {
164+
std::vector<int64_t> dims_mapping;
165+
for (int64_t i = 0, n = static_cast<int64_t>(axes.size()); i < n; i++) {
166+
std::string axis = axes.substr(i, 1);
167+
if (axis == "1") {
168+
dims_mapping.emplace_back(-1);
169+
} else {
170+
auto iter = axis_to_dim_map.find(axis);
171+
if (iter == axis_to_dim_map.end()) {
172+
if (unsharded_miss_axis) {
173+
dims_mapping.emplace_back(-1);
174+
} else {
175+
phi::errors::InvalidArgument(
176+
"Tensor axis [%s] of not in axis_to_dim_map.", axis);
177+
}
178+
} else {
179+
dims_mapping.emplace_back(iter->second);
180+
}
181+
}
182+
}
183+
return dims_mapping;
184+
}
185+
160186
} // namespace distributed
161187
} // namespace phi

paddle/phi/infermeta/spmd_rules/utils.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,5 +134,17 @@ struct PhiSpmdVariadicArgumentParser
134134
}
135135
};
136136
} // namespace detail
137+
138+
// Get dims mapping for the given axes according to sharding information of
139+
// the annotated axes after inferring forward or backward. The parameter axis
140+
// stores the axes of the tensor. "1" is a special axis, for the axis "1", set
141+
// its dims mapping to -1.
142+
// if unsharded_miss_axis, "-1" is assigend to axes that has no key in
143+
// axis_to_dim_map.
144+
std::vector<int64_t> GetDimsMappingForAxes(
145+
const std::string& axes,
146+
const std::unordered_map<std::string, int64_t>& axis_to_dim_map,
147+
const bool unsharded_miss_axis = false);
148+
137149
} // namespace distributed
138150
} // namespace phi

0 commit comments

Comments
 (0)