
Commit a79c9cb

joey12300 and ZeyuChen authored
Eval einsum in bigbird (PaddlePaddle#314)

* matmul->einsum
* fix rand_mask_idx_list

Co-authored-by: Zeyu Chen <chenzeyu01@baidu.com>
1 parent 04ab7af commit a79c9cb

2 files changed: +33 −34


examples/language_model/bigbird/run_classifier.py

+1 −1
```diff
@@ -186,7 +186,7 @@ def do_evalute(model, criterion, metric, test_data_loader):
         global_steps += 1
         input_ids, labels = batch[:2]
         rand_mask_idx_list = batch[2:]
-        output = model(input_ids, None, rand_mask_idx_list)
+        output = model(input_ids, rand_mask_idx_list=rand_mask_idx_list)
         loss = criterion(output, labels)
         correct = metric.compute(output, labels)
         metric.update(correct)
```
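The fix here is about argument binding, not einsum: in the old call the third positional value was consumed by whatever parameter sits between `input_ids` and `rand_mask_idx_list` in the model's `forward` signature. A minimal sketch of the failure mode, with a hypothetical signature (the real one is defined by the BigBird model in `paddlenlp.transformers`):

```python
# Hypothetical signature, for illustration only.
def forward(input_ids, token_type_ids=None, attention_mask=None,
            rand_mask_idx_list=None):
    return attention_mask, rand_mask_idx_list

# Old call: the list lands in attention_mask, the third parameter.
print(forward("ids", None, [0, 1]))               # ([0, 1], None)
# New call: binding by name reaches the intended slot regardless of
# any parameters that sit in between.
print(forward("ids", rand_mask_idx_list=[0, 1]))  # (None, [0, 1])
```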

paddlenlp/transformers/attention_utils.py

+32 −33
```diff
@@ -22,6 +22,7 @@
 
 from paddle.nn import Linear, Dropout, LayerNorm, LayerList, Layer
 from paddle import ParamAttr
+import paddlenlp
 
 
 class Registry(object):
```
```diff
@@ -276,14 +277,13 @@ def _get_band_mask(self, blocked_query_mask, blocked_key_mask, batch_size,
                                           [B, L - G, bs, 1])
         temp_key_mask_front = paddle.reshape(blocked_key_mask[:, :GF],
                                              [B, 1, 1, GF * bs])
-        global_block_mask_front = paddle.matmul(temp_query_mask,
-                                                temp_key_mask_front)
+        global_block_mask_front = paddlenlp.ops.einsum(
+            "blqd,bmdk->blqk", temp_query_mask, temp_key_mask_front)
 
         temp_key_mask_back = paddle.reshape(blocked_key_mask[:, -GB:],
                                             [B, 1, 1, GB * bs])
-        global_block_mask_back = paddle.matmul(temp_query_mask,
-                                               temp_key_mask_back)
-
+        global_block_mask_back = paddlenlp.ops.einsum(
+            "blqd,bmdk->blqk", temp_query_mask, temp_key_mask_back)
         # create window block mask
         key_mask_list = []
         for query_block_id in range(GF, GF + W // 2):
```
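A quick way to convince yourself the einsum matches the broadcasted `matmul` it replaces: `d` and `m` are both size 1 here, so contracting them leaves the same `[B, L-G, bs, GF*bs]` product. A NumPy sketch with toy shapes (not from the commit):

```python
import numpy as np

B, L_G, bs, GF = 2, 3, 4, 2                # toy sizes
q = np.random.rand(B, L_G, bs, 1)          # temp_query_mask
k = np.random.rand(B, 1, 1, GF * bs)       # temp_key_mask_front

ref = np.matmul(q, k)                      # old path: broadcasted matmul
out = np.einsum("blqd,bmdk->blqk", q, k)   # new path: contract d and m (both size 1)
assert out.shape == (B, L_G, bs, GF * bs)
assert np.allclose(ref, out)
```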
```diff
@@ -326,8 +326,8 @@ def _get_band_mask(self, blocked_query_mask, blocked_key_mask, batch_size,
             [roll_key_mask1, window_key_mask, roll_key_mask2], axis=1)
         window_key_mask = paddle.unsqueeze(window_key_mask, axis=2)
         # [B, L-G, bs, 1] * [B, L-G, 1, W*bs] -> [B, L-G, bs, W*bs]
-        window_block_mask = paddle.matmul(temp_query_mask, window_key_mask)
-
+        window_block_mask = paddlenlp.ops.einsum(
+            "blkd,bldq->blkq", temp_query_mask, window_key_mask)
         band_mask = paddle.concat(
             [
                 global_block_mask_front, window_block_mask,
```
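One readability wrinkle: the subscript letters here differ from the global-mask calls above, but einsum labels are positional, not meaningful names, so `"blkd,bldq->blkq"` performs the same size-1 contraction. A NumPy check with toy shapes (not from the commit):

```python
import numpy as np

B, L_G, bs, Wbs = 2, 3, 4, 6
q = np.random.rand(B, L_G, bs, 1)        # temp_query_mask
k = np.random.rand(B, L_G, 1, Wbs)       # window_key_mask after unsqueeze

a = np.einsum("blkd,bldq->blkq", q, k)   # spelling used in this hunk
b = np.einsum("blqd,bldk->blqk", q, k)   # same contraction, relabeled
assert np.allclose(a, b)
assert np.allclose(a, np.matmul(q, k))   # and both match the old matmul
```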
```diff
@@ -435,17 +435,13 @@ def _get_rand_mask(self, blocked_query_mask, blocked_key_mask,
             for b in range(B)
         ]
         temp_block_key_mask = paddle.concat(temp_block_key_mask_list, 0)
-        temp_block_key_mask = paddle.reshape(temp_block_key_mask,
-                                             [B, H, L - G, 1, R * bs])
-
-        temp_blocked_query_mask = paddle.unsqueeze(
-            blocked_query_mask[:, GF:-GB], 1)
-        temp_blocked_query_mask = paddle.expand(temp_blocked_query_mask,
-                                                [B, H, L - G, -1])
-        temp_blocked_query_mask = paddle.reshape(temp_blocked_query_mask,
-                                                 [B, H, L - G, bs, 1])
-
-        rand_mask = paddle.matmul(temp_blocked_query_mask, temp_block_key_mask)
+        temp_block_key_mask = paddle.reshape(temp_block_key_mask, [
+            B, temp_block_key_mask.shape[0] // B // (L - GF - GB) // R,
+            L - GF - GB, -1
+        ])
+        rand_mask = paddlenlp.ops.einsum("blq,bhlk->bhlqk",
+                                         blocked_query_mask[:, GF:-GB],
+                                         temp_block_key_mask)
         return rand_mask
 
     def _gather_random_key_value(self, blocked_matrix, rand_mask_idx, B, T):
```
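This einsum has no contracted index at all: every letter in the inputs survives to the output, so it is a broadcast outer product, which is exactly what the deleted unsqueeze/expand/reshape + matmul sequence was building. A NumPy sketch with toy shapes (not from the commit):

```python
import numpy as np

B, H, L_G, bs, Rbs = 2, 3, 4, 5, 6
q = np.random.rand(B, L_G, bs)           # blocked_query_mask[:, GF:-GB]
k = np.random.rand(B, H, L_G, Rbs)       # temp_block_key_mask

# out[b,h,l,q,k] = q_mask[b,l,q] * k_mask[b,h,l,k]; no index is summed,
# and the query mask broadcasts across the head axis h.
rand_mask = np.einsum("blq,bhlk->bhlqk", q, k)

# The deleted code built the same tensor explicitly:
ref = np.matmul(q[:, None, :, :, None],  # [B, 1, L-G, bs, 1]
                k[:, :, :, None, :])     # [B, H, L-G, 1, R*bs]
assert rand_mask.shape == (B, H, L_G, bs, Rbs)
assert np.allclose(rand_mask, ref)
```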
```diff
@@ -575,35 +571,38 @@ def forward(self,
                 [band_value_matrix, random_values], axis=3)
             second_top_value_matrix, second_middle_value_matrix, second_bottom_value_matrix = \
                 self._get_splited_matrix(second_value_matrix)
-
-            second_product = paddle.matmul(
-                second_query_matrix, second_key_matrix, transpose_y=True)
+            second_product = paddlenlp.ops.einsum(
+                "bhlqd,bhlkd->bhlqk", second_query_matrix, second_key_matrix)
             second_product = second_product * (d_head**-0.5)
             second_product += (1 - second_mask) * -1e6
             second_weights = F.softmax(second_product)
 
             second_top_weights, second_middle_weights, second_bottom_weights = \
                 self._get_splited_matrix(second_weights)
-            second_top_out = paddle.matmul(second_top_weights,
-                                           second_top_value_matrix)
+            second_top_out = paddlenlp.ops.einsum(
+                "bhlqk,bhlkd->bhlqd", second_top_weights, second_top_value_matrix)
 
-            second_middle_out = paddle.matmul(
+            second_middle_out = paddlenlp.ops.einsum(
+                "bhlqk,bhlkd->bhlqd",
                 second_middle_weights[:, :, :, :, GF * bs:-(GB + R) * bs],
                 second_middle_value_matrix[:, :, :, GF * bs:-(GB + R) * bs])
             # add global block attention
-            second_middle_out += paddle.matmul(
-                second_middle_weights[:, :, :, :, :GF * bs],
-                blocked_value_matrix[:, :, 0:GF])
-            second_middle_out += paddle.matmul(
+            second_middle_out += paddlenlp.ops.einsum(
+                "bhlqk,bhkd->bhlqd", second_middle_weights[:, :, :, :, :GF * bs],
+                blocked_value_matrix[:, :, 0])
+            second_middle_out += paddlenlp.ops.einsum(
+                "bhlqk,bhkd->bhlqd",
                 second_middle_weights[:, :, :, :, -(GB + R) * bs:-R * bs],
-                blocked_value_matrix[:, :, -GB:])
+                blocked_value_matrix[:, :, -GB])
             # add random block attention
-            second_middle_out += paddle.matmul(
-                second_middle_weights[:, :, :, :, -R * bs:],
+            second_middle_out += paddlenlp.ops.einsum(
+                "...qk,...kd->...qd", second_middle_weights[:, :, :, :, -R * bs:],
                 random_values[:, :, GF:-GB])
 
-            second_bottom_out = paddle.matmul(second_bottom_weights,
-                                              second_bottom_value_matrix)
+            second_bottom_out = paddlenlp.ops.einsum("bhlqk,bhlkd->bhlqd",
+                                                     second_bottom_weights,
+                                                     second_bottom_value_matrix)
+
             second_out = paddle.concat(
                 [second_top_out, second_middle_out, second_bottom_out], axis=2)
             second_out = paddle.reshape(second_out, [B, H, (L - G) * bs, -1])
```
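Two patterns recur in this hunk: `"bhlqd,bhlkd->bhlqk"` is `matmul(..., transpose_y=True)` with the head and block axes as batch dimensions, and the ellipsis form `"...qk,...kd->...qd"` is a plain batched matmul where `...` absorbs whatever leading axes the operands share. A NumPy sketch with toy shapes (not from the commit):

```python
import numpy as np

B, H, L, q_len, k_len, d = 2, 3, 4, 5, 6, 7
Q = np.random.rand(B, H, L, q_len, d)
K = np.random.rand(B, H, L, k_len, d)

# Attention scores: contract d, i.e. matmul with the key transposed.
scores = np.einsum("bhlqd,bhlkd->bhlqk", Q, K)
assert np.allclose(scores, np.matmul(Q, np.swapaxes(K, -1, -2)))

# Ellipsis form: "..." stands for the shared leading axes (B, H, L).
W = np.random.rand(B, H, L, q_len, k_len)
V = np.random.rand(B, H, L, k_len, d)
out = np.einsum("...qk,...kd->...qd", W, V)
assert np.allclose(out, np.matmul(W, V))
```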
