Commits


Sushanth Rajasankar authored and GitHub committed 1be64f88319
Fix flash attention for GQA (Phi4) (#23850) ### Description This change fixes GQA for Flash Attention on Nvidia GPUs. The root cause appears to be `k_start + capped_sg_id < seq_causal_length` check. This is either because, a. seq_causal_length varies per lane, so the check becomes non uniform control flow, which is having interactions with subgroupShuffle. or b. The check itself is incorrect and is wiping out values of v based on the source lane's seq_causal_length. While in actualness values of v need to be causal as per the lane that is going to multiply it with qkt. qkt is already causal because earlier values of qk for out of bounds k are set to min_value, and exp(<-4) are 0. This fix works by removing that causal check and relying on the qk being wiped out earlier. The documentation for causality behavior for GQA is missing to determine which of this reason is the true reason. Prior to this prompts with sequence length > 16 < 32 or 1k would break with Phi 4 but smaller prompts would work. Tested on Intel Alderlake, Nvidia 4070.