Commits


pengwa authored and GitHub committed b125446f9c2
Optimize python overhead of APEX amp (#9447) * optimize python overhead of _post_amp_backward * overwrite apex amp's zero_grad for faster implementation * move unscale_fp16_grads_into_fp32_grads into C++ impl * improve the efficiency furthur, reducing 3.5ms to 1.7ms for unilm. * unilm 1.7ms to 338us: 1). optimize python list <==> std::vector copy, 2). launch the kernels as long as num_elem reach thresh hold. This help reduce the CUDA idel time. * refine the logic a bit after validating Co-authored-by: Baiju Meswani <bmeswani@microsoft.com>