Commits


Tianlei Wu authored and GitHub committed (6a9dc6c993a)
[CUDA] Update fused MHA to support flash attention and causal mask (#13953)

### Description

Update fused attention kernels to support flash attention and causal mask (GPT-2 initial decoder run).

Note: the causal kernels are from FasterTransformer 5.2. The flash attention kernels that are not causal are from TensorRT 8.5.1.

#### Performance Test of bert-base model

Tested with the following command:
```
python -m onnxruntime.transformers.benchmark -m bert-base-cased -b 1 4 8 16 32 64 -s 512 -t 1000 -o by_script -g -p fp16 -i 3 --use_mask_index
```

Original Flash Attention is from https://github.com/HazyResearch/flash-attention. RemovePadding and RestorePadding are added before/after the original flash attention (but not in this PR), so the result is not an apples-to-apples comparison; it is included for reference only.

Average latency (ms) of the float16 bert-base-cased model:

* A100

Kernel | b1_s512 | b4_s512 | b8_s512 | b16_s512 | b32_s512 | b64_s512 | b128_s512
-- | -- | -- | -- | -- | -- | -- | --
Unfused | 1.83 | 5.00 | 9.31 | 17.76 | 34.47 | 67.43 | 133.38
TRT Fused | 2.05 | 3.58 | 5.70 | 10.96 | 21.22 | 41.23 | 80.56
Flash Attention (from FT) | 1.43 | 3.20 | 5.71 | 10.95 | 22.19 | 42.96 | 84.54
Flash Attention (from TRT) | 1.44 | 3.28 | 5.70 | 10.86 | 21.00 | 40.56 | 79.53
Original Flash Attention | 1.81 | 4.04 | 6.82 | 13.06 | 24.62 | 46.58 | 91.10

* T4

Kernel | b1_s512 | b4_s512 | b8_s512 | b16_s512 | b32_s512 | b64_s512
-- | -- | -- | -- | -- | -- | --
Unfused | 8.17 | 29.86 | 59.56 | 115.77 | 236.66 | 461.43
Flash Attention (from FT) | 5.65 | 21.12 | 44.94 | 86.83 | 174.16 | 351.38
Flash Attention (from TRT) | 5.73 | 21.49 | 45.49 | 89.15 | 174.37 | 352.08
Original Flash Attention | 6.22 | 22.16 | 43.39 | 83.8 | 168.77 | 337.04

* V100

Kernel | b1_s512 | b4_s512 | b8_s512 | b16_s512 | b32_s512 | b64_s512
-- | -- | -- | -- | -- | -- | --
Unfused | 3.77 | 10.48 | 19.53 | 37.63 | 73.68 | 145.58
Flash Attention (from FT) | 3.21 | 8.25 | 14.95 | 28.83 | 56.28 | 111.15
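As a rough companion to the benchmark command above, here is a minimal sketch of how a single latency entry could be measured directly with the onnxruntime Python API. The model path (`bert_base_cased_fp16.onnx`) and the input names are assumptions for illustration, not something produced by this PR; adjust them to the actual fp16 export.

```python
import time

import numpy as np
import onnxruntime as ort

# Hypothetical model path and input names; replace with the actual fp16 export.
session = ort.InferenceSession(
    "bert_base_cased_fp16.onnx", providers=["CUDAExecutionProvider"]
)

batch_size, sequence_length = 4, 512
feeds = {
    "input_ids": np.random.randint(
        0, 28996, size=(batch_size, sequence_length), dtype=np.int64
    ),
    "attention_mask": np.ones((batch_size, sequence_length), dtype=np.int64),
    "token_type_ids": np.zeros((batch_size, sequence_length), dtype=np.int64),
}

# Warm up so one-time initialization costs are excluded from the measurement.
for _ in range(10):
    session.run(None, feeds)

# Time repeated runs and report the average latency in milliseconds.
runs = 100
start = time.perf_counter()
for _ in range(runs):
    session.run(None, feeds)
average_ms = (time.perf_counter() - start) / runs * 1000.0
print(f"b{batch_size}_s{sequence_length}: {average_ms:.2f} ms")
```

The benchmark command above sweeps multiple batch sizes and averages over many more iterations; this sketch only times one configuration.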
#### Performance Test of GPT-2 model

Tested with the following command:

`python benchmark_gpt2.py -m distilgpt2 -o --stage 1 --use_gpu -p fp16 -b 1 4 8 16 32 64 128 -s 0 --sequence_lengths 8 16 32 64 128 256 512`

* A100

Note that flash attention is used as fused attention when sequence_length > 128.

batch_size | sequence_length | with Fused Attention | without Fused Attention | A100 Gain
-- | -- | -- | -- | --
1 | 8 | 0.93 | 1 | 7.0%
4 | 8 | 0.82 | 0.88 | 6.8%
8 | 8 | 0.84 | 0.88 | 4.5%
16 | 8 | 0.92 | 0.97 | 5.2%
32 | 8 | 1.15 | 1.17 | 1.7%
64 | 8 | 1.68 | 1.72 | 2.3%
128 | 8 | 2.76 | 2.78 | 0.7%
1 | 16 | 0.95 | 0.95 | 0.0%
4 | 16 | 0.83 | 0.88 | 5.7%
8 | 16 | 0.91 | 0.97 | 6.2%
16 | 16 | 1.12 | 1.17 | 4.3%
32 | 16 | 1.67 | 1.72 | 2.9%
64 | 16 | 2.73 | 2.76 | 1.1%
128 | 16 | 4.96 | 4.95 | -0.2%
1 | 32 | 0.94 | 0.88 | -6.8%
4 | 32 | 0.91 | 0.97 | 6.2%
8 | 32 | 1.12 | 1.17 | 4.3%
16 | 32 | 1.65 | 1.71 | 3.5%
32 | 32 | 2.69 | 2.76 | 2.5%
64 | 32 | 4.86 | 4.94 | 1.6%
128 | 32 | 9.35 | 9.38 | 0.3%
1 | 64 | 0.84 | 0.88 | 4.5%
4 | 64 | 1.1 | 1.17 | 6.0%
8 | 64 | 1.64 | 1.73 | 5.2%
16 | 64 | 2.66 | 2.77 | 4.0%
32 | 64 | 4.82 | 4.97 | 3.0%
64 | 64 | 9.23 | 9.4 | 1.8%
128 | 64 | 18.54 | 19.12 | 3.0%
1 | 128 | 0.91 | 0.98 | 7.1%
4 | 128 | 1.68 | 1.74 | 3.4%
8 | 128 | 2.71 | 2.83 | 4.2%
16 | 128 | 4.85 | 5.09 | 4.7%
32 | 128 | 9.32 | 9.69 | 3.8%
64 | 128 | 18.54 | 19.44 | 4.6%
128 | 128 | 36.86 | 38.47 | 4.2%
1 | 256 | 1.15 | 1.23 | 6.5%
4 | 256 | 2.71 | 2.95 | 8.1%
8 | 256 | 4.87 | 5.3 | 8.1%
16 | 256 | 9.32 | 10.23 | 8.9%
32 | 256 | 18.6 | 20.53 | 9.4%
64 | 256 | 36.93 | 40.41 | 8.6%
128 | 256 | 72.84 | 80.14 | 9.1%
1 | 512 | 1.68 | 1.96 | 14.3%
4 | 512 | 4.9 | 6.02 | 18.6%
8 | 512 | 9.4 | 11.59 | 18.9%
16 | 512 | 18.71 | 23.05 | 18.8%
32 | 512 | 37.13 | 45.46 | 18.3%
64 | 512 | 74.04 | 89.88 | 17.6%
128 | 512 | NA | NA | NA

* T4

batch_size | sequence_length | with Fused Attention | with Unfused Attention | T4 Gain
-- | -- | -- | -- | --
1 | 8 | 1.97 | 2.11 | 6.6%
4 | 8 | 2.2 | 2.25 | 2.2%
8 | 8 | 2.77 | 3.1 | 10.6%
16 | 8 | 4.17 | 4.2 | 0.7%
32 | 8 | 6.86 | 6.82 | -0.6%
64 | 8 | 14.88 | 14.92 | 0.3%
128 | 8 | 31.4 | 31.29 | -0.4%
1 | 16 | 1.61 | 1.71 | 5.8%
4 | 16 | 2.13 | 2.31 | 7.8%
8 | 16 | 3.38 | 3.67 | 7.9%
16 | 16 | 6.16 | 6.54 | 5.8%
32 | 16 | 14.16 | 14.76 | 4.1%
64 | 16 | 30.36 | 30.57 | 0.7%
128 | 16 | 63.14 | 63.57 | 0.7%
1 | 32 | 1.53 | 1.69 | 9.5%
4 | 32 | 3.34 | 3.66 | 8.7%
8 | 32 | 6.25 | 6.64 | 5.9%
16 | 32 | 14.12 | 14.9 | 5.2%
32 | 32 | 28.96 | 29.82 | 2.9%
64 | 32 | 61.07 | 61.77 | 1.1%
128 | 32 | 116.38 | 117.98 | 1.4%
1 | 64 | 2.01 | 2.21 | 9.0%
4 | 64 | 6.18 | 6.67 | 7.3%
8 | 64 | 13.72 | 14.49 | 5.3%
16 | 64 | 28.71 | 29.83 | 3.8%
32 | 64 | 58.65 | 60.68 | 3.3%
64 | 64 | 113.09 | 113.17 | 0.1%
128 | 64 | 205.21 | 209.4 | 2.0%
1 | 128 | 3.37 | 3.76 | 10.4%
4 | 128 | 13.54 | 14.85 | 8.8%
8 | 128 | 28.32 | 30.22 | 6.3%
16 | 128 | 58.16 | 62.09 | 6.3%
32 | 128 | 109.17 | 113.99 | 4.2%
64 | 128 | 198.9 | 207.1 | 4.0%
128 | 128 | 413.25 | 421.82 | 2.0%
1 | 256 | 6.33 | 7.05 | 10.2%
4 | 256 | 28.09 | 31.49 | 10.8%
8 | 256 | 57.47 | 62.76 | 8.4%
16 | 256 | 106.77 | 117.95 | 9.5%
32 | 256 | 197.02 | 208.58 | 5.5%
64 | 256 | 406.81 | 431.36 | 5.7%
128 | 256 | NA | NA | NA
1 | 512 | 13.84 | 16.32 | 15.2%
4 | 512 | NA | NA | NA
8 | 512 | NA | NA | NA
16 | 512 | NA | NA | NA
32 | 512 | NA | NA | NA
64 | 512 | NA | NA | NA
128 | 512 | NA | NA | NA

* V100

batch_size | sequence_length | with Fused Attention | with Unfused Attention | V100 Gain
-- | -- | -- | -- | --
1 | 8 | 1.31 | 1.6 | 18.1%
4 | 8 | 1.17 | 1.26 | 7.1%
8 | 8 | 1.43 | 1.79 | 20.1%
16 | 8 | 2.14 | 1.96 | -9.2%
32 | 8 | 2.91 | 3.08 | 5.5%
64 | 8 | 5.32 | 5.27 | -0.9%
128 | 8 | 9.34 | 8.97 | -4.1%
1 | 16 | 1.41 | 1.58 | 10.8%
4 | 16 | 1.38 | 1.49 | 7.4%
8 | 16 | 1.81 | 2.2 | 17.7%
16 | 16 | 2.8 | 2.83 | 1.1%
32 | 16 | 4.94 | 4.99 | 1.0%
64 | 16 | 8.88 | 8.84 | -0.5%
128 | 16 | 17.35 | 17.2 | -0.9%
1 | 32 | 1.38 | 1.77 | 22.0%
4 | 32 | 1.77 | 1.93 | 8.3%
8 | 32 | 2.71 | 2.86 | 5.2%
16 | 32 | 5.03 | 4.92 | -2.2%
32 | 32 | 8.8 | 8.79 | -0.1%
64 | 32 | 17.29 | 17.23 | -0.3%
128 | 32 | 33.27 | 33.1 | -0.5%
1 | 64 | 1.67 | 1.87 | 10.7%
4 | 64 | 2.69 | 2.76 | 2.5%
8 | 64 | 4.87 | 4.94 | 1.4%
16 | 64 | 8.73 | 8.81 | 0.9%
32 | 64 | 16.92 | 17.24 | 1.9%
64 | 64 | 33 | 33.38 | 1.1%
128 | 64 | 65.33 | 65.86 | 0.8%
1 | 128 | 2.03 | 2.22 | 8.6%
4 | 128 | 4.9 | 5.04 | 2.8%
8 | 128 | 8.76 | 8.81 | 0.6%
16 | 128 | 17.06 | 17.29 | 1.3%
32 | 128 | 33.25 | 33.56 | 0.9%
64 | 128 | 65.54 | 66.5 | 1.4%
128 | 128 | 130.44 | 131.44 | 0.8%
1 | 256 | 2.78 | 2.86 | 2.8%
4 | 256 | 8.75 | 9.04 | 3.2%
8 | 256 | 17 | 17.68 | 3.8%
16 | 256 | 33.19 | 34.32 | 3.3%
32 | 256 | 65.43 | 67.86 | 3.6%
64 | 256 | 129.92 | 134.68 | 3.5%
128 | 256 | NA | NA | NA
1 | 512 | 4.95 | 5.32 | 7.0%
4 | 512 | NA | NA | NA
8 | 512 | NA | NA | NA
16 | 512 | NA | NA | NA
32 | 512 | NA | NA | NA
64 | 512 | NA | NA | NA
128 | 512 | NA | NA | NA
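For reference, the Gain columns above are consistent with the relative latency reduction of fused over unfused attention. A minimal sketch of that arithmetic (the exact formula is inferred from the tables, not stated in this PR), using the A100 row with batch_size=1 and sequence_length=512:

```python
def gain(fused_ms: float, unfused_ms: float) -> float:
    """Relative latency reduction of fused vs. unfused attention, in percent."""
    return (unfused_ms - fused_ms) / unfused_ms * 100.0

# A100, batch_size=1, sequence_length=512: 1.68 ms fused vs. 1.96 ms unfused.
print(f"{gain(1.68, 1.96):.1f}%")  # prints 14.3%, matching the table above
```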