Commits


zhijiang authored and GitHub committed 05ec22330f1
softmax perf improvement pr2 - import softmax bw (#15199)

When the dimension to do softmax over is 2048, the original ORT code falls back to cuDNN, while with some optimization of ORT's softmax_warp_backward we can be faster than the cuDNN implementation. The ideas for optimizing softmax_warp_backward are (see the sketch below):

1. Instead of saving intermediate results in registers, recompute them to save register resources.
2. Keep the input data in fp16 instead of fp32 in registers to further save resources.

The perf numbers:

Please note that when the dimension to do softmax over is less than 2048, nothing changes, so perf numbers are only given for the 2048 case. Added more perf numbers for smaller batch sizes.
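As a rough illustration of the two ideas, here is a minimal, hypothetical CUDA sketch of a warp-level softmax backward kernel. It is not the actual ORT softmax_warp_backward implementation; the kernel name, template parameter, and launch layout (one row per 32-lane warp) are assumptions made for this sketch. It computes dx = y * (dy - sum(y * dy)), where y is the forward softmax output.

```cuda
#include <cuda_fp16.h>
#include <cuda_runtime.h>

// Hypothetical sketch, not the ORT kernel. Each 32-lane warp handles one row
// of `element_count` values (up to kElementsPerLane * 32 elements).
// dx = y * (dy - sum(y * dy)), where y is the forward softmax output.
//
// Idea 1: only the raw inputs are kept in registers; y * dy is recomputed in
//         the second pass instead of caching the fp32 products.
// Idea 2: the staged inputs are held as __half (fp16), halving the register
//         footprint compared with fp32 copies.
template <int kElementsPerLane>
__global__ void softmax_warp_backward_sketch(
    __half* grad_input, const __half* grad_output, const __half* output,
    int batch_count, int element_count) {
  const int row = blockIdx.x * blockDim.y + threadIdx.y;
  if (row >= batch_count) return;

  const int lane = threadIdx.x;  // 0..31
  const __half* y_row  = output      + (size_t)row * element_count;
  const __half* dy_row = grad_output + (size_t)row * element_count;
  __half* dx_row       = grad_input  + (size_t)row * element_count;

  // Idea 2: stage the inputs in fp16 registers.
  __half y_reg[kElementsPerLane];
  __half dy_reg[kElementsPerLane];
  for (int i = 0; i < kElementsPerLane; ++i) {
    const int col = lane + i * 32;
    y_reg[i]  = (col < element_count) ? y_row[col]  : __float2half(0.f);
    dy_reg[i] = (col < element_count) ? dy_row[col] : __float2half(0.f);
  }

  // First pass: accumulate sum(y * dy) in fp32.
  float sum = 0.f;
  for (int i = 0; i < kElementsPerLane; ++i) {
    sum += __half2float(y_reg[i]) * __half2float(dy_reg[i]);
  }
  // Warp-wide butterfly reduction; every lane ends up with the full sum.
  for (int offset = 16; offset > 0; offset >>= 1) {
    sum += __shfl_xor_sync(0xffffffff, sum, offset);
  }

  // Second pass (idea 1): recompute y * dy instead of having saved the
  // fp32 products from the first pass.
  for (int i = 0; i < kElementsPerLane; ++i) {
    const int col = lane + i * 32;
    if (col < element_count) {
      const float y  = __half2float(y_reg[i]);
      const float dy = __half2float(dy_reg[i]);
      dx_row[col] = __float2half(y * (dy - sum));
    }
  }
}
```

For a softmax dimension of 2048, this sketch would be instantiated with kElementsPerLane = 64 and launched with blockDim.x = 32, so each lane holds 64 fp16 input pairs and the fp32 products are never stored.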