Tianlei Wu authored and GitHub committed baaef596962
Add sparse attention kernel for H100 (sm90) (#20553)

### Description

Follow-up of https://github.com/microsoft/onnxruntime/pull/20216 to add a sparse attention kernel compiled by Triton for H100 (sm90).

- [x] Refine sparse attention v1 kernel compilation (remove some combinations)
- [x] Compile v1 kernels
- [x] Compile kernels for H100
- [x] Run performance tests

### Performance

Test setting: `batch_size=4, num_heads=32, max_seq_len=8192, head_size=128, sparse_block_size=64, local_blocks=16, vert_stride=8, num_layout=8`

We compare sparse attention to the corresponding GQA with a local attention window of size 1024, and to GQA with dense causal attention. Note that ORT-GQA-Dense does more computation than ORT-SparseAtt, while ORT-GQA-Local does less computation (no vertical strides) than ORT-SparseAtt; they are included for reference. It is not a fair comparison, but it illustrates the benefit of sparsity versus dense attention: for example, at prompt sequence length 8192, ORT-SparseAtt is roughly 8x faster than ORT-GQA-Dense and roughly 2x faster than ORT-GQA-Local. (A minimal sketch of the local + vertical-stride block layout used in these tests is appended at the end of this description.)

Example results on an Azure Standard_ND96isr_H100_v5 VM with an NVIDIA H100-80GB-HBM3 GPU (sm=90):

```
prompt-sm90-batch4-head32-d128-local16-vert8-torch.float16:
   sequence_length  TORCH-GQA  ORT-GQA-Dense  ORT-GQA-Local  ORT-SparseAtt
0             16.0   0.079877       0.006362       0.006403       0.042758
1             32.0   0.086920       0.016404       0.016686       0.044183
2             64.0   0.090727       0.020429       0.020409       0.045343
3            128.0   0.128148       0.032009       0.031984       0.051516
4            256.0   0.323933       0.074110       0.073920       0.068308
5            512.0   1.021856       0.162167       0.161951       0.109226
6           1024.0   3.596002       0.452629       0.452780       0.231653
7           2048.0  13.865088       1.499534       1.195749       0.515488
8           4096.0   0.000000       5.454785       2.669682       1.163233
9           8192.0   0.000000      22.068159       6.018604       2.772873

token-sm90-batch4-head32-d128-local16-vert8-torch.float16:
   past_sequence_length  TORCH-GQA  ORT-GQA-Dense  ORT-GQA-Local  ORT-SparseAtt
0                  16.0   0.104460       0.012652       0.012661       0.069549
1                  32.0   0.113866       0.012776       0.012765       0.069024
2                  64.0   0.124600       0.016791       0.012672       0.069397
3                 128.0   0.108658       0.017900       0.018294       0.074844
4                 256.0   0.115463       0.029409       0.029608       0.078911
5                 512.0   0.149824       0.033968       0.033701       0.092998
6                1024.0   0.234050       0.042930       0.042951       0.116920
7                2048.0   0.390695       0.061462       0.043008       0.121555
8                4096.0   0.000000       0.097505       0.042948       0.134757
9                8191.0   0.000000       0.165861       0.043542       0.158796
```

The following might help performance at short sequence lengths, but requires an operator spec update: fall back to flash attention when total_sequence_length < local_blocks * block_size (see the sketch at the end of this description).

### Motivation and Context
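
As referenced above, here is a minimal sketch of the block-level mask implied by a local window plus vertical strides, using the test parameters (`local_blocks=16`, `vert_stride=8`, `num_layout=8`, `sparse_block_size=64`). It is illustrative only: `build_block_sparse_mask` and its exact stride/offset rule are assumptions for this sketch, not the layout code used by the Triton kernels in this PR.

```python
import numpy as np

def build_block_sparse_mask(num_blocks: int, local_blocks: int = 16, vert_stride: int = 8,
                            head_idx: int = 0, num_layout: int = 8) -> np.ndarray:
    """Illustrative block-level causal mask: a local window plus strided vertical columns.

    Query block q attends to key block k when k <= q (causal) and either
      * q - k < local_blocks (dense local window), or
      * (k + 1 + head_idx % num_layout) % vert_stride == 0 (vertical stride, offset per layout).
    The actual kernel layout may differ; this only mirrors the pattern described above.
    """
    mask = np.zeros((num_blocks, num_blocks), dtype=bool)
    offset = head_idx % num_layout
    for q in range(num_blocks):
        for k in range(q + 1):                              # causal: past/current blocks only
            is_local = (q - k) < local_blocks               # dense local window
            is_vert = (k + 1 + offset) % vert_stride == 0   # strided "vertical" columns
            mask[q, k] = is_local or is_vert
    return mask

if __name__ == "__main__":
    # max_seq_len=8192 with sparse_block_size=64 gives 128 blocks per sequence.
    m = build_block_sparse_mask(num_blocks=128)
    density = m.sum() / (128 * 129 / 2)  # fraction of causal block pairs that are kept
    print(f"block density vs dense causal: {density:.2%}")
```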
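
The proposed short-sequence fallback is just a length check. The sketch below is hypothetical dispatch logic (the function name and signature are not part of the operator): when total_sequence_length is below local_blocks * block_size, every key block already lies inside the local window, so the sparse layout degenerates to dense causal attention and a flash attention kernel could serve the request.

```python
def use_sparse_kernel(total_sequence_length: int, local_blocks: int, block_size: int) -> bool:
    """Hypothetical dispatch: prefer the sparse kernel only once the sequence is long enough
    that the layout is actually sparse; otherwise fall back to flash (dense causal) attention."""
    return total_sequence_length >= local_blocks * block_size

# With the test settings (local_blocks=16, block_size=64), the crossover is 1024 tokens.
assert use_sparse_kernel(512, 16, 64) is False   # short prompt: dense/flash path
assert use_sparse_kernel(2048, 16, 64) is True   # long prompt: sparse kernel pays off
```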