Skip to content

Commit a7b2356

Browse files
authored
fix flash attn mm (#3257)
1 parent f064291 commit a7b2356

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

ppcls/arch/backbone/legendary_models/mobilenet_v4.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -557,10 +557,10 @@ def forward(self, x, attn_mask=None):
557557
else:
558558
q = q.transpose([0, 2, 1, 3]) * self.scale
559559
v = v.transpose([0, 2, 1, 3])
560-
attn = paddle.mm(q, k.transpose([0, 2, 3, 1]))
560+
attn = q @ k.transpose([0, 2, 3, 1])
561561
attn = self.softmax(attn)
562562
attn = self.attn_drop(attn)
563-
attn = paddle.mm(attn, v)
563+
attn = attn @ v
564564
attn = attn.transpose([0, 2, 1, 3])
565565

566566
attn = attn.reshape([B, H, W, self.query_dim])

0 commit comments

Comments
 (0)