Skip to content

Commit be4e135

Browse files
committed
modify GQA decompose logic to fit gpu unsqueeze_broadcast_reshape_sdpa_fusion pattern for MQA model performance
1 parent 2739b89 commit be4e135

File tree

1 file changed

+18
-7
lines changed

1 file changed

+18
-7
lines changed

src/common/transformations/src/transformations/op_conversions/group_query_attention_decomposition.cpp

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -134,19 +134,30 @@ ov::OutputVector ov::pass::GroupQueryAttentionDecomposition::decompose(
134134
// Broadcast KV if grouped query attention
135135
const size_t kv_num_heads_factor = num_heads / kv_num_heads;
136136
if (kv_num_heads_factor > 1) {
137+
auto K_unsqueeze = register_new_node<v0::Unsqueeze>(K, two);
138+
auto V_unsqueeze = register_new_node<v0::Unsqueeze>(V, two);
139+
137140
const auto kv_shape = register_new_node<v3::ShapeOf>(K);
138141
const auto kv_shape_prev_2 = get_dimensions(kv_shape, {0, 1});
139142
const auto kv_shape_last_2 = get_dimensions(kv_shape, {2, 3});
140-
auto new_kv_shape = register_new_node<v0::Concat>(ov::NodeVector{kv_shape_prev_2, one, kv_shape_last_2}, 0);
141-
K = register_new_node<v1::Reshape>(K, new_kv_shape, false);
142-
V = register_new_node<v1::Reshape>(V, new_kv_shape, false);
143-
K = register_new_node<v0::Concat>(ov::OutputVector(kv_num_heads_factor, K), 2);
144-
V = register_new_node<v0::Concat>(ov::OutputVector(kv_num_heads_factor, V), 2);
143+
const auto kv_num_heads_factor_const = register_new_node(
144+
v0::Constant::create(ov::element::i64,
145+
ov::Shape{1},
146+
{Q.get_partial_shape()[1].get_length() / K.get_partial_shape()[1].get_length()}));
147+
auto new_kv_shape =
148+
register_new_node<v0::Concat>(ov::NodeVector{kv_shape_prev_2, kv_num_heads_factor_const, kv_shape_last_2},
149+
0);
150+
151+
auto K_broadcast =
152+
register_new_node<v3::Broadcast>(K_unsqueeze, new_kv_shape, ov::op::BroadcastType::BIDIRECTIONAL);
153+
auto V_broadcast =
154+
register_new_node<v3::Broadcast>(V_unsqueeze, new_kv_shape, ov::op::BroadcastType::BIDIRECTIONAL);
155+
145156
const auto q_shape = register_new_node<v3::ShapeOf>(Q);
146157
const auto q_shape_prev_2 = get_dimensions(q_shape, {0, 1});
147158
auto extended_kv_shape = register_new_node<v0::Concat>(ov::NodeVector{q_shape_prev_2, kv_shape_last_2}, 0);
148-
K = register_new_node<v1::Reshape>(K, extended_kv_shape, false);
149-
V = register_new_node<v1::Reshape>(V, extended_kv_shape, false);
159+
K = register_new_node<v1::Reshape>(K_broadcast, extended_kv_shape, false);
160+
V = register_new_node<v1::Reshape>(V_broadcast, extended_kv_shape, false);
150161
}
151162

152163
// Make attention mask

0 commit comments

Comments
 (0)