@@ -22,6 +22,7 @@
 
 from paddle.nn import Linear, Dropout, LayerNorm, LayerList, Layer
 from paddle import ParamAttr
+import paddlenlp
 
 
 class Registry(object):
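Note: the new `paddlenlp` import pulls in `paddlenlp.ops.einsum`, which the rest of this diff uses as a drop-in replacement for `paddle.matmul`. A minimal sketch of the substitution, assuming `paddlenlp.ops.einsum` follows `np.einsum` semantics as the patch suggests (shapes are illustrative, not from the patch):

```python
import paddle
import paddlenlp

# A batched matmul and the equivalent einsum contraction: "bij,bjk->bik"
# contracts the shared j axis, which is exactly what paddle.matmul does.
a = paddle.rand([2, 3, 4])
b = paddle.rand([2, 4, 5])
out_matmul = paddle.matmul(a, b)
out_einsum = paddlenlp.ops.einsum("bij,bjk->bik", a, b)
print(paddle.allclose(out_matmul, out_einsum))  # Tensor([True])
```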
@@ -276,14 +277,13 @@ def _get_band_mask(self, blocked_query_mask, blocked_key_mask, batch_size,
                                          [B, L - G, bs, 1])
         temp_key_mask_front = paddle.reshape(blocked_key_mask[:, :GF],
                                              [B, 1, 1, GF * bs])
-        global_block_mask_front = paddle.matmul(temp_query_mask,
-                                                temp_key_mask_front)
+        global_block_mask_front = paddlenlp.ops.einsum(
+            "blqd,bmdk->blqk", temp_query_mask, temp_key_mask_front)
 
         temp_key_mask_back = paddle.reshape(blocked_key_mask[:, -GB:],
                                             [B, 1, 1, GB * bs])
-        global_block_mask_back = paddle.matmul(temp_query_mask,
-                                               temp_key_mask_back)
-
+        global_block_mask_back = paddlenlp.ops.einsum(
+            "blqd,bmdk->blqk", temp_query_mask, temp_key_mask_back)
         # create window block mask
         key_mask_list = []
         for query_block_id in range(GF, GF + W // 2):
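The two mask products above are outer products in disguise: `temp_query_mask` is `[B, L-G, bs, 1]` and `temp_key_mask_front` is `[B, 1, 1, GF*bs]`, so the contracted `m` and `d` axes both have size 1. A NumPy sketch (illustrative shapes, not from the patch) of why the einsum reproduces the old broadcasted matmul; the window-mask rewrite in the next hunk follows the same pattern:

```python
import numpy as np

B, Lg, bs, GF = 2, 6, 4, 1
q = np.random.rand(B, Lg, bs, 1)      # temp_query_mask: [B, L-G, bs, 1]
k = np.random.rand(B, 1, 1, GF * bs)  # temp_key_mask_front: [B, 1, 1, GF*bs]

# Old code: broadcasted batched matmul. New code: einsum that sums over the
# size-1 axes m and d, i.e. an outer product of the two mask vectors.
out_matmul = np.matmul(q, k)
out_einsum = np.einsum("blqd,bmdk->blqk", q, k)
assert np.allclose(out_matmul, out_einsum)  # both [B, L-G, bs, GF*bs]
```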
@@ -326,8 +326,8 @@ def _get_band_mask(self, blocked_query_mask, blocked_key_mask, batch_size,
             [roll_key_mask1, window_key_mask, roll_key_mask2], axis=1)
         window_key_mask = paddle.unsqueeze(window_key_mask, axis=2)
         # [B, L-G, bs, 1] * [B, L-G, 1, W*bs] -> [B, L-G, bs, W*bs]
-        window_block_mask = paddle.matmul(temp_query_mask, window_key_mask)
-
+        window_block_mask = paddlenlp.ops.einsum(
+            "blkd,bldq->blkq", temp_query_mask, window_key_mask)
         band_mask = paddle.concat(
             [
                 global_block_mask_front, window_block_mask,
@@ -435,17 +435,13 @@ def _get_rand_mask(self, blocked_query_mask, blocked_key_mask,
             for b in range(B)
         ]
         temp_block_key_mask = paddle.concat(temp_block_key_mask_list, 0)
-        temp_block_key_mask = paddle.reshape(temp_block_key_mask,
-                                             [B, H, L - G, 1, R * bs])
-
-        temp_blocked_query_mask = paddle.unsqueeze(
-            blocked_query_mask[:, GF:-GB], 1)
-        temp_blocked_query_mask = paddle.expand(temp_blocked_query_mask,
-                                                [B, H, L - G, -1])
-        temp_blocked_query_mask = paddle.reshape(temp_blocked_query_mask,
-                                                 [B, H, L - G, bs, 1])
-
-        rand_mask = paddle.matmul(temp_blocked_query_mask, temp_block_key_mask)
+        temp_block_key_mask = paddle.reshape(temp_block_key_mask, [
+            B, temp_block_key_mask.shape[0] // B // (L - GF - GB) // R,
+            L - GF - GB, -1
+        ])
+        rand_mask = paddlenlp.ops.einsum("blq,bhlk->bhlqk",
+                                         blocked_query_mask[:, GF:-GB],
+                                         temp_block_key_mask)
         return rand_mask
 
     def _gather_random_key_value(self, blocked_matrix, rand_mask_idx, B, T):
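The rewritten `_get_rand_mask` folds the old unsqueeze/expand/reshape/matmul chain into a single contraction: `"blq,bhlk->bhlqk"` has no summed index, so it is one outer product per (batch, head, block), with the query mask broadcast across heads instead of being expanded explicitly. A NumPy sketch under illustrative shapes:

```python
import numpy as np

B, H, Lg, bs, R = 2, 3, 5, 4, 2
query_mask = np.random.rand(B, Lg, bs)       # blocked_query_mask[:, GF:-GB]
key_mask = np.random.rand(B, H, Lg, R * bs)  # temp_block_key_mask

# One outer product per (batch, head, block); the head axis of the query
# mask comes from broadcasting, replacing the old unsqueeze/expand pair.
rand_mask = np.einsum("blq,bhlk->bhlqk", query_mask, key_mask)
assert rand_mask.shape == (B, H, Lg, bs, R * bs)

# Old-style equivalent: expand the query mask to [B, H, L-G, bs, 1] and
# matmul it against the key mask viewed as [B, H, L-G, 1, R*bs].
old = np.matmul(
    np.broadcast_to(query_mask[:, None, :, :, None], (B, H, Lg, bs, 1)),
    key_mask[:, :, :, None, :])
assert np.allclose(rand_mask, old)
```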
@@ -575,35 +571,38 @@ def forward(self,
             [band_value_matrix, random_values], axis=3)
         second_top_value_matrix, second_middle_value_matrix, second_bottom_value_matrix = \
             self._get_splited_matrix(second_value_matrix)
-
-        second_product = paddle.matmul(
-            second_query_matrix, second_key_matrix, transpose_y=True)
+        second_product = paddlenlp.ops.einsum(
+            "bhlqd,bhlkd->bhlqk", second_query_matrix, second_key_matrix)
         second_product = second_product * (d_head**-0.5)
         second_product += (1 - second_mask) * -1e6
         second_weights = F.softmax(second_product)
 
         second_top_weights, second_middle_weights, second_bottom_weights = \
             self._get_splited_matrix(second_weights)
-        second_top_out = paddle.matmul(second_top_weights,
-                                       second_top_value_matrix)
+        second_top_out = paddlenlp.ops.einsum(
+            "bhlqk,bhlkd->bhlqd", second_top_weights, second_top_value_matrix)
 
-        second_middle_out = paddle.matmul(
+        second_middle_out = paddlenlp.ops.einsum(
+            "bhlqk,bhlkd->bhlqd",
             second_middle_weights[:, :, :, :, GF * bs:-(GB + R) * bs],
             second_middle_value_matrix[:, :, :, GF * bs:-(GB + R) * bs])
         # add global block attention
-        second_middle_out += paddle.matmul(
-            second_middle_weights[:, :, :, :, :GF * bs],
-            blocked_value_matrix[:, :, 0:GF])
-        second_middle_out += paddle.matmul(
+        second_middle_out += paddlenlp.ops.einsum(
+            "bhlqk,bhkd->bhlqd", second_middle_weights[:, :, :, :, :GF * bs],
+            blocked_value_matrix[:, :, 0])
+        second_middle_out += paddlenlp.ops.einsum(
+            "bhlqk,bhkd->bhlqd",
             second_middle_weights[:, :, :, :, -(GB + R) * bs:-R * bs],
-            blocked_value_matrix[:, :, -GB:])
+            blocked_value_matrix[:, :, -GB])
         # add random block attention
-        second_middle_out += paddle.matmul(
-            second_middle_weights[:, :, :, :, -R * bs:],
+        second_middle_out += paddlenlp.ops.einsum(
+            "...qk,...kd->...qd", second_middle_weights[:, :, :, :, -R * bs:],
             random_values[:, :, GF:-GB])
 
-        second_bottom_out = paddle.matmul(second_bottom_weights,
-                                          second_bottom_value_matrix)
+        second_bottom_out = paddlenlp.ops.einsum("bhlqk,bhlkd->bhlqd",
+                                                 second_bottom_weights,
+                                                 second_bottom_value_matrix)
+
         second_out = paddle.concat(
             [second_top_out, second_middle_out, second_bottom_out], axis=2)
         second_out = paddle.reshape(second_out, [B, H, (L - G) * bs, -1])
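In `forward`, the attention products become explicit contractions: `"bhlqd,bhlkd->bhlqk"` contracts the head-size axis `d` and so reproduces `paddle.matmul(..., transpose_y=True)`, while `"bhlqk,bhlkd->bhlqd"` contracts the key axis `k`, reproducing the plain weights-times-values matmul (the `"...qk,...kd->...qd"` form is the same contraction with the batch axes folded into an ellipsis). Note that the global-block terms now index a single value block (`blocked_value_matrix[:, :, 0]` and `[:, :, -GB]`) so the second operand is 4-D and fits the `bhkd` pattern, which presumably relies on the first/last global block counts being 1. A NumPy sketch of the two patterns (shapes illustrative):

```python
import numpy as np

B, H, L, q_len, k_len, d = 2, 3, 4, 5, 7, 8
Q = np.random.rand(B, H, L, q_len, d)
K = np.random.rand(B, H, L, k_len, d)
V = np.random.rand(B, H, L, k_len, d)

# Scores: contract d; same as matmul(Q, K, transpose_y=True) over the
# last two axes.
scores = np.einsum("bhlqd,bhlkd->bhlqk", Q, K)
assert np.allclose(scores, np.matmul(Q, np.swapaxes(K, -1, -2)))

# Context: contract k; same as matmul(weights, V). Softmax is omitted
# here since only the contraction pattern is being checked.
weights = scores
context = np.einsum("bhlqk,bhlkd->bhlqd", weights, V)
assert np.allclose(context, np.matmul(weights, V))

# The ellipsis form broadcasts over any leading batch axes.
assert np.allclose(context, np.einsum("...qk,...kd->...qd", weights, V))
```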