Add Multi-Head Attention support for Vitis #1163
base: main
Conversation
Thank you so much for merging it into main!
Hi @rianbrooksflynn! Great work on the Multi-Head Attention implementation. Could you consider adding usage examples? The examples could demonstrate the important PyTorch requirements you mentioned (`batch_first=True`, `channels_last_conversion='off'`, same key/value inputs) and basic Keras usage. Thanks!
As far as I can tell, masking (e.g., causal masking) is not supported here. Would it be OK if I build on top of this PR and add it?
Description
This PR adds support for Multi-Head Attention using either Keras or PyTorch with the Vitis backend in `io_parallel` mode. Tests have been added for both Keras and PyTorch parsing.
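For reference, a minimal Keras usage sketch (layer sizes, the toy model, and the output directory are illustrative assumptions, not taken from this PR):

```python
# Sketch only: layer sizes, the toy model, and the output directory are
# illustrative assumptions, not taken from this PR.
from tensorflow import keras

import hls4ml

seq_len, embed_dim = 8, 16

inp = keras.layers.Input(shape=(seq_len, embed_dim))
# Self-attention: the same tensor serves as query, key, and value.
attn = keras.layers.MultiHeadAttention(num_heads=2, key_dim=embed_dim)(inp, inp)
model = keras.Model(inputs=inp, outputs=attn)

config = hls4ml.utils.config_from_keras_model(model, granularity='name')
hls_model = hls4ml.converters.convert_from_keras_model(
    model,
    hls_config=config,
    backend='Vitis',
    io_type='io_parallel',
    output_dir='hls4ml_prj_mha_keras',
)
hls_model.compile()
```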
Credit is due to @Ethan0Jiang and @LostEcho365 (Zhixing Jiang and Dennis Yin) for their original implementation and Keras parsing support; my contributions were implementing PyTorch support and adding unit tests. (Here's a link to their pre-print.) The original code authors have given permission for their code to be merged into hls4ml.
There are some important notes for PyTorch (TODO: add documentation to this effect); see the sketch after this list:

- Use `batch_first=True` when instantiating `nn.MultiheadAttention` so that the inputs match up (`(batch_size, seq_len, embed_dim)` instead of `(seq_len, batch_size, embed_dim)`).
- Use `channels_last_conversion='off'` when calling `config_from_pytorch_model()`, since batch-first PyTorch and Keras use the same input shape.
- `nn.MultiheadAttention` takes three inputs, `query`, `key`, and `value`; hls4ml currently only supports the case where `key` and `value` are the same, so you must give PyTorch the same data for the second input and the third input.
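A minimal PyTorch sketch of these constraints (shapes, module names, and the exact `config_from_pytorch_model()` / `convert_from_pytorch_model()` signatures are assumptions and may differ between hls4ml versions):

```python
# Sketch only: shapes, module names, and conversion-call signatures are
# illustrative assumptions, not taken from this PR.
import torch
import torch.nn as nn

import hls4ml

seq_len, embed_dim, num_heads = 8, 16, 2


class MHAModel(nn.Module):
    def __init__(self):
        super().__init__()
        # batch_first=True so inputs are (batch_size, seq_len, embed_dim).
        self.mha = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)

    def forward(self, query, key, value):
        out, _ = self.mha(query, key, value)
        return out


model = MHAModel().eval()

# hls4ml only supports key == value, so pass the same tensor for both.
x = torch.randn(1, seq_len, embed_dim)
_ = model(x, x, x)

# channels_last_conversion='off' because batch-first PyTorch and Keras
# already use the same (seq_len, embed_dim) input layout.
config = hls4ml.utils.config_from_pytorch_model(
    model,
    [(seq_len, embed_dim)] * 3,
    channels_last_conversion='off',
)
hls_model = hls4ml.converters.convert_from_pytorch_model(
    model,
    hls_config=config,
    backend='Vitis',
    io_type='io_parallel',
    output_dir='hls4ml_prj_mha_pytorch',
)
hls_model.compile()
```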
Type of change
Tests
Two unit tests added: `test/pytest/test_multiheadattention.py` and `test/pytest/test_multiheadattention_pytorch.py`.
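A hypothetical way to run both tests through pytest's Python API (assumes pytest and the hls4ml test dependencies are installed, running from the repository root):

```python
# Hypothetical runner: invoke the two new unit tests via pytest's Python API.
# Assumes pytest and the hls4ml test dependencies are installed and this is
# run from the repository root.
import pytest

pytest.main([
    "test/pytest/test_multiheadattention.py",
    "test/pytest/test_multiheadattention_pytorch.py",
])
```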
Checklist
- I have run `pre-commit` on the files I edited or added.