File tree 5 files changed +81
-2
lines changed
distributed/auto_parallel
5 files changed +81
-2
lines changed Original file line number Diff line number Diff line change @@ -30,14 +30,17 @@ namespace phi {
30
30
class Place ;
31
31
32
32
// NOTE: Add needed type in the future
33
+ // Move vector<int> before vector<bool>, because when
34
+ // vector<bool> is before vector<int>, a python integer
35
+ // list will be converted to vector<bool> in error.
33
36
using Attribute = paddle::variant<bool ,
34
37
int ,
35
38
int64_t ,
36
39
float ,
37
40
double ,
38
41
std::string,
39
- std::vector<bool >,
40
42
std::vector<int >,
43
+ std::vector<bool >,
41
44
std::vector<int64_t >,
42
45
std::vector<float >,
43
46
std::vector<double >,
Original file line number Diff line number Diff line change @@ -54,7 +54,7 @@ AttrType InferSpmdContext::AttrAt(size_t idx) const {
54
54
}
55
55
56
56
template <>
57
- bool InferSpmdContext::AttrAt< bool > (size_t idx) const {
57
+ bool InferSpmdContext::AttrAt (size_t idx) const {
58
58
try {
59
59
auto attr = attrs_.at (idx);
60
60
if (attr.type () == typeid (int )) {
@@ -70,6 +70,28 @@ bool InferSpmdContext::AttrAt<bool>(size_t idx) const {
70
70
}
71
71
}
72
72
73
+ template <>
74
+ std::vector<int > InferSpmdContext::AttrAt (size_t idx) const {
75
+ try {
76
+ auto attr = attrs_.at (idx);
77
+ if (attr.type () == typeid (std::vector<bool >)) {
78
+ std::vector<bool > val = PADDLE_GET_CONST (std::vector<bool >, attr);
79
+ return std::vector<int >(val.begin (), val.end ());
80
+ } else {
81
+ return paddle::get<std::vector<int >>(attr);
82
+ }
83
+ } catch (paddle::bad_variant_access const & e) {
84
+ PADDLE_THROW (phi::errors::InvalidArgument (
85
+ " Attribute cast error in InferSpmd Context, the input attr type is "
86
+ " `%s`, but the expected attribute type is `bool`." ,
87
+ attrs_.at (idx).type ().name ()));
88
+ }
89
+ }
90
+
91
+ // template const std::vector<int64_t>& InferSpmdContext::AttrAt(size_t idx)
92
+ // const; template const std::vector<int>& InferSpmdContext::AttrAt(size_t idx)
93
+ // const;
94
+
73
95
const Attribute& InferSpmdContext::AttrAt (size_t idx) const {
74
96
return attrs_.at (idx);
75
97
}
Original file line number Diff line number Diff line change @@ -138,8 +138,24 @@ struct InferSpmdFnImpl<Return (*)(Args...), infer_spmd_fn> {
138
138
} \
139
139
}
140
140
141
+ #define PD_SPECIALIZE_InferSpmdFnCallHelper_FOR_CONST_ATTRIBUTE_REF (attr_type ) \
142
+ template <typename ... Tail> \
143
+ struct InferSpmdFnCallHelper <const attr_type&, Tail...> { \
144
+ template <int in_idx, int attr_idx, typename ... PreviousArgs> \
145
+ static SpmdInfo Call (const InferSpmdContext& ctx, \
146
+ PreviousArgs&... pargs) { \
147
+ attr_type arg = ctx.AttrAt <attr_type>(attr_idx); \
148
+ return InferSpmdFnCallHelper<Tail...>::template Call<in_idx, \
149
+ attr_idx + 1 >( \
150
+ ctx, pargs..., arg); \
151
+ } \
152
+ }
153
+
141
154
// TODO(chenweihang): support other attr type later as needed
142
155
PD_SPECIALIZE_InferSpmdFnCallHelper_FOR_ATTRIBUTE (bool );
156
+ PD_SPECIALIZE_InferSpmdFnCallHelper_FOR_CONST_ATTRIBUTE_REF (std::vector<int >);
157
+ PD_SPECIALIZE_InferSpmdFnCallHelper_FOR_CONST_ATTRIBUTE_REF (
158
+ std::vector<int64_t >);
143
159
144
160
/* End case */
145
161
template <typename T>
Original file line number Diff line number Diff line change @@ -157,5 +157,31 @@ std::vector<int64_t> ResoluteOutputPartialDimension(
157
157
return partial_on_dims;
158
158
}
159
159
160
+ std::vector<int64_t > GetDimsMappingForAxes (
161
+ const std::string& axes,
162
+ const std::unordered_map<std::string, int64_t >& axis_to_dim_map,
163
+ const bool unsharded_miss_axis) {
164
+ std::vector<int64_t > dims_mapping;
165
+ for (int64_t i = 0 , n = static_cast <int64_t >(axes.size ()); i < n; i++) {
166
+ std::string axis = axes.substr (i, 1 );
167
+ if (axis == " 1" ) {
168
+ dims_mapping.emplace_back (-1 );
169
+ } else {
170
+ auto iter = axis_to_dim_map.find (axis);
171
+ if (iter == axis_to_dim_map.end ()) {
172
+ if (unsharded_miss_axis) {
173
+ dims_mapping.emplace_back (-1 );
174
+ } else {
175
+ phi::errors::InvalidArgument (
176
+ " Tensor axis [%s] of not in axis_to_dim_map." , axis);
177
+ }
178
+ } else {
179
+ dims_mapping.emplace_back (iter->second );
180
+ }
181
+ }
182
+ }
183
+ return dims_mapping;
184
+ }
185
+
160
186
} // namespace distributed
161
187
} // namespace phi
Original file line number Diff line number Diff line change @@ -134,5 +134,17 @@ struct PhiSpmdVariadicArgumentParser
134
134
}
135
135
};
136
136
} // namespace detail
137
+
138
+ // Get dims mapping for the given axes according to sharding information of
139
+ // the annotated axes after inferring forward or backward. The parameter axis
140
+ // stores the axes of the tensor. "1" is a special axis, for the axis "1", set
141
+ // its dims mapping to -1.
142
+ // if unsharded_miss_axis, "-1" is assigend to axes that has no key in
143
+ // axis_to_dim_map.
144
+ std::vector<int64_t > GetDimsMappingForAxes (
145
+ const std::string& axes,
146
+ const std::unordered_map<std::string, int64_t >& axis_to_dim_map,
147
+ const bool unsharded_miss_axis = false );
148
+
137
149
} // namespace distributed
138
150
} // namespace phi
You can’t perform that action at this time.
0 commit comments