Commits


aciddelgado authored and GitHub committed 406cd324e0f
[CUDA] GroupQueryAttention operator using FlashAttention (#17674)

### Description

Added the Group Query Attention op, which supports a number of Q heads that is an integer multiple of the number of KV heads (a sketch of the head grouping is included at the end of this note). As of now, this op can only use the FlashAttention kernel, meaning it only supports sm>=80 on Linux.

Results from onnxruntime/test/python/transformers/benchmark_gqa.py show an average ~37% speed-up over Decoder Masked Multi-Head Attention (dmmha below), with even greater improvements for long past sequence lengths. Latency is in milliseconds; throughput is in TFLOPS.

```
op     batch  s_kv  heads  h_dim  ms    TFLOPS
gqa    16     2048  8      32     0.34  0.10
dmmha  16     2048  8      32     0.39  0.09
---------
gqa    16     2048  8      64     0.45  0.15
dmmha  16     2048  8      64     0.61  0.11
---------
gqa    16     2048  8      128    0.54  0.25
dmmha  16     2048  8      128    0.83  0.16
---------
gqa    16     2048  16     32     0.45  0.15
dmmha  16     2048  16     32     0.69  0.10
---------
gqa    16     2048  16     64     0.69  0.19
dmmha  16     2048  16     64     0.83  0.16
---------
gqa    16     2048  16     128    0.71  0.38
dmmha  16     2048  16     128    1.28  0.21
---------
gqa    16     2048  32     32     0.58  0.23
dmmha  16     2048  32     32     0.77  0.17
---------
gqa    16     2048  32     64     0.58  0.46
dmmha  16     2048  32     64     1.25  0.21
---------
gqa    16     2048  32     128    0.76  0.71
dmmha  16     2048  32     128    2.15  0.25
---------
gqa    16     2048  64     32     0.68  0.39
dmmha  16     2048  64     32     1.23  0.22
---------
gqa    16     2048  64     64     0.77  0.70
dmmha  16     2048  64     64     2.11  0.25
---------
gqa    16     2048  64     128    1.10  0.97
dmmha  16     2048  64     128    4.06  0.26
---------
gqa    16     2048  128    32     1.00  0.54
dmmha  16     2048  128    32     2.09  0.26
---------
gqa    16     2048  128    64     1.10  0.97
dmmha  16     2048  128    64     4.08  0.26
```

### Motivation and Context

As of now, this op is targeted for use in Llama models, as it supports kv-caching and a different number of heads for Q and KV (Grouped Query Attention). We plan to add support for more platforms, input formats, etc. in the future.

---------

Co-authored-by: Tianlei Wu <tlwu@microsoft.com>
Co-authored-by: tlwu@microsoft.com <tlwu@a100.crj0ad2y1kku1j4yxl4sj10o4e.gx.internal.cloudapp.net>
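
For reference, the sketch below shows how grouped-query attention maps Q heads onto a smaller set of shared KV heads when the Q head count is an integer multiple of the KV head count. This is a minimal NumPy illustration of the head-grouping idea only; the function name, shapes, and scaling choices are assumptions for demonstration and do not reflect the CUDA kernel or the operator's actual inputs and attributes.

```python
# Illustrative sketch of grouped-query attention (NOT the CUDA kernel or the op's API).
# Shapes and names are assumptions for demonstration only.
import numpy as np

def gqa_attention(q, k, v):
    """q: (batch, num_q_heads, seq_q, head_dim)
       k, v: (batch, num_kv_heads, seq_kv, head_dim)
       num_q_heads must be an integer multiple of num_kv_heads."""
    b, hq, sq, d = q.shape
    hkv = k.shape[1]
    assert hq % hkv == 0, "Q heads must be an integer multiple of KV heads"
    group = hq // hkv

    # Each group of `group` consecutive Q heads shares one KV head.
    k = np.repeat(k, group, axis=1)   # (b, hq, seq_kv, d)
    v = np.repeat(v, group, axis=1)

    scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(d)   # (b, hq, sq, seq_kv)
    scores -= scores.max(axis=-1, keepdims=True)        # numerical stability
    probs = np.exp(scores)
    probs /= probs.sum(axis=-1, keepdims=True)
    return probs @ v                                    # (b, hq, sq, d)

# Example: 8 Q heads sharing 2 KV heads (group size 4).
q = np.random.randn(1, 8, 16, 64).astype(np.float32)
k = np.random.randn(1, 2, 128, 64).astype(np.float32)
v = np.random.randn(1, 2, 128, 64).astype(np.float32)
print(gqa_attention(q, k, v).shape)  # (1, 8, 16, 64)
```

Compared with full multi-head attention, this grouping shrinks the KV tensors (and hence the kv-cache) by the group factor, which is why the speed-up over dmmha grows with longer past sequence lengths.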