Commits


Jiajia Qin authored and GitHub committed 8fbbf2fd4fe
[js/webgpu] Optimize MatMul with M = 1 (#22577) ### Description <!-- Describe your changes. --> BUG #22031 In the demucs model, there are lots of MatMul ops with shapes like below: `input[0]: [3448,1,512] | float32, input[1]: [512,1536] | float32, output[0]: [3448,1,1536] | float32` We can see that for this kind of shape, the batch size is a big value, but M = 1. Our current algorithm is based on [M, N] to partition tiles, which is not efficient for such kind of shapes. This PR reshapes the inputs to improve the matmul performance. Before: [3448,1,512] x [512,1536] = [3448,1,1536] After: [1, 3448, 512] x [512, 1536] = [1, 3448, 1536] , then the output can be reshaped to [3448, 1, 1536] The overall MatMul time in demucs model becomes 1778.45 ms from 4418.17 ms on my iGPUs. --------- Co-authored-by: Yulong Wang <7679871+fs-eire@users.noreply.github.com>