@@ -134,19 +134,30 @@ ov::OutputVector ov::pass::GroupQueryAttentionDecomposition::decompose(
134
134
// Broadcast KV if grouped query attention
135
135
const size_t kv_num_heads_factor = num_heads / kv_num_heads;
136
136
if (kv_num_heads_factor > 1 ) {
137
+ auto K_unsqueeze = register_new_node<v0::Unsqueeze>(K, two);
138
+ auto V_unsqueeze = register_new_node<v0::Unsqueeze>(V, two);
139
+
137
140
const auto kv_shape = register_new_node<v3::ShapeOf>(K);
138
141
const auto kv_shape_prev_2 = get_dimensions (kv_shape, {0 , 1 });
139
142
const auto kv_shape_last_2 = get_dimensions (kv_shape, {2 , 3 });
140
- auto new_kv_shape = register_new_node<v0::Concat>(ov::NodeVector{kv_shape_prev_2, one, kv_shape_last_2}, 0 );
141
- K = register_new_node<v1::Reshape>(K, new_kv_shape, false );
142
- V = register_new_node<v1::Reshape>(V, new_kv_shape, false );
143
- K = register_new_node<v0::Concat>(ov::OutputVector (kv_num_heads_factor, K), 2 );
144
- V = register_new_node<v0::Concat>(ov::OutputVector (kv_num_heads_factor, V), 2 );
143
+ const auto kv_num_heads_factor_const = register_new_node (
144
+ v0::Constant::create (ov::element::i64 ,
145
+ ov::Shape{1 },
146
+ {Q.get_partial_shape ()[1 ].get_length () / K.get_partial_shape ()[1 ].get_length ()}));
147
+ auto new_kv_shape =
148
+ register_new_node<v0::Concat>(ov::NodeVector{kv_shape_prev_2, kv_num_heads_factor_const, kv_shape_last_2},
149
+ 0 );
150
+
151
+ auto K_broadcast =
152
+ register_new_node<v3::Broadcast>(K_unsqueeze, new_kv_shape, ov::op::BroadcastType::BIDIRECTIONAL);
153
+ auto V_broadcast =
154
+ register_new_node<v3::Broadcast>(V_unsqueeze, new_kv_shape, ov::op::BroadcastType::BIDIRECTIONAL);
155
+
145
156
const auto q_shape = register_new_node<v3::ShapeOf>(Q);
146
157
const auto q_shape_prev_2 = get_dimensions (q_shape, {0 , 1 });
147
158
auto extended_kv_shape = register_new_node<v0::Concat>(ov::NodeVector{q_shape_prev_2, kv_shape_last_2}, 0 );
148
- K = register_new_node<v1::Reshape>(K , extended_kv_shape, false );
149
- V = register_new_node<v1::Reshape>(V , extended_kv_shape, false );
159
+ K = register_new_node<v1::Reshape>(K_broadcast , extended_kv_shape, false );
160
+ V = register_new_node<v1::Reshape>(V_broadcast , extended_kv_shape, false );
150
161
}
151
162
152
163
// Make attention mask
0 commit comments